Register
Login
Resources
Docs Blog Datasets Glossary Case Studies Tutorials & Webinars
Product
Data Engine LLMs Platform Enterprise
Pricing Explore
Connect to our Discord channel

#572 new generated docs

Merged
Ghost merged 1 commits into Deci-AI:master from deci-ai:feature/SG-000_new_generated_docs
100 changed files with 6628 additions and 14555 deletions
  1. 1
    1
      docs/.buildinfo
  2. 0
    0
      docs/.nojekyll
  3. 39
    38
      docs/CONTRIBUTING.html
  4. 24
    23
      docs/LICENSE.html
  5. 36
    55
      docs/_modules/index.html
  6. 0
    999
      docs/_modules/logging.html
  7. 0
    144
      docs/_modules/super_gradients/common/abstractions/abstract_logger.html
  8. 136
    120
      docs/_modules/super_gradients/common/auto_logging/auto_logger.html
  9. 294
    0
      docs/_modules/super_gradients/common/auto_logging/console_logging.html
  10. 28
    26
      docs/_modules/super_gradients/common/aws_connection/aws_connector.html
  11. 0
    248
      docs/_modules/super_gradients/common/aws_connection/aws_secrets_manager_connector.html
  12. 133
    0
      docs/_modules/super_gradients/common/crash_handler/crash_handler.html
  13. 27
    25
      docs/_modules/super_gradients/common/data_connection/s3_connector.html
  14. 26
    24
      docs/_modules/super_gradients/common/data_interface/adnn_model_repository_data_interface.html
  15. 24
    22
      docs/_modules/super_gradients/common/data_interface/dataset_data_interface.html
  16. 24
    22
      docs/_modules/super_gradients/common/data_types/enum/deep_learning_task.html
  17. 25
    23
      docs/_modules/super_gradients/common/data_types/enum/evaluation_type.html
  18. 45
    26
      docs/_modules/super_gradients/common/data_types/enum/multi_gpu_mode.html
  19. 31
    28
      docs/_modules/super_gradients/common/data_types/enum/strict_load.html
  20. 128
    0
      docs/_modules/super_gradients/common/data_types/enum/upsample_mode.html
  21. 0
    206
      docs/_modules/super_gradients/common/decorators/deci_logger.html
  22. 24
    22
      docs/_modules/super_gradients/common/decorators/explicit_params_validator.html
  23. 26
    24
      docs/_modules/super_gradients/common/decorators/singleton.html
  24. 145
    54
      docs/_modules/super_gradients/common/environment/env_helpers.html
  25. 373
    0
      docs/_modules/super_gradients/common/object_names.html
  26. 785
    0
      docs/_modules/super_gradients/training/dataloaders/dataloaders.html
  27. 0
    207
      docs/_modules/super_gradients/training/datasets/all_datasets.html
  28. 0
    558
      docs/_modules/super_gradients/training/datasets/auto_augment.html
  29. 194
    0
      docs/_modules/super_gradients/training/datasets/classification_datasets/cifar.html
  30. 139
    0
      docs/_modules/super_gradients/training/datasets/classification_datasets/imagenet_dataset.html
  31. 31
    29
      docs/_modules/super_gradients/training/datasets/data_augmentation.html
  32. 0
    959
      docs/_modules/super_gradients/training/datasets/dataset_interfaces/dataset_interface.html
  33. 0
    811
      docs/_modules/super_gradients/training/datasets/datasets_utils.html
  34. 126
    273
      docs/_modules/super_gradients/training/datasets/detection_datasets/coco_detection.html
  35. 99
    91
      docs/_modules/super_gradients/training/datasets/detection_datasets/detection_dataset.html
  36. 109
    59
      docs/_modules/super_gradients/training/datasets/detection_datasets/pascal_voc_detection.html
  37. 0
    422
      docs/_modules/super_gradients/training/datasets/mixup.html
  38. 0
    222
      docs/_modules/super_gradients/training/datasets/segmentation_datasets/cityscape_segmentation.html
  39. 36
    49
      docs/_modules/super_gradients/training/datasets/segmentation_datasets/coco_segmentation.html
  40. 0
    155
      docs/_modules/super_gradients/training/datasets/segmentation_datasets/pascal_aug_segmentation.html
  41. 128
    34
      docs/_modules/super_gradients/training/datasets/segmentation_datasets/pascal_voc_segmentation.html
  42. 42
    141
      docs/_modules/super_gradients/training/datasets/segmentation_datasets/segmentation_dataset.html
  43. 26
    23
      docs/_modules/super_gradients/training/datasets/segmentation_datasets/supervisely_persons_segmentation.html
  44. 30
    28
      docs/_modules/super_gradients/training/datasets/sg_dataset.html
  45. 0
    158
      docs/_modules/super_gradients/training/exceptions/dataset_exceptions.html
  46. 0
    150
      docs/_modules/super_gradients/training/exceptions/sg_model_exceptions.html
  47. 115
    95
      docs/_modules/super_gradients/training/kd_trainer/kd_trainer.html
  48. 0
    251
      docs/_modules/super_gradients/training/legacy/utils.html
  49. 25
    23
      docs/_modules/super_gradients/training/losses/bce_dice_loss.html
  50. 0
    166
      docs/_modules/super_gradients/training/losses/ddrnet_loss.html
  51. 40
    23
      docs/_modules/super_gradients/training/losses/dice_ce_edge_loss.html
  52. 25
    23
      docs/_modules/super_gradients/training/losses/focal_loss.html
  53. 34
    23
      docs/_modules/super_gradients/training/losses/kd_losses.html
  54. 29
    27
      docs/_modules/super_gradients/training/losses/label_smoothing_cross_entropy_loss.html
  55. 0
    227
      docs/_modules/super_gradients/training/losses/ohem_ce_loss.html
  56. 25
    23
      docs/_modules/super_gradients/training/losses/r_squared_loss.html
  57. 35
    24
      docs/_modules/super_gradients/training/losses/shelfnet_ohem_loss.html
  58. 35
    24
      docs/_modules/super_gradients/training/losses/shelfnet_semantic_encoding_loss.html
  59. 38
    27
      docs/_modules/super_gradients/training/losses/ssd_loss.html
  60. 0
    180
      docs/_modules/super_gradients/training/losses/yolo_v3_loss.html
  61. 0
    340
      docs/_modules/super_gradients/training/losses/yolo_v5_loss.html
  62. 531
    75
      docs/_modules/super_gradients/training/losses/yolox_loss.html
  63. 36
    31
      docs/_modules/super_gradients/training/metrics/classification_metrics.html
  64. 160
    53
      docs/_modules/super_gradients/training/metrics/detection_metrics.html
  65. 0
    226
      docs/_modules/super_gradients/training/metrics/metric_utils.html
  66. 136
    105
      docs/_modules/super_gradients/training/metrics/segmentation_metrics.html
  67. 0
    184
      docs/_modules/super_gradients/training/models/sg_module.html
  68. 0
    187
      docs/_modules/super_gradients/training/params.html
  69. 423
    249
      docs/_modules/super_gradients/training/sg_trainer/sg_trainer.html
  70. 231
    0
      docs/_modules/super_gradients/training/training_hyperparams/training_hyperparams.html
  71. 999
    0
      docs/_modules/super_gradients/training/transforms/transforms.html
  72. 0
    885
      docs/_modules/super_gradients/training/utils/callbacks.html
  73. 147
    102
      docs/_modules/super_gradients/training/utils/checkpoint_utils.html
  74. 0
    999
      docs/_modules/super_gradients/training/utils/detection_utils.html
  75. 0
    273
      docs/_modules/super_gradients/training/utils/distributed_training_utils.html
  76. 0
    268
      docs/_modules/super_gradients/training/utils/early_stopping.html
  77. 0
    245
      docs/_modules/super_gradients/training/utils/ema.html
  78. 0
    143
      docs/_modules/super_gradients/training/utils/export_utils.html
  79. 0
    339
      docs/_modules/super_gradients/training/utils/module_utils.html
  80. 0
    230
      docs/_modules/super_gradients/training/utils/optimizer_utils.html
  81. 0
    247
      docs/_modules/super_gradients/training/utils/optimizers/rmsprop_tf.html
  82. 0
    143
      docs/_modules/super_gradients/training/utils/regularization_utils.html
  83. 0
    321
      docs/_modules/super_gradients/training/utils/segmentation_utils.html
  84. 0
    449
      docs/_modules/super_gradients/training/utils/sg_model_utils.html
  85. 0
    265
      docs/_modules/super_gradients/training/utils/ssd_utils.html
  86. 98
    71
      docs/_modules/super_gradients/training/utils/utils.html
  87. 130
    0
      docs/_modules/super_gradients/training/utils/version_utils.html
  88. 0
    253
      docs/_modules/super_gradients/training/utils/weight_averaging_utils.html
  89. 0
    23
      docs/_sources/generated/super_gradients.common.abstractions.rst.txt
  90. 0
    23
      docs/_sources/generated/super_gradients.common.auto_logging.rst.txt
  91. 0
    23
      docs/_sources/generated/super_gradients.common.aws_connection.rst.txt
  92. 0
    23
      docs/_sources/generated/super_gradients.common.data_connection.rst.txt
  93. 0
    23
      docs/_sources/generated/super_gradients.common.data_interface.rst.txt
  94. 0
    23
      docs/_sources/generated/super_gradients.common.data_types.rst.txt
  95. 0
    23
      docs/_sources/generated/super_gradients.common.decorators.rst.txt
  96. 0
    23
      docs/_sources/generated/super_gradients.common.environment.rst.txt
  97. 2
    4
      docs/_sources/index.rst.txt
  98. 0
    7
      docs/_sources/modules.rst.txt
  99. 0
    21
      docs/_sources/super_gradients.common.abstractions.rst.txt
  100. 0
    21
      docs/_sources/super_gradients.common.auto_logging.rst.txt
@@ -1,4 +1,4 @@
 # Sphinx build info version 1
 # Sphinx build info version 1
 # This file hashes the configuration used when building these files. When it is not found, a full rebuild will be done.
 # This file hashes the configuration used when building these files. When it is not found, a full rebuild will be done.
-config: 1ceb2b6b1de4e767037aa12749cd16be
+config: 338ace02139fbf63879758a90ee35818
 tags: 645f666f9bcd5a90fca523b33c5a78b7
 tags: 645f666f9bcd5a90fca523b33c5a78b7
Discard
    Discard
    @@ -1,13 +1,13 @@
     <!DOCTYPE html>
     <!DOCTYPE html>
     <html class="writer-html5" lang="en" >
     <html class="writer-html5" lang="en" >
     <head>
     <head>
    -  <meta charset="utf-8" /><meta name="generator" content="Docutils 0.17.1: http://docutils.sourceforge.net/" />
    -
    +  <meta charset="utf-8" />
       <meta name="viewport" content="width=device-width, initial-scale=1.0" />
       <meta name="viewport" content="width=device-width, initial-scale=1.0" />
    -  <title>Contribution Guidelines &mdash; SuperGradients 1.0 documentation</title>
    +  <title>Contribution Guidelines &mdash; SuperGradients 3.0.3 documentation</title>
           <link rel="stylesheet" href="_static/pygments.css" type="text/css" />
           <link rel="stylesheet" href="_static/pygments.css" type="text/css" />
           <link rel="stylesheet" href="_static/css/theme.css" type="text/css" />
           <link rel="stylesheet" href="_static/css/theme.css" type="text/css" />
           <link rel="stylesheet" href="_static/graphviz.css" type="text/css" />
           <link rel="stylesheet" href="_static/graphviz.css" type="text/css" />
    +      <link rel="stylesheet" href="_static/custom.css" type="text/css" />
       <!--[if lt IE 9]>
       <!--[if lt IE 9]>
         <script src="_static/js/html5shiv.min.js"></script>
         <script src="_static/js/html5shiv.min.js"></script>
       <![endif]-->
       <![endif]-->
    @@ -15,7 +15,9 @@
             <script data-url_root="./" id="documentation_options" src="_static/documentation_options.js"></script>
             <script data-url_root="./" id="documentation_options" src="_static/documentation_options.js"></script>
             <script src="_static/jquery.js"></script>
             <script src="_static/jquery.js"></script>
             <script src="_static/underscore.js"></script>
             <script src="_static/underscore.js"></script>
    +        <script src="_static/_sphinx_javascript_frameworks_compat.js"></script>
             <script src="_static/doctools.js"></script>
             <script src="_static/doctools.js"></script>
    +        <script src="_static/sphinx_highlight.js"></script>
         <script src="_static/js/theme.js"></script>
         <script src="_static/js/theme.js"></script>
         <link rel="index" title="Index" href="genindex.html" />
         <link rel="index" title="Index" href="genindex.html" />
         <link rel="search" title="Search" href="search.html" /> 
         <link rel="search" title="Search" href="search.html" /> 
    @@ -36,30 +38,29 @@
       </form>
       </form>
     </div>
     </div>
             </div><div class="wy-menu wy-menu-vertical" data-spy="affix" role="navigation" aria-label="Navigation menu">
             </div><div class="wy-menu wy-menu-vertical" data-spy="affix" role="navigation" aria-label="Navigation menu">
    -              <p class="caption"><span class="caption-text">Welcome To SuperGradients</span></p>
    +              <p class="caption" role="heading"><span class="caption-text">Welcome To SuperGradients</span></p>
     <ul>
     <ul>
    -<li class="toctree-l1"><a class="reference internal" href="welcome.html">Fill our 4-question quick survey! We will raffle free SuperGradients swag between those who will participate -&gt; Fill Survey</a></li>
    -<li class="toctree-l1"><a class="reference internal" href="welcome.html#supergradients">SuperGradients</a></li>
    +<li class="toctree-l1"><a class="reference internal" href="welcome.html">Version 3 is out! Notebooks have been updated!</a></li>
    +<li class="toctree-l1"><a class="reference internal" href="welcome.html#build-with-supergradients">Build with SuperGradients</a></li>
    +<li class="toctree-l1"><a class="reference internal" href="welcome.html#quick-installation">Quick Installation</a></li>
    +<li class="toctree-l1"><a class="reference internal" href="welcome.html#what-s-new">What’s New</a></li>
    +<li class="toctree-l1"><a class="reference internal" href="welcome.html#coming-soon">Coming soon</a></li>
    +<li class="toctree-l1"><a class="reference internal" href="welcome.html#table-of-content">Table of Content</a></li>
    +<li class="toctree-l1"><a class="reference internal" href="welcome.html#getting-started">Getting Started</a></li>
    +<li class="toctree-l1"><a class="reference internal" href="welcome.html#advanced-features">Advanced Features</a></li>
    +<li class="toctree-l1"><a class="reference internal" href="welcome.html#installation-methods">Installation Methods</a></li>
    +<li class="toctree-l1"><a class="reference internal" href="welcome.html#implemented-model-architectures">Implemented Model Architectures</a></li>
    +<li class="toctree-l1"><a class="reference internal" href="welcome.html#documentation">Documentation</a></li>
    +<li class="toctree-l1"><a class="reference internal" href="welcome.html#contributing">Contributing</a></li>
    +<li class="toctree-l1"><a class="reference internal" href="welcome.html#citation">Citation</a></li>
    +<li class="toctree-l1"><a class="reference internal" href="welcome.html#community">Community</a></li>
    +<li class="toctree-l1"><a class="reference internal" href="welcome.html#license">License</a></li>
    +<li class="toctree-l1"><a class="reference internal" href="welcome.html#deci-platform">Deci Platform</a></li>
     </ul>
     </ul>
    -<p class="caption"><span class="caption-text">Technical Documentation</span></p>
    +<p class="caption" role="heading"><span class="caption-text">Technical Documentation</span></p>
     <ul>
     <ul>
     <li class="toctree-l1"><a class="reference internal" href="super_gradients.common.html">Common package</a></li>
     <li class="toctree-l1"><a class="reference internal" href="super_gradients.common.html">Common package</a></li>
     <li class="toctree-l1"><a class="reference internal" href="super_gradients.training.html">Training package</a></li>
     <li class="toctree-l1"><a class="reference internal" href="super_gradients.training.html">Training package</a></li>
    -</ul>
    -<p class="caption"><span class="caption-text">User Guide</span></p>
    -<ul>
    -<li class="toctree-l1"><a class="reference internal" href="user_guide.html">What is SuperGradients?</a></li>
    -<li class="toctree-l1"><a class="reference internal" href="user_guide.html#introducing-the-supergradients-library">Introducing the SuperGradients library</a></li>
    -<li class="toctree-l1"><a class="reference internal" href="user_guide.html#installation">Installation</a></li>
    -<li class="toctree-l1"><a class="reference internal" href="user_guide.html#integrating-your-training-code-complete-walkthrough">Integrating Your Training Code - Complete Walkthrough</a></li>
    -<li class="toctree-l1"><a class="reference internal" href="user_guide.html#training-parameters">Training Parameters</a></li>
    -<li class="toctree-l1"><a class="reference internal" href="user_guide.html#logs-and-checkpoints">Logs and Checkpoints</a></li>
    -<li class="toctree-l1"><a class="reference internal" href="user_guide.html#dataset-parameters">Dataset Parameters</a></li>
    -<li class="toctree-l1"><a class="reference internal" href="user_guide.html#network-architectures">Network Architectures</a></li>
    -<li class="toctree-l1"><a class="reference internal" href="user_guide.html#pretrained-models">Pretrained Models</a></li>
    -<li class="toctree-l1"><a class="reference internal" href="user_guide.html#how-to-reproduce-our-training-recipes">How To Reproduce Our Training Recipes</a></li>
    -<li class="toctree-l1"><a class="reference internal" href="user_guide.html#professional-tools-integration">Professional Tools Integration</a></li>
    -<li class="toctree-l1"><a class="reference internal" href="user_guide.html#supergradients-faq">SuperGradients FAQ</a></li>
     </ul>
     </ul>
     
     
             </div>
             </div>
    @@ -86,8 +87,8 @@
               <div role="main" class="document" itemscope="itemscope" itemtype="http://schema.org/Article">
               <div role="main" class="document" itemscope="itemscope" itemtype="http://schema.org/Article">
                <div itemprop="articleBody">
                <div itemprop="articleBody">
                  
                  
    -  <section id="contribution-guidelines">
    -<h1>Contribution Guidelines<a class="headerlink" href="#contribution-guidelines" title="Permalink to this headline"></a></h1>
    +  <div class="section" id="contribution-guidelines">
    +<h1>Contribution Guidelines<a class="headerlink" href="#contribution-guidelines" title="Permalink to this heading"></a></h1>
     <p>We appreciate all contributions. If you are planning to contribute back bug-fixes, please do so without any further discussion. If you plan to contribute new features, utility functions or extensions, please first open an issue and discuss the feature with us.</p>
     <p>We appreciate all contributions. If you are planning to contribute back bug-fixes, please do so without any further discussion. If you plan to contribute new features, utility functions or extensions, please first open an issue and discuss the feature with us.</p>
     <p>Here are a few more things to know:</p>
     <p>Here are a few more things to know:</p>
     <ul class="simple">
     <ul class="simple">
    @@ -95,8 +96,8 @@
     <li><p><span class="xref myst">Jupyter Notebooks Contribution</span></p></li>
     <li><p><span class="xref myst">Jupyter Notebooks Contribution</span></p></li>
     <li><p><span class="xref myst">Code Style Guidelines</span></p></li>
     <li><p><span class="xref myst">Code Style Guidelines</span></p></li>
     </ul>
     </ul>
    -<section id="how-to-contribute">
    -<h2>How to Contribute<a class="headerlink" href="#how-to-contribute" title="Permalink to this headline"></a></h2>
    +<div class="section" id="how-to-contribute">
    +<h2>How to Contribute<a class="headerlink" href="#how-to-contribute" title="Permalink to this heading"></a></h2>
     <p>Here is a simple guideline to get you started with your first contribution</p>
     <p>Here is a simple guideline to get you started with your first contribution</p>
     <ol class="arabic simple">
     <ol class="arabic simple">
     <li><p>Use <a class="reference external" href="https://github.com/Deci-AI/super-gradients/issues">issues</a> to discuss the suggested changes. Create an issue describing changes if necessary and add labels to ease orientation.</p></li>
     <li><p>Use <a class="reference external" href="https://github.com/Deci-AI/super-gradients/issues">issues</a> to discuss the suggested changes. Create an issue describing changes if necessary and add labels to ease orientation.</p></li>
    @@ -130,9 +131,9 @@
     <ol class="arabic simple" start="7">
     <ol class="arabic simple" start="7">
     <li><p>Create a pull request against <b>master</b> branch.</p></li>
     <li><p>Create a pull request against <b>master</b> branch.</p></li>
     </ol>
     </ol>
    -</section>
    -<section id="jupyter-notebooks-contribution">
    -<h2>Jupyter Notebooks Contribution<a class="headerlink" href="#jupyter-notebooks-contribution" title="Permalink to this headline"></a></h2>
    +</div>
    +<div class="section" id="jupyter-notebooks-contribution">
    +<h2>Jupyter Notebooks Contribution<a class="headerlink" href="#jupyter-notebooks-contribution" title="Permalink to this heading"></a></h2>
     <p>Pulling updates from remote might cause merge conflicts with jupyter notebooks. The tool <a class="reference external" href="https://nbdime.readthedocs.io/en/latest/">nbdime</a> might solve this.</p>
     <p>Pulling updates from remote might cause merge conflicts with jupyter notebooks. The tool <a class="reference external" href="https://nbdime.readthedocs.io/en/latest/">nbdime</a> might solve this.</p>
     <ul class="simple">
     <ul class="simple">
     <li><p>Installing nbdime</p></li>
     <li><p>Installing nbdime</p></li>
    @@ -146,9 +147,9 @@
     <div class="highlight-default notranslate"><div class="highlight"><pre><span></span><span class="n">nbdiff</span> <span class="n">notebook_1</span><span class="o">.</span><span class="n">ipynb</span> <span class="n">notebook_2</span><span class="o">.</span><span class="n">ipynb</span>
     <div class="highlight-default notranslate"><div class="highlight"><pre><span></span><span class="n">nbdiff</span> <span class="n">notebook_1</span><span class="o">.</span><span class="n">ipynb</span> <span class="n">notebook_2</span><span class="o">.</span><span class="n">ipynb</span>
     </pre></div>
     </pre></div>
     </div>
     </div>
    -</section>
    -<section id="code-style-guidelines">
    -<h2>Code Style Guidelines<a class="headerlink" href="#code-style-guidelines" title="Permalink to this headline"></a></h2>
    +</div>
    +<div class="section" id="code-style-guidelines">
    +<h2>Code Style Guidelines<a class="headerlink" href="#code-style-guidelines" title="Permalink to this heading"></a></h2>
     <p>We are working hard to make sure all the code in this repository is readable, maintainable and testable.
     <p>We are working hard to make sure all the code in this repository is readable, maintainable and testable.
     We follow the Google docstring guidelines outlined on this <a class="reference external" href="https://github.com/google/styleguide/blob/gh-pages/pyguide.md#38-comments-and-docstrings">styleguide</a> page. For example:</p>
     We follow the Google docstring guidelines outlined on this <a class="reference external" href="https://github.com/google/styleguide/blob/gh-pages/pyguide.md#38-comments-and-docstrings">styleguide</a> page. For example:</p>
     <div class="highlight-python notranslate"><div class="highlight"><pre><span></span>  <span class="k">def</span> <span class="nf">python_function</span><span class="p">(</span><span class="n">first_argument</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span> <span class="n">second_argument</span><span class="p">:</span> <span class="nb">int</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="nb">bool</span><span class="p">:</span>
     <div class="highlight-python notranslate"><div class="highlight"><pre><span></span>  <span class="k">def</span> <span class="nf">python_function</span><span class="p">(</span><span class="n">first_argument</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span> <span class="n">second_argument</span><span class="p">:</span> <span class="nb">int</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="nb">bool</span><span class="p">:</span>
    @@ -167,9 +168,9 @@ We follow the Google docstring guidelines outlined on this <a class="reference e
     <span class="sd">      &quot;&quot;&quot;</span>
     <span class="sd">      &quot;&quot;&quot;</span>
     </pre></div>
     </pre></div>
     </div>
     </div>
    -</section>
    -<section id="documentation">
    -<h2>Documentation<a class="headerlink" href="#documentation" title="Permalink to this headline"></a></h2>
    +</div>
    +<div class="section" id="documentation">
    +<h2>Documentation<a class="headerlink" href="#documentation" title="Permalink to this heading"></a></h2>
     <p>We use  GitHub Pages for technical documentation hosting on https://deci-ai.github.io/super-gradients/welcome.html <br>
     <p>We use  GitHub Pages for technical documentation hosting on https://deci-ai.github.io/super-gradients/welcome.html <br>
     To generate the docs based on the current work tree, run: <br>
     To generate the docs based on the current work tree, run: <br>
     <code>./scripts/generate_docs.sh</code> <br><br>
     <code>./scripts/generate_docs.sh</code> <br><br>
    @@ -177,8 +178,8 @@ And the documentation will automatically update, based on <code>documentation/</
     The new documentation HTML will be generated to <code>docs/</code>. <br>
     The new documentation HTML will be generated to <code>docs/</code>. <br>
     Once <code>docs/</code> is committed and pushed, GitHub Pages will use it.<br>
     Once <code>docs/</code> is committed and pushed, GitHub Pages will use it.<br>
     The step of documentation update is currently manual.</p>
     The step of documentation update is currently manual.</p>
    -</section>
    -</section>
    +</div>
    +</div>
     
     
     
     
                </div>
                </div>
    @@ -208,4 +209,4 @@ The step of documentation update is currently manual.</p>
       </script> 
       </script> 
     
     
     </body>
     </body>
    -</html>
    +</html>
    Discard
    @@ -1,13 +1,13 @@
     <!DOCTYPE html>
     <!DOCTYPE html>
     <html class="writer-html5" lang="en" >
     <html class="writer-html5" lang="en" >
     <head>
     <head>
    -  <meta charset="utf-8" /><meta name="generator" content="Docutils 0.17.1: http://docutils.sourceforge.net/" />
    -
    +  <meta charset="utf-8" />
       <meta name="viewport" content="width=device-width, initial-scale=1.0" />
       <meta name="viewport" content="width=device-width, initial-scale=1.0" />
    -  <title>&lt;no title&gt; &mdash; SuperGradients 1.0 documentation</title>
    +  <title>&lt;no title&gt; &mdash; SuperGradients 3.0.3 documentation</title>
           <link rel="stylesheet" href="_static/pygments.css" type="text/css" />
           <link rel="stylesheet" href="_static/pygments.css" type="text/css" />
           <link rel="stylesheet" href="_static/css/theme.css" type="text/css" />
           <link rel="stylesheet" href="_static/css/theme.css" type="text/css" />
           <link rel="stylesheet" href="_static/graphviz.css" type="text/css" />
           <link rel="stylesheet" href="_static/graphviz.css" type="text/css" />
    +      <link rel="stylesheet" href="_static/custom.css" type="text/css" />
       <!--[if lt IE 9]>
       <!--[if lt IE 9]>
         <script src="_static/js/html5shiv.min.js"></script>
         <script src="_static/js/html5shiv.min.js"></script>
       <![endif]-->
       <![endif]-->
    @@ -15,7 +15,9 @@
             <script data-url_root="./" id="documentation_options" src="_static/documentation_options.js"></script>
             <script data-url_root="./" id="documentation_options" src="_static/documentation_options.js"></script>
             <script src="_static/jquery.js"></script>
             <script src="_static/jquery.js"></script>
             <script src="_static/underscore.js"></script>
             <script src="_static/underscore.js"></script>
    +        <script src="_static/_sphinx_javascript_frameworks_compat.js"></script>
             <script src="_static/doctools.js"></script>
             <script src="_static/doctools.js"></script>
    +        <script src="_static/sphinx_highlight.js"></script>
         <script src="_static/js/theme.js"></script>
         <script src="_static/js/theme.js"></script>
         <link rel="index" title="Index" href="genindex.html" />
         <link rel="index" title="Index" href="genindex.html" />
         <link rel="search" title="Search" href="search.html" /> 
         <link rel="search" title="Search" href="search.html" /> 
    @@ -36,30 +38,29 @@
       </form>
       </form>
     </div>
     </div>
             </div><div class="wy-menu wy-menu-vertical" data-spy="affix" role="navigation" aria-label="Navigation menu">
             </div><div class="wy-menu wy-menu-vertical" data-spy="affix" role="navigation" aria-label="Navigation menu">
    -              <p class="caption"><span class="caption-text">Welcome To SuperGradients</span></p>
    +              <p class="caption" role="heading"><span class="caption-text">Welcome To SuperGradients</span></p>
     <ul>
     <ul>
    -<li class="toctree-l1"><a class="reference internal" href="welcome.html">Fill our 4-question quick survey! We will raffle free SuperGradients swag between those who will participate -&gt; Fill Survey</a></li>
    -<li class="toctree-l1"><a class="reference internal" href="welcome.html#supergradients">SuperGradients</a></li>
    +<li class="toctree-l1"><a class="reference internal" href="welcome.html">Version 3 is out! Notebooks have been updated!</a></li>
    +<li class="toctree-l1"><a class="reference internal" href="welcome.html#build-with-supergradients">Build with SuperGradients</a></li>
    +<li class="toctree-l1"><a class="reference internal" href="welcome.html#quick-installation">Quick Installation</a></li>
    +<li class="toctree-l1"><a class="reference internal" href="welcome.html#what-s-new">What’s New</a></li>
    +<li class="toctree-l1"><a class="reference internal" href="welcome.html#coming-soon">Coming soon</a></li>
    +<li class="toctree-l1"><a class="reference internal" href="welcome.html#table-of-content">Table of Content</a></li>
    +<li class="toctree-l1"><a class="reference internal" href="welcome.html#getting-started">Getting Started</a></li>
    +<li class="toctree-l1"><a class="reference internal" href="welcome.html#advanced-features">Advanced Features</a></li>
    +<li class="toctree-l1"><a class="reference internal" href="welcome.html#installation-methods">Installation Methods</a></li>
    +<li class="toctree-l1"><a class="reference internal" href="welcome.html#implemented-model-architectures">Implemented Model Architectures</a></li>
    +<li class="toctree-l1"><a class="reference internal" href="welcome.html#documentation">Documentation</a></li>
    +<li class="toctree-l1"><a class="reference internal" href="welcome.html#contributing">Contributing</a></li>
    +<li class="toctree-l1"><a class="reference internal" href="welcome.html#citation">Citation</a></li>
    +<li class="toctree-l1"><a class="reference internal" href="welcome.html#community">Community</a></li>
    +<li class="toctree-l1"><a class="reference internal" href="welcome.html#license">License</a></li>
    +<li class="toctree-l1"><a class="reference internal" href="welcome.html#deci-platform">Deci Platform</a></li>
     </ul>
     </ul>
    -<p class="caption"><span class="caption-text">Technical Documentation</span></p>
    +<p class="caption" role="heading"><span class="caption-text">Technical Documentation</span></p>
     <ul>
     <ul>
     <li class="toctree-l1"><a class="reference internal" href="super_gradients.common.html">Common package</a></li>
     <li class="toctree-l1"><a class="reference internal" href="super_gradients.common.html">Common package</a></li>
     <li class="toctree-l1"><a class="reference internal" href="super_gradients.training.html">Training package</a></li>
     <li class="toctree-l1"><a class="reference internal" href="super_gradients.training.html">Training package</a></li>
    -</ul>
    -<p class="caption"><span class="caption-text">User Guide</span></p>
    -<ul>
    -<li class="toctree-l1"><a class="reference internal" href="user_guide.html">What is SuperGradients?</a></li>
    -<li class="toctree-l1"><a class="reference internal" href="user_guide.html#introducing-the-supergradients-library">Introducing the SuperGradients library</a></li>
    -<li class="toctree-l1"><a class="reference internal" href="user_guide.html#installation">Installation</a></li>
    -<li class="toctree-l1"><a class="reference internal" href="user_guide.html#integrating-your-training-code-complete-walkthrough">Integrating Your Training Code - Complete Walkthrough</a></li>
    -<li class="toctree-l1"><a class="reference internal" href="user_guide.html#training-parameters">Training Parameters</a></li>
    -<li class="toctree-l1"><a class="reference internal" href="user_guide.html#logs-and-checkpoints">Logs and Checkpoints</a></li>
    -<li class="toctree-l1"><a class="reference internal" href="user_guide.html#dataset-parameters">Dataset Parameters</a></li>
    -<li class="toctree-l1"><a class="reference internal" href="user_guide.html#network-architectures">Network Architectures</a></li>
    -<li class="toctree-l1"><a class="reference internal" href="user_guide.html#pretrained-models">Pretrained Models</a></li>
    -<li class="toctree-l1"><a class="reference internal" href="user_guide.html#how-to-reproduce-our-training-recipes">How To Reproduce Our Training Recipes</a></li>
    -<li class="toctree-l1"><a class="reference internal" href="user_guide.html#professional-tools-integration">Professional Tools Integration</a></li>
    -<li class="toctree-l1"><a class="reference internal" href="user_guide.html#supergradients-faq">SuperGradients FAQ</a></li>
     </ul>
     </ul>
     
     
             </div>
             </div>
    @@ -294,4 +295,4 @@ limitations under the License.</p>
       </script> 
       </script> 
     
     
     </body>
     </body>
    -</html>
    +</html>
    Discard
    @@ -3,10 +3,11 @@
     <head>
     <head>
       <meta charset="utf-8" />
       <meta charset="utf-8" />
       <meta name="viewport" content="width=device-width, initial-scale=1.0" />
       <meta name="viewport" content="width=device-width, initial-scale=1.0" />
    -  <title>Overview: module code &mdash; SuperGradients 1.0 documentation</title>
    +  <title>Overview: module code &mdash; SuperGradients 3.0.3 documentation</title>
           <link rel="stylesheet" href="../_static/pygments.css" type="text/css" />
           <link rel="stylesheet" href="../_static/pygments.css" type="text/css" />
           <link rel="stylesheet" href="../_static/css/theme.css" type="text/css" />
           <link rel="stylesheet" href="../_static/css/theme.css" type="text/css" />
           <link rel="stylesheet" href="../_static/graphviz.css" type="text/css" />
           <link rel="stylesheet" href="../_static/graphviz.css" type="text/css" />
    +      <link rel="stylesheet" href="../_static/custom.css" type="text/css" />
       <!--[if lt IE 9]>
       <!--[if lt IE 9]>
         <script src="../_static/js/html5shiv.min.js"></script>
         <script src="../_static/js/html5shiv.min.js"></script>
       <![endif]-->
       <![endif]-->
    @@ -14,7 +15,9 @@
             <script data-url_root="../" id="documentation_options" src="../_static/documentation_options.js"></script>
             <script data-url_root="../" id="documentation_options" src="../_static/documentation_options.js"></script>
             <script src="../_static/jquery.js"></script>
             <script src="../_static/jquery.js"></script>
             <script src="../_static/underscore.js"></script>
             <script src="../_static/underscore.js"></script>
    +        <script src="../_static/_sphinx_javascript_frameworks_compat.js"></script>
             <script src="../_static/doctools.js"></script>
             <script src="../_static/doctools.js"></script>
    +        <script src="../_static/sphinx_highlight.js"></script>
         <script src="../_static/js/theme.js"></script>
         <script src="../_static/js/theme.js"></script>
         <link rel="index" title="Index" href="../genindex.html" />
         <link rel="index" title="Index" href="../genindex.html" />
         <link rel="search" title="Search" href="../search.html" /> 
         <link rel="search" title="Search" href="../search.html" /> 
    @@ -35,30 +38,29 @@
       </form>
       </form>
     </div>
     </div>
             </div><div class="wy-menu wy-menu-vertical" data-spy="affix" role="navigation" aria-label="Navigation menu">
             </div><div class="wy-menu wy-menu-vertical" data-spy="affix" role="navigation" aria-label="Navigation menu">
    -              <p class="caption"><span class="caption-text">Welcome To SuperGradients</span></p>
    +              <p class="caption" role="heading"><span class="caption-text">Welcome To SuperGradients</span></p>
     <ul>
     <ul>
    -<li class="toctree-l1"><a class="reference internal" href="../welcome.html">Fill our 4-question quick survey! We will raffle free SuperGradients swag between those who will participate -&gt; Fill Survey</a></li>
    -<li class="toctree-l1"><a class="reference internal" href="../welcome.html#supergradients">SuperGradients</a></li>
    +<li class="toctree-l1"><a class="reference internal" href="../welcome.html">Version 3 is out! Notebooks have been updated!</a></li>
    +<li class="toctree-l1"><a class="reference internal" href="../welcome.html#build-with-supergradients">Build with SuperGradients</a></li>
    +<li class="toctree-l1"><a class="reference internal" href="../welcome.html#quick-installation">Quick Installation</a></li>
    +<li class="toctree-l1"><a class="reference internal" href="../welcome.html#what-s-new">What’s New</a></li>
    +<li class="toctree-l1"><a class="reference internal" href="../welcome.html#coming-soon">Coming soon</a></li>
    +<li class="toctree-l1"><a class="reference internal" href="../welcome.html#table-of-content">Table of Content</a></li>
    +<li class="toctree-l1"><a class="reference internal" href="../welcome.html#getting-started">Getting Started</a></li>
    +<li class="toctree-l1"><a class="reference internal" href="../welcome.html#advanced-features">Advanced Features</a></li>
    +<li class="toctree-l1"><a class="reference internal" href="../welcome.html#installation-methods">Installation Methods</a></li>
    +<li class="toctree-l1"><a class="reference internal" href="../welcome.html#implemented-model-architectures">Implemented Model Architectures</a></li>
    +<li class="toctree-l1"><a class="reference internal" href="../welcome.html#documentation">Documentation</a></li>
    +<li class="toctree-l1"><a class="reference internal" href="../welcome.html#contributing">Contributing</a></li>
    +<li class="toctree-l1"><a class="reference internal" href="../welcome.html#citation">Citation</a></li>
    +<li class="toctree-l1"><a class="reference internal" href="../welcome.html#community">Community</a></li>
    +<li class="toctree-l1"><a class="reference internal" href="../welcome.html#license">License</a></li>
    +<li class="toctree-l1"><a class="reference internal" href="../welcome.html#deci-platform">Deci Platform</a></li>
     </ul>
     </ul>
    -<p class="caption"><span class="caption-text">Technical Documentation</span></p>
    +<p class="caption" role="heading"><span class="caption-text">Technical Documentation</span></p>
     <ul>
     <ul>
     <li class="toctree-l1"><a class="reference internal" href="../super_gradients.common.html">Common package</a></li>
     <li class="toctree-l1"><a class="reference internal" href="../super_gradients.common.html">Common package</a></li>
     <li class="toctree-l1"><a class="reference internal" href="../super_gradients.training.html">Training package</a></li>
     <li class="toctree-l1"><a class="reference internal" href="../super_gradients.training.html">Training package</a></li>
    -</ul>
    -<p class="caption"><span class="caption-text">User Guide</span></p>
    -<ul>
    -<li class="toctree-l1"><a class="reference internal" href="../user_guide.html">What is SuperGradients?</a></li>
    -<li class="toctree-l1"><a class="reference internal" href="../user_guide.html#introducing-the-supergradients-library">Introducing the SuperGradients library</a></li>
    -<li class="toctree-l1"><a class="reference internal" href="../user_guide.html#installation">Installation</a></li>
    -<li class="toctree-l1"><a class="reference internal" href="../user_guide.html#integrating-your-training-code-complete-walkthrough">Integrating Your Training Code - Complete Walkthrough</a></li>
    -<li class="toctree-l1"><a class="reference internal" href="../user_guide.html#training-parameters">Training Parameters</a></li>
    -<li class="toctree-l1"><a class="reference internal" href="../user_guide.html#logs-and-checkpoints">Logs and Checkpoints</a></li>
    -<li class="toctree-l1"><a class="reference internal" href="../user_guide.html#dataset-parameters">Dataset Parameters</a></li>
    -<li class="toctree-l1"><a class="reference internal" href="../user_guide.html#network-architectures">Network Architectures</a></li>
    -<li class="toctree-l1"><a class="reference internal" href="../user_guide.html#pretrained-models">Pretrained Models</a></li>
    -<li class="toctree-l1"><a class="reference internal" href="../user_guide.html#how-to-reproduce-our-training-recipes">How To Reproduce Our Training Recipes</a></li>
    -<li class="toctree-l1"><a class="reference internal" href="../user_guide.html#professional-tools-integration">Professional Tools Integration</a></li>
    -<li class="toctree-l1"><a class="reference internal" href="../user_guide.html#supergradients-faq">SuperGradients FAQ</a></li>
     </ul>
     </ul>
     
     
             </div>
             </div>
    @@ -85,10 +87,10 @@
                <div itemprop="articleBody">
                <div itemprop="articleBody">
                  
                  
       <h1>All modules for which code is available</h1>
       <h1>All modules for which code is available</h1>
    -<ul><li><a href="super_gradients/common/abstractions/abstract_logger.html">super_gradients.common.abstractions.abstract_logger</a></li>
    -<li><a href="super_gradients/common/auto_logging/auto_logger.html">super_gradients.common.auto_logging.auto_logger</a></li>
    +<ul><li><a href="super_gradients/common/auto_logging/auto_logger.html">super_gradients.common.auto_logging.auto_logger</a></li>
    +<li><a href="super_gradients/common/auto_logging/console_logging.html">super_gradients.common.auto_logging.console_logging</a></li>
     <li><a href="super_gradients/common/aws_connection/aws_connector.html">super_gradients.common.aws_connection.aws_connector</a></li>
     <li><a href="super_gradients/common/aws_connection/aws_connector.html">super_gradients.common.aws_connection.aws_connector</a></li>
    -<li><a href="super_gradients/common/aws_connection/aws_secrets_manager_connector.html">super_gradients.common.aws_connection.aws_secrets_manager_connector</a></li>
    +<li><a href="super_gradients/common/crash_handler/crash_handler.html">super_gradients.common.crash_handler.crash_handler</a></li>
     <li><a href="super_gradients/common/data_connection/s3_connector.html">super_gradients.common.data_connection.s3_connector</a></li>
     <li><a href="super_gradients/common/data_connection/s3_connector.html">super_gradients.common.data_connection.s3_connector</a></li>
     <li><a href="super_gradients/common/data_interface/adnn_model_repository_data_interface.html">super_gradients.common.data_interface.adnn_model_repository_data_interface</a></li>
     <li><a href="super_gradients/common/data_interface/adnn_model_repository_data_interface.html">super_gradients.common.data_interface.adnn_model_repository_data_interface</a></li>
     <li><a href="super_gradients/common/data_interface/dataset_data_interface.html">super_gradients.common.data_interface.dataset_data_interface</a></li>
     <li><a href="super_gradients/common/data_interface/dataset_data_interface.html">super_gradients.common.data_interface.dataset_data_interface</a></li>
    @@ -96,37 +98,29 @@
     <li><a href="super_gradients/common/data_types/enum/evaluation_type.html">super_gradients.common.data_types.enum.evaluation_type</a></li>
     <li><a href="super_gradients/common/data_types/enum/evaluation_type.html">super_gradients.common.data_types.enum.evaluation_type</a></li>
     <li><a href="super_gradients/common/data_types/enum/multi_gpu_mode.html">super_gradients.common.data_types.enum.multi_gpu_mode</a></li>
     <li><a href="super_gradients/common/data_types/enum/multi_gpu_mode.html">super_gradients.common.data_types.enum.multi_gpu_mode</a></li>
     <li><a href="super_gradients/common/data_types/enum/strict_load.html">super_gradients.common.data_types.enum.strict_load</a></li>
     <li><a href="super_gradients/common/data_types/enum/strict_load.html">super_gradients.common.data_types.enum.strict_load</a></li>
    -<li><a href="super_gradients/common/decorators/deci_logger.html">super_gradients.common.decorators.deci_logger</a></li>
    +<li><a href="super_gradients/common/data_types/enum/upsample_mode.html">super_gradients.common.data_types.enum.upsample_mode</a></li>
     <li><a href="super_gradients/common/decorators/explicit_params_validator.html">super_gradients.common.decorators.explicit_params_validator</a></li>
     <li><a href="super_gradients/common/decorators/explicit_params_validator.html">super_gradients.common.decorators.explicit_params_validator</a></li>
     <li><a href="super_gradients/common/decorators/singleton.html">super_gradients.common.decorators.singleton</a></li>
     <li><a href="super_gradients/common/decorators/singleton.html">super_gradients.common.decorators.singleton</a></li>
     <li><a href="super_gradients/common/environment/env_helpers.html">super_gradients.common.environment.env_helpers</a></li>
     <li><a href="super_gradients/common/environment/env_helpers.html">super_gradients.common.environment.env_helpers</a></li>
    -<li><a href="super_gradients/training/datasets/all_datasets.html">super_gradients.training.datasets.all_datasets</a></li>
    -<li><a href="super_gradients/training/datasets/auto_augment.html">super_gradients.training.datasets.auto_augment</a></li>
    +<li><a href="super_gradients/common/object_names.html">super_gradients.common.object_names</a></li>
    +<li><a href="super_gradients/training/dataloaders/dataloaders.html">super_gradients.training.dataloaders.dataloaders</a></li>
    +<li><a href="super_gradients/training/datasets/classification_datasets/cifar.html">super_gradients.training.datasets.classification_datasets.cifar</a></li>
    +<li><a href="super_gradients/training/datasets/classification_datasets/imagenet_dataset.html">super_gradients.training.datasets.classification_datasets.imagenet_dataset</a></li>
     <li><a href="super_gradients/training/datasets/data_augmentation.html">super_gradients.training.datasets.data_augmentation</a></li>
     <li><a href="super_gradients/training/datasets/data_augmentation.html">super_gradients.training.datasets.data_augmentation</a></li>
    -<li><a href="super_gradients/training/datasets/dataset_interfaces/dataset_interface.html">super_gradients.training.datasets.dataset_interfaces.dataset_interface</a></li>
    -<li><a href="super_gradients/training/datasets/datasets_utils.html">super_gradients.training.datasets.datasets_utils</a></li>
     <li><a href="super_gradients/training/datasets/detection_datasets/coco_detection.html">super_gradients.training.datasets.detection_datasets.coco_detection</a></li>
     <li><a href="super_gradients/training/datasets/detection_datasets/coco_detection.html">super_gradients.training.datasets.detection_datasets.coco_detection</a></li>
     <li><a href="super_gradients/training/datasets/detection_datasets/detection_dataset.html">super_gradients.training.datasets.detection_datasets.detection_dataset</a></li>
     <li><a href="super_gradients/training/datasets/detection_datasets/detection_dataset.html">super_gradients.training.datasets.detection_datasets.detection_dataset</a></li>
     <li><a href="super_gradients/training/datasets/detection_datasets/pascal_voc_detection.html">super_gradients.training.datasets.detection_datasets.pascal_voc_detection</a></li>
     <li><a href="super_gradients/training/datasets/detection_datasets/pascal_voc_detection.html">super_gradients.training.datasets.detection_datasets.pascal_voc_detection</a></li>
    -<li><a href="super_gradients/training/datasets/mixup.html">super_gradients.training.datasets.mixup</a></li>
    -<li><a href="super_gradients/training/datasets/segmentation_datasets/cityscape_segmentation.html">super_gradients.training.datasets.segmentation_datasets.cityscape_segmentation</a></li>
     <li><a href="super_gradients/training/datasets/segmentation_datasets/coco_segmentation.html">super_gradients.training.datasets.segmentation_datasets.coco_segmentation</a></li>
     <li><a href="super_gradients/training/datasets/segmentation_datasets/coco_segmentation.html">super_gradients.training.datasets.segmentation_datasets.coco_segmentation</a></li>
    -<li><a href="super_gradients/training/datasets/segmentation_datasets/pascal_aug_segmentation.html">super_gradients.training.datasets.segmentation_datasets.pascal_aug_segmentation</a></li>
     <li><a href="super_gradients/training/datasets/segmentation_datasets/pascal_voc_segmentation.html">super_gradients.training.datasets.segmentation_datasets.pascal_voc_segmentation</a></li>
     <li><a href="super_gradients/training/datasets/segmentation_datasets/pascal_voc_segmentation.html">super_gradients.training.datasets.segmentation_datasets.pascal_voc_segmentation</a></li>
     <li><a href="super_gradients/training/datasets/segmentation_datasets/segmentation_dataset.html">super_gradients.training.datasets.segmentation_datasets.segmentation_dataset</a></li>
     <li><a href="super_gradients/training/datasets/segmentation_datasets/segmentation_dataset.html">super_gradients.training.datasets.segmentation_datasets.segmentation_dataset</a></li>
     <li><a href="super_gradients/training/datasets/segmentation_datasets/supervisely_persons_segmentation.html">super_gradients.training.datasets.segmentation_datasets.supervisely_persons_segmentation</a></li>
     <li><a href="super_gradients/training/datasets/segmentation_datasets/supervisely_persons_segmentation.html">super_gradients.training.datasets.segmentation_datasets.supervisely_persons_segmentation</a></li>
     <li><a href="super_gradients/training/datasets/sg_dataset.html">super_gradients.training.datasets.sg_dataset</a></li>
     <li><a href="super_gradients/training/datasets/sg_dataset.html">super_gradients.training.datasets.sg_dataset</a></li>
    -<li><a href="super_gradients/training/exceptions/dataset_exceptions.html">super_gradients.training.exceptions.dataset_exceptions</a></li>
    -<li><a href="super_gradients/training/exceptions/sg_model_exceptions.html">super_gradients.training.exceptions.sg_model_exceptions</a></li>
    -<li><a href="super_gradients/training/kd_model/kd_model.html">super_gradients.training.kd_model.kd_model</a></li>
    -<li><a href="super_gradients/training/legacy/utils.html">super_gradients.training.legacy.utils</a></li>
    +<li><a href="super_gradients/training/kd_trainer/kd_trainer.html">super_gradients.training.kd_trainer.kd_trainer</a></li>
     <li><a href="super_gradients/training/losses/bce_dice_loss.html">super_gradients.training.losses.bce_dice_loss</a></li>
     <li><a href="super_gradients/training/losses/bce_dice_loss.html">super_gradients.training.losses.bce_dice_loss</a></li>
    -<li><a href="super_gradients/training/losses/ddrnet_loss.html">super_gradients.training.losses.ddrnet_loss</a></li>
     <li><a href="super_gradients/training/losses/dice_ce_edge_loss.html">super_gradients.training.losses.dice_ce_edge_loss</a></li>
     <li><a href="super_gradients/training/losses/dice_ce_edge_loss.html">super_gradients.training.losses.dice_ce_edge_loss</a></li>
     <li><a href="super_gradients/training/losses/focal_loss.html">super_gradients.training.losses.focal_loss</a></li>
     <li><a href="super_gradients/training/losses/focal_loss.html">super_gradients.training.losses.focal_loss</a></li>
     <li><a href="super_gradients/training/losses/kd_losses.html">super_gradients.training.losses.kd_losses</a></li>
     <li><a href="super_gradients/training/losses/kd_losses.html">super_gradients.training.losses.kd_losses</a></li>
     <li><a href="super_gradients/training/losses/label_smoothing_cross_entropy_loss.html">super_gradients.training.losses.label_smoothing_cross_entropy_loss</a></li>
     <li><a href="super_gradients/training/losses/label_smoothing_cross_entropy_loss.html">super_gradients.training.losses.label_smoothing_cross_entropy_loss</a></li>
    -<li><a href="super_gradients/training/losses/ohem_ce_loss.html">super_gradients.training.losses.ohem_ce_loss</a></li>
     <li><a href="super_gradients/training/losses/r_squared_loss.html">super_gradients.training.losses.r_squared_loss</a></li>
     <li><a href="super_gradients/training/losses/r_squared_loss.html">super_gradients.training.losses.r_squared_loss</a></li>
     <li><a href="super_gradients/training/losses/shelfnet_ohem_loss.html">super_gradients.training.losses.shelfnet_ohem_loss</a></li>
     <li><a href="super_gradients/training/losses/shelfnet_ohem_loss.html">super_gradients.training.losses.shelfnet_ohem_loss</a></li>
     <li><a href="super_gradients/training/losses/shelfnet_semantic_encoding_loss.html">super_gradients.training.losses.shelfnet_semantic_encoding_loss</a></li>
     <li><a href="super_gradients/training/losses/shelfnet_semantic_encoding_loss.html">super_gradients.training.losses.shelfnet_semantic_encoding_loss</a></li>
    @@ -134,26 +128,13 @@
     <li><a href="super_gradients/training/losses/yolox_loss.html">super_gradients.training.losses.yolox_loss</a></li>
     <li><a href="super_gradients/training/losses/yolox_loss.html">super_gradients.training.losses.yolox_loss</a></li>
     <li><a href="super_gradients/training/metrics/classification_metrics.html">super_gradients.training.metrics.classification_metrics</a></li>
     <li><a href="super_gradients/training/metrics/classification_metrics.html">super_gradients.training.metrics.classification_metrics</a></li>
     <li><a href="super_gradients/training/metrics/detection_metrics.html">super_gradients.training.metrics.detection_metrics</a></li>
     <li><a href="super_gradients/training/metrics/detection_metrics.html">super_gradients.training.metrics.detection_metrics</a></li>
    -<li><a href="super_gradients/training/metrics/metric_utils.html">super_gradients.training.metrics.metric_utils</a></li>
     <li><a href="super_gradients/training/metrics/segmentation_metrics.html">super_gradients.training.metrics.segmentation_metrics</a></li>
     <li><a href="super_gradients/training/metrics/segmentation_metrics.html">super_gradients.training.metrics.segmentation_metrics</a></li>
    -<li><a href="super_gradients/training/models/sg_module.html">super_gradients.training.models.sg_module</a></li>
    -<li><a href="super_gradients/training/sg_model/sg_model.html">super_gradients.training.sg_model.sg_model</a></li>
    -<li><a href="super_gradients/training/utils/callbacks.html">super_gradients.training.utils.callbacks</a></li>
    +<li><a href="super_gradients/training/sg_trainer/sg_trainer.html">super_gradients.training.sg_trainer.sg_trainer</a></li>
    +<li><a href="super_gradients/training/training_hyperparams/training_hyperparams.html">super_gradients.training.training_hyperparams.training_hyperparams</a></li>
    +<li><a href="super_gradients/training/transforms/transforms.html">super_gradients.training.transforms.transforms</a></li>
     <li><a href="super_gradients/training/utils/checkpoint_utils.html">super_gradients.training.utils.checkpoint_utils</a></li>
     <li><a href="super_gradients/training/utils/checkpoint_utils.html">super_gradients.training.utils.checkpoint_utils</a></li>
    -<li><a href="super_gradients/training/utils/detection_utils.html">super_gradients.training.utils.detection_utils</a></li>
    -<li><a href="super_gradients/training/utils/distributed_training_utils.html">super_gradients.training.utils.distributed_training_utils</a></li>
    -<li><a href="super_gradients/training/utils/early_stopping.html">super_gradients.training.utils.early_stopping</a></li>
    -<li><a href="super_gradients/training/utils/ema.html">super_gradients.training.utils.ema</a></li>
    -<li><a href="super_gradients/training/utils/export_utils.html">super_gradients.training.utils.export_utils</a></li>
    -<li><a href="super_gradients/training/utils/module_utils.html">super_gradients.training.utils.module_utils</a></li>
    -<li><a href="super_gradients/training/utils/optimizer_utils.html">super_gradients.training.utils.optimizer_utils</a></li>
    -<li><a href="super_gradients/training/utils/optimizers/rmsprop_tf.html">super_gradients.training.utils.optimizers.rmsprop_tf</a></li>
    -<li><a href="super_gradients/training/utils/regularization_utils.html">super_gradients.training.utils.regularization_utils</a></li>
    -<li><a href="super_gradients/training/utils/segmentation_utils.html">super_gradients.training.utils.segmentation_utils</a></li>
    -<li><a href="super_gradients/training/utils/sg_model_utils.html">super_gradients.training.utils.sg_model_utils</a></li>
    -<li><a href="super_gradients/training/utils/ssd_utils.html">super_gradients.training.utils.ssd_utils</a></li>
     <li><a href="super_gradients/training/utils/utils.html">super_gradients.training.utils.utils</a></li>
     <li><a href="super_gradients/training/utils/utils.html">super_gradients.training.utils.utils</a></li>
    -<li><a href="super_gradients/training/utils/weight_averaging_utils.html">super_gradients.training.utils.weight_averaging_utils</a></li>
    +<li><a href="super_gradients/training/utils/version_utils.html">super_gradients.training.utils.version_utils</a></li>
     </ul>
     </ul>
     
     
                </div>
                </div>
    @@ -183,4 +164,4 @@
       </script> 
       </script> 
     
     
     </body>
     </body>
    -</html>
    +</html>
    Discard
    1
    2
    3
    4
    5
    6
    7
    8
    9
    10
    11
    12
    13
    14
    15
    16
    17
    18
    19
    20
    21
    22
    23
    24
    25
    26
    27
    28
    29
    30
    31
    32
    33
    34
    35
    36
    37
    38
    39
    40
    41
    42
    43
    44
    45
    46
    47
    48
    49
    50
    51
    52
    53
    54
    55
    56
    57
    58
    59
    60
    61
    62
    63
    64
    65
    66
    67
    68
    69
    70
    71
    72
    73
    74
    75
    76
    77
    78
    79
    80
    81
    82
    83
    84
    85
    86
    87
    88
    89
    90
    91
    92
    93
    94
    95
    96
    97
    98
    99
    100
    101
    102
    103
    104
    105
    106
    107
    108
    109
    110
    111
    112
    113
    114
    115
    116
    117
    118
    119
    120
    121
    122
    123
    124
    125
    126
    127
    128
    129
    130
    131
    132
    133
    134
    135
    136
    137
    138
    139
    140
    141
    142
    143
    144
    145
    146
    147
    148
    149
    150
    151
    152
    153
    154
    155
    156
    157
    158
    159
    160
    161
    162
    163
    164
    165
    166
    167
    168
    169
    170
    171
    172
    173
    174
    175
    176
    177
    178
    179
    180
    181
    182
    183
    184
    185
    186
    187
    188
    189
    190
    191
    192
    193
    194
    195
    196
    197
    198
    199
    200
    201
    202
    203
    204
    205
    206
    207
    208
    209
    210
    211
    212
    213
    214
    215
    216
    217
    218
    219
    220
    221
    222
    223
    224
    225
    226
    227
    228
    229
    230
    231
    232
    233
    234
    235
    236
    237
    238
    239
    240
    241
    242
    243
    244
    245
    246
    247
    248
    249
    250
    251
    252
    253
    254
    255
    256
    257
    258
    259
    260
    261
    262
    263
    264
    265
    266
    267
    268
    269
    270
    271
    272
    273
    274
    275
    276
    277
    278
    279
    280
    281
    282
    283
    284
    285
    286
    287
    288
    289
    290
    291
    292
    293
    294
    295
    296
    297
    298
    299
    300
    301
    302
    303
    304
    305
    306
    307
    308
    309
    310
    311
    312
    313
    314
    315
    316
    317
    318
    319
    320
    321
    322
    323
    324
    325
    326
    327
    328
    329
    330
    331
    332
    333
    334
    335
    336
    337
    338
    339
    340
    341
    342
    343
    344
    345
    346
    347
    348
    349
    350
    351
    352
    353
    354
    355
    356
    357
    358
    359
    360
    361
    362
    363
    364
    365
    366
    367
    368
    369
    370
    371
    372
    373
    374
    375
    376
    377
    378
    379
    380
    381
    382
    383
    384
    385
    386
    387
    388
    389
    390
    391
    392
    393
    394
    395
    396
    397
    398
    399
    400
    401
    402
    403
    404
    405
    406
    407
    408
    409
    410
    411
    412
    413
    414
    415
    416
    417
    418
    419
    420
    421
    422
    423
    424
    425
    426
    427
    428
    429
    430
    431
    432
    433
    434
    435
    436
    437
    438
    439
    440
    441
    442
    443
    444
    445
    446
    447
    448
    449
    450
    451
    452
    453
    454
    455
    456
    457
    458
    459
    460
    461
    462
    463
    464
    465
    466
    467
    468
    469
    470
    471
    472
    473
    474
    475
    476
    477
    478
    479
    480
    481
    482
    483
    484
    485
    486
    487
    488
    489
    490
    491
    492
    493
    494
    495
    496
    497
    498
    499
    500
    501
    502
    503
    504
    505
    506
    507
    508
    509
    510
    511
    512
    513
    514
    515
    516
    517
    518
    519
    520
    521
    522
    523
    524
    525
    526
    527
    528
    529
    530
    531
    532
    533
    534
    535
    536
    537
    538
    539
    540
    541
    542
    543
    544
    545
    546
    547
    548
    549
    550
    551
    552
    553
    554
    555
    556
    557
    558
    559
    560
    561
    562
    563
    564
    565
    566
    567
    568
    569
    570
    571
    572
    573
    574
    575
    576
    577
    578
    579
    580
    581
    582
    583
    584
    585
    586
    587
    588
    589
    590
    591
    592
    593
    594
    595
    596
    597
    598
    599
    600
    601
    602
    603
    604
    605
    606
    607
    608
    609
    610
    611
    612
    613
    614
    615
    616
    617
    618
    619
    620
    621
    622
    623
    624
    625
    626
    627
    628
    629
    630
    631
    632
    633
    634
    635
    636
    637
    638
    639
    640
    641
    642
    643
    644
    645
    646
    647
    648
    649
    650
    651
    652
    653
    654
    655
    656
    657
    658
    659
    660
    661
    662
    663
    664
    665
    666
    667
    668
    669
    670
    671
    672
    673
    674
    675
    676
    677
    678
    679
    680
    681
    682
    683
    684
    685
    686
    687
    688
    689
    690
    691
    692
    693
    694
    695
    696
    697
    698
    699
    700
    701
    702
    703
    704
    705
    706
    707
    708
    709
    710
    711
    712
    713
    714
    715
    716
    717
    718
    719
    720
    721
    722
    723
    724
    725
    726
    727
    728
    729
    730
    731
    732
    733
    734
    735
    736
    737
    738
    739
    740
    741
    742
    743
    744
    745
    746
    747
    748
    749
    750
    751
    752
    753
    754
    755
    756
    757
    758
    759
    760
    761
    762
    763
    764
    765
    766
    767
    768
    769
    770
    771
    772
    773
    774
    775
    776
    777
    778
    779
    780
    781
    782
    783
    784
    785
    786
    787
    788
    789
    790
    791
    792
    793
    794
    795
    796
    797
    798
    799
    800
    801
    802
    803
    804
    805
    806
    807
    808
    809
    810
    811
    812
    813
    814
    815
    816
    817
    818
    819
    820
    821
    822
    823
    824
    825
    826
    827
    828
    829
    830
    831
    832
    833
    834
    835
    836
    837
    838
    839
    840
    841
    842
    843
    844
    845
    846
    847
    848
    849
    850
    851
    852
    853
    854
    855
    856
    857
    858
    859
    860
    861
    862
    863
    864
    865
    866
    867
    868
    869
    870
    871
    872
    873
    874
    875
    876
    877
    878
    879
    880
    881
    882
    883
    884
    885
    886
    887
    888
    889
    890
    891
    892
    893
    894
    895
    896
    897
    898
    899
    900
    901
    902
    903
    904
    905
    906
    907
    908
    909
    910
    911
    912
    913
    914
    915
    916
    917
    918
    919
    920
    921
    922
    923
    924
    925
    926
    927
    928
    929
    930
    931
    932
    933
    934
    935
    936
    937
    938
    939
    940
    941
    942
    943
    944
    945
    946
    947
    948
    949
    950
    951
    952
    953
    954
    955
    956
    957
    958
    959
    960
    961
    962
    963
    964
    965
    966
    967
    968
    969
    970
    971
    972
    973
    974
    975
    976
    977
    978
    979
    980
    981
    982
    983
    984
    985
    986
    987
    988
    989
    990
    991
    992
    993
    994
    995
    996
    997
    998
    999
    1000
    1001
    1002
    1003
    1004
    1005
    1006
    1007
    1008
    1009
    1010
    1011
    1012
    1013
    1014
    1015
    1016
    1017
    1018
    1019
    1020
    1021
    1022
    1023
    1024
    1025
    1026
    1027
    1028
    1029
    1030
    1031
    1032
    1033
    1034
    1035
    1036
    1037
    1038
    1039
    1040
    1041
    1042
    1043
    1044
    1045
    1046
    1047
    1048
    1049
    1050
    1051
    1052
    1053
    1054
    1055
    1056
    1057
    1058
    1059
    1060
    1061
    1062
    1063
    1064
    1065
    1066
    1067
    1068
    1069
    1070
    1071
    1072
    1073
    1074
    1075
    1076
    1077
    1078
    1079
    1080
    1081
    1082
    1083
    1084
    1085
    1086
    1087
    1088
    1089
    1090
    1091
    1092
    1093
    1094
    1095
    1096
    1097
    1098
    1099
    1100
    1101
    1102
    1103
    1104
    1105
    1106
    1107
    1108
    1109
    1110
    1111
    1112
    1113
    1114
    1115
    1116
    1117
    1118
    1119
    1120
    1121
    1122
    1123
    1124
    1125
    1126
    1127
    1128
    1129
    1130
    1131
    1132
    1133
    1134
    1135
    1136
    1137
    1138
    1139
    1140
    1141
    1142
    1143
    1144
    1145
    1146
    1147
    1148
    1149
    1150
    1151
    1152
    1153
    1154
    1155
    1156
    1157
    1158
    1159
    1160
    1161
    1162
    1163
    1164
    1165
    1166
    1167
    1168
    1169
    1170
    1171
    1172
    1173
    1174
    1175
    1176
    1177
    1178
    1179
    1180
    1181
    1182
    1183
    1184
    1185
    1186
    1187
    1188
    1189
    1190
    1191
    1192
    1193
    1194
    1195
    1196
    1197
    1198
    1199
    1200
    1201
    1202
    1203
    1204
    1205
    1206
    1207
    1208
    1209
    1210
    1211
    1212
    1213
    1214
    1215
    1216
    1217
    1218
    1219
    1220
    1221
    1222
    1223
    1224
    1225
    1226
    1227
    1228
    1229
    1230
    1231
    1232
    1233
    1234
    1235
    1236
    1237
    1238
    1239
    1240
    1241
    1242
    1243
    1244
    1245
    1246
    1247
    1248
    1249
    1250
    1251
    1252
    1253
    1254
    1255
    1256
    1257
    1258
    1259
    1260
    1261
    1262
    1263
    1264
    1265
    1266
    1267
    1268
    1269
    1270
    1271
    1272
    1273
    1274
    1275
    1276
    1277
    1278
    1279
    1280
    1281
    1282
    1283
    1284
    1285
    1286
    1287
    1288
    1289
    1290
    1291
    1292
    1293
    1294
    1295
    1296
    1297
    1298
    1299
    1300
    1301
    1302
    1303
    1304
    1305
    1306
    1307
    1308
    1309
    1310
    1311
    1312
    1313
    1314
    1315
    1316
    1317
    1318
    1319
    1320
    1321
    1322
    1323
    1324
    1325
    1326
    1327
    1328
    1329
    1330
    1331
    1332
    1333
    1334
    1335
    1336
    1337
    1338
    1339
    1340
    1341
    1342
    1343
    1344
    1345
    1346
    1347
    1348
    1349
    1350
    1351
    1352
    1353
    1354
    1355
    1356
    1357
    1358
    1359
    1360
    1361
    1362
    1363
    1364
    1365
    1366
    1367
    1368
    1369
    1370
    1371
    1372
    1373
    1374
    1375
    1376
    1377
    1378
    1379
    1380
    1381
    1382
    1383
    1384
    1385
    1386
    1387
    1388
    1389
    1390
    1391
    1392
    1393
    1394
    1395
    1396
    1397
    1398
    1399
    1400
    1401
    1402
    1403
    1404
    1405
    1406
    1407
    1408
    1409
    1410
    1411
    1412
    1413
    1414
    1415
    1416
    1417
    1418
    1419
    1420
    1421
    1422
    1423
    1424
    1425
    1426
    1427
    1428
    1429
    1430
    1431
    1432
    1433
    1434
    1435
    1436
    1437
    1438
    1439
    1440
    1441
    1442
    1443
    1444
    1445
    1446
    1447
    1448
    1449
    1450
    1451
    1452
    1453
    1454
    1455
    1456
    1457
    1458
    1459
    1460
    1461
    1462
    1463
    1464
    1465
    1466
    1467
    1468
    1469
    1470
    1471
    1472
    1473
    1474
    1475
    1476
    1477
    1478
    1479
    1480
    1481
    1482
    1483
    1484
    1485
    1486
    1487
    1488
    1489
    1490
    1491
    1492
    1493
    1494
    1495
    1496
    1497
    1498
    1499
    1500
    1501
    1502
    1503
    1504
    1505
    1506
    1507
    1508
    1509
    1510
    1511
    1512
    1513
    1514
    1515
    1516
    1517
    1518
    1519
    1520
    1521
    1522
    1523
    1524
    1525
    1526
    1527
    1528
    1529
    1530
    1531
    1532
    1533
    1534
    1535
    1536
    1537
    1538
    1539
    1540
    1541
    1542
    1543
    1544
    1545
    1546
    1547
    1548
    1549
    1550
    1551
    1552
    1553
    1554
    1555
    1556
    1557
    1558
    1559
    1560
    1561
    1562
    1563
    1564
    1565
    1566
    1567
    1568
    1569
    1570
    1571
    1572
    1573
    1574
    1575
    1576
    1577
    1578
    1579
    1580
    1581
    1582
    1583
    1584
    1585
    1586
    1587
    1588
    1589
    1590
    1591
    1592
    1593
    1594
    1595
    1596
    1597
    1598
    1599
    1600
    1601
    1602
    1603
    1604
    1605
    1606
    1607
    1608
    1609
    1610
    1611
    1612
    1613
    1614
    1615
    1616
    1617
    1618
    1619
    1620
    1621
    1622
    1623
    1624
    1625
    1626
    1627
    1628
    1629
    1630
    1631
    1632
    1633
    1634
    1635
    1636
    1637
    1638
    1639
    1640
    1641
    1642
    1643
    1644
    1645
    1646
    1647
    1648
    1649
    1650
    1651
    1652
    1653
    1654
    1655
    1656
    1657
    1658
    1659
    1660
    1661
    1662
    1663
    1664
    1665
    1666
    1667
    1668
    1669
    1670
    1671
    1672
    1673
    1674
    1675
    1676
    1677
    1678
    1679
    1680
    1681
    1682
    1683
    1684
    1685
    1686
    1687
    1688
    1689
    1690
    1691
    1692
    1693
    1694
    1695
    1696
    1697
    1698
    1699
    1700
    1701
    1702
    1703
    1704
    1705
    1706
    1707
    1708
    1709
    1710
    1711
    1712
    1713
    1714
    1715
    1716
    1717
    1718
    1719
    1720
    1721
    1722
    1723
    1724
    1725
    1726
    1727
    1728
    1729
    1730
    1731
    1732
    1733
    1734
    1735
    1736
    1737
    1738
    1739
    1740
    1741
    1742
    1743
    1744
    1745
    1746
    1747
    1748
    1749
    1750
    1751
    1752
    1753
    1754
    1755
    1756
    1757
    1758
    1759
    1760
    1761
    1762
    1763
    1764
    1765
    1766
    1767
    1768
    1769
    1770
    1771
    1772
    1773
    1774
    1775
    1776
    1777
    1778
    1779
    1780
    1781
    1782
    1783
    1784
    1785
    1786
    1787
    1788
    1789
    1790
    1791
    1792
    1793
    1794
    1795
    1796
    1797
    1798
    1799
    1800
    1801
    1802
    1803
    1804
    1805
    1806
    1807
    1808
    1809
    1810
    1811
    1812
    1813
    1814
    1815
    1816
    1817
    1818
    1819
    1820
    1821
    1822
    1823
    1824
    1825
    1826
    1827
    1828
    1829
    1830
    1831
    1832
    1833
    1834
    1835
    1836
    1837
    1838
    1839
    1840
    1841
    1842
    1843
    1844
    1845
    1846
    1847
    1848
    1849
    1850
    1851
    1852
    1853
    1854
    1855
    1856
    1857
    1858
    1859
    1860
    1861
    1862
    1863
    1864
    1865
    1866
    1867
    1868
    1869
    1870
    1871
    1872
    1873
    1874
    1875
    1876
    1877
    1878
    1879
    1880
    1881
    1882
    1883
    1884
    1885
    1886
    1887
    1888
    1889
    1890
    1891
    1892
    1893
    1894
    1895
    1896
    1897
    1898
    1899
    1900
    1901
    1902
    1903
    1904
    1905
    1906
    1907
    1908
    1909
    1910
    1911
    1912
    1913
    1914
    1915
    1916
    1917
    1918
    1919
    1920
    1921
    1922
    1923
    1924
    1925
    1926
    1927
    1928
    1929
    1930
    1931
    1932
    1933
    1934
    1935
    1936
    1937
    1938
    1939
    1940
    1941
    1942
    1943
    1944
    1945
    1946
    1947
    1948
    1949
    1950
    1951
    1952
    1953
    1954
    1955
    1956
    1957
    1958
    1959
    1960
    1961
    1962
    1963
    1964
    1965
    1966
    1967
    1968
    1969
    1970
    1971
    1972
    1973
    1974
    1975
    1976
    1977
    1978
    1979
    1980
    1981
    1982
    1983
    1984
    1985
    1986
    1987
    1988
    1989
    1990
    1991
    1992
    1993
    1994
    1995
    1996
    1997
    1998
    1999
    2000
    2001
    2002
    2003
    2004
    2005
    2006
    2007
    2008
    2009
    2010
    2011
    2012
    2013
    2014
    2015
    2016
    2017
    2018
    2019
    2020
    2021
    2022
    2023
    2024
    2025
    2026
    2027
    2028
    2029
    2030
    2031
    2032
    2033
    2034
    2035
    2036
    2037
    2038
    2039
    2040
    2041
    2042
    2043
    2044
    2045
    2046
    2047
    2048
    2049
    2050
    2051
    2052
    2053
    2054
    2055
    2056
    2057
    2058
    2059
    2060
    2061
    2062
    2063
    2064
    2065
    2066
    2067
    2068
    2069
    2070
    2071
    2072
    2073
    2074
    2075
    2076
    2077
    2078
    2079
    2080
    2081
    2082
    2083
    2084
    2085
    2086
    2087
    2088
    2089
    2090
    2091
    2092
    2093
    2094
    2095
    2096
    2097
    2098
    2099
    2100
    2101
    2102
    2103
    2104
    2105
    2106
    2107
    2108
    2109
    2110
    2111
    2112
    2113
    2114
    2115
    2116
    2117
    2118
    2119
    2120
    2121
    2122
    2123
    2124
    2125
    2126
    2127
    2128
    2129
    2130
    2131
    2132
    2133
    2134
    2135
    2136
    2137
    2138
    2139
    2140
    2141
    2142
    2143
    2144
    2145
    2146
    2147
    2148
    2149
    2150
    2151
    2152
    2153
    2154
    2155
    2156
    2157
    2158
    2159
    2160
    2161
    2162
    2163
    2164
    2165
    2166
    2167
    2168
    2169
    2170
    2171
    2172
    2173
    2174
    2175
    2176
    2177
    2178
    2179
    2180
    2181
    2182
    2183
    2184
    2185
    2186
    2187
    2188
    2189
    2190
    2191
    2192
    2193
    2194
    2195
    2196
    2197
    2198
    2199
    2200
    2201
    2202
    2203
    2204
    2205
    2206
    2207
    2208
    2209
    2210
    2211
    2212
    2213
    2214
    2215
    2216
    2217
    2218
    2219
    2220
    2221
    2222
    2223
    2224
    2225
    2226
    2227
    2228
    2229
    2230
    2231
    2232
    2233
    2234
    2235
    2236
    2237
    2238
    2239
    2240
    2241
    2242
    2243
    2244
    2245
    2246
    2247
    2248
    2249
    2250
    2251
    2252
    2253
    2254
    2255
    2256
    2257
    2258
    2259
    2260
    2261
    2262
    2263
    2264
    2265
    2266
    2267
    2268
    2269
    2270
    2271
    2272
    2273
    2274
    2275
    2276
    2277
    2278
    2279
    2280
    2281
    2282
    2283
    2284
    2285
    2286
    2287
    2288
    2289
    2290
    2291
    2292
    2293
    2294
    2295
    2296
    2297
    2298
    2299
    2300
    2301
    2302
    2303
    2304
    2305
    2306
    2307
    2308
    2309
    2310
    2311
    2312
    2313
    2314
    2315
    2316
    2317
    2318
    2319
    1. <!DOCTYPE html>
    2. <html class="writer-html5" lang="en" >
    3. <head>
    4. <meta charset="utf-8" />
    5. <meta name="viewport" content="width=device-width, initial-scale=1.0" />
    6. <title>logging &mdash; SuperGradients 1.0 documentation</title>
    7. <link rel="stylesheet" href="../_static/pygments.css" type="text/css" />
    8. <link rel="stylesheet" href="../_static/css/theme.css" type="text/css" />
    9. <link rel="stylesheet" href="../_static/graphviz.css" type="text/css" />
    10. <!--[if lt IE 9]>
    11. <script src="../_static/js/html5shiv.min.js"></script>
    12. <![endif]-->
    13. <script data-url_root="../" id="documentation_options" src="../_static/documentation_options.js"></script>
    14. <script src="../_static/jquery.js"></script>
    15. <script src="../_static/underscore.js"></script>
    16. <script src="../_static/doctools.js"></script>
    17. <script src="../_static/js/theme.js"></script>
    18. <link rel="index" title="Index" href="../genindex.html" />
    19. <link rel="search" title="Search" href="../search.html" />
    20. </head>
    21. <body class="wy-body-for-nav">
    22. <div class="wy-grid-for-nav">
    23. <nav data-toggle="wy-nav-shift" class="wy-nav-side">
    24. <div class="wy-side-scroll">
    25. <div class="wy-side-nav-search" >
    26. <a href="../index.html" class="icon icon-home"> SuperGradients
    27. </a>
    28. <div role="search">
    29. <form id="rtd-search-form" class="wy-form" action="../search.html" method="get">
    30. <input type="text" name="q" placeholder="Search docs" />
    31. <input type="hidden" name="check_keywords" value="yes" />
    32. <input type="hidden" name="area" value="default" />
    33. </form>
    34. </div>
    35. </div><div class="wy-menu wy-menu-vertical" data-spy="affix" role="navigation" aria-label="Navigation menu">
    36. <p class="caption"><span class="caption-text">SuperGradients</span></p>
    37. <ul>
    38. <li class="toctree-l1"><a class="reference internal" href="../welcome.html">SuperGradients</a></li>
    39. <li class="toctree-l1"><a class="reference internal" href="../super_gradients.common.html">Common</a></li>
    40. <li class="toctree-l1"><a class="reference internal" href="../super_gradients.training.html">Training</a></li>
    41. </ul>
    42. </div>
    43. </div>
    44. </nav>
    45. <section data-toggle="wy-nav-shift" class="wy-nav-content-wrap"><nav class="wy-nav-top" aria-label="Mobile navigation menu" >
    46. <i data-toggle="wy-nav-top" class="fa fa-bars"></i>
    47. <a href="../index.html">SuperGradients</a>
    48. </nav>
    49. <div class="wy-nav-content">
    50. <div class="rst-content">
    51. <div role="navigation" aria-label="Page navigation">
    52. <ul class="wy-breadcrumbs">
    53. <li><a href="../index.html" class="icon icon-home"></a> &raquo;</li>
    54. <li><a href="index.html">Module code</a> &raquo;</li>
    55. <li>logging</li>
    56. <li class="wy-breadcrumbs-aside">
    57. </li>
    58. </ul>
    59. <hr/>
    60. </div>
    61. <div role="main" class="document" itemscope="itemscope" itemtype="http://schema.org/Article">
    62. <div itemprop="articleBody">
    63. <h1>Source code for logging</h1><div class="highlight"><pre>
    64. <span></span><span class="c1"># Copyright 2001-2019 by Vinay Sajip. All Rights Reserved.</span>
    65. <span class="c1">#</span>
    66. <span class="c1"># Permission to use, copy, modify, and distribute this software and its</span>
    67. <span class="c1"># documentation for any purpose and without fee is hereby granted,</span>
    68. <span class="c1"># provided that the above copyright notice appear in all copies and that</span>
    69. <span class="c1"># both that copyright notice and this permission notice appear in</span>
    70. <span class="c1"># supporting documentation, and that the name of Vinay Sajip</span>
    71. <span class="c1"># not be used in advertising or publicity pertaining to distribution</span>
    72. <span class="c1"># of the software without specific, written prior permission.</span>
    73. <span class="c1"># VINAY SAJIP DISCLAIMS ALL WARRANTIES WITH REGARD TO THIS SOFTWARE, INCLUDING</span>
    74. <span class="c1"># ALL IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL</span>
    75. <span class="c1"># VINAY SAJIP BE LIABLE FOR ANY SPECIAL, INDIRECT OR CONSEQUENTIAL DAMAGES OR</span>
    76. <span class="c1"># ANY DAMAGES WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER</span>
    77. <span class="c1"># IN AN ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT</span>
    78. <span class="c1"># OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.</span>
    79. <span class="sd">&quot;&quot;&quot;</span>
    80. <span class="sd">Logging package for Python. Based on PEP 282 and comments thereto in</span>
    81. <span class="sd">comp.lang.python.</span>
    82. <span class="sd">Copyright (C) 2001-2019 Vinay Sajip. All Rights Reserved.</span>
    83. <span class="sd">To use, simply &#39;import logging&#39; and log away!</span>
    84. <span class="sd">&quot;&quot;&quot;</span>
    85. <span class="kn">import</span> <span class="nn">sys</span><span class="o">,</span> <span class="nn">os</span><span class="o">,</span> <span class="nn">time</span><span class="o">,</span> <span class="nn">io</span><span class="o">,</span> <span class="nn">re</span><span class="o">,</span> <span class="nn">traceback</span><span class="o">,</span> <span class="nn">warnings</span><span class="o">,</span> <span class="nn">weakref</span><span class="o">,</span> <span class="nn">collections.abc</span>
    86. <span class="kn">from</span> <span class="nn">string</span> <span class="kn">import</span> <span class="n">Template</span>
    87. <span class="kn">from</span> <span class="nn">string</span> <span class="kn">import</span> <span class="n">Formatter</span> <span class="k">as</span> <span class="n">StrFormatter</span>
    88. <span class="n">__all__</span> <span class="o">=</span> <span class="p">[</span><span class="s1">&#39;BASIC_FORMAT&#39;</span><span class="p">,</span> <span class="s1">&#39;BufferingFormatter&#39;</span><span class="p">,</span> <span class="s1">&#39;CRITICAL&#39;</span><span class="p">,</span> <span class="s1">&#39;DEBUG&#39;</span><span class="p">,</span> <span class="s1">&#39;ERROR&#39;</span><span class="p">,</span>
    89. <span class="s1">&#39;FATAL&#39;</span><span class="p">,</span> <span class="s1">&#39;FileHandler&#39;</span><span class="p">,</span> <span class="s1">&#39;Filter&#39;</span><span class="p">,</span> <span class="s1">&#39;Formatter&#39;</span><span class="p">,</span> <span class="s1">&#39;Handler&#39;</span><span class="p">,</span> <span class="s1">&#39;INFO&#39;</span><span class="p">,</span>
    90. <span class="s1">&#39;LogRecord&#39;</span><span class="p">,</span> <span class="s1">&#39;Logger&#39;</span><span class="p">,</span> <span class="s1">&#39;LoggerAdapter&#39;</span><span class="p">,</span> <span class="s1">&#39;NOTSET&#39;</span><span class="p">,</span> <span class="s1">&#39;NullHandler&#39;</span><span class="p">,</span>
    91. <span class="s1">&#39;StreamHandler&#39;</span><span class="p">,</span> <span class="s1">&#39;WARN&#39;</span><span class="p">,</span> <span class="s1">&#39;WARNING&#39;</span><span class="p">,</span> <span class="s1">&#39;addLevelName&#39;</span><span class="p">,</span> <span class="s1">&#39;basicConfig&#39;</span><span class="p">,</span>
    92. <span class="s1">&#39;captureWarnings&#39;</span><span class="p">,</span> <span class="s1">&#39;critical&#39;</span><span class="p">,</span> <span class="s1">&#39;debug&#39;</span><span class="p">,</span> <span class="s1">&#39;disable&#39;</span><span class="p">,</span> <span class="s1">&#39;error&#39;</span><span class="p">,</span>
    93. <span class="s1">&#39;exception&#39;</span><span class="p">,</span> <span class="s1">&#39;fatal&#39;</span><span class="p">,</span> <span class="s1">&#39;getLevelName&#39;</span><span class="p">,</span> <span class="s1">&#39;getLogger&#39;</span><span class="p">,</span> <span class="s1">&#39;getLoggerClass&#39;</span><span class="p">,</span>
    94. <span class="s1">&#39;info&#39;</span><span class="p">,</span> <span class="s1">&#39;log&#39;</span><span class="p">,</span> <span class="s1">&#39;makeLogRecord&#39;</span><span class="p">,</span> <span class="s1">&#39;setLoggerClass&#39;</span><span class="p">,</span> <span class="s1">&#39;shutdown&#39;</span><span class="p">,</span>
    95. <span class="s1">&#39;warn&#39;</span><span class="p">,</span> <span class="s1">&#39;warning&#39;</span><span class="p">,</span> <span class="s1">&#39;getLogRecordFactory&#39;</span><span class="p">,</span> <span class="s1">&#39;setLogRecordFactory&#39;</span><span class="p">,</span>
    96. <span class="s1">&#39;lastResort&#39;</span><span class="p">,</span> <span class="s1">&#39;raiseExceptions&#39;</span><span class="p">]</span>
    97. <span class="kn">import</span> <span class="nn">threading</span>
    98. <span class="n">__author__</span> <span class="o">=</span> <span class="s2">&quot;Vinay Sajip &lt;vinay_sajip@red-dove.com&gt;&quot;</span>
    99. <span class="n">__status__</span> <span class="o">=</span> <span class="s2">&quot;production&quot;</span>
    100. <span class="c1"># The following module attributes are no longer updated.</span>
    101. <span class="n">__version__</span> <span class="o">=</span> <span class="s2">&quot;0.5.1.2&quot;</span>
    102. <span class="n">__date__</span> <span class="o">=</span> <span class="s2">&quot;07 February 2010&quot;</span>
    103. <span class="c1">#---------------------------------------------------------------------------</span>
    104. <span class="c1"># Miscellaneous module data</span>
    105. <span class="c1">#---------------------------------------------------------------------------</span>
    106. <span class="c1">#</span>
    107. <span class="c1">#_startTime is used as the base when calculating the relative time of events</span>
    108. <span class="c1">#</span>
    109. <span class="n">_startTime</span> <span class="o">=</span> <span class="n">time</span><span class="o">.</span><span class="n">time</span><span class="p">()</span>
    110. <span class="c1">#</span>
    111. <span class="c1">#raiseExceptions is used to see if exceptions during handling should be</span>
    112. <span class="c1">#propagated</span>
    113. <span class="c1">#</span>
    114. <span class="n">raiseExceptions</span> <span class="o">=</span> <span class="kc">True</span>
    115. <span class="c1">#</span>
    116. <span class="c1"># If you don&#39;t want threading information in the log, set this to zero</span>
    117. <span class="c1">#</span>
    118. <span class="n">logThreads</span> <span class="o">=</span> <span class="kc">True</span>
    119. <span class="c1">#</span>
    120. <span class="c1"># If you don&#39;t want multiprocessing information in the log, set this to zero</span>
    121. <span class="c1">#</span>
    122. <span class="n">logMultiprocessing</span> <span class="o">=</span> <span class="kc">True</span>
    123. <span class="c1">#</span>
    124. <span class="c1"># If you don&#39;t want process information in the log, set this to zero</span>
    125. <span class="c1">#</span>
    126. <span class="n">logProcesses</span> <span class="o">=</span> <span class="kc">True</span>
    127. <span class="c1">#---------------------------------------------------------------------------</span>
    128. <span class="c1"># Level related stuff</span>
    129. <span class="c1">#---------------------------------------------------------------------------</span>
    130. <span class="c1">#</span>
    131. <span class="c1"># Default levels and level names, these can be replaced with any positive set</span>
    132. <span class="c1"># of values having corresponding names. There is a pseudo-level, NOTSET, which</span>
    133. <span class="c1"># is only really there as a lower limit for user-defined levels. Handlers and</span>
    134. <span class="c1"># loggers are initialized with NOTSET so that they will log all messages, even</span>
    135. <span class="c1"># at user-defined levels.</span>
    136. <span class="c1">#</span>
    137. <span class="n">CRITICAL</span> <span class="o">=</span> <span class="mi">50</span>
    138. <span class="n">FATAL</span> <span class="o">=</span> <span class="n">CRITICAL</span>
    139. <span class="n">ERROR</span> <span class="o">=</span> <span class="mi">40</span>
    140. <span class="n">WARNING</span> <span class="o">=</span> <span class="mi">30</span>
    141. <span class="n">WARN</span> <span class="o">=</span> <span class="n">WARNING</span>
    142. <span class="n">INFO</span> <span class="o">=</span> <span class="mi">20</span>
    143. <span class="n">DEBUG</span> <span class="o">=</span> <span class="mi">10</span>
    144. <span class="n">NOTSET</span> <span class="o">=</span> <span class="mi">0</span>
    145. <span class="n">_levelToName</span> <span class="o">=</span> <span class="p">{</span>
    146. <span class="n">CRITICAL</span><span class="p">:</span> <span class="s1">&#39;CRITICAL&#39;</span><span class="p">,</span>
    147. <span class="n">ERROR</span><span class="p">:</span> <span class="s1">&#39;ERROR&#39;</span><span class="p">,</span>
    148. <span class="n">WARNING</span><span class="p">:</span> <span class="s1">&#39;WARNING&#39;</span><span class="p">,</span>
    149. <span class="n">INFO</span><span class="p">:</span> <span class="s1">&#39;INFO&#39;</span><span class="p">,</span>
    150. <span class="n">DEBUG</span><span class="p">:</span> <span class="s1">&#39;DEBUG&#39;</span><span class="p">,</span>
    151. <span class="n">NOTSET</span><span class="p">:</span> <span class="s1">&#39;NOTSET&#39;</span><span class="p">,</span>
    152. <span class="p">}</span>
    153. <span class="n">_nameToLevel</span> <span class="o">=</span> <span class="p">{</span>
    154. <span class="s1">&#39;CRITICAL&#39;</span><span class="p">:</span> <span class="n">CRITICAL</span><span class="p">,</span>
    155. <span class="s1">&#39;FATAL&#39;</span><span class="p">:</span> <span class="n">FATAL</span><span class="p">,</span>
    156. <span class="s1">&#39;ERROR&#39;</span><span class="p">:</span> <span class="n">ERROR</span><span class="p">,</span>
    157. <span class="s1">&#39;WARN&#39;</span><span class="p">:</span> <span class="n">WARNING</span><span class="p">,</span>
    158. <span class="s1">&#39;WARNING&#39;</span><span class="p">:</span> <span class="n">WARNING</span><span class="p">,</span>
    159. <span class="s1">&#39;INFO&#39;</span><span class="p">:</span> <span class="n">INFO</span><span class="p">,</span>
    160. <span class="s1">&#39;DEBUG&#39;</span><span class="p">:</span> <span class="n">DEBUG</span><span class="p">,</span>
    161. <span class="s1">&#39;NOTSET&#39;</span><span class="p">:</span> <span class="n">NOTSET</span><span class="p">,</span>
    162. <span class="p">}</span>
    163. <span class="k">def</span> <span class="nf">getLevelName</span><span class="p">(</span><span class="n">level</span><span class="p">):</span>
    164. <span class="sd">&quot;&quot;&quot;</span>
    165. <span class="sd"> Return the textual or numeric representation of logging level &#39;level&#39;.</span>
    166. <span class="sd"> If the level is one of the predefined levels (CRITICAL, ERROR, WARNING,</span>
    167. <span class="sd"> INFO, DEBUG) then you get the corresponding string. If you have</span>
    168. <span class="sd"> associated levels with names using addLevelName then the name you have</span>
    169. <span class="sd"> associated with &#39;level&#39; is returned.</span>
    170. <span class="sd"> If a numeric value corresponding to one of the defined levels is passed</span>
    171. <span class="sd"> in, the corresponding string representation is returned.</span>
    172. <span class="sd"> If a string representation of the level is passed in, the corresponding</span>
    173. <span class="sd"> numeric value is returned.</span>
    174. <span class="sd"> If no matching numeric or string value is passed in, the string</span>
    175. <span class="sd"> &#39;Level %s&#39; % level is returned.</span>
    176. <span class="sd"> &quot;&quot;&quot;</span>
    177. <span class="c1"># See Issues #22386, #27937 and #29220 for why it&#39;s this way</span>
    178. <span class="n">result</span> <span class="o">=</span> <span class="n">_levelToName</span><span class="o">.</span><span class="n">get</span><span class="p">(</span><span class="n">level</span><span class="p">)</span>
    179. <span class="k">if</span> <span class="n">result</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span><span class="p">:</span>
    180. <span class="k">return</span> <span class="n">result</span>
    181. <span class="n">result</span> <span class="o">=</span> <span class="n">_nameToLevel</span><span class="o">.</span><span class="n">get</span><span class="p">(</span><span class="n">level</span><span class="p">)</span>
    182. <span class="k">if</span> <span class="n">result</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span><span class="p">:</span>
    183. <span class="k">return</span> <span class="n">result</span>
    184. <span class="k">return</span> <span class="s2">&quot;Level </span><span class="si">%s</span><span class="s2">&quot;</span> <span class="o">%</span> <span class="n">level</span>
    185. <span class="k">def</span> <span class="nf">addLevelName</span><span class="p">(</span><span class="n">level</span><span class="p">,</span> <span class="n">levelName</span><span class="p">):</span>
    186. <span class="sd">&quot;&quot;&quot;</span>
    187. <span class="sd"> Associate &#39;levelName&#39; with &#39;level&#39;.</span>
    188. <span class="sd"> This is used when converting levels to text during message formatting.</span>
    189. <span class="sd"> &quot;&quot;&quot;</span>
    190. <span class="n">_acquireLock</span><span class="p">()</span>
    191. <span class="k">try</span><span class="p">:</span> <span class="c1">#unlikely to cause an exception, but you never know...</span>
    192. <span class="n">_levelToName</span><span class="p">[</span><span class="n">level</span><span class="p">]</span> <span class="o">=</span> <span class="n">levelName</span>
    193. <span class="n">_nameToLevel</span><span class="p">[</span><span class="n">levelName</span><span class="p">]</span> <span class="o">=</span> <span class="n">level</span>
    194. <span class="k">finally</span><span class="p">:</span>
    195. <span class="n">_releaseLock</span><span class="p">()</span>
    196. <span class="k">if</span> <span class="nb">hasattr</span><span class="p">(</span><span class="n">sys</span><span class="p">,</span> <span class="s1">&#39;_getframe&#39;</span><span class="p">):</span>
    197. <span class="n">currentframe</span> <span class="o">=</span> <span class="k">lambda</span><span class="p">:</span> <span class="n">sys</span><span class="o">.</span><span class="n">_getframe</span><span class="p">(</span><span class="mi">3</span><span class="p">)</span>
    198. <span class="k">else</span><span class="p">:</span> <span class="c1">#pragma: no cover</span>
    199. <span class="k">def</span> <span class="nf">currentframe</span><span class="p">():</span>
    200. <span class="sd">&quot;&quot;&quot;Return the frame object for the caller&#39;s stack frame.&quot;&quot;&quot;</span>
    201. <span class="k">try</span><span class="p">:</span>
    202. <span class="k">raise</span> <span class="ne">Exception</span>
    203. <span class="k">except</span> <span class="ne">Exception</span><span class="p">:</span>
    204. <span class="k">return</span> <span class="n">sys</span><span class="o">.</span><span class="n">exc_info</span><span class="p">()[</span><span class="mi">2</span><span class="p">]</span><span class="o">.</span><span class="n">tb_frame</span><span class="o">.</span><span class="n">f_back</span>
    205. <span class="c1">#</span>
    206. <span class="c1"># _srcfile is used when walking the stack to check when we&#39;ve got the first</span>
    207. <span class="c1"># caller stack frame, by skipping frames whose filename is that of this</span>
    208. <span class="c1"># module&#39;s source. It therefore should contain the filename of this module&#39;s</span>
    209. <span class="c1"># source file.</span>
    210. <span class="c1">#</span>
    211. <span class="c1"># Ordinarily we would use __file__ for this, but frozen modules don&#39;t always</span>
    212. <span class="c1"># have __file__ set, for some reason (see Issue #21736). Thus, we get the</span>
    213. <span class="c1"># filename from a handy code object from a function defined in this module.</span>
    214. <span class="c1"># (There&#39;s no particular reason for picking addLevelName.)</span>
    215. <span class="c1">#</span>
    216. <span class="n">_srcfile</span> <span class="o">=</span> <span class="n">os</span><span class="o">.</span><span class="n">path</span><span class="o">.</span><span class="n">normcase</span><span class="p">(</span><span class="n">addLevelName</span><span class="o">.</span><span class="vm">__code__</span><span class="o">.</span><span class="n">co_filename</span><span class="p">)</span>
    217. <span class="c1"># _srcfile is only used in conjunction with sys._getframe().</span>
    218. <span class="c1"># To provide compatibility with older versions of Python, set _srcfile</span>
    219. <span class="c1"># to None if _getframe() is not available; this value will prevent</span>
    220. <span class="c1"># findCaller() from being called. You can also do this if you want to avoid</span>
    221. <span class="c1"># the overhead of fetching caller information, even when _getframe() is</span>
    222. <span class="c1"># available.</span>
    223. <span class="c1">#if not hasattr(sys, &#39;_getframe&#39;):</span>
    224. <span class="c1"># _srcfile = None</span>
    225. <span class="k">def</span> <span class="nf">_checkLevel</span><span class="p">(</span><span class="n">level</span><span class="p">):</span>
    226. <span class="k">if</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">level</span><span class="p">,</span> <span class="nb">int</span><span class="p">):</span>
    227. <span class="n">rv</span> <span class="o">=</span> <span class="n">level</span>
    228. <span class="k">elif</span> <span class="nb">str</span><span class="p">(</span><span class="n">level</span><span class="p">)</span> <span class="o">==</span> <span class="n">level</span><span class="p">:</span>
    229. <span class="k">if</span> <span class="n">level</span> <span class="ow">not</span> <span class="ow">in</span> <span class="n">_nameToLevel</span><span class="p">:</span>
    230. <span class="k">raise</span> <span class="ne">ValueError</span><span class="p">(</span><span class="s2">&quot;Unknown level: </span><span class="si">%r</span><span class="s2">&quot;</span> <span class="o">%</span> <span class="n">level</span><span class="p">)</span>
    231. <span class="n">rv</span> <span class="o">=</span> <span class="n">_nameToLevel</span><span class="p">[</span><span class="n">level</span><span class="p">]</span>
    232. <span class="k">else</span><span class="p">:</span>
    233. <span class="k">raise</span> <span class="ne">TypeError</span><span class="p">(</span><span class="s2">&quot;Level not an integer or a valid string: </span><span class="si">%r</span><span class="s2">&quot;</span> <span class="o">%</span> <span class="n">level</span><span class="p">)</span>
    234. <span class="k">return</span> <span class="n">rv</span>
    235. <span class="c1">#---------------------------------------------------------------------------</span>
    236. <span class="c1"># Thread-related stuff</span>
    237. <span class="c1">#---------------------------------------------------------------------------</span>
    238. <span class="c1">#</span>
    239. <span class="c1">#_lock is used to serialize access to shared data structures in this module.</span>
    240. <span class="c1">#This needs to be an RLock because fileConfig() creates and configures</span>
    241. <span class="c1">#Handlers, and so might arbitrary user threads. Since Handler code updates the</span>
    242. <span class="c1">#shared dictionary _handlers, it needs to acquire the lock. But if configuring,</span>
    243. <span class="c1">#the lock would already have been acquired - so we need an RLock.</span>
    244. <span class="c1">#The same argument applies to Loggers and Manager.loggerDict.</span>
    245. <span class="c1">#</span>
    246. <span class="n">_lock</span> <span class="o">=</span> <span class="n">threading</span><span class="o">.</span><span class="n">RLock</span><span class="p">()</span>
    247. <span class="k">def</span> <span class="nf">_acquireLock</span><span class="p">():</span>
    248. <span class="sd">&quot;&quot;&quot;</span>
    249. <span class="sd"> Acquire the module-level lock for serializing access to shared data.</span>
    250. <span class="sd"> This should be released with _releaseLock().</span>
    251. <span class="sd"> &quot;&quot;&quot;</span>
    252. <span class="k">if</span> <span class="n">_lock</span><span class="p">:</span>
    253. <span class="n">_lock</span><span class="o">.</span><span class="n">acquire</span><span class="p">()</span>
    254. <span class="k">def</span> <span class="nf">_releaseLock</span><span class="p">():</span>
    255. <span class="sd">&quot;&quot;&quot;</span>
    256. <span class="sd"> Release the module-level lock acquired by calling _acquireLock().</span>
    257. <span class="sd"> &quot;&quot;&quot;</span>
    258. <span class="k">if</span> <span class="n">_lock</span><span class="p">:</span>
    259. <span class="n">_lock</span><span class="o">.</span><span class="n">release</span><span class="p">()</span>
    260. <span class="c1"># Prevent a held logging lock from blocking a child from logging.</span>
    261. <span class="k">if</span> <span class="ow">not</span> <span class="nb">hasattr</span><span class="p">(</span><span class="n">os</span><span class="p">,</span> <span class="s1">&#39;register_at_fork&#39;</span><span class="p">):</span> <span class="c1"># Windows and friends.</span>
    262. <span class="k">def</span> <span class="nf">_register_at_fork_reinit_lock</span><span class="p">(</span><span class="n">instance</span><span class="p">):</span>
    263. <span class="k">pass</span> <span class="c1"># no-op when os.register_at_fork does not exist.</span>
    264. <span class="k">else</span><span class="p">:</span>
    265. <span class="c1"># A collection of instances with a _at_fork_reinit method (logging.Handler)</span>
    266. <span class="c1"># to be called in the child after forking. The weakref avoids us keeping</span>
    267. <span class="c1"># discarded Handler instances alive.</span>
    268. <span class="n">_at_fork_reinit_lock_weakset</span> <span class="o">=</span> <span class="n">weakref</span><span class="o">.</span><span class="n">WeakSet</span><span class="p">()</span>
    269. <span class="k">def</span> <span class="nf">_register_at_fork_reinit_lock</span><span class="p">(</span><span class="n">instance</span><span class="p">):</span>
    270. <span class="n">_acquireLock</span><span class="p">()</span>
    271. <span class="k">try</span><span class="p">:</span>
    272. <span class="n">_at_fork_reinit_lock_weakset</span><span class="o">.</span><span class="n">add</span><span class="p">(</span><span class="n">instance</span><span class="p">)</span>
    273. <span class="k">finally</span><span class="p">:</span>
    274. <span class="n">_releaseLock</span><span class="p">()</span>
    275. <span class="k">def</span> <span class="nf">_after_at_fork_child_reinit_locks</span><span class="p">():</span>
    276. <span class="k">for</span> <span class="n">handler</span> <span class="ow">in</span> <span class="n">_at_fork_reinit_lock_weakset</span><span class="p">:</span>
    277. <span class="n">handler</span><span class="o">.</span><span class="n">_at_fork_reinit</span><span class="p">()</span>
    278. <span class="c1"># _acquireLock() was called in the parent before forking.</span>
    279. <span class="c1"># The lock is reinitialized to unlocked state.</span>
    280. <span class="n">_lock</span><span class="o">.</span><span class="n">_at_fork_reinit</span><span class="p">()</span>
    281. <span class="n">os</span><span class="o">.</span><span class="n">register_at_fork</span><span class="p">(</span><span class="n">before</span><span class="o">=</span><span class="n">_acquireLock</span><span class="p">,</span>
    282. <span class="n">after_in_child</span><span class="o">=</span><span class="n">_after_at_fork_child_reinit_locks</span><span class="p">,</span>
    283. <span class="n">after_in_parent</span><span class="o">=</span><span class="n">_releaseLock</span><span class="p">)</span>
    284. <span class="c1">#---------------------------------------------------------------------------</span>
    285. <span class="c1"># The logging record</span>
    286. <span class="c1">#---------------------------------------------------------------------------</span>
    287. <span class="k">class</span> <span class="nc">LogRecord</span><span class="p">(</span><span class="nb">object</span><span class="p">):</span>
    288. <span class="sd">&quot;&quot;&quot;</span>
    289. <span class="sd"> A LogRecord instance represents an event being logged.</span>
    290. <span class="sd"> LogRecord instances are created every time something is logged. They</span>
    291. <span class="sd"> contain all the information pertinent to the event being logged. The</span>
    292. <span class="sd"> main information passed in is in msg and args, which are combined</span>
    293. <span class="sd"> using str(msg) % args to create the message field of the record. The</span>
    294. <span class="sd"> record also includes information such as when the record was created,</span>
    295. <span class="sd"> the source line where the logging call was made, and any exception</span>
    296. <span class="sd"> information to be logged.</span>
    297. <span class="sd"> &quot;&quot;&quot;</span>
    298. <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">name</span><span class="p">,</span> <span class="n">level</span><span class="p">,</span> <span class="n">pathname</span><span class="p">,</span> <span class="n">lineno</span><span class="p">,</span>
    299. <span class="n">msg</span><span class="p">,</span> <span class="n">args</span><span class="p">,</span> <span class="n">exc_info</span><span class="p">,</span> <span class="n">func</span><span class="o">=</span><span class="kc">None</span><span class="p">,</span> <span class="n">sinfo</span><span class="o">=</span><span class="kc">None</span><span class="p">,</span> <span class="o">**</span><span class="n">kwargs</span><span class="p">):</span>
    300. <span class="sd">&quot;&quot;&quot;</span>
    301. <span class="sd"> Initialize a logging record with interesting information.</span>
    302. <span class="sd"> &quot;&quot;&quot;</span>
    303. <span class="n">ct</span> <span class="o">=</span> <span class="n">time</span><span class="o">.</span><span class="n">time</span><span class="p">()</span>
    304. <span class="bp">self</span><span class="o">.</span><span class="n">name</span> <span class="o">=</span> <span class="n">name</span>
    305. <span class="bp">self</span><span class="o">.</span><span class="n">msg</span> <span class="o">=</span> <span class="n">msg</span>
    306. <span class="c1">#</span>
    307. <span class="c1"># The following statement allows passing of a dictionary as a sole</span>
    308. <span class="c1"># argument, so that you can do something like</span>
    309. <span class="c1"># logging.debug(&quot;a %(a)d b %(b)s&quot;, {&#39;a&#39;:1, &#39;b&#39;:2})</span>
    310. <span class="c1"># Suggested by Stefan Behnel.</span>
    311. <span class="c1"># Note that without the test for args[0], we get a problem because</span>
    312. <span class="c1"># during formatting, we test to see if the arg is present using</span>
    313. <span class="c1"># &#39;if self.args:&#39;. If the event being logged is e.g. &#39;Value is %d&#39;</span>
    314. <span class="c1"># and if the passed arg fails &#39;if self.args:&#39; then no formatting</span>
    315. <span class="c1"># is done. For example, logger.warning(&#39;Value is %d&#39;, 0) would log</span>
    316. <span class="c1"># &#39;Value is %d&#39; instead of &#39;Value is 0&#39;.</span>
    317. <span class="c1"># For the use case of passing a dictionary, this should not be a</span>
    318. <span class="c1"># problem.</span>
    319. <span class="c1"># Issue #21172: a request was made to relax the isinstance check</span>
    320. <span class="c1"># to hasattr(args[0], &#39;__getitem__&#39;). However, the docs on string</span>
    321. <span class="c1"># formatting still seem to suggest a mapping object is required.</span>
    322. <span class="c1"># Thus, while not removing the isinstance check, it does now look</span>
    323. <span class="c1"># for collections.abc.Mapping rather than, as before, dict.</span>
    324. <span class="k">if</span> <span class="p">(</span><span class="n">args</span> <span class="ow">and</span> <span class="nb">len</span><span class="p">(</span><span class="n">args</span><span class="p">)</span> <span class="o">==</span> <span class="mi">1</span> <span class="ow">and</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">args</span><span class="p">[</span><span class="mi">0</span><span class="p">],</span> <span class="n">collections</span><span class="o">.</span><span class="n">abc</span><span class="o">.</span><span class="n">Mapping</span><span class="p">)</span>
    325. <span class="ow">and</span> <span class="n">args</span><span class="p">[</span><span class="mi">0</span><span class="p">]):</span>
    326. <span class="n">args</span> <span class="o">=</span> <span class="n">args</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span>
    327. <span class="bp">self</span><span class="o">.</span><span class="n">args</span> <span class="o">=</span> <span class="n">args</span>
    328. <span class="bp">self</span><span class="o">.</span><span class="n">levelname</span> <span class="o">=</span> <span class="n">getLevelName</span><span class="p">(</span><span class="n">level</span><span class="p">)</span>
    329. <span class="bp">self</span><span class="o">.</span><span class="n">levelno</span> <span class="o">=</span> <span class="n">level</span>
    330. <span class="bp">self</span><span class="o">.</span><span class="n">pathname</span> <span class="o">=</span> <span class="n">pathname</span>
    331. <span class="k">try</span><span class="p">:</span>
    332. <span class="bp">self</span><span class="o">.</span><span class="n">filename</span> <span class="o">=</span> <span class="n">os</span><span class="o">.</span><span class="n">path</span><span class="o">.</span><span class="n">basename</span><span class="p">(</span><span class="n">pathname</span><span class="p">)</span>
    333. <span class="bp">self</span><span class="o">.</span><span class="n">module</span> <span class="o">=</span> <span class="n">os</span><span class="o">.</span><span class="n">path</span><span class="o">.</span><span class="n">splitext</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">filename</span><span class="p">)[</span><span class="mi">0</span><span class="p">]</span>
    334. <span class="k">except</span> <span class="p">(</span><span class="ne">TypeError</span><span class="p">,</span> <span class="ne">ValueError</span><span class="p">,</span> <span class="ne">AttributeError</span><span class="p">):</span>
    335. <span class="bp">self</span><span class="o">.</span><span class="n">filename</span> <span class="o">=</span> <span class="n">pathname</span>
    336. <span class="bp">self</span><span class="o">.</span><span class="n">module</span> <span class="o">=</span> <span class="s2">&quot;Unknown module&quot;</span>
    337. <span class="bp">self</span><span class="o">.</span><span class="n">exc_info</span> <span class="o">=</span> <span class="n">exc_info</span>
    338. <span class="bp">self</span><span class="o">.</span><span class="n">exc_text</span> <span class="o">=</span> <span class="kc">None</span> <span class="c1"># used to cache the traceback text</span>
    339. <span class="bp">self</span><span class="o">.</span><span class="n">stack_info</span> <span class="o">=</span> <span class="n">sinfo</span>
    340. <span class="bp">self</span><span class="o">.</span><span class="n">lineno</span> <span class="o">=</span> <span class="n">lineno</span>
    341. <span class="bp">self</span><span class="o">.</span><span class="n">funcName</span> <span class="o">=</span> <span class="n">func</span>
    342. <span class="bp">self</span><span class="o">.</span><span class="n">created</span> <span class="o">=</span> <span class="n">ct</span>
    343. <span class="bp">self</span><span class="o">.</span><span class="n">msecs</span> <span class="o">=</span> <span class="p">(</span><span class="n">ct</span> <span class="o">-</span> <span class="nb">int</span><span class="p">(</span><span class="n">ct</span><span class="p">))</span> <span class="o">*</span> <span class="mi">1000</span>
    344. <span class="bp">self</span><span class="o">.</span><span class="n">relativeCreated</span> <span class="o">=</span> <span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">created</span> <span class="o">-</span> <span class="n">_startTime</span><span class="p">)</span> <span class="o">*</span> <span class="mi">1000</span>
    345. <span class="k">if</span> <span class="n">logThreads</span><span class="p">:</span>
    346. <span class="bp">self</span><span class="o">.</span><span class="n">thread</span> <span class="o">=</span> <span class="n">threading</span><span class="o">.</span><span class="n">get_ident</span><span class="p">()</span>
    347. <span class="bp">self</span><span class="o">.</span><span class="n">threadName</span> <span class="o">=</span> <span class="n">threading</span><span class="o">.</span><span class="n">current_thread</span><span class="p">()</span><span class="o">.</span><span class="n">name</span>
    348. <span class="k">else</span><span class="p">:</span> <span class="c1"># pragma: no cover</span>
    349. <span class="bp">self</span><span class="o">.</span><span class="n">thread</span> <span class="o">=</span> <span class="kc">None</span>
    350. <span class="bp">self</span><span class="o">.</span><span class="n">threadName</span> <span class="o">=</span> <span class="kc">None</span>
    351. <span class="k">if</span> <span class="ow">not</span> <span class="n">logMultiprocessing</span><span class="p">:</span> <span class="c1"># pragma: no cover</span>
    352. <span class="bp">self</span><span class="o">.</span><span class="n">processName</span> <span class="o">=</span> <span class="kc">None</span>
    353. <span class="k">else</span><span class="p">:</span>
    354. <span class="bp">self</span><span class="o">.</span><span class="n">processName</span> <span class="o">=</span> <span class="s1">&#39;MainProcess&#39;</span>
    355. <span class="n">mp</span> <span class="o">=</span> <span class="n">sys</span><span class="o">.</span><span class="n">modules</span><span class="o">.</span><span class="n">get</span><span class="p">(</span><span class="s1">&#39;multiprocessing&#39;</span><span class="p">)</span>
    356. <span class="k">if</span> <span class="n">mp</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span><span class="p">:</span>
    357. <span class="c1"># Errors may occur if multiprocessing has not finished loading</span>
    358. <span class="c1"># yet - e.g. if a custom import hook causes third-party code</span>
    359. <span class="c1"># to run when multiprocessing calls import. See issue 8200</span>
    360. <span class="c1"># for an example</span>
    361. <span class="k">try</span><span class="p">:</span>
    362. <span class="bp">self</span><span class="o">.</span><span class="n">processName</span> <span class="o">=</span> <span class="n">mp</span><span class="o">.</span><span class="n">current_process</span><span class="p">()</span><span class="o">.</span><span class="n">name</span>
    363. <span class="k">except</span> <span class="ne">Exception</span><span class="p">:</span> <span class="c1">#pragma: no cover</span>
    364. <span class="k">pass</span>
    365. <span class="k">if</span> <span class="n">logProcesses</span> <span class="ow">and</span> <span class="nb">hasattr</span><span class="p">(</span><span class="n">os</span><span class="p">,</span> <span class="s1">&#39;getpid&#39;</span><span class="p">):</span>
    366. <span class="bp">self</span><span class="o">.</span><span class="n">process</span> <span class="o">=</span> <span class="n">os</span><span class="o">.</span><span class="n">getpid</span><span class="p">()</span>
    367. <span class="k">else</span><span class="p">:</span>
    368. <span class="bp">self</span><span class="o">.</span><span class="n">process</span> <span class="o">=</span> <span class="kc">None</span>
    369. <span class="k">def</span> <span class="fm">__repr__</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span>
    370. <span class="k">return</span> <span class="s1">&#39;&lt;LogRecord: </span><span class="si">%s</span><span class="s1">, </span><span class="si">%s</span><span class="s1">, </span><span class="si">%s</span><span class="s1">, </span><span class="si">%s</span><span class="s1">, &quot;</span><span class="si">%s</span><span class="s1">&quot;&gt;&#39;</span><span class="o">%</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">name</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">levelno</span><span class="p">,</span>
    371. <span class="bp">self</span><span class="o">.</span><span class="n">pathname</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">lineno</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">msg</span><span class="p">)</span>
    372. <span class="k">def</span> <span class="nf">getMessage</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span>
    373. <span class="sd">&quot;&quot;&quot;</span>
    374. <span class="sd"> Return the message for this LogRecord.</span>
    375. <span class="sd"> Return the message for this LogRecord after merging any user-supplied</span>
    376. <span class="sd"> arguments with the message.</span>
    377. <span class="sd"> &quot;&quot;&quot;</span>
    378. <span class="n">msg</span> <span class="o">=</span> <span class="nb">str</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">msg</span><span class="p">)</span>
    379. <span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">args</span><span class="p">:</span>
    380. <span class="n">msg</span> <span class="o">=</span> <span class="n">msg</span> <span class="o">%</span> <span class="bp">self</span><span class="o">.</span><span class="n">args</span>
    381. <span class="k">return</span> <span class="n">msg</span>
    382. <span class="c1">#</span>
    383. <span class="c1"># Determine which class to use when instantiating log records.</span>
    384. <span class="c1">#</span>
    385. <span class="n">_logRecordFactory</span> <span class="o">=</span> <span class="n">LogRecord</span>
    386. <span class="k">def</span> <span class="nf">setLogRecordFactory</span><span class="p">(</span><span class="n">factory</span><span class="p">):</span>
    387. <span class="sd">&quot;&quot;&quot;</span>
    388. <span class="sd"> Set the factory to be used when instantiating a log record.</span>
    389. <span class="sd"> :param factory: A callable which will be called to instantiate</span>
    390. <span class="sd"> a log record.</span>
    391. <span class="sd"> &quot;&quot;&quot;</span>
    392. <span class="k">global</span> <span class="n">_logRecordFactory</span>
    393. <span class="n">_logRecordFactory</span> <span class="o">=</span> <span class="n">factory</span>
    394. <span class="k">def</span> <span class="nf">getLogRecordFactory</span><span class="p">():</span>
    395. <span class="sd">&quot;&quot;&quot;</span>
    396. <span class="sd"> Return the factory to be used when instantiating a log record.</span>
    397. <span class="sd"> &quot;&quot;&quot;</span>
    398. <span class="k">return</span> <span class="n">_logRecordFactory</span>
    399. <span class="k">def</span> <span class="nf">makeLogRecord</span><span class="p">(</span><span class="nb">dict</span><span class="p">):</span>
    400. <span class="sd">&quot;&quot;&quot;</span>
    401. <span class="sd"> Make a LogRecord whose attributes are defined by the specified dictionary,</span>
    402. <span class="sd"> This function is useful for converting a logging event received over</span>
    403. <span class="sd"> a socket connection (which is sent as a dictionary) into a LogRecord</span>
    404. <span class="sd"> instance.</span>
    405. <span class="sd"> &quot;&quot;&quot;</span>
    406. <span class="n">rv</span> <span class="o">=</span> <span class="n">_logRecordFactory</span><span class="p">(</span><span class="kc">None</span><span class="p">,</span> <span class="kc">None</span><span class="p">,</span> <span class="s2">&quot;&quot;</span><span class="p">,</span> <span class="mi">0</span><span class="p">,</span> <span class="s2">&quot;&quot;</span><span class="p">,</span> <span class="p">(),</span> <span class="kc">None</span><span class="p">,</span> <span class="kc">None</span><span class="p">)</span>
    407. <span class="n">rv</span><span class="o">.</span><span class="vm">__dict__</span><span class="o">.</span><span class="n">update</span><span class="p">(</span><span class="nb">dict</span><span class="p">)</span>
    408. <span class="k">return</span> <span class="n">rv</span>
    409. <span class="c1">#---------------------------------------------------------------------------</span>
    410. <span class="c1"># Formatter classes and functions</span>
    411. <span class="c1">#---------------------------------------------------------------------------</span>
    412. <span class="n">_str_formatter</span> <span class="o">=</span> <span class="n">StrFormatter</span><span class="p">()</span>
    413. <span class="k">del</span> <span class="n">StrFormatter</span>
    414. <span class="k">class</span> <span class="nc">PercentStyle</span><span class="p">(</span><span class="nb">object</span><span class="p">):</span>
    415. <span class="n">default_format</span> <span class="o">=</span> <span class="s1">&#39;</span><span class="si">%(message)s</span><span class="s1">&#39;</span>
    416. <span class="n">asctime_format</span> <span class="o">=</span> <span class="s1">&#39;</span><span class="si">%(asctime)s</span><span class="s1">&#39;</span>
    417. <span class="n">asctime_search</span> <span class="o">=</span> <span class="s1">&#39;%(asctime)&#39;</span>
    418. <span class="n">validation_pattern</span> <span class="o">=</span> <span class="n">re</span><span class="o">.</span><span class="n">compile</span><span class="p">(</span><span class="sa">r</span><span class="s1">&#39;%\(\w+\)[#0+ -]*(\*|\d+)?(\.(\*|\d+))?[diouxefgcrsa%]&#39;</span><span class="p">,</span> <span class="n">re</span><span class="o">.</span><span class="n">I</span><span class="p">)</span>
    419. <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">fmt</span><span class="p">):</span>
    420. <span class="bp">self</span><span class="o">.</span><span class="n">_fmt</span> <span class="o">=</span> <span class="n">fmt</span> <span class="ow">or</span> <span class="bp">self</span><span class="o">.</span><span class="n">default_format</span>
    421. <span class="k">def</span> <span class="nf">usesTime</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span>
    422. <span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">_fmt</span><span class="o">.</span><span class="n">find</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">asctime_search</span><span class="p">)</span> <span class="o">&gt;=</span> <span class="mi">0</span>
    423. <span class="k">def</span> <span class="nf">validate</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span>
    424. <span class="sd">&quot;&quot;&quot;Validate the input format, ensure it matches the correct style&quot;&quot;&quot;</span>
    425. <span class="k">if</span> <span class="ow">not</span> <span class="bp">self</span><span class="o">.</span><span class="n">validation_pattern</span><span class="o">.</span><span class="n">search</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">_fmt</span><span class="p">):</span>
    426. <span class="k">raise</span> <span class="ne">ValueError</span><span class="p">(</span><span class="s2">&quot;Invalid format &#39;</span><span class="si">%s</span><span class="s2">&#39; for &#39;</span><span class="si">%s</span><span class="s2">&#39; style&quot;</span> <span class="o">%</span> <span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">_fmt</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">default_format</span><span class="p">[</span><span class="mi">0</span><span class="p">]))</span>
    427. <span class="k">def</span> <span class="nf">_format</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">record</span><span class="p">):</span>
    428. <span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">_fmt</span> <span class="o">%</span> <span class="n">record</span><span class="o">.</span><span class="vm">__dict__</span>
    429. <span class="k">def</span> <span class="nf">format</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">record</span><span class="p">):</span>
    430. <span class="k">try</span><span class="p">:</span>
    431. <span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">_format</span><span class="p">(</span><span class="n">record</span><span class="p">)</span>
    432. <span class="k">except</span> <span class="ne">KeyError</span> <span class="k">as</span> <span class="n">e</span><span class="p">:</span>
    433. <span class="k">raise</span> <span class="ne">ValueError</span><span class="p">(</span><span class="s1">&#39;Formatting field not found in record: </span><span class="si">%s</span><span class="s1">&#39;</span> <span class="o">%</span> <span class="n">e</span><span class="p">)</span>
    434. <span class="k">class</span> <span class="nc">StrFormatStyle</span><span class="p">(</span><span class="n">PercentStyle</span><span class="p">):</span>
    435. <span class="n">default_format</span> <span class="o">=</span> <span class="s1">&#39;</span><span class="si">{message}</span><span class="s1">&#39;</span>
    436. <span class="n">asctime_format</span> <span class="o">=</span> <span class="s1">&#39;</span><span class="si">{asctime}</span><span class="s1">&#39;</span>
    437. <span class="n">asctime_search</span> <span class="o">=</span> <span class="s1">&#39;{asctime&#39;</span>
    438. <span class="n">fmt_spec</span> <span class="o">=</span> <span class="n">re</span><span class="o">.</span><span class="n">compile</span><span class="p">(</span><span class="sa">r</span><span class="s1">&#39;^(.?[&lt;&gt;=^])?[+ -]?#?0?(\d+|{\w+})?[,_]?(\.(\d+|{\w+}))?[bcdefgnosx%]?$&#39;</span><span class="p">,</span> <span class="n">re</span><span class="o">.</span><span class="n">I</span><span class="p">)</span>
    439. <span class="n">field_spec</span> <span class="o">=</span> <span class="n">re</span><span class="o">.</span><span class="n">compile</span><span class="p">(</span><span class="sa">r</span><span class="s1">&#39;^(\d+|\w+)(\.\w+|\[[^]]+\])*$&#39;</span><span class="p">)</span>
    440. <span class="k">def</span> <span class="nf">_format</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">record</span><span class="p">):</span>
    441. <span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">_fmt</span><span class="o">.</span><span class="n">format</span><span class="p">(</span><span class="o">**</span><span class="n">record</span><span class="o">.</span><span class="vm">__dict__</span><span class="p">)</span>
    442. <span class="k">def</span> <span class="nf">validate</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span>
    443. <span class="sd">&quot;&quot;&quot;Validate the input format, ensure it is the correct string formatting style&quot;&quot;&quot;</span>
    444. <span class="n">fields</span> <span class="o">=</span> <span class="nb">set</span><span class="p">()</span>
    445. <span class="k">try</span><span class="p">:</span>
    446. <span class="k">for</span> <span class="n">_</span><span class="p">,</span> <span class="n">fieldname</span><span class="p">,</span> <span class="n">spec</span><span class="p">,</span> <span class="n">conversion</span> <span class="ow">in</span> <span class="n">_str_formatter</span><span class="o">.</span><span class="n">parse</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">_fmt</span><span class="p">):</span>
    447. <span class="k">if</span> <span class="n">fieldname</span><span class="p">:</span>
    448. <span class="k">if</span> <span class="ow">not</span> <span class="bp">self</span><span class="o">.</span><span class="n">field_spec</span><span class="o">.</span><span class="n">match</span><span class="p">(</span><span class="n">fieldname</span><span class="p">):</span>
    449. <span class="k">raise</span> <span class="ne">ValueError</span><span class="p">(</span><span class="s1">&#39;invalid field name/expression: </span><span class="si">%r</span><span class="s1">&#39;</span> <span class="o">%</span> <span class="n">fieldname</span><span class="p">)</span>
    450. <span class="n">fields</span><span class="o">.</span><span class="n">add</span><span class="p">(</span><span class="n">fieldname</span><span class="p">)</span>
    451. <span class="k">if</span> <span class="n">conversion</span> <span class="ow">and</span> <span class="n">conversion</span> <span class="ow">not</span> <span class="ow">in</span> <span class="s1">&#39;rsa&#39;</span><span class="p">:</span>
    452. <span class="k">raise</span> <span class="ne">ValueError</span><span class="p">(</span><span class="s1">&#39;invalid conversion: </span><span class="si">%r</span><span class="s1">&#39;</span> <span class="o">%</span> <span class="n">conversion</span><span class="p">)</span>
    453. <span class="k">if</span> <span class="n">spec</span> <span class="ow">and</span> <span class="ow">not</span> <span class="bp">self</span><span class="o">.</span><span class="n">fmt_spec</span><span class="o">.</span><span class="n">match</span><span class="p">(</span><span class="n">spec</span><span class="p">):</span>
    454. <span class="k">raise</span> <span class="ne">ValueError</span><span class="p">(</span><span class="s1">&#39;bad specifier: </span><span class="si">%r</span><span class="s1">&#39;</span> <span class="o">%</span> <span class="n">spec</span><span class="p">)</span>
    455. <span class="k">except</span> <span class="ne">ValueError</span> <span class="k">as</span> <span class="n">e</span><span class="p">:</span>
    456. <span class="k">raise</span> <span class="ne">ValueError</span><span class="p">(</span><span class="s1">&#39;invalid format: </span><span class="si">%s</span><span class="s1">&#39;</span> <span class="o">%</span> <span class="n">e</span><span class="p">)</span>
    457. <span class="k">if</span> <span class="ow">not</span> <span class="n">fields</span><span class="p">:</span>
    458. <span class="k">raise</span> <span class="ne">ValueError</span><span class="p">(</span><span class="s1">&#39;invalid format: no fields&#39;</span><span class="p">)</span>
    459. <span class="k">class</span> <span class="nc">StringTemplateStyle</span><span class="p">(</span><span class="n">PercentStyle</span><span class="p">):</span>
    460. <span class="n">default_format</span> <span class="o">=</span> <span class="s1">&#39;$</span><span class="si">{message}</span><span class="s1">&#39;</span>
    461. <span class="n">asctime_format</span> <span class="o">=</span> <span class="s1">&#39;$</span><span class="si">{asctime}</span><span class="s1">&#39;</span>
    462. <span class="n">asctime_search</span> <span class="o">=</span> <span class="s1">&#39;$</span><span class="si">{asctime}</span><span class="s1">&#39;</span>
    463. <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">fmt</span><span class="p">):</span>
    464. <span class="bp">self</span><span class="o">.</span><span class="n">_fmt</span> <span class="o">=</span> <span class="n">fmt</span> <span class="ow">or</span> <span class="bp">self</span><span class="o">.</span><span class="n">default_format</span>
    465. <span class="bp">self</span><span class="o">.</span><span class="n">_tpl</span> <span class="o">=</span> <span class="n">Template</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">_fmt</span><span class="p">)</span>
    466. <span class="k">def</span> <span class="nf">usesTime</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span>
    467. <span class="n">fmt</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">_fmt</span>
    468. <span class="k">return</span> <span class="n">fmt</span><span class="o">.</span><span class="n">find</span><span class="p">(</span><span class="s1">&#39;$asctime&#39;</span><span class="p">)</span> <span class="o">&gt;=</span> <span class="mi">0</span> <span class="ow">or</span> <span class="n">fmt</span><span class="o">.</span><span class="n">find</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">asctime_format</span><span class="p">)</span> <span class="o">&gt;=</span> <span class="mi">0</span>
    469. <span class="k">def</span> <span class="nf">validate</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span>
    470. <span class="n">pattern</span> <span class="o">=</span> <span class="n">Template</span><span class="o">.</span><span class="n">pattern</span>
    471. <span class="n">fields</span> <span class="o">=</span> <span class="nb">set</span><span class="p">()</span>
    472. <span class="k">for</span> <span class="n">m</span> <span class="ow">in</span> <span class="n">pattern</span><span class="o">.</span><span class="n">finditer</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">_fmt</span><span class="p">):</span>
    473. <span class="n">d</span> <span class="o">=</span> <span class="n">m</span><span class="o">.</span><span class="n">groupdict</span><span class="p">()</span>
    474. <span class="k">if</span> <span class="n">d</span><span class="p">[</span><span class="s1">&#39;named&#39;</span><span class="p">]:</span>
    475. <span class="n">fields</span><span class="o">.</span><span class="n">add</span><span class="p">(</span><span class="n">d</span><span class="p">[</span><span class="s1">&#39;named&#39;</span><span class="p">])</span>
    476. <span class="k">elif</span> <span class="n">d</span><span class="p">[</span><span class="s1">&#39;braced&#39;</span><span class="p">]:</span>
    477. <span class="n">fields</span><span class="o">.</span><span class="n">add</span><span class="p">(</span><span class="n">d</span><span class="p">[</span><span class="s1">&#39;braced&#39;</span><span class="p">])</span>
    478. <span class="k">elif</span> <span class="n">m</span><span class="o">.</span><span class="n">group</span><span class="p">(</span><span class="mi">0</span><span class="p">)</span> <span class="o">==</span> <span class="s1">&#39;$&#39;</span><span class="p">:</span>
    479. <span class="k">raise</span> <span class="ne">ValueError</span><span class="p">(</span><span class="s1">&#39;invalid format: bare </span><span class="se">\&#39;</span><span class="s1">$</span><span class="se">\&#39;</span><span class="s1"> not allowed&#39;</span><span class="p">)</span>
    480. <span class="k">if</span> <span class="ow">not</span> <span class="n">fields</span><span class="p">:</span>
    481. <span class="k">raise</span> <span class="ne">ValueError</span><span class="p">(</span><span class="s1">&#39;invalid format: no fields&#39;</span><span class="p">)</span>
    482. <span class="k">def</span> <span class="nf">_format</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">record</span><span class="p">):</span>
    483. <span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">_tpl</span><span class="o">.</span><span class="n">substitute</span><span class="p">(</span><span class="o">**</span><span class="n">record</span><span class="o">.</span><span class="vm">__dict__</span><span class="p">)</span>
    484. <span class="n">BASIC_FORMAT</span> <span class="o">=</span> <span class="s2">&quot;</span><span class="si">%(levelname)s</span><span class="s2">:</span><span class="si">%(name)s</span><span class="s2">:</span><span class="si">%(message)s</span><span class="s2">&quot;</span>
    485. <span class="n">_STYLES</span> <span class="o">=</span> <span class="p">{</span>
    486. <span class="s1">&#39;%&#39;</span><span class="p">:</span> <span class="p">(</span><span class="n">PercentStyle</span><span class="p">,</span> <span class="n">BASIC_FORMAT</span><span class="p">),</span>
    487. <span class="s1">&#39;{&#39;</span><span class="p">:</span> <span class="p">(</span><span class="n">StrFormatStyle</span><span class="p">,</span> <span class="s1">&#39;</span><span class="si">{levelname}</span><span class="s1">:</span><span class="si">{name}</span><span class="s1">:</span><span class="si">{message}</span><span class="s1">&#39;</span><span class="p">),</span>
    488. <span class="s1">&#39;$&#39;</span><span class="p">:</span> <span class="p">(</span><span class="n">StringTemplateStyle</span><span class="p">,</span> <span class="s1">&#39;$</span><span class="si">{levelname}</span><span class="s1">:$</span><span class="si">{name}</span><span class="s1">:$</span><span class="si">{message}</span><span class="s1">&#39;</span><span class="p">),</span>
    489. <span class="p">}</span>
    490. <span class="k">class</span> <span class="nc">Formatter</span><span class="p">(</span><span class="nb">object</span><span class="p">):</span>
    491. <span class="sd">&quot;&quot;&quot;</span>
    492. <span class="sd"> Formatter instances are used to convert a LogRecord to text.</span>
    493. <span class="sd"> Formatters need to know how a LogRecord is constructed. They are</span>
    494. <span class="sd"> responsible for converting a LogRecord to (usually) a string which can</span>
    495. <span class="sd"> be interpreted by either a human or an external system. The base Formatter</span>
    496. <span class="sd"> allows a formatting string to be specified. If none is supplied, the</span>
    497. <span class="sd"> style-dependent default value, &quot;%(message)s&quot;, &quot;{message}&quot;, or</span>
    498. <span class="sd"> &quot;${message}&quot;, is used.</span>
    499. <span class="sd"> The Formatter can be initialized with a format string which makes use of</span>
    500. <span class="sd"> knowledge of the LogRecord attributes - e.g. the default value mentioned</span>
    501. <span class="sd"> above makes use of the fact that the user&#39;s message and arguments are pre-</span>
    502. <span class="sd"> formatted into a LogRecord&#39;s message attribute. Currently, the useful</span>
    503. <span class="sd"> attributes in a LogRecord are described by:</span>
    504. <span class="sd"> %(name)s Name of the logger (logging channel)</span>
    505. <span class="sd"> %(levelno)s Numeric logging level for the message (DEBUG, INFO,</span>
    506. <span class="sd"> WARNING, ERROR, CRITICAL)</span>
    507. <span class="sd"> %(levelname)s Text logging level for the message (&quot;DEBUG&quot;, &quot;INFO&quot;,</span>
    508. <span class="sd"> &quot;WARNING&quot;, &quot;ERROR&quot;, &quot;CRITICAL&quot;)</span>
    509. <span class="sd"> %(pathname)s Full pathname of the source file where the logging</span>
    510. <span class="sd"> call was issued (if available)</span>
    511. <span class="sd"> %(filename)s Filename portion of pathname</span>
    512. <span class="sd"> %(module)s Module (name portion of filename)</span>
    513. <span class="sd"> %(lineno)d Source line number where the logging call was issued</span>
    514. <span class="sd"> (if available)</span>
    515. <span class="sd"> %(funcName)s Function name</span>
    516. <span class="sd"> %(created)f Time when the LogRecord was created (time.time()</span>
    517. <span class="sd"> return value)</span>
    518. <span class="sd"> %(asctime)s Textual time when the LogRecord was created</span>
    519. <span class="sd"> %(msecs)d Millisecond portion of the creation time</span>
    520. <span class="sd"> %(relativeCreated)d Time in milliseconds when the LogRecord was created,</span>
    521. <span class="sd"> relative to the time the logging module was loaded</span>
    522. <span class="sd"> (typically at application startup time)</span>
    523. <span class="sd"> %(thread)d Thread ID (if available)</span>
    524. <span class="sd"> %(threadName)s Thread name (if available)</span>
    525. <span class="sd"> %(process)d Process ID (if available)</span>
    526. <span class="sd"> %(message)s The result of record.getMessage(), computed just as</span>
    527. <span class="sd"> the record is emitted</span>
    528. <span class="sd"> &quot;&quot;&quot;</span>
    529. <span class="n">converter</span> <span class="o">=</span> <span class="n">time</span><span class="o">.</span><span class="n">localtime</span>
    530. <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">fmt</span><span class="o">=</span><span class="kc">None</span><span class="p">,</span> <span class="n">datefmt</span><span class="o">=</span><span class="kc">None</span><span class="p">,</span> <span class="n">style</span><span class="o">=</span><span class="s1">&#39;%&#39;</span><span class="p">,</span> <span class="n">validate</span><span class="o">=</span><span class="kc">True</span><span class="p">):</span>
    531. <span class="sd">&quot;&quot;&quot;</span>
    532. <span class="sd"> Initialize the formatter with specified format strings.</span>
    533. <span class="sd"> Initialize the formatter either with the specified format string, or a</span>
    534. <span class="sd"> default as described above. Allow for specialized date formatting with</span>
    535. <span class="sd"> the optional datefmt argument. If datefmt is omitted, you get an</span>
    536. <span class="sd"> ISO8601-like (or RFC 3339-like) format.</span>
    537. <span class="sd"> Use a style parameter of &#39;%&#39;, &#39;{&#39; or &#39;$&#39; to specify that you want to</span>
    538. <span class="sd"> use one of %-formatting, :meth:`str.format` (``{}``) formatting or</span>
    539. <span class="sd"> :class:`string.Template` formatting in your format string.</span>
    540. <span class="sd"> .. versionchanged:: 3.2</span>
    541. <span class="sd"> Added the ``style`` parameter.</span>
    542. <span class="sd"> &quot;&quot;&quot;</span>
    543. <span class="k">if</span> <span class="n">style</span> <span class="ow">not</span> <span class="ow">in</span> <span class="n">_STYLES</span><span class="p">:</span>
    544. <span class="k">raise</span> <span class="ne">ValueError</span><span class="p">(</span><span class="s1">&#39;Style must be one of: </span><span class="si">%s</span><span class="s1">&#39;</span> <span class="o">%</span> <span class="s1">&#39;,&#39;</span><span class="o">.</span><span class="n">join</span><span class="p">(</span>
    545. <span class="n">_STYLES</span><span class="o">.</span><span class="n">keys</span><span class="p">()))</span>
    546. <span class="bp">self</span><span class="o">.</span><span class="n">_style</span> <span class="o">=</span> <span class="n">_STYLES</span><span class="p">[</span><span class="n">style</span><span class="p">][</span><span class="mi">0</span><span class="p">](</span><span class="n">fmt</span><span class="p">)</span>
    547. <span class="k">if</span> <span class="n">validate</span><span class="p">:</span>
    548. <span class="bp">self</span><span class="o">.</span><span class="n">_style</span><span class="o">.</span><span class="n">validate</span><span class="p">()</span>
    549. <span class="bp">self</span><span class="o">.</span><span class="n">_fmt</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">_style</span><span class="o">.</span><span class="n">_fmt</span>
    550. <span class="bp">self</span><span class="o">.</span><span class="n">datefmt</span> <span class="o">=</span> <span class="n">datefmt</span>
    551. <span class="n">default_time_format</span> <span class="o">=</span> <span class="s1">&#39;%Y-%m-</span><span class="si">%d</span><span class="s1"> %H:%M:%S&#39;</span>
    552. <span class="n">default_msec_format</span> <span class="o">=</span> <span class="s1">&#39;</span><span class="si">%s</span><span class="s1">,</span><span class="si">%03d</span><span class="s1">&#39;</span>
    553. <span class="k">def</span> <span class="nf">formatTime</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">record</span><span class="p">,</span> <span class="n">datefmt</span><span class="o">=</span><span class="kc">None</span><span class="p">):</span>
    554. <span class="sd">&quot;&quot;&quot;</span>
    555. <span class="sd"> Return the creation time of the specified LogRecord as formatted text.</span>
    556. <span class="sd"> This method should be called from format() by a formatter which</span>
    557. <span class="sd"> wants to make use of a formatted time. This method can be overridden</span>
    558. <span class="sd"> in formatters to provide for any specific requirement, but the</span>
    559. <span class="sd"> basic behaviour is as follows: if datefmt (a string) is specified,</span>
    560. <span class="sd"> it is used with time.strftime() to format the creation time of the</span>
    561. <span class="sd"> record. Otherwise, an ISO8601-like (or RFC 3339-like) format is used.</span>
    562. <span class="sd"> The resulting string is returned. This function uses a user-configurable</span>
    563. <span class="sd"> function to convert the creation time to a tuple. By default,</span>
    564. <span class="sd"> time.localtime() is used; to change this for a particular formatter</span>
    565. <span class="sd"> instance, set the &#39;converter&#39; attribute to a function with the same</span>
    566. <span class="sd"> signature as time.localtime() or time.gmtime(). To change it for all</span>
    567. <span class="sd"> formatters, for example if you want all logging times to be shown in GMT,</span>
    568. <span class="sd"> set the &#39;converter&#39; attribute in the Formatter class.</span>
    569. <span class="sd"> &quot;&quot;&quot;</span>
    570. <span class="n">ct</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">converter</span><span class="p">(</span><span class="n">record</span><span class="o">.</span><span class="n">created</span><span class="p">)</span>
    571. <span class="k">if</span> <span class="n">datefmt</span><span class="p">:</span>
    572. <span class="n">s</span> <span class="o">=</span> <span class="n">time</span><span class="o">.</span><span class="n">strftime</span><span class="p">(</span><span class="n">datefmt</span><span class="p">,</span> <span class="n">ct</span><span class="p">)</span>
    573. <span class="k">else</span><span class="p">:</span>
    574. <span class="n">s</span> <span class="o">=</span> <span class="n">time</span><span class="o">.</span><span class="n">strftime</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">default_time_format</span><span class="p">,</span> <span class="n">ct</span><span class="p">)</span>
    575. <span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">default_msec_format</span><span class="p">:</span>
    576. <span class="n">s</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">default_msec_format</span> <span class="o">%</span> <span class="p">(</span><span class="n">s</span><span class="p">,</span> <span class="n">record</span><span class="o">.</span><span class="n">msecs</span><span class="p">)</span>
    577. <span class="k">return</span> <span class="n">s</span>
    578. <span class="k">def</span> <span class="nf">formatException</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">ei</span><span class="p">):</span>
    579. <span class="sd">&quot;&quot;&quot;</span>
    580. <span class="sd"> Format and return the specified exception information as a string.</span>
    581. <span class="sd"> This default implementation just uses</span>
    582. <span class="sd"> traceback.print_exception()</span>
    583. <span class="sd"> &quot;&quot;&quot;</span>
    584. <span class="n">sio</span> <span class="o">=</span> <span class="n">io</span><span class="o">.</span><span class="n">StringIO</span><span class="p">()</span>
    585. <span class="n">tb</span> <span class="o">=</span> <span class="n">ei</span><span class="p">[</span><span class="mi">2</span><span class="p">]</span>
    586. <span class="c1"># See issues #9427, #1553375. Commented out for now.</span>
    587. <span class="c1">#if getattr(self, &#39;fullstack&#39;, False):</span>
    588. <span class="c1"># traceback.print_stack(tb.tb_frame.f_back, file=sio)</span>
    589. <span class="n">traceback</span><span class="o">.</span><span class="n">print_exception</span><span class="p">(</span><span class="n">ei</span><span class="p">[</span><span class="mi">0</span><span class="p">],</span> <span class="n">ei</span><span class="p">[</span><span class="mi">1</span><span class="p">],</span> <span class="n">tb</span><span class="p">,</span> <span class="kc">None</span><span class="p">,</span> <span class="n">sio</span><span class="p">)</span>
    590. <span class="n">s</span> <span class="o">=</span> <span class="n">sio</span><span class="o">.</span><span class="n">getvalue</span><span class="p">()</span>
    591. <span class="n">sio</span><span class="o">.</span><span class="n">close</span><span class="p">()</span>
    592. <span class="k">if</span> <span class="n">s</span><span class="p">[</span><span class="o">-</span><span class="mi">1</span><span class="p">:]</span> <span class="o">==</span> <span class="s2">&quot;</span><span class="se">\n</span><span class="s2">&quot;</span><span class="p">:</span>
    593. <span class="n">s</span> <span class="o">=</span> <span class="n">s</span><span class="p">[:</span><span class="o">-</span><span class="mi">1</span><span class="p">]</span>
    594. <span class="k">return</span> <span class="n">s</span>
    595. <span class="k">def</span> <span class="nf">usesTime</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span>
    596. <span class="sd">&quot;&quot;&quot;</span>
    597. <span class="sd"> Check if the format uses the creation time of the record.</span>
    598. <span class="sd"> &quot;&quot;&quot;</span>
    599. <span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">_style</span><span class="o">.</span><span class="n">usesTime</span><span class="p">()</span>
    600. <span class="k">def</span> <span class="nf">formatMessage</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">record</span><span class="p">):</span>
    601. <span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">_style</span><span class="o">.</span><span class="n">format</span><span class="p">(</span><span class="n">record</span><span class="p">)</span>
    602. <span class="k">def</span> <span class="nf">formatStack</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">stack_info</span><span class="p">):</span>
    603. <span class="sd">&quot;&quot;&quot;</span>
    604. <span class="sd"> This method is provided as an extension point for specialized</span>
    605. <span class="sd"> formatting of stack information.</span>
    606. <span class="sd"> The input data is a string as returned from a call to</span>
    607. <span class="sd"> :func:`traceback.print_stack`, but with the last trailing newline</span>
    608. <span class="sd"> removed.</span>
    609. <span class="sd"> The base implementation just returns the value passed in.</span>
    610. <span class="sd"> &quot;&quot;&quot;</span>
    611. <span class="k">return</span> <span class="n">stack_info</span>
    612. <span class="k">def</span> <span class="nf">format</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">record</span><span class="p">):</span>
    613. <span class="sd">&quot;&quot;&quot;</span>
    614. <span class="sd"> Format the specified record as text.</span>
    615. <span class="sd"> The record&#39;s attribute dictionary is used as the operand to a</span>
    616. <span class="sd"> string formatting operation which yields the returned string.</span>
    617. <span class="sd"> Before formatting the dictionary, a couple of preparatory steps</span>
    618. <span class="sd"> are carried out. The message attribute of the record is computed</span>
    619. <span class="sd"> using LogRecord.getMessage(). If the formatting string uses the</span>
    620. <span class="sd"> time (as determined by a call to usesTime(), formatTime() is</span>
    621. <span class="sd"> called to format the event time. If there is exception information,</span>
    622. <span class="sd"> it is formatted using formatException() and appended to the message.</span>
    623. <span class="sd"> &quot;&quot;&quot;</span>
    624. <span class="n">record</span><span class="o">.</span><span class="n">message</span> <span class="o">=</span> <span class="n">record</span><span class="o">.</span><span class="n">getMessage</span><span class="p">()</span>
    625. <span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">usesTime</span><span class="p">():</span>
    626. <span class="n">record</span><span class="o">.</span><span class="n">asctime</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">formatTime</span><span class="p">(</span><span class="n">record</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">datefmt</span><span class="p">)</span>
    627. <span class="n">s</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">formatMessage</span><span class="p">(</span><span class="n">record</span><span class="p">)</span>
    628. <span class="k">if</span> <span class="n">record</span><span class="o">.</span><span class="n">exc_info</span><span class="p">:</span>
    629. <span class="c1"># Cache the traceback text to avoid converting it multiple times</span>
    630. <span class="c1"># (it&#39;s constant anyway)</span>
    631. <span class="k">if</span> <span class="ow">not</span> <span class="n">record</span><span class="o">.</span><span class="n">exc_text</span><span class="p">:</span>
    632. <span class="n">record</span><span class="o">.</span><span class="n">exc_text</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">formatException</span><span class="p">(</span><span class="n">record</span><span class="o">.</span><span class="n">exc_info</span><span class="p">)</span>
    633. <span class="k">if</span> <span class="n">record</span><span class="o">.</span><span class="n">exc_text</span><span class="p">:</span>
    634. <span class="k">if</span> <span class="n">s</span><span class="p">[</span><span class="o">-</span><span class="mi">1</span><span class="p">:]</span> <span class="o">!=</span> <span class="s2">&quot;</span><span class="se">\n</span><span class="s2">&quot;</span><span class="p">:</span>
    635. <span class="n">s</span> <span class="o">=</span> <span class="n">s</span> <span class="o">+</span> <span class="s2">&quot;</span><span class="se">\n</span><span class="s2">&quot;</span>
    636. <span class="n">s</span> <span class="o">=</span> <span class="n">s</span> <span class="o">+</span> <span class="n">record</span><span class="o">.</span><span class="n">exc_text</span>
    637. <span class="k">if</span> <span class="n">record</span><span class="o">.</span><span class="n">stack_info</span><span class="p">:</span>
    638. <span class="k">if</span> <span class="n">s</span><span class="p">[</span><span class="o">-</span><span class="mi">1</span><span class="p">:]</span> <span class="o">!=</span> <span class="s2">&quot;</span><span class="se">\n</span><span class="s2">&quot;</span><span class="p">:</span>
    639. <span class="n">s</span> <span class="o">=</span> <span class="n">s</span> <span class="o">+</span> <span class="s2">&quot;</span><span class="se">\n</span><span class="s2">&quot;</span>
    640. <span class="n">s</span> <span class="o">=</span> <span class="n">s</span> <span class="o">+</span> <span class="bp">self</span><span class="o">.</span><span class="n">formatStack</span><span class="p">(</span><span class="n">record</span><span class="o">.</span><span class="n">stack_info</span><span class="p">)</span>
    641. <span class="k">return</span> <span class="n">s</span>
    642. <span class="c1">#</span>
    643. <span class="c1"># The default formatter to use when no other is specified</span>
    644. <span class="c1">#</span>
    645. <span class="n">_defaultFormatter</span> <span class="o">=</span> <span class="n">Formatter</span><span class="p">()</span>
    646. <span class="k">class</span> <span class="nc">BufferingFormatter</span><span class="p">(</span><span class="nb">object</span><span class="p">):</span>
    647. <span class="sd">&quot;&quot;&quot;</span>
    648. <span class="sd"> A formatter suitable for formatting a number of records.</span>
    649. <span class="sd"> &quot;&quot;&quot;</span>
    650. <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">linefmt</span><span class="o">=</span><span class="kc">None</span><span class="p">):</span>
    651. <span class="sd">&quot;&quot;&quot;</span>
    652. <span class="sd"> Optionally specify a formatter which will be used to format each</span>
    653. <span class="sd"> individual record.</span>
    654. <span class="sd"> &quot;&quot;&quot;</span>
    655. <span class="k">if</span> <span class="n">linefmt</span><span class="p">:</span>
    656. <span class="bp">self</span><span class="o">.</span><span class="n">linefmt</span> <span class="o">=</span> <span class="n">linefmt</span>
    657. <span class="k">else</span><span class="p">:</span>
    658. <span class="bp">self</span><span class="o">.</span><span class="n">linefmt</span> <span class="o">=</span> <span class="n">_defaultFormatter</span>
    659. <span class="k">def</span> <span class="nf">formatHeader</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">records</span><span class="p">):</span>
    660. <span class="sd">&quot;&quot;&quot;</span>
    661. <span class="sd"> Return the header string for the specified records.</span>
    662. <span class="sd"> &quot;&quot;&quot;</span>
    663. <span class="k">return</span> <span class="s2">&quot;&quot;</span>
    664. <span class="k">def</span> <span class="nf">formatFooter</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">records</span><span class="p">):</span>
    665. <span class="sd">&quot;&quot;&quot;</span>
    666. <span class="sd"> Return the footer string for the specified records.</span>
    667. <span class="sd"> &quot;&quot;&quot;</span>
    668. <span class="k">return</span> <span class="s2">&quot;&quot;</span>
    669. <span class="k">def</span> <span class="nf">format</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">records</span><span class="p">):</span>
    670. <span class="sd">&quot;&quot;&quot;</span>
    671. <span class="sd"> Format the specified records and return the result as a string.</span>
    672. <span class="sd"> &quot;&quot;&quot;</span>
    673. <span class="n">rv</span> <span class="o">=</span> <span class="s2">&quot;&quot;</span>
    674. <span class="k">if</span> <span class="nb">len</span><span class="p">(</span><span class="n">records</span><span class="p">)</span> <span class="o">&gt;</span> <span class="mi">0</span><span class="p">:</span>
    675. <span class="n">rv</span> <span class="o">=</span> <span class="n">rv</span> <span class="o">+</span> <span class="bp">self</span><span class="o">.</span><span class="n">formatHeader</span><span class="p">(</span><span class="n">records</span><span class="p">)</span>
    676. <span class="k">for</span> <span class="n">record</span> <span class="ow">in</span> <span class="n">records</span><span class="p">:</span>
    677. <span class="n">rv</span> <span class="o">=</span> <span class="n">rv</span> <span class="o">+</span> <span class="bp">self</span><span class="o">.</span><span class="n">linefmt</span><span class="o">.</span><span class="n">format</span><span class="p">(</span><span class="n">record</span><span class="p">)</span>
    678. <span class="n">rv</span> <span class="o">=</span> <span class="n">rv</span> <span class="o">+</span> <span class="bp">self</span><span class="o">.</span><span class="n">formatFooter</span><span class="p">(</span><span class="n">records</span><span class="p">)</span>
    679. <span class="k">return</span> <span class="n">rv</span>
    680. <span class="c1">#---------------------------------------------------------------------------</span>
    681. <span class="c1"># Filter classes and functions</span>
    682. <span class="c1">#---------------------------------------------------------------------------</span>
    683. <span class="k">class</span> <span class="nc">Filter</span><span class="p">(</span><span class="nb">object</span><span class="p">):</span>
    684. <span class="sd">&quot;&quot;&quot;</span>
    685. <span class="sd"> Filter instances are used to perform arbitrary filtering of LogRecords.</span>
    686. <span class="sd"> Loggers and Handlers can optionally use Filter instances to filter</span>
    687. <span class="sd"> records as desired. The base filter class only allows events which are</span>
    688. <span class="sd"> below a certain point in the logger hierarchy. For example, a filter</span>
    689. <span class="sd"> initialized with &quot;A.B&quot; will allow events logged by loggers &quot;A.B&quot;,</span>
    690. <span class="sd"> &quot;A.B.C&quot;, &quot;A.B.C.D&quot;, &quot;A.B.D&quot; etc. but not &quot;A.BB&quot;, &quot;B.A.B&quot; etc. If</span>
    691. <span class="sd"> initialized with the empty string, all events are passed.</span>
    692. <span class="sd"> &quot;&quot;&quot;</span>
    693. <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">name</span><span class="o">=</span><span class="s1">&#39;&#39;</span><span class="p">):</span>
    694. <span class="sd">&quot;&quot;&quot;</span>
    695. <span class="sd"> Initialize a filter.</span>
    696. <span class="sd"> Initialize with the name of the logger which, together with its</span>
    697. <span class="sd"> children, will have its events allowed through the filter. If no</span>
    698. <span class="sd"> name is specified, allow every event.</span>
    699. <span class="sd"> &quot;&quot;&quot;</span>
    700. <span class="bp">self</span><span class="o">.</span><span class="n">name</span> <span class="o">=</span> <span class="n">name</span>
    701. <span class="bp">self</span><span class="o">.</span><span class="n">nlen</span> <span class="o">=</span> <span class="nb">len</span><span class="p">(</span><span class="n">name</span><span class="p">)</span>
    702. <span class="k">def</span> <span class="nf">filter</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">record</span><span class="p">):</span>
    703. <span class="sd">&quot;&quot;&quot;</span>
    704. <span class="sd"> Determine if the specified record is to be logged.</span>
    705. <span class="sd"> Returns True if the record should be logged, or False otherwise.</span>
    706. <span class="sd"> If deemed appropriate, the record may be modified in-place.</span>
    707. <span class="sd"> &quot;&quot;&quot;</span>
    708. <span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">nlen</span> <span class="o">==</span> <span class="mi">0</span><span class="p">:</span>
    709. <span class="k">return</span> <span class="kc">True</span>
    710. <span class="k">elif</span> <span class="bp">self</span><span class="o">.</span><span class="n">name</span> <span class="o">==</span> <span class="n">record</span><span class="o">.</span><span class="n">name</span><span class="p">:</span>
    711. <span class="k">return</span> <span class="kc">True</span>
    712. <span class="k">elif</span> <span class="n">record</span><span class="o">.</span><span class="n">name</span><span class="o">.</span><span class="n">find</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">name</span><span class="p">,</span> <span class="mi">0</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">nlen</span><span class="p">)</span> <span class="o">!=</span> <span class="mi">0</span><span class="p">:</span>
    713. <span class="k">return</span> <span class="kc">False</span>
    714. <span class="k">return</span> <span class="p">(</span><span class="n">record</span><span class="o">.</span><span class="n">name</span><span class="p">[</span><span class="bp">self</span><span class="o">.</span><span class="n">nlen</span><span class="p">]</span> <span class="o">==</span> <span class="s2">&quot;.&quot;</span><span class="p">)</span>
    715. <span class="k">class</span> <span class="nc">Filterer</span><span class="p">(</span><span class="nb">object</span><span class="p">):</span>
    716. <span class="sd">&quot;&quot;&quot;</span>
    717. <span class="sd"> A base class for loggers and handlers which allows them to share</span>
    718. <span class="sd"> common code.</span>
    719. <span class="sd"> &quot;&quot;&quot;</span>
    720. <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span>
    721. <span class="sd">&quot;&quot;&quot;</span>
    722. <span class="sd"> Initialize the list of filters to be an empty list.</span>
    723. <span class="sd"> &quot;&quot;&quot;</span>
    724. <span class="bp">self</span><span class="o">.</span><span class="n">filters</span> <span class="o">=</span> <span class="p">[]</span>
    725. <span class="k">def</span> <span class="nf">addFilter</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="nb">filter</span><span class="p">):</span>
    726. <span class="sd">&quot;&quot;&quot;</span>
    727. <span class="sd"> Add the specified filter to this handler.</span>
    728. <span class="sd"> &quot;&quot;&quot;</span>
    729. <span class="k">if</span> <span class="ow">not</span> <span class="p">(</span><span class="nb">filter</span> <span class="ow">in</span> <span class="bp">self</span><span class="o">.</span><span class="n">filters</span><span class="p">):</span>
    730. <span class="bp">self</span><span class="o">.</span><span class="n">filters</span><span class="o">.</span><span class="n">append</span><span class="p">(</span><span class="nb">filter</span><span class="p">)</span>
    731. <span class="k">def</span> <span class="nf">removeFilter</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="nb">filter</span><span class="p">):</span>
    732. <span class="sd">&quot;&quot;&quot;</span>
    733. <span class="sd"> Remove the specified filter from this handler.</span>
    734. <span class="sd"> &quot;&quot;&quot;</span>
    735. <span class="k">if</span> <span class="nb">filter</span> <span class="ow">in</span> <span class="bp">self</span><span class="o">.</span><span class="n">filters</span><span class="p">:</span>
    736. <span class="bp">self</span><span class="o">.</span><span class="n">filters</span><span class="o">.</span><span class="n">remove</span><span class="p">(</span><span class="nb">filter</span><span class="p">)</span>
    737. <span class="k">def</span> <span class="nf">filter</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">record</span><span class="p">):</span>
    738. <span class="sd">&quot;&quot;&quot;</span>
    739. <span class="sd"> Determine if a record is loggable by consulting all the filters.</span>
    740. <span class="sd"> The default is to allow the record to be logged; any filter can veto</span>
    741. <span class="sd"> this and the record is then dropped. Returns a zero value if a record</span>
    742. <span class="sd"> is to be dropped, else non-zero.</span>
    743. <span class="sd"> .. versionchanged:: 3.2</span>
    744. <span class="sd"> Allow filters to be just callables.</span>
    745. <span class="sd"> &quot;&quot;&quot;</span>
    746. <span class="n">rv</span> <span class="o">=</span> <span class="kc">True</span>
    747. <span class="k">for</span> <span class="n">f</span> <span class="ow">in</span> <span class="bp">self</span><span class="o">.</span><span class="n">filters</span><span class="p">:</span>
    748. <span class="k">if</span> <span class="nb">hasattr</span><span class="p">(</span><span class="n">f</span><span class="p">,</span> <span class="s1">&#39;filter&#39;</span><span class="p">):</span>
    749. <span class="n">result</span> <span class="o">=</span> <span class="n">f</span><span class="o">.</span><span class="n">filter</span><span class="p">(</span><span class="n">record</span><span class="p">)</span>
    750. <span class="k">else</span><span class="p">:</span>
    751. <span class="n">result</span> <span class="o">=</span> <span class="n">f</span><span class="p">(</span><span class="n">record</span><span class="p">)</span> <span class="c1"># assume callable - will raise if not</span>
    752. <span class="k">if</span> <span class="ow">not</span> <span class="n">result</span><span class="p">:</span>
    753. <span class="n">rv</span> <span class="o">=</span> <span class="kc">False</span>
    754. <span class="k">break</span>
    755. <span class="k">return</span> <span class="n">rv</span>
    756. <span class="c1">#---------------------------------------------------------------------------</span>
    757. <span class="c1"># Handler classes and functions</span>
    758. <span class="c1">#---------------------------------------------------------------------------</span>
    759. <span class="n">_handlers</span> <span class="o">=</span> <span class="n">weakref</span><span class="o">.</span><span class="n">WeakValueDictionary</span><span class="p">()</span> <span class="c1">#map of handler names to handlers</span>
    760. <span class="n">_handlerList</span> <span class="o">=</span> <span class="p">[]</span> <span class="c1"># added to allow handlers to be removed in reverse of order initialized</span>
    761. <span class="k">def</span> <span class="nf">_removeHandlerRef</span><span class="p">(</span><span class="n">wr</span><span class="p">):</span>
    762. <span class="sd">&quot;&quot;&quot;</span>
    763. <span class="sd"> Remove a handler reference from the internal cleanup list.</span>
    764. <span class="sd"> &quot;&quot;&quot;</span>
    765. <span class="c1"># This function can be called during module teardown, when globals are</span>
    766. <span class="c1"># set to None. It can also be called from another thread. So we need to</span>
    767. <span class="c1"># pre-emptively grab the necessary globals and check if they&#39;re None,</span>
    768. <span class="c1"># to prevent race conditions and failures during interpreter shutdown.</span>
    769. <span class="n">acquire</span><span class="p">,</span> <span class="n">release</span><span class="p">,</span> <span class="n">handlers</span> <span class="o">=</span> <span class="n">_acquireLock</span><span class="p">,</span> <span class="n">_releaseLock</span><span class="p">,</span> <span class="n">_handlerList</span>
    770. <span class="k">if</span> <span class="n">acquire</span> <span class="ow">and</span> <span class="n">release</span> <span class="ow">and</span> <span class="n">handlers</span><span class="p">:</span>
    771. <span class="n">acquire</span><span class="p">()</span>
    772. <span class="k">try</span><span class="p">:</span>
    773. <span class="k">if</span> <span class="n">wr</span> <span class="ow">in</span> <span class="n">handlers</span><span class="p">:</span>
    774. <span class="n">handlers</span><span class="o">.</span><span class="n">remove</span><span class="p">(</span><span class="n">wr</span><span class="p">)</span>
    775. <span class="k">finally</span><span class="p">:</span>
    776. <span class="n">release</span><span class="p">()</span>
    777. <span class="k">def</span> <span class="nf">_addHandlerRef</span><span class="p">(</span><span class="n">handler</span><span class="p">):</span>
    778. <span class="sd">&quot;&quot;&quot;</span>
    779. <span class="sd"> Add a handler to the internal cleanup list using a weak reference.</span>
    780. <span class="sd"> &quot;&quot;&quot;</span>
    781. <span class="n">_acquireLock</span><span class="p">()</span>
    782. <span class="k">try</span><span class="p">:</span>
    783. <span class="n">_handlerList</span><span class="o">.</span><span class="n">append</span><span class="p">(</span><span class="n">weakref</span><span class="o">.</span><span class="n">ref</span><span class="p">(</span><span class="n">handler</span><span class="p">,</span> <span class="n">_removeHandlerRef</span><span class="p">))</span>
    784. <span class="k">finally</span><span class="p">:</span>
    785. <span class="n">_releaseLock</span><span class="p">()</span>
    786. <span class="k">class</span> <span class="nc">Handler</span><span class="p">(</span><span class="n">Filterer</span><span class="p">):</span>
    787. <span class="sd">&quot;&quot;&quot;</span>
    788. <span class="sd"> Handler instances dispatch logging events to specific destinations.</span>
    789. <span class="sd"> The base handler class. Acts as a placeholder which defines the Handler</span>
    790. <span class="sd"> interface. Handlers can optionally use Formatter instances to format</span>
    791. <span class="sd"> records as desired. By default, no formatter is specified; in this case,</span>
    792. <span class="sd"> the &#39;raw&#39; message as determined by record.message is logged.</span>
    793. <span class="sd"> &quot;&quot;&quot;</span>
    794. <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">level</span><span class="o">=</span><span class="n">NOTSET</span><span class="p">):</span>
    795. <span class="sd">&quot;&quot;&quot;</span>
    796. <span class="sd"> Initializes the instance - basically setting the formatter to None</span>
    797. <span class="sd"> and the filter list to empty.</span>
    798. <span class="sd"> &quot;&quot;&quot;</span>
    799. <span class="n">Filterer</span><span class="o">.</span><span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">)</span>
    800. <span class="bp">self</span><span class="o">.</span><span class="n">_name</span> <span class="o">=</span> <span class="kc">None</span>
    801. <span class="bp">self</span><span class="o">.</span><span class="n">level</span> <span class="o">=</span> <span class="n">_checkLevel</span><span class="p">(</span><span class="n">level</span><span class="p">)</span>
    802. <span class="bp">self</span><span class="o">.</span><span class="n">formatter</span> <span class="o">=</span> <span class="kc">None</span>
    803. <span class="c1"># Add the handler to the global _handlerList (for cleanup on shutdown)</span>
    804. <span class="n">_addHandlerRef</span><span class="p">(</span><span class="bp">self</span><span class="p">)</span>
    805. <span class="bp">self</span><span class="o">.</span><span class="n">createLock</span><span class="p">()</span>
    806. <span class="k">def</span> <span class="nf">get_name</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span>
    807. <span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">_name</span>
    808. <span class="k">def</span> <span class="nf">set_name</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">name</span><span class="p">):</span>
    809. <span class="n">_acquireLock</span><span class="p">()</span>
    810. <span class="k">try</span><span class="p">:</span>
    811. <span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">_name</span> <span class="ow">in</span> <span class="n">_handlers</span><span class="p">:</span>
    812. <span class="k">del</span> <span class="n">_handlers</span><span class="p">[</span><span class="bp">self</span><span class="o">.</span><span class="n">_name</span><span class="p">]</span>
    813. <span class="bp">self</span><span class="o">.</span><span class="n">_name</span> <span class="o">=</span> <span class="n">name</span>
    814. <span class="k">if</span> <span class="n">name</span><span class="p">:</span>
    815. <span class="n">_handlers</span><span class="p">[</span><span class="n">name</span><span class="p">]</span> <span class="o">=</span> <span class="bp">self</span>
    816. <span class="k">finally</span><span class="p">:</span>
    817. <span class="n">_releaseLock</span><span class="p">()</span>
    818. <span class="n">name</span> <span class="o">=</span> <span class="nb">property</span><span class="p">(</span><span class="n">get_name</span><span class="p">,</span> <span class="n">set_name</span><span class="p">)</span>
    819. <span class="k">def</span> <span class="nf">createLock</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span>
    820. <span class="sd">&quot;&quot;&quot;</span>
    821. <span class="sd"> Acquire a thread lock for serializing access to the underlying I/O.</span>
    822. <span class="sd"> &quot;&quot;&quot;</span>
    823. <span class="bp">self</span><span class="o">.</span><span class="n">lock</span> <span class="o">=</span> <span class="n">threading</span><span class="o">.</span><span class="n">RLock</span><span class="p">()</span>
    824. <span class="n">_register_at_fork_reinit_lock</span><span class="p">(</span><span class="bp">self</span><span class="p">)</span>
    825. <span class="k">def</span> <span class="nf">_at_fork_reinit</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span>
    826. <span class="bp">self</span><span class="o">.</span><span class="n">lock</span><span class="o">.</span><span class="n">_at_fork_reinit</span><span class="p">()</span>
    827. <span class="k">def</span> <span class="nf">acquire</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span>
    828. <span class="sd">&quot;&quot;&quot;</span>
    829. <span class="sd"> Acquire the I/O thread lock.</span>
    830. <span class="sd"> &quot;&quot;&quot;</span>
    831. <span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">lock</span><span class="p">:</span>
    832. <span class="bp">self</span><span class="o">.</span><span class="n">lock</span><span class="o">.</span><span class="n">acquire</span><span class="p">()</span>
    833. <span class="k">def</span> <span class="nf">release</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span>
    834. <span class="sd">&quot;&quot;&quot;</span>
    835. <span class="sd"> Release the I/O thread lock.</span>
    836. <span class="sd"> &quot;&quot;&quot;</span>
    837. <span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">lock</span><span class="p">:</span>
    838. <span class="bp">self</span><span class="o">.</span><span class="n">lock</span><span class="o">.</span><span class="n">release</span><span class="p">()</span>
    839. <span class="k">def</span> <span class="nf">setLevel</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">level</span><span class="p">):</span>
    840. <span class="sd">&quot;&quot;&quot;</span>
    841. <span class="sd"> Set the logging level of this handler. level must be an int or a str.</span>
    842. <span class="sd"> &quot;&quot;&quot;</span>
    843. <span class="bp">self</span><span class="o">.</span><span class="n">level</span> <span class="o">=</span> <span class="n">_checkLevel</span><span class="p">(</span><span class="n">level</span><span class="p">)</span>
    844. <span class="k">def</span> <span class="nf">format</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">record</span><span class="p">):</span>
    845. <span class="sd">&quot;&quot;&quot;</span>
    846. <span class="sd"> Format the specified record.</span>
    847. <span class="sd"> If a formatter is set, use it. Otherwise, use the default formatter</span>
    848. <span class="sd"> for the module.</span>
    849. <span class="sd"> &quot;&quot;&quot;</span>
    850. <span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">formatter</span><span class="p">:</span>
    851. <span class="n">fmt</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">formatter</span>
    852. <span class="k">else</span><span class="p">:</span>
    853. <span class="n">fmt</span> <span class="o">=</span> <span class="n">_defaultFormatter</span>
    854. <span class="k">return</span> <span class="n">fmt</span><span class="o">.</span><span class="n">format</span><span class="p">(</span><span class="n">record</span><span class="p">)</span>
    855. <span class="k">def</span> <span class="nf">emit</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">record</span><span class="p">):</span>
    856. <span class="sd">&quot;&quot;&quot;</span>
    857. <span class="sd"> Do whatever it takes to actually log the specified logging record.</span>
    858. <span class="sd"> This version is intended to be implemented by subclasses and so</span>
    859. <span class="sd"> raises a NotImplementedError.</span>
    860. <span class="sd"> &quot;&quot;&quot;</span>
    861. <span class="k">raise</span> <span class="ne">NotImplementedError</span><span class="p">(</span><span class="s1">&#39;emit must be implemented &#39;</span>
    862. <span class="s1">&#39;by Handler subclasses&#39;</span><span class="p">)</span>
    863. <span class="k">def</span> <span class="nf">handle</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">record</span><span class="p">):</span>
    864. <span class="sd">&quot;&quot;&quot;</span>
    865. <span class="sd"> Conditionally emit the specified logging record.</span>
    866. <span class="sd"> Emission depends on filters which may have been added to the handler.</span>
    867. <span class="sd"> Wrap the actual emission of the record with acquisition/release of</span>
    868. <span class="sd"> the I/O thread lock. Returns whether the filter passed the record for</span>
    869. <span class="sd"> emission.</span>
    870. <span class="sd"> &quot;&quot;&quot;</span>
    871. <span class="n">rv</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">filter</span><span class="p">(</span><span class="n">record</span><span class="p">)</span>
    872. <span class="k">if</span> <span class="n">rv</span><span class="p">:</span>
    873. <span class="bp">self</span><span class="o">.</span><span class="n">acquire</span><span class="p">()</span>
    874. <span class="k">try</span><span class="p">:</span>
    875. <span class="bp">self</span><span class="o">.</span><span class="n">emit</span><span class="p">(</span><span class="n">record</span><span class="p">)</span>
    876. <span class="k">finally</span><span class="p">:</span>
    877. <span class="bp">self</span><span class="o">.</span><span class="n">release</span><span class="p">()</span>
    878. <span class="k">return</span> <span class="n">rv</span>
    879. <span class="k">def</span> <span class="nf">setFormatter</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">fmt</span><span class="p">):</span>
    880. <span class="sd">&quot;&quot;&quot;</span>
    881. <span class="sd"> Set the formatter for this handler.</span>
    882. <span class="sd"> &quot;&quot;&quot;</span>
    883. <span class="bp">self</span><span class="o">.</span><span class="n">formatter</span> <span class="o">=</span> <span class="n">fmt</span>
    884. <span class="k">def</span> <span class="nf">flush</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span>
    885. <span class="sd">&quot;&quot;&quot;</span>
    886. <span class="sd"> Ensure all logging output has been flushed.</span>
    887. <span class="sd"> This version does nothing and is intended to be implemented by</span>
    888. <span class="sd"> subclasses.</span>
    889. <span class="sd"> &quot;&quot;&quot;</span>
    890. <span class="k">pass</span>
    891. <span class="k">def</span> <span class="nf">close</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span>
    892. <span class="sd">&quot;&quot;&quot;</span>
    893. <span class="sd"> Tidy up any resources used by the handler.</span>
    894. <span class="sd"> This version removes the handler from an internal map of handlers,</span>
    895. <span class="sd"> _handlers, which is used for handler lookup by name. Subclasses</span>
    896. <span class="sd"> should ensure that this gets called from overridden close()</span>
    897. <span class="sd"> methods.</span>
    898. <span class="sd"> &quot;&quot;&quot;</span>
    899. <span class="c1">#get the module data lock, as we&#39;re updating a shared structure.</span>
    900. <span class="n">_acquireLock</span><span class="p">()</span>
    901. <span class="k">try</span><span class="p">:</span> <span class="c1">#unlikely to raise an exception, but you never know...</span>
    902. <span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">_name</span> <span class="ow">and</span> <span class="bp">self</span><span class="o">.</span><span class="n">_name</span> <span class="ow">in</span> <span class="n">_handlers</span><span class="p">:</span>
    903. <span class="k">del</span> <span class="n">_handlers</span><span class="p">[</span><span class="bp">self</span><span class="o">.</span><span class="n">_name</span><span class="p">]</span>
    904. <span class="k">finally</span><span class="p">:</span>
    905. <span class="n">_releaseLock</span><span class="p">()</span>
    906. <span class="k">def</span> <span class="nf">handleError</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">record</span><span class="p">):</span>
    907. <span class="sd">&quot;&quot;&quot;</span>
    908. <span class="sd"> Handle errors which occur during an emit() call.</span>
    909. <span class="sd"> This method should be called from handlers when an exception is</span>
    910. <span class="sd"> encountered during an emit() call. If raiseExceptions is false,</span>
    911. <span class="sd"> exceptions get silently ignored. This is what is mostly wanted</span>
    912. <span class="sd"> for a logging system - most users will not care about errors in</span>
    913. <span class="sd"> the logging system, they are more interested in application errors.</span>
    914. <span class="sd"> You could, however, replace this with a custom handler if you wish.</span>
    915. <span class="sd"> The record which was being processed is passed in to this method.</span>
    916. <span class="sd"> &quot;&quot;&quot;</span>
    917. <span class="k">if</span> <span class="n">raiseExceptions</span> <span class="ow">and</span> <span class="n">sys</span><span class="o">.</span><span class="n">stderr</span><span class="p">:</span> <span class="c1"># see issue 13807</span>
    918. <span class="n">t</span><span class="p">,</span> <span class="n">v</span><span class="p">,</span> <span class="n">tb</span> <span class="o">=</span> <span class="n">sys</span><span class="o">.</span><span class="n">exc_info</span><span class="p">()</span>
    919. <span class="k">try</span><span class="p">:</span>
    920. <span class="n">sys</span><span class="o">.</span><span class="n">stderr</span><span class="o">.</span><span class="n">write</span><span class="p">(</span><span class="s1">&#39;--- Logging error ---</span><span class="se">\n</span><span class="s1">&#39;</span><span class="p">)</span>
    921. <span class="n">traceback</span><span class="o">.</span><span class="n">print_exception</span><span class="p">(</span><span class="n">t</span><span class="p">,</span> <span class="n">v</span><span class="p">,</span> <span class="n">tb</span><span class="p">,</span> <span class="kc">None</span><span class="p">,</span> <span class="n">sys</span><span class="o">.</span><span class="n">stderr</span><span class="p">)</span>
    922. <span class="n">sys</span><span class="o">.</span><span class="n">stderr</span><span class="o">.</span><span class="n">write</span><span class="p">(</span><span class="s1">&#39;Call stack:</span><span class="se">\n</span><span class="s1">&#39;</span><span class="p">)</span>
    923. <span class="c1"># Walk the stack frame up until we&#39;re out of logging,</span>
    924. <span class="c1"># so as to print the calling context.</span>
    925. <span class="n">frame</span> <span class="o">=</span> <span class="n">tb</span><span class="o">.</span><span class="n">tb_frame</span>
    926. <span class="k">while</span> <span class="p">(</span><span class="n">frame</span> <span class="ow">and</span> <span class="n">os</span><span class="o">.</span><span class="n">path</span><span class="o">.</span><span class="n">dirname</span><span class="p">(</span><span class="n">frame</span><span class="o">.</span><span class="n">f_code</span><span class="o">.</span><span class="n">co_filename</span><span class="p">)</span> <span class="o">==</span>
    927. <span class="n">__path__</span><span class="p">[</span><span class="mi">0</span><span class="p">]):</span>
    928. <span class="n">frame</span> <span class="o">=</span> <span class="n">frame</span><span class="o">.</span><span class="n">f_back</span>
    929. <span class="k">if</span> <span class="n">frame</span><span class="p">:</span>
    930. <span class="n">traceback</span><span class="o">.</span><span class="n">print_stack</span><span class="p">(</span><span class="n">frame</span><span class="p">,</span> <span class="n">file</span><span class="o">=</span><span class="n">sys</span><span class="o">.</span><span class="n">stderr</span><span class="p">)</span>
    931. <span class="k">else</span><span class="p">:</span>
    932. <span class="c1"># couldn&#39;t find the right stack frame, for some reason</span>
    933. <span class="n">sys</span><span class="o">.</span><span class="n">stderr</span><span class="o">.</span><span class="n">write</span><span class="p">(</span><span class="s1">&#39;Logged from file </span><span class="si">%s</span><span class="s1">, line </span><span class="si">%s</span><span class="se">\n</span><span class="s1">&#39;</span> <span class="o">%</span> <span class="p">(</span>
    934. <span class="n">record</span><span class="o">.</span><span class="n">filename</span><span class="p">,</span> <span class="n">record</span><span class="o">.</span><span class="n">lineno</span><span class="p">))</span>
    935. <span class="c1"># Issue 18671: output logging message and arguments</span>
    936. <span class="k">try</span><span class="p">:</span>
    937. <span class="n">sys</span><span class="o">.</span><span class="n">stderr</span><span class="o">.</span><span class="n">write</span><span class="p">(</span><span class="s1">&#39;Message: </span><span class="si">%r</span><span class="se">\n</span><span class="s1">&#39;</span>
    938. <span class="s1">&#39;Arguments: </span><span class="si">%s</span><span class="se">\n</span><span class="s1">&#39;</span> <span class="o">%</span> <span class="p">(</span><span class="n">record</span><span class="o">.</span><span class="n">msg</span><span class="p">,</span>
    939. <span class="n">record</span><span class="o">.</span><span class="n">args</span><span class="p">))</span>
    940. <span class="k">except</span> <span class="ne">RecursionError</span><span class="p">:</span> <span class="c1"># See issue 36272</span>
    941. <span class="k">raise</span>
    942. <span class="k">except</span> <span class="ne">Exception</span><span class="p">:</span>
    943. <span class="n">sys</span><span class="o">.</span><span class="n">stderr</span><span class="o">.</span><span class="n">write</span><span class="p">(</span><span class="s1">&#39;Unable to print the message and arguments&#39;</span>
    944. <span class="s1">&#39; - possible formatting error.</span><span class="se">\n</span><span class="s1">Use the&#39;</span>
    945. <span class="s1">&#39; traceback above to help find the error.</span><span class="se">\n</span><span class="s1">&#39;</span>
    946. <span class="p">)</span>
    947. <span class="k">except</span> <span class="ne">OSError</span><span class="p">:</span> <span class="c1">#pragma: no cover</span>
    948. <span class="k">pass</span> <span class="c1"># see issue 5971</span>
    949. <span class="k">finally</span><span class="p">:</span>
    950. <span class="k">del</span> <span class="n">t</span><span class="p">,</span> <span class="n">v</span><span class="p">,</span> <span class="n">tb</span>
    951. <span class="k">def</span> <span class="fm">__repr__</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span>
    952. <span class="n">level</span> <span class="o">=</span> <span class="n">getLevelName</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">level</span><span class="p">)</span>
    953. <span class="k">return</span> <span class="s1">&#39;&lt;</span><span class="si">%s</span><span class="s1"> (</span><span class="si">%s</span><span class="s1">)&gt;&#39;</span> <span class="o">%</span> <span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="vm">__class__</span><span class="o">.</span><span class="vm">__name__</span><span class="p">,</span> <span class="n">level</span><span class="p">)</span>
    954. <span class="k">class</span> <span class="nc">StreamHandler</span><span class="p">(</span><span class="n">Handler</span><span class="p">):</span>
    955. <span class="sd">&quot;&quot;&quot;</span>
    956. <span class="sd"> A handler class which writes logging records, appropriately formatted,</span>
    957. <span class="sd"> to a stream. Note that this class does not close the stream, as</span>
    958. <span class="sd"> sys.stdout or sys.stderr may be used.</span>
    959. <span class="sd"> &quot;&quot;&quot;</span>
    960. <span class="n">terminator</span> <span class="o">=</span> <span class="s1">&#39;</span><span class="se">\n</span><span class="s1">&#39;</span>
    961. <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">stream</span><span class="o">=</span><span class="kc">None</span><span class="p">):</span>
    962. <span class="sd">&quot;&quot;&quot;</span>
    963. <span class="sd"> Initialize the handler.</span>
    964. <span class="sd"> If stream is not specified, sys.stderr is used.</span>
    965. <span class="sd"> &quot;&quot;&quot;</span>
    966. <span class="n">Handler</span><span class="o">.</span><span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">)</span>
    967. <span class="k">if</span> <span class="n">stream</span> <span class="ow">is</span> <span class="kc">None</span><span class="p">:</span>
    968. <span class="n">stream</span> <span class="o">=</span> <span class="n">sys</span><span class="o">.</span><span class="n">stderr</span>
    969. <span class="bp">self</span><span class="o">.</span><span class="n">stream</span> <span class="o">=</span> <span class="n">stream</span>
    970. <span class="k">def</span> <span class="nf">flush</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span>
    971. <span class="sd">&quot;&quot;&quot;</span>
    972. <span class="sd"> Flushes the stream.</span>
    973. <span class="sd"> &quot;&quot;&quot;</span>
    974. <span class="bp">self</span><span class="o">.</span><span class="n">acquire</span><span class="p">()</span>
    975. <span class="k">try</span><span class="p">:</span>
    976. <span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">stream</span> <span class="ow">and</span> <span class="nb">hasattr</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">stream</span><span class="p">,</span> <span class="s2">&quot;flush&quot;</span><span class="p">):</span>
    977. <span class="bp">self</span><span class="o">.</span><span class="n">stream</span><span class="o">.</span><span class="n">flush</span><span class="p">()</span>
    978. <span class="k">finally</span><span class="p">:</span>
    979. <span class="bp">self</span><span class="o">.</span><span class="n">release</span><span class="p">()</span>
    980. <span class="k">def</span> <span class="nf">emit</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">record</span><span class="p">):</span>
    981. <span class="sd">&quot;&quot;&quot;</span>
    982. <span class="sd"> Emit a record.</span>
    983. <span class="sd"> If a formatter is specified, it is used to format the record.</span>
    984. <span class="sd"> The record is then written to the stream with a trailing newline. If</span>
    985. <span class="sd"> exception information is present, it is formatted using</span>
    986. <span class="sd"> traceback.print_exception and appended to the stream. If the stream</span>
    987. <span class="sd"> has an &#39;encoding&#39; attribute, it is used to determine how to do the</span>
    988. <span class="sd"> output to the stream.</span>
    989. <span class="sd"> &quot;&quot;&quot;</span>
    990. <span class="k">try</span><span class="p">:</span>
    991. <span class="n">msg</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">format</span><span class="p">(</span><span class="n">record</span><span class="p">)</span>
    992. <span class="n">stream</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">stream</span>
    993. <span class="c1"># issue 35046: merged two stream.writes into one.</span>
    994. <span class="n">stream</span><span class="o">.</span><span class="n">write</span><span class="p">(</span><span class="n">msg</span> <span class="o">+</span> <span class="bp">self</span><span class="o">.</span><span class="n">terminator</span><span class="p">)</span>
    995. <span class="bp">self</span><span class="o">.</span><span class="n">flush</span><span class="p">()</span>
    996. <span class="k">except</span> <span class="ne">RecursionError</span><span class="p">:</span> <span class="c1"># See issue 36272</span>
    997. <span class="k">raise</span>
    998. <span class="k">except</span> <span class="ne">Exception</span><span class="p">:</span>
    999. <span class="bp">self</span><span class="o">.</span><span class="n">handleError</span><span class="p">(</span><span class="n">record</span><span class="p">)</span>
    1000. <span class="k">def</span> <span class="nf">setStream</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">stream</span><span class="p">):</span>
    1001. <span class="sd">&quot;&quot;&quot;</span>
    1002. <span class="sd"> Sets the StreamHandler&#39;s stream to the specified value,</span>
    1003. <span class="sd"> if it is different.</span>
    1004. <span class="sd"> Returns the old stream, if the stream was changed, or None</span>
    1005. <span class="sd"> if it wasn&#39;t.</span>
    1006. <span class="sd"> &quot;&quot;&quot;</span>
    1007. <span class="k">if</span> <span class="n">stream</span> <span class="ow">is</span> <span class="bp">self</span><span class="o">.</span><span class="n">stream</span><span class="p">:</span>
    1008. <span class="n">result</span> <span class="o">=</span> <span class="kc">None</span>
    1009. <span class="k">else</span><span class="p">:</span>
    1010. <span class="n">result</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">stream</span>
    1011. <span class="bp">self</span><span class="o">.</span><span class="n">acquire</span><span class="p">()</span>
    1012. <span class="k">try</span><span class="p">:</span>
    1013. <span class="bp">self</span><span class="o">.</span><span class="n">flush</span><span class="p">()</span>
    1014. <span class="bp">self</span><span class="o">.</span><span class="n">stream</span> <span class="o">=</span> <span class="n">stream</span>
    1015. <span class="k">finally</span><span class="p">:</span>
    1016. <span class="bp">self</span><span class="o">.</span><span class="n">release</span><span class="p">()</span>
    1017. <span class="k">return</span> <span class="n">result</span>
    1018. <span class="k">def</span> <span class="fm">__repr__</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span>
    1019. <span class="n">level</span> <span class="o">=</span> <span class="n">getLevelName</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">level</span><span class="p">)</span>
    1020. <span class="n">name</span> <span class="o">=</span> <span class="nb">getattr</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">stream</span><span class="p">,</span> <span class="s1">&#39;name&#39;</span><span class="p">,</span> <span class="s1">&#39;&#39;</span><span class="p">)</span>
    1021. <span class="c1"># bpo-36015: name can be an int</span>
    1022. <span class="n">name</span> <span class="o">=</span> <span class="nb">str</span><span class="p">(</span><span class="n">name</span><span class="p">)</span>
    1023. <span class="k">if</span> <span class="n">name</span><span class="p">:</span>
    1024. <span class="n">name</span> <span class="o">+=</span> <span class="s1">&#39; &#39;</span>
    1025. <span class="k">return</span> <span class="s1">&#39;&lt;</span><span class="si">%s</span><span class="s1"> </span><span class="si">%s</span><span class="s1">(</span><span class="si">%s</span><span class="s1">)&gt;&#39;</span> <span class="o">%</span> <span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="vm">__class__</span><span class="o">.</span><span class="vm">__name__</span><span class="p">,</span> <span class="n">name</span><span class="p">,</span> <span class="n">level</span><span class="p">)</span>
    1026. <span class="k">class</span> <span class="nc">FileHandler</span><span class="p">(</span><span class="n">StreamHandler</span><span class="p">):</span>
    1027. <span class="sd">&quot;&quot;&quot;</span>
    1028. <span class="sd"> A handler class which writes formatted logging records to disk files.</span>
    1029. <span class="sd"> &quot;&quot;&quot;</span>
    1030. <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">filename</span><span class="p">,</span> <span class="n">mode</span><span class="o">=</span><span class="s1">&#39;a&#39;</span><span class="p">,</span> <span class="n">encoding</span><span class="o">=</span><span class="kc">None</span><span class="p">,</span> <span class="n">delay</span><span class="o">=</span><span class="kc">False</span><span class="p">,</span> <span class="n">errors</span><span class="o">=</span><span class="kc">None</span><span class="p">):</span>
    1031. <span class="sd">&quot;&quot;&quot;</span>
    1032. <span class="sd"> Open the specified file and use it as the stream for logging.</span>
    1033. <span class="sd"> &quot;&quot;&quot;</span>
    1034. <span class="c1"># Issue #27493: add support for Path objects to be passed in</span>
    1035. <span class="n">filename</span> <span class="o">=</span> <span class="n">os</span><span class="o">.</span><span class="n">fspath</span><span class="p">(</span><span class="n">filename</span><span class="p">)</span>
    1036. <span class="c1">#keep the absolute path, otherwise derived classes which use this</span>
    1037. <span class="c1">#may come a cropper when the current directory changes</span>
    1038. <span class="bp">self</span><span class="o">.</span><span class="n">baseFilename</span> <span class="o">=</span> <span class="n">os</span><span class="o">.</span><span class="n">path</span><span class="o">.</span><span class="n">abspath</span><span class="p">(</span><span class="n">filename</span><span class="p">)</span>
    1039. <span class="bp">self</span><span class="o">.</span><span class="n">mode</span> <span class="o">=</span> <span class="n">mode</span>
    1040. <span class="bp">self</span><span class="o">.</span><span class="n">encoding</span> <span class="o">=</span> <span class="n">encoding</span>
    1041. <span class="bp">self</span><span class="o">.</span><span class="n">errors</span> <span class="o">=</span> <span class="n">errors</span>
    1042. <span class="bp">self</span><span class="o">.</span><span class="n">delay</span> <span class="o">=</span> <span class="n">delay</span>
    1043. <span class="k">if</span> <span class="n">delay</span><span class="p">:</span>
    1044. <span class="c1">#We don&#39;t open the stream, but we still need to call the</span>
    1045. <span class="c1">#Handler constructor to set level, formatter, lock etc.</span>
    1046. <span class="n">Handler</span><span class="o">.</span><span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">)</span>
    1047. <span class="bp">self</span><span class="o">.</span><span class="n">stream</span> <span class="o">=</span> <span class="kc">None</span>
    1048. <span class="k">else</span><span class="p">:</span>
    1049. <span class="n">StreamHandler</span><span class="o">.</span><span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">_open</span><span class="p">())</span>
    1050. <span class="k">def</span> <span class="nf">close</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span>
    1051. <span class="sd">&quot;&quot;&quot;</span>
    1052. <span class="sd"> Closes the stream.</span>
    1053. <span class="sd"> &quot;&quot;&quot;</span>
    1054. <span class="bp">self</span><span class="o">.</span><span class="n">acquire</span><span class="p">()</span>
    1055. <span class="k">try</span><span class="p">:</span>
    1056. <span class="k">try</span><span class="p">:</span>
    1057. <span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">stream</span><span class="p">:</span>
    1058. <span class="k">try</span><span class="p">:</span>
    1059. <span class="bp">self</span><span class="o">.</span><span class="n">flush</span><span class="p">()</span>
    1060. <span class="k">finally</span><span class="p">:</span>
    1061. <span class="n">stream</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">stream</span>
    1062. <span class="bp">self</span><span class="o">.</span><span class="n">stream</span> <span class="o">=</span> <span class="kc">None</span>
    1063. <span class="k">if</span> <span class="nb">hasattr</span><span class="p">(</span><span class="n">stream</span><span class="p">,</span> <span class="s2">&quot;close&quot;</span><span class="p">):</span>
    1064. <span class="n">stream</span><span class="o">.</span><span class="n">close</span><span class="p">()</span>
    1065. <span class="k">finally</span><span class="p">:</span>
    1066. <span class="c1"># Issue #19523: call unconditionally to</span>
    1067. <span class="c1"># prevent a handler leak when delay is set</span>
    1068. <span class="n">StreamHandler</span><span class="o">.</span><span class="n">close</span><span class="p">(</span><span class="bp">self</span><span class="p">)</span>
    1069. <span class="k">finally</span><span class="p">:</span>
    1070. <span class="bp">self</span><span class="o">.</span><span class="n">release</span><span class="p">()</span>
    1071. <span class="k">def</span> <span class="nf">_open</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span>
    1072. <span class="sd">&quot;&quot;&quot;</span>
    1073. <span class="sd"> Open the current base file with the (original) mode and encoding.</span>
    1074. <span class="sd"> Return the resulting stream.</span>
    1075. <span class="sd"> &quot;&quot;&quot;</span>
    1076. <span class="k">return</span> <span class="nb">open</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">baseFilename</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">mode</span><span class="p">,</span> <span class="n">encoding</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">encoding</span><span class="p">,</span>
    1077. <span class="n">errors</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">errors</span><span class="p">)</span>
    1078. <span class="k">def</span> <span class="nf">emit</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">record</span><span class="p">):</span>
    1079. <span class="sd">&quot;&quot;&quot;</span>
    1080. <span class="sd"> Emit a record.</span>
    1081. <span class="sd"> If the stream was not opened because &#39;delay&#39; was specified in the</span>
    1082. <span class="sd"> constructor, open it before calling the superclass&#39;s emit.</span>
    1083. <span class="sd"> &quot;&quot;&quot;</span>
    1084. <span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">stream</span> <span class="ow">is</span> <span class="kc">None</span><span class="p">:</span>
    1085. <span class="bp">self</span><span class="o">.</span><span class="n">stream</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">_open</span><span class="p">()</span>
    1086. <span class="n">StreamHandler</span><span class="o">.</span><span class="n">emit</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">record</span><span class="p">)</span>
    1087. <span class="k">def</span> <span class="fm">__repr__</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span>
    1088. <span class="n">level</span> <span class="o">=</span> <span class="n">getLevelName</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">level</span><span class="p">)</span>
    1089. <span class="k">return</span> <span class="s1">&#39;&lt;</span><span class="si">%s</span><span class="s1"> </span><span class="si">%s</span><span class="s1"> (</span><span class="si">%s</span><span class="s1">)&gt;&#39;</span> <span class="o">%</span> <span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="vm">__class__</span><span class="o">.</span><span class="vm">__name__</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">baseFilename</span><span class="p">,</span> <span class="n">level</span><span class="p">)</span>
    1090. <span class="k">class</span> <span class="nc">_StderrHandler</span><span class="p">(</span><span class="n">StreamHandler</span><span class="p">):</span>
    1091. <span class="sd">&quot;&quot;&quot;</span>
    1092. <span class="sd"> This class is like a StreamHandler using sys.stderr, but always uses</span>
    1093. <span class="sd"> whatever sys.stderr is currently set to rather than the value of</span>
    1094. <span class="sd"> sys.stderr at handler construction time.</span>
    1095. <span class="sd"> &quot;&quot;&quot;</span>
    1096. <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">level</span><span class="o">=</span><span class="n">NOTSET</span><span class="p">):</span>
    1097. <span class="sd">&quot;&quot;&quot;</span>
    1098. <span class="sd"> Initialize the handler.</span>
    1099. <span class="sd"> &quot;&quot;&quot;</span>
    1100. <span class="n">Handler</span><span class="o">.</span><span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">level</span><span class="p">)</span>
    1101. <span class="nd">@property</span>
    1102. <span class="k">def</span> <span class="nf">stream</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span>
    1103. <span class="k">return</span> <span class="n">sys</span><span class="o">.</span><span class="n">stderr</span>
    1104. <span class="n">_defaultLastResort</span> <span class="o">=</span> <span class="n">_StderrHandler</span><span class="p">(</span><span class="n">WARNING</span><span class="p">)</span>
    1105. <span class="n">lastResort</span> <span class="o">=</span> <span class="n">_defaultLastResort</span>
    1106. <span class="c1">#---------------------------------------------------------------------------</span>
    1107. <span class="c1"># Manager classes and functions</span>
    1108. <span class="c1">#---------------------------------------------------------------------------</span>
    1109. <span class="k">class</span> <span class="nc">PlaceHolder</span><span class="p">(</span><span class="nb">object</span><span class="p">):</span>
    1110. <span class="sd">&quot;&quot;&quot;</span>
    1111. <span class="sd"> PlaceHolder instances are used in the Manager logger hierarchy to take</span>
    1112. <span class="sd"> the place of nodes for which no loggers have been defined. This class is</span>
    1113. <span class="sd"> intended for internal use only and not as part of the public API.</span>
    1114. <span class="sd"> &quot;&quot;&quot;</span>
    1115. <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">alogger</span><span class="p">):</span>
    1116. <span class="sd">&quot;&quot;&quot;</span>
    1117. <span class="sd"> Initialize with the specified logger being a child of this placeholder.</span>
    1118. <span class="sd"> &quot;&quot;&quot;</span>
    1119. <span class="bp">self</span><span class="o">.</span><span class="n">loggerMap</span> <span class="o">=</span> <span class="p">{</span> <span class="n">alogger</span> <span class="p">:</span> <span class="kc">None</span> <span class="p">}</span>
    1120. <span class="k">def</span> <span class="nf">append</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">alogger</span><span class="p">):</span>
    1121. <span class="sd">&quot;&quot;&quot;</span>
    1122. <span class="sd"> Add the specified logger as a child of this placeholder.</span>
    1123. <span class="sd"> &quot;&quot;&quot;</span>
    1124. <span class="k">if</span> <span class="n">alogger</span> <span class="ow">not</span> <span class="ow">in</span> <span class="bp">self</span><span class="o">.</span><span class="n">loggerMap</span><span class="p">:</span>
    1125. <span class="bp">self</span><span class="o">.</span><span class="n">loggerMap</span><span class="p">[</span><span class="n">alogger</span><span class="p">]</span> <span class="o">=</span> <span class="kc">None</span>
    1126. <span class="c1">#</span>
    1127. <span class="c1"># Determine which class to use when instantiating loggers.</span>
    1128. <span class="c1">#</span>
    1129. <span class="k">def</span> <span class="nf">setLoggerClass</span><span class="p">(</span><span class="n">klass</span><span class="p">):</span>
    1130. <span class="sd">&quot;&quot;&quot;</span>
    1131. <span class="sd"> Set the class to be used when instantiating a logger. The class should</span>
    1132. <span class="sd"> define __init__() such that only a name argument is required, and the</span>
    1133. <span class="sd"> __init__() should call Logger.__init__()</span>
    1134. <span class="sd"> &quot;&quot;&quot;</span>
    1135. <span class="k">if</span> <span class="n">klass</span> <span class="o">!=</span> <span class="n">Logger</span><span class="p">:</span>
    1136. <span class="k">if</span> <span class="ow">not</span> <span class="nb">issubclass</span><span class="p">(</span><span class="n">klass</span><span class="p">,</span> <span class="n">Logger</span><span class="p">):</span>
    1137. <span class="k">raise</span> <span class="ne">TypeError</span><span class="p">(</span><span class="s2">&quot;logger not derived from logging.Logger: &quot;</span>
    1138. <span class="o">+</span> <span class="n">klass</span><span class="o">.</span><span class="vm">__name__</span><span class="p">)</span>
    1139. <span class="k">global</span> <span class="n">_loggerClass</span>
    1140. <span class="n">_loggerClass</span> <span class="o">=</span> <span class="n">klass</span>
    1141. <span class="k">def</span> <span class="nf">getLoggerClass</span><span class="p">():</span>
    1142. <span class="sd">&quot;&quot;&quot;</span>
    1143. <span class="sd"> Return the class to be used when instantiating a logger.</span>
    1144. <span class="sd"> &quot;&quot;&quot;</span>
    1145. <span class="k">return</span> <span class="n">_loggerClass</span>
    1146. <span class="k">class</span> <span class="nc">Manager</span><span class="p">(</span><span class="nb">object</span><span class="p">):</span>
    1147. <span class="sd">&quot;&quot;&quot;</span>
    1148. <span class="sd"> There is [under normal circumstances] just one Manager instance, which</span>
    1149. <span class="sd"> holds the hierarchy of loggers.</span>
    1150. <span class="sd"> &quot;&quot;&quot;</span>
    1151. <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">rootnode</span><span class="p">):</span>
    1152. <span class="sd">&quot;&quot;&quot;</span>
    1153. <span class="sd"> Initialize the manager with the root node of the logger hierarchy.</span>
    1154. <span class="sd"> &quot;&quot;&quot;</span>
    1155. <span class="bp">self</span><span class="o">.</span><span class="n">root</span> <span class="o">=</span> <span class="n">rootnode</span>
    1156. <span class="bp">self</span><span class="o">.</span><span class="n">disable</span> <span class="o">=</span> <span class="mi">0</span>
    1157. <span class="bp">self</span><span class="o">.</span><span class="n">emittedNoHandlerWarning</span> <span class="o">=</span> <span class="kc">False</span>
    1158. <span class="bp">self</span><span class="o">.</span><span class="n">loggerDict</span> <span class="o">=</span> <span class="p">{}</span>
    1159. <span class="bp">self</span><span class="o">.</span><span class="n">loggerClass</span> <span class="o">=</span> <span class="kc">None</span>
    1160. <span class="bp">self</span><span class="o">.</span><span class="n">logRecordFactory</span> <span class="o">=</span> <span class="kc">None</span>
    1161. <span class="nd">@property</span>
    1162. <span class="k">def</span> <span class="nf">disable</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span>
    1163. <span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">_disable</span>
    1164. <span class="nd">@disable</span><span class="o">.</span><span class="n">setter</span>
    1165. <span class="k">def</span> <span class="nf">disable</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">value</span><span class="p">):</span>
    1166. <span class="bp">self</span><span class="o">.</span><span class="n">_disable</span> <span class="o">=</span> <span class="n">_checkLevel</span><span class="p">(</span><span class="n">value</span><span class="p">)</span>
    1167. <span class="k">def</span> <span class="nf">getLogger</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">name</span><span class="p">):</span>
    1168. <span class="sd">&quot;&quot;&quot;</span>
    1169. <span class="sd"> Get a logger with the specified name (channel name), creating it</span>
    1170. <span class="sd"> if it doesn&#39;t yet exist. This name is a dot-separated hierarchical</span>
    1171. <span class="sd"> name, such as &quot;a&quot;, &quot;a.b&quot;, &quot;a.b.c&quot; or similar.</span>
    1172. <span class="sd"> If a PlaceHolder existed for the specified name [i.e. the logger</span>
    1173. <span class="sd"> didn&#39;t exist but a child of it did], replace it with the created</span>
    1174. <span class="sd"> logger and fix up the parent/child references which pointed to the</span>
    1175. <span class="sd"> placeholder to now point to the logger.</span>
    1176. <span class="sd"> &quot;&quot;&quot;</span>
    1177. <span class="n">rv</span> <span class="o">=</span> <span class="kc">None</span>
    1178. <span class="k">if</span> <span class="ow">not</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">name</span><span class="p">,</span> <span class="nb">str</span><span class="p">):</span>
    1179. <span class="k">raise</span> <span class="ne">TypeError</span><span class="p">(</span><span class="s1">&#39;A logger name must be a string&#39;</span><span class="p">)</span>
    1180. <span class="n">_acquireLock</span><span class="p">()</span>
    1181. <span class="k">try</span><span class="p">:</span>
    1182. <span class="k">if</span> <span class="n">name</span> <span class="ow">in</span> <span class="bp">self</span><span class="o">.</span><span class="n">loggerDict</span><span class="p">:</span>
    1183. <span class="n">rv</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">loggerDict</span><span class="p">[</span><span class="n">name</span><span class="p">]</span>
    1184. <span class="k">if</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">rv</span><span class="p">,</span> <span class="n">PlaceHolder</span><span class="p">):</span>
    1185. <span class="n">ph</span> <span class="o">=</span> <span class="n">rv</span>
    1186. <span class="n">rv</span> <span class="o">=</span> <span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">loggerClass</span> <span class="ow">or</span> <span class="n">_loggerClass</span><span class="p">)(</span><span class="n">name</span><span class="p">)</span>
    1187. <span class="n">rv</span><span class="o">.</span><span class="n">manager</span> <span class="o">=</span> <span class="bp">self</span>
    1188. <span class="bp">self</span><span class="o">.</span><span class="n">loggerDict</span><span class="p">[</span><span class="n">name</span><span class="p">]</span> <span class="o">=</span> <span class="n">rv</span>
    1189. <span class="bp">self</span><span class="o">.</span><span class="n">_fixupChildren</span><span class="p">(</span><span class="n">ph</span><span class="p">,</span> <span class="n">rv</span><span class="p">)</span>
    1190. <span class="bp">self</span><span class="o">.</span><span class="n">_fixupParents</span><span class="p">(</span><span class="n">rv</span><span class="p">)</span>
    1191. <span class="k">else</span><span class="p">:</span>
    1192. <span class="n">rv</span> <span class="o">=</span> <span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">loggerClass</span> <span class="ow">or</span> <span class="n">_loggerClass</span><span class="p">)(</span><span class="n">name</span><span class="p">)</span>
    1193. <span class="n">rv</span><span class="o">.</span><span class="n">manager</span> <span class="o">=</span> <span class="bp">self</span>
    1194. <span class="bp">self</span><span class="o">.</span><span class="n">loggerDict</span><span class="p">[</span><span class="n">name</span><span class="p">]</span> <span class="o">=</span> <span class="n">rv</span>
    1195. <span class="bp">self</span><span class="o">.</span><span class="n">_fixupParents</span><span class="p">(</span><span class="n">rv</span><span class="p">)</span>
    1196. <span class="k">finally</span><span class="p">:</span>
    1197. <span class="n">_releaseLock</span><span class="p">()</span>
    1198. <span class="k">return</span> <span class="n">rv</span>
    1199. <span class="k">def</span> <span class="nf">setLoggerClass</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">klass</span><span class="p">):</span>
    1200. <span class="sd">&quot;&quot;&quot;</span>
    1201. <span class="sd"> Set the class to be used when instantiating a logger with this Manager.</span>
    1202. <span class="sd"> &quot;&quot;&quot;</span>
    1203. <span class="k">if</span> <span class="n">klass</span> <span class="o">!=</span> <span class="n">Logger</span><span class="p">:</span>
    1204. <span class="k">if</span> <span class="ow">not</span> <span class="nb">issubclass</span><span class="p">(</span><span class="n">klass</span><span class="p">,</span> <span class="n">Logger</span><span class="p">):</span>
    1205. <span class="k">raise</span> <span class="ne">TypeError</span><span class="p">(</span><span class="s2">&quot;logger not derived from logging.Logger: &quot;</span>
    1206. <span class="o">+</span> <span class="n">klass</span><span class="o">.</span><span class="vm">__name__</span><span class="p">)</span>
    1207. <span class="bp">self</span><span class="o">.</span><span class="n">loggerClass</span> <span class="o">=</span> <span class="n">klass</span>
    1208. <span class="k">def</span> <span class="nf">setLogRecordFactory</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">factory</span><span class="p">):</span>
    1209. <span class="sd">&quot;&quot;&quot;</span>
    1210. <span class="sd"> Set the factory to be used when instantiating a log record with this</span>
    1211. <span class="sd"> Manager.</span>
    1212. <span class="sd"> &quot;&quot;&quot;</span>
    1213. <span class="bp">self</span><span class="o">.</span><span class="n">logRecordFactory</span> <span class="o">=</span> <span class="n">factory</span>
    1214. <span class="k">def</span> <span class="nf">_fixupParents</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">alogger</span><span class="p">):</span>
    1215. <span class="sd">&quot;&quot;&quot;</span>
    1216. <span class="sd"> Ensure that there are either loggers or placeholders all the way</span>
    1217. <span class="sd"> from the specified logger to the root of the logger hierarchy.</span>
    1218. <span class="sd"> &quot;&quot;&quot;</span>
    1219. <span class="n">name</span> <span class="o">=</span> <span class="n">alogger</span><span class="o">.</span><span class="n">name</span>
    1220. <span class="n">i</span> <span class="o">=</span> <span class="n">name</span><span class="o">.</span><span class="n">rfind</span><span class="p">(</span><span class="s2">&quot;.&quot;</span><span class="p">)</span>
    1221. <span class="n">rv</span> <span class="o">=</span> <span class="kc">None</span>
    1222. <span class="k">while</span> <span class="p">(</span><span class="n">i</span> <span class="o">&gt;</span> <span class="mi">0</span><span class="p">)</span> <span class="ow">and</span> <span class="ow">not</span> <span class="n">rv</span><span class="p">:</span>
    1223. <span class="n">substr</span> <span class="o">=</span> <span class="n">name</span><span class="p">[:</span><span class="n">i</span><span class="p">]</span>
    1224. <span class="k">if</span> <span class="n">substr</span> <span class="ow">not</span> <span class="ow">in</span> <span class="bp">self</span><span class="o">.</span><span class="n">loggerDict</span><span class="p">:</span>
    1225. <span class="bp">self</span><span class="o">.</span><span class="n">loggerDict</span><span class="p">[</span><span class="n">substr</span><span class="p">]</span> <span class="o">=</span> <span class="n">PlaceHolder</span><span class="p">(</span><span class="n">alogger</span><span class="p">)</span>
    1226. <span class="k">else</span><span class="p">:</span>
    1227. <span class="n">obj</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">loggerDict</span><span class="p">[</span><span class="n">substr</span><span class="p">]</span>
    1228. <span class="k">if</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">obj</span><span class="p">,</span> <span class="n">Logger</span><span class="p">):</span>
    1229. <span class="n">rv</span> <span class="o">=</span> <span class="n">obj</span>
    1230. <span class="k">else</span><span class="p">:</span>
    1231. <span class="k">assert</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">obj</span><span class="p">,</span> <span class="n">PlaceHolder</span><span class="p">)</span>
    1232. <span class="n">obj</span><span class="o">.</span><span class="n">append</span><span class="p">(</span><span class="n">alogger</span><span class="p">)</span>
    1233. <span class="n">i</span> <span class="o">=</span> <span class="n">name</span><span class="o">.</span><span class="n">rfind</span><span class="p">(</span><span class="s2">&quot;.&quot;</span><span class="p">,</span> <span class="mi">0</span><span class="p">,</span> <span class="n">i</span> <span class="o">-</span> <span class="mi">1</span><span class="p">)</span>
    1234. <span class="k">if</span> <span class="ow">not</span> <span class="n">rv</span><span class="p">:</span>
    1235. <span class="n">rv</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">root</span>
    1236. <span class="n">alogger</span><span class="o">.</span><span class="n">parent</span> <span class="o">=</span> <span class="n">rv</span>
    1237. <span class="k">def</span> <span class="nf">_fixupChildren</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">ph</span><span class="p">,</span> <span class="n">alogger</span><span class="p">):</span>
    1238. <span class="sd">&quot;&quot;&quot;</span>
    1239. <span class="sd"> Ensure that children of the placeholder ph are connected to the</span>
    1240. <span class="sd"> specified logger.</span>
    1241. <span class="sd"> &quot;&quot;&quot;</span>
    1242. <span class="n">name</span> <span class="o">=</span> <span class="n">alogger</span><span class="o">.</span><span class="n">name</span>
    1243. <span class="n">namelen</span> <span class="o">=</span> <span class="nb">len</span><span class="p">(</span><span class="n">name</span><span class="p">)</span>
    1244. <span class="k">for</span> <span class="n">c</span> <span class="ow">in</span> <span class="n">ph</span><span class="o">.</span><span class="n">loggerMap</span><span class="o">.</span><span class="n">keys</span><span class="p">():</span>
    1245. <span class="c1">#The if means ... if not c.parent.name.startswith(nm)</span>
    1246. <span class="k">if</span> <span class="n">c</span><span class="o">.</span><span class="n">parent</span><span class="o">.</span><span class="n">name</span><span class="p">[:</span><span class="n">namelen</span><span class="p">]</span> <span class="o">!=</span> <span class="n">name</span><span class="p">:</span>
    1247. <span class="n">alogger</span><span class="o">.</span><span class="n">parent</span> <span class="o">=</span> <span class="n">c</span><span class="o">.</span><span class="n">parent</span>
    1248. <span class="n">c</span><span class="o">.</span><span class="n">parent</span> <span class="o">=</span> <span class="n">alogger</span>
    1249. <span class="k">def</span> <span class="nf">_clear_cache</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span>
    1250. <span class="sd">&quot;&quot;&quot;</span>
    1251. <span class="sd"> Clear the cache for all loggers in loggerDict</span>
    1252. <span class="sd"> Called when level changes are made</span>
    1253. <span class="sd"> &quot;&quot;&quot;</span>
    1254. <span class="n">_acquireLock</span><span class="p">()</span>
    1255. <span class="k">for</span> <span class="n">logger</span> <span class="ow">in</span> <span class="bp">self</span><span class="o">.</span><span class="n">loggerDict</span><span class="o">.</span><span class="n">values</span><span class="p">():</span>
    1256. <span class="k">if</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">logger</span><span class="p">,</span> <span class="n">Logger</span><span class="p">):</span>
    1257. <span class="n">logger</span><span class="o">.</span><span class="n">_cache</span><span class="o">.</span><span class="n">clear</span><span class="p">()</span>
    1258. <span class="bp">self</span><span class="o">.</span><span class="n">root</span><span class="o">.</span><span class="n">_cache</span><span class="o">.</span><span class="n">clear</span><span class="p">()</span>
    1259. <span class="n">_releaseLock</span><span class="p">()</span>
    1260. <span class="c1">#---------------------------------------------------------------------------</span>
    1261. <span class="c1"># Logger classes and functions</span>
    1262. <span class="c1">#---------------------------------------------------------------------------</span>
    1263. <span class="k">class</span> <span class="nc">Logger</span><span class="p">(</span><span class="n">Filterer</span><span class="p">):</span>
    1264. <span class="sd">&quot;&quot;&quot;</span>
    1265. <span class="sd"> Instances of the Logger class represent a single logging channel. A</span>
    1266. <span class="sd"> &quot;logging channel&quot; indicates an area of an application. Exactly how an</span>
    1267. <span class="sd"> &quot;area&quot; is defined is up to the application developer. Since an</span>
    1268. <span class="sd"> application can have any number of areas, logging channels are identified</span>
    1269. <span class="sd"> by a unique string. Application areas can be nested (e.g. an area</span>
    1270. <span class="sd"> of &quot;input processing&quot; might include sub-areas &quot;read CSV files&quot;, &quot;read</span>
    1271. <span class="sd"> XLS files&quot; and &quot;read Gnumeric files&quot;). To cater for this natural nesting,</span>
    1272. <span class="sd"> channel names are organized into a namespace hierarchy where levels are</span>
    1273. <span class="sd"> separated by periods, much like the Java or Python package namespace. So</span>
    1274. <span class="sd"> in the instance given above, channel names might be &quot;input&quot; for the upper</span>
    1275. <span class="sd"> level, and &quot;input.csv&quot;, &quot;input.xls&quot; and &quot;input.gnu&quot; for the sub-levels.</span>
    1276. <span class="sd"> There is no arbitrary limit to the depth of nesting.</span>
    1277. <span class="sd"> &quot;&quot;&quot;</span>
    1278. <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">name</span><span class="p">,</span> <span class="n">level</span><span class="o">=</span><span class="n">NOTSET</span><span class="p">):</span>
    1279. <span class="sd">&quot;&quot;&quot;</span>
    1280. <span class="sd"> Initialize the logger with a name and an optional level.</span>
    1281. <span class="sd"> &quot;&quot;&quot;</span>
    1282. <span class="n">Filterer</span><span class="o">.</span><span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">)</span>
    1283. <span class="bp">self</span><span class="o">.</span><span class="n">name</span> <span class="o">=</span> <span class="n">name</span>
    1284. <span class="bp">self</span><span class="o">.</span><span class="n">level</span> <span class="o">=</span> <span class="n">_checkLevel</span><span class="p">(</span><span class="n">level</span><span class="p">)</span>
    1285. <span class="bp">self</span><span class="o">.</span><span class="n">parent</span> <span class="o">=</span> <span class="kc">None</span>
    1286. <span class="bp">self</span><span class="o">.</span><span class="n">propagate</span> <span class="o">=</span> <span class="kc">True</span>
    1287. <span class="bp">self</span><span class="o">.</span><span class="n">handlers</span> <span class="o">=</span> <span class="p">[]</span>
    1288. <span class="bp">self</span><span class="o">.</span><span class="n">disabled</span> <span class="o">=</span> <span class="kc">False</span>
    1289. <span class="bp">self</span><span class="o">.</span><span class="n">_cache</span> <span class="o">=</span> <span class="p">{}</span>
    1290. <span class="k">def</span> <span class="nf">setLevel</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">level</span><span class="p">):</span>
    1291. <span class="sd">&quot;&quot;&quot;</span>
    1292. <span class="sd"> Set the logging level of this logger. level must be an int or a str.</span>
    1293. <span class="sd"> &quot;&quot;&quot;</span>
    1294. <span class="bp">self</span><span class="o">.</span><span class="n">level</span> <span class="o">=</span> <span class="n">_checkLevel</span><span class="p">(</span><span class="n">level</span><span class="p">)</span>
    1295. <span class="bp">self</span><span class="o">.</span><span class="n">manager</span><span class="o">.</span><span class="n">_clear_cache</span><span class="p">()</span>
    1296. <span class="k">def</span> <span class="nf">debug</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">msg</span><span class="p">,</span> <span class="o">*</span><span class="n">args</span><span class="p">,</span> <span class="o">**</span><span class="n">kwargs</span><span class="p">):</span>
    1297. <span class="sd">&quot;&quot;&quot;</span>
    1298. <span class="sd"> Log &#39;msg % args&#39; with severity &#39;DEBUG&#39;.</span>
    1299. <span class="sd"> To pass exception information, use the keyword argument exc_info with</span>
    1300. <span class="sd"> a true value, e.g.</span>
    1301. <span class="sd"> logger.debug(&quot;Houston, we have a %s&quot;, &quot;thorny problem&quot;, exc_info=1)</span>
    1302. <span class="sd"> &quot;&quot;&quot;</span>
    1303. <span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">isEnabledFor</span><span class="p">(</span><span class="n">DEBUG</span><span class="p">):</span>
    1304. <span class="bp">self</span><span class="o">.</span><span class="n">_log</span><span class="p">(</span><span class="n">DEBUG</span><span class="p">,</span> <span class="n">msg</span><span class="p">,</span> <span class="n">args</span><span class="p">,</span> <span class="o">**</span><span class="n">kwargs</span><span class="p">)</span>
    1305. <span class="k">def</span> <span class="nf">info</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">msg</span><span class="p">,</span> <span class="o">*</span><span class="n">args</span><span class="p">,</span> <span class="o">**</span><span class="n">kwargs</span><span class="p">):</span>
    1306. <span class="sd">&quot;&quot;&quot;</span>
    1307. <span class="sd"> Log &#39;msg % args&#39; with severity &#39;INFO&#39;.</span>
    1308. <span class="sd"> To pass exception information, use the keyword argument exc_info with</span>
    1309. <span class="sd"> a true value, e.g.</span>
    1310. <span class="sd"> logger.info(&quot;Houston, we have a %s&quot;, &quot;interesting problem&quot;, exc_info=1)</span>
    1311. <span class="sd"> &quot;&quot;&quot;</span>
    1312. <span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">isEnabledFor</span><span class="p">(</span><span class="n">INFO</span><span class="p">):</span>
    1313. <span class="bp">self</span><span class="o">.</span><span class="n">_log</span><span class="p">(</span><span class="n">INFO</span><span class="p">,</span> <span class="n">msg</span><span class="p">,</span> <span class="n">args</span><span class="p">,</span> <span class="o">**</span><span class="n">kwargs</span><span class="p">)</span>
    1314. <span class="k">def</span> <span class="nf">warning</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">msg</span><span class="p">,</span> <span class="o">*</span><span class="n">args</span><span class="p">,</span> <span class="o">**</span><span class="n">kwargs</span><span class="p">):</span>
    1315. <span class="sd">&quot;&quot;&quot;</span>
    1316. <span class="sd"> Log &#39;msg % args&#39; with severity &#39;WARNING&#39;.</span>
    1317. <span class="sd"> To pass exception information, use the keyword argument exc_info with</span>
    1318. <span class="sd"> a true value, e.g.</span>
    1319. <span class="sd"> logger.warning(&quot;Houston, we have a %s&quot;, &quot;bit of a problem&quot;, exc_info=1)</span>
    1320. <span class="sd"> &quot;&quot;&quot;</span>
    1321. <span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">isEnabledFor</span><span class="p">(</span><span class="n">WARNING</span><span class="p">):</span>
    1322. <span class="bp">self</span><span class="o">.</span><span class="n">_log</span><span class="p">(</span><span class="n">WARNING</span><span class="p">,</span> <span class="n">msg</span><span class="p">,</span> <span class="n">args</span><span class="p">,</span> <span class="o">**</span><span class="n">kwargs</span><span class="p">)</span>
    1323. <span class="k">def</span> <span class="nf">warn</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">msg</span><span class="p">,</span> <span class="o">*</span><span class="n">args</span><span class="p">,</span> <span class="o">**</span><span class="n">kwargs</span><span class="p">):</span>
    1324. <span class="n">warnings</span><span class="o">.</span><span class="n">warn</span><span class="p">(</span><span class="s2">&quot;The &#39;warn&#39; method is deprecated, &quot;</span>
    1325. <span class="s2">&quot;use &#39;warning&#39; instead&quot;</span><span class="p">,</span> <span class="ne">DeprecationWarning</span><span class="p">,</span> <span class="mi">2</span><span class="p">)</span>
    1326. <span class="bp">self</span><span class="o">.</span><span class="n">warning</span><span class="p">(</span><span class="n">msg</span><span class="p">,</span> <span class="o">*</span><span class="n">args</span><span class="p">,</span> <span class="o">**</span><span class="n">kwargs</span><span class="p">)</span>
    1327. <span class="k">def</span> <span class="nf">error</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">msg</span><span class="p">,</span> <span class="o">*</span><span class="n">args</span><span class="p">,</span> <span class="o">**</span><span class="n">kwargs</span><span class="p">):</span>
    1328. <span class="sd">&quot;&quot;&quot;</span>
    1329. <span class="sd"> Log &#39;msg % args&#39; with severity &#39;ERROR&#39;.</span>
    1330. <span class="sd"> To pass exception information, use the keyword argument exc_info with</span>
    1331. <span class="sd"> a true value, e.g.</span>
    1332. <span class="sd"> logger.error(&quot;Houston, we have a %s&quot;, &quot;major problem&quot;, exc_info=1)</span>
    1333. <span class="sd"> &quot;&quot;&quot;</span>
    1334. <span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">isEnabledFor</span><span class="p">(</span><span class="n">ERROR</span><span class="p">):</span>
    1335. <span class="bp">self</span><span class="o">.</span><span class="n">_log</span><span class="p">(</span><span class="n">ERROR</span><span class="p">,</span> <span class="n">msg</span><span class="p">,</span> <span class="n">args</span><span class="p">,</span> <span class="o">**</span><span class="n">kwargs</span><span class="p">)</span>
    1336. <span class="k">def</span> <span class="nf">exception</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">msg</span><span class="p">,</span> <span class="o">*</span><span class="n">args</span><span class="p">,</span> <span class="n">exc_info</span><span class="o">=</span><span class="kc">True</span><span class="p">,</span> <span class="o">**</span><span class="n">kwargs</span><span class="p">):</span>
    1337. <span class="sd">&quot;&quot;&quot;</span>
    1338. <span class="sd"> Convenience method for logging an ERROR with exception information.</span>
    1339. <span class="sd"> &quot;&quot;&quot;</span>
    1340. <span class="bp">self</span><span class="o">.</span><span class="n">error</span><span class="p">(</span><span class="n">msg</span><span class="p">,</span> <span class="o">*</span><span class="n">args</span><span class="p">,</span> <span class="n">exc_info</span><span class="o">=</span><span class="n">exc_info</span><span class="p">,</span> <span class="o">**</span><span class="n">kwargs</span><span class="p">)</span>
    1341. <span class="k">def</span> <span class="nf">critical</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">msg</span><span class="p">,</span> <span class="o">*</span><span class="n">args</span><span class="p">,</span> <span class="o">**</span><span class="n">kwargs</span><span class="p">):</span>
    1342. <span class="sd">&quot;&quot;&quot;</span>
    1343. <span class="sd"> Log &#39;msg % args&#39; with severity &#39;CRITICAL&#39;.</span>
    1344. <span class="sd"> To pass exception information, use the keyword argument exc_info with</span>
    1345. <span class="sd"> a true value, e.g.</span>
    1346. <span class="sd"> logger.critical(&quot;Houston, we have a %s&quot;, &quot;major disaster&quot;, exc_info=1)</span>
    1347. <span class="sd"> &quot;&quot;&quot;</span>
    1348. <span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">isEnabledFor</span><span class="p">(</span><span class="n">CRITICAL</span><span class="p">):</span>
    1349. <span class="bp">self</span><span class="o">.</span><span class="n">_log</span><span class="p">(</span><span class="n">CRITICAL</span><span class="p">,</span> <span class="n">msg</span><span class="p">,</span> <span class="n">args</span><span class="p">,</span> <span class="o">**</span><span class="n">kwargs</span><span class="p">)</span>
    1350. <span class="n">fatal</span> <span class="o">=</span> <span class="n">critical</span>
    1351. <span class="k">def</span> <span class="nf">log</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">level</span><span class="p">,</span> <span class="n">msg</span><span class="p">,</span> <span class="o">*</span><span class="n">args</span><span class="p">,</span> <span class="o">**</span><span class="n">kwargs</span><span class="p">):</span>
    1352. <span class="sd">&quot;&quot;&quot;</span>
    1353. <span class="sd"> Log &#39;msg % args&#39; with the integer severity &#39;level&#39;.</span>
    1354. <span class="sd"> To pass exception information, use the keyword argument exc_info with</span>
    1355. <span class="sd"> a true value, e.g.</span>
    1356. <span class="sd"> logger.log(level, &quot;We have a %s&quot;, &quot;mysterious problem&quot;, exc_info=1)</span>
    1357. <span class="sd"> &quot;&quot;&quot;</span>
    1358. <span class="k">if</span> <span class="ow">not</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">level</span><span class="p">,</span> <span class="nb">int</span><span class="p">):</span>
    1359. <span class="k">if</span> <span class="n">raiseExceptions</span><span class="p">:</span>
    1360. <span class="k">raise</span> <span class="ne">TypeError</span><span class="p">(</span><span class="s2">&quot;level must be an integer&quot;</span><span class="p">)</span>
    1361. <span class="k">else</span><span class="p">:</span>
    1362. <span class="k">return</span>
    1363. <span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">isEnabledFor</span><span class="p">(</span><span class="n">level</span><span class="p">):</span>
    1364. <span class="bp">self</span><span class="o">.</span><span class="n">_log</span><span class="p">(</span><span class="n">level</span><span class="p">,</span> <span class="n">msg</span><span class="p">,</span> <span class="n">args</span><span class="p">,</span> <span class="o">**</span><span class="n">kwargs</span><span class="p">)</span>
    1365. <span class="k">def</span> <span class="nf">findCaller</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">stack_info</span><span class="o">=</span><span class="kc">False</span><span class="p">,</span> <span class="n">stacklevel</span><span class="o">=</span><span class="mi">1</span><span class="p">):</span>
    1366. <span class="sd">&quot;&quot;&quot;</span>
    1367. <span class="sd"> Find the stack frame of the caller so that we can note the source</span>
    1368. <span class="sd"> file name, line number and function name.</span>
    1369. <span class="sd"> &quot;&quot;&quot;</span>
    1370. <span class="n">f</span> <span class="o">=</span> <span class="n">currentframe</span><span class="p">()</span>
    1371. <span class="c1">#On some versions of IronPython, currentframe() returns None if</span>
    1372. <span class="c1">#IronPython isn&#39;t run with -X:Frames.</span>
    1373. <span class="k">if</span> <span class="n">f</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span><span class="p">:</span>
    1374. <span class="n">f</span> <span class="o">=</span> <span class="n">f</span><span class="o">.</span><span class="n">f_back</span>
    1375. <span class="n">orig_f</span> <span class="o">=</span> <span class="n">f</span>
    1376. <span class="k">while</span> <span class="n">f</span> <span class="ow">and</span> <span class="n">stacklevel</span> <span class="o">&gt;</span> <span class="mi">1</span><span class="p">:</span>
    1377. <span class="n">f</span> <span class="o">=</span> <span class="n">f</span><span class="o">.</span><span class="n">f_back</span>
    1378. <span class="n">stacklevel</span> <span class="o">-=</span> <span class="mi">1</span>
    1379. <span class="k">if</span> <span class="ow">not</span> <span class="n">f</span><span class="p">:</span>
    1380. <span class="n">f</span> <span class="o">=</span> <span class="n">orig_f</span>
    1381. <span class="n">rv</span> <span class="o">=</span> <span class="s2">&quot;(unknown file)&quot;</span><span class="p">,</span> <span class="mi">0</span><span class="p">,</span> <span class="s2">&quot;(unknown function)&quot;</span><span class="p">,</span> <span class="kc">None</span>
    1382. <span class="k">while</span> <span class="nb">hasattr</span><span class="p">(</span><span class="n">f</span><span class="p">,</span> <span class="s2">&quot;f_code&quot;</span><span class="p">):</span>
    1383. <span class="n">co</span> <span class="o">=</span> <span class="n">f</span><span class="o">.</span><span class="n">f_code</span>
    1384. <span class="n">filename</span> <span class="o">=</span> <span class="n">os</span><span class="o">.</span><span class="n">path</span><span class="o">.</span><span class="n">normcase</span><span class="p">(</span><span class="n">co</span><span class="o">.</span><span class="n">co_filename</span><span class="p">)</span>
    1385. <span class="k">if</span> <span class="n">filename</span> <span class="o">==</span> <span class="n">_srcfile</span><span class="p">:</span>
    1386. <span class="n">f</span> <span class="o">=</span> <span class="n">f</span><span class="o">.</span><span class="n">f_back</span>
    1387. <span class="k">continue</span>
    1388. <span class="n">sinfo</span> <span class="o">=</span> <span class="kc">None</span>
    1389. <span class="k">if</span> <span class="n">stack_info</span><span class="p">:</span>
    1390. <span class="n">sio</span> <span class="o">=</span> <span class="n">io</span><span class="o">.</span><span class="n">StringIO</span><span class="p">()</span>
    1391. <span class="n">sio</span><span class="o">.</span><span class="n">write</span><span class="p">(</span><span class="s1">&#39;Stack (most recent call last):</span><span class="se">\n</span><span class="s1">&#39;</span><span class="p">)</span>
    1392. <span class="n">traceback</span><span class="o">.</span><span class="n">print_stack</span><span class="p">(</span><span class="n">f</span><span class="p">,</span> <span class="n">file</span><span class="o">=</span><span class="n">sio</span><span class="p">)</span>
    1393. <span class="n">sinfo</span> <span class="o">=</span> <span class="n">sio</span><span class="o">.</span><span class="n">getvalue</span><span class="p">()</span>
    1394. <span class="k">if</span> <span class="n">sinfo</span><span class="p">[</span><span class="o">-</span><span class="mi">1</span><span class="p">]</span> <span class="o">==</span> <span class="s1">&#39;</span><span class="se">\n</span><span class="s1">&#39;</span><span class="p">:</span>
    1395. <span class="n">sinfo</span> <span class="o">=</span> <span class="n">sinfo</span><span class="p">[:</span><span class="o">-</span><span class="mi">1</span><span class="p">]</span>
    1396. <span class="n">sio</span><span class="o">.</span><span class="n">close</span><span class="p">()</span>
    1397. <span class="n">rv</span> <span class="o">=</span> <span class="p">(</span><span class="n">co</span><span class="o">.</span><span class="n">co_filename</span><span class="p">,</span> <span class="n">f</span><span class="o">.</span><span class="n">f_lineno</span><span class="p">,</span> <span class="n">co</span><span class="o">.</span><span class="n">co_name</span><span class="p">,</span> <span class="n">sinfo</span><span class="p">)</span>
    1398. <span class="k">break</span>
    1399. <span class="k">return</span> <span class="n">rv</span>
    1400. <span class="k">def</span> <span class="nf">makeRecord</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">name</span><span class="p">,</span> <span class="n">level</span><span class="p">,</span> <span class="n">fn</span><span class="p">,</span> <span class="n">lno</span><span class="p">,</span> <span class="n">msg</span><span class="p">,</span> <span class="n">args</span><span class="p">,</span> <span class="n">exc_info</span><span class="p">,</span>
    1401. <span class="n">func</span><span class="o">=</span><span class="kc">None</span><span class="p">,</span> <span class="n">extra</span><span class="o">=</span><span class="kc">None</span><span class="p">,</span> <span class="n">sinfo</span><span class="o">=</span><span class="kc">None</span><span class="p">):</span>
    1402. <span class="sd">&quot;&quot;&quot;</span>
    1403. <span class="sd"> A factory method which can be overridden in subclasses to create</span>
    1404. <span class="sd"> specialized LogRecords.</span>
    1405. <span class="sd"> &quot;&quot;&quot;</span>
    1406. <span class="n">rv</span> <span class="o">=</span> <span class="n">_logRecordFactory</span><span class="p">(</span><span class="n">name</span><span class="p">,</span> <span class="n">level</span><span class="p">,</span> <span class="n">fn</span><span class="p">,</span> <span class="n">lno</span><span class="p">,</span> <span class="n">msg</span><span class="p">,</span> <span class="n">args</span><span class="p">,</span> <span class="n">exc_info</span><span class="p">,</span> <span class="n">func</span><span class="p">,</span>
    1407. <span class="n">sinfo</span><span class="p">)</span>
    1408. <span class="k">if</span> <span class="n">extra</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span><span class="p">:</span>
    1409. <span class="k">for</span> <span class="n">key</span> <span class="ow">in</span> <span class="n">extra</span><span class="p">:</span>
    1410. <span class="k">if</span> <span class="p">(</span><span class="n">key</span> <span class="ow">in</span> <span class="p">[</span><span class="s2">&quot;message&quot;</span><span class="p">,</span> <span class="s2">&quot;asctime&quot;</span><span class="p">])</span> <span class="ow">or</span> <span class="p">(</span><span class="n">key</span> <span class="ow">in</span> <span class="n">rv</span><span class="o">.</span><span class="vm">__dict__</span><span class="p">):</span>
    1411. <span class="k">raise</span> <span class="ne">KeyError</span><span class="p">(</span><span class="s2">&quot;Attempt to overwrite </span><span class="si">%r</span><span class="s2"> in LogRecord&quot;</span> <span class="o">%</span> <span class="n">key</span><span class="p">)</span>
    1412. <span class="n">rv</span><span class="o">.</span><span class="vm">__dict__</span><span class="p">[</span><span class="n">key</span><span class="p">]</span> <span class="o">=</span> <span class="n">extra</span><span class="p">[</span><span class="n">key</span><span class="p">]</span>
    1413. <span class="k">return</span> <span class="n">rv</span>
    1414. <span class="k">def</span> <span class="nf">_log</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">level</span><span class="p">,</span> <span class="n">msg</span><span class="p">,</span> <span class="n">args</span><span class="p">,</span> <span class="n">exc_info</span><span class="o">=</span><span class="kc">None</span><span class="p">,</span> <span class="n">extra</span><span class="o">=</span><span class="kc">None</span><span class="p">,</span> <span class="n">stack_info</span><span class="o">=</span><span class="kc">False</span><span class="p">,</span>
    1415. <span class="n">stacklevel</span><span class="o">=</span><span class="mi">1</span><span class="p">):</span>
    1416. <span class="sd">&quot;&quot;&quot;</span>
    1417. <span class="sd"> Low-level logging routine which creates a LogRecord and then calls</span>
    1418. <span class="sd"> all the handlers of this logger to handle the record.</span>
    1419. <span class="sd"> &quot;&quot;&quot;</span>
    1420. <span class="n">sinfo</span> <span class="o">=</span> <span class="kc">None</span>
    1421. <span class="k">if</span> <span class="n">_srcfile</span><span class="p">:</span>
    1422. <span class="c1">#IronPython doesn&#39;t track Python frames, so findCaller raises an</span>
    1423. <span class="c1">#exception on some versions of IronPython. We trap it here so that</span>
    1424. <span class="c1">#IronPython can use logging.</span>
    1425. <span class="k">try</span><span class="p">:</span>
    1426. <span class="n">fn</span><span class="p">,</span> <span class="n">lno</span><span class="p">,</span> <span class="n">func</span><span class="p">,</span> <span class="n">sinfo</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">findCaller</span><span class="p">(</span><span class="n">stack_info</span><span class="p">,</span> <span class="n">stacklevel</span><span class="p">)</span>
    1427. <span class="k">except</span> <span class="ne">ValueError</span><span class="p">:</span> <span class="c1"># pragma: no cover</span>
    1428. <span class="n">fn</span><span class="p">,</span> <span class="n">lno</span><span class="p">,</span> <span class="n">func</span> <span class="o">=</span> <span class="s2">&quot;(unknown file)&quot;</span><span class="p">,</span> <span class="mi">0</span><span class="p">,</span> <span class="s2">&quot;(unknown function)&quot;</span>
    1429. <span class="k">else</span><span class="p">:</span> <span class="c1"># pragma: no cover</span>
    1430. <span class="n">fn</span><span class="p">,</span> <span class="n">lno</span><span class="p">,</span> <span class="n">func</span> <span class="o">=</span> <span class="s2">&quot;(unknown file)&quot;</span><span class="p">,</span> <span class="mi">0</span><span class="p">,</span> <span class="s2">&quot;(unknown function)&quot;</span>
    1431. <span class="k">if</span> <span class="n">exc_info</span><span class="p">:</span>
    1432. <span class="k">if</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">exc_info</span><span class="p">,</span> <span class="ne">BaseException</span><span class="p">):</span>
    1433. <span class="n">exc_info</span> <span class="o">=</span> <span class="p">(</span><span class="nb">type</span><span class="p">(</span><span class="n">exc_info</span><span class="p">),</span> <span class="n">exc_info</span><span class="p">,</span> <span class="n">exc_info</span><span class="o">.</span><span class="n">__traceback__</span><span class="p">)</span>
    1434. <span class="k">elif</span> <span class="ow">not</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">exc_info</span><span class="p">,</span> <span class="nb">tuple</span><span class="p">):</span>
    1435. <span class="n">exc_info</span> <span class="o">=</span> <span class="n">sys</span><span class="o">.</span><span class="n">exc_info</span><span class="p">()</span>
    1436. <span class="n">record</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">makeRecord</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">name</span><span class="p">,</span> <span class="n">level</span><span class="p">,</span> <span class="n">fn</span><span class="p">,</span> <span class="n">lno</span><span class="p">,</span> <span class="n">msg</span><span class="p">,</span> <span class="n">args</span><span class="p">,</span>
    1437. <span class="n">exc_info</span><span class="p">,</span> <span class="n">func</span><span class="p">,</span> <span class="n">extra</span><span class="p">,</span> <span class="n">sinfo</span><span class="p">)</span>
    1438. <span class="bp">self</span><span class="o">.</span><span class="n">handle</span><span class="p">(</span><span class="n">record</span><span class="p">)</span>
    1439. <span class="k">def</span> <span class="nf">handle</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">record</span><span class="p">):</span>
    1440. <span class="sd">&quot;&quot;&quot;</span>
    1441. <span class="sd"> Call the handlers for the specified record.</span>
    1442. <span class="sd"> This method is used for unpickled records received from a socket, as</span>
    1443. <span class="sd"> well as those created locally. Logger-level filtering is applied.</span>
    1444. <span class="sd"> &quot;&quot;&quot;</span>
    1445. <span class="k">if</span> <span class="p">(</span><span class="ow">not</span> <span class="bp">self</span><span class="o">.</span><span class="n">disabled</span><span class="p">)</span> <span class="ow">and</span> <span class="bp">self</span><span class="o">.</span><span class="n">filter</span><span class="p">(</span><span class="n">record</span><span class="p">):</span>
    1446. <span class="bp">self</span><span class="o">.</span><span class="n">callHandlers</span><span class="p">(</span><span class="n">record</span><span class="p">)</span>
    1447. <span class="k">def</span> <span class="nf">addHandler</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">hdlr</span><span class="p">):</span>
    1448. <span class="sd">&quot;&quot;&quot;</span>
    1449. <span class="sd"> Add the specified handler to this logger.</span>
    1450. <span class="sd"> &quot;&quot;&quot;</span>
    1451. <span class="n">_acquireLock</span><span class="p">()</span>
    1452. <span class="k">try</span><span class="p">:</span>
    1453. <span class="k">if</span> <span class="ow">not</span> <span class="p">(</span><span class="n">hdlr</span> <span class="ow">in</span> <span class="bp">self</span><span class="o">.</span><span class="n">handlers</span><span class="p">):</span>
    1454. <span class="bp">self</span><span class="o">.</span><span class="n">handlers</span><span class="o">.</span><span class="n">append</span><span class="p">(</span><span class="n">hdlr</span><span class="p">)</span>
    1455. <span class="k">finally</span><span class="p">:</span>
    1456. <span class="n">_releaseLock</span><span class="p">()</span>
    1457. <span class="k">def</span> <span class="nf">removeHandler</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">hdlr</span><span class="p">):</span>
    1458. <span class="sd">&quot;&quot;&quot;</span>
    1459. <span class="sd"> Remove the specified handler from this logger.</span>
    1460. <span class="sd"> &quot;&quot;&quot;</span>
    1461. <span class="n">_acquireLock</span><span class="p">()</span>
    1462. <span class="k">try</span><span class="p">:</span>
    1463. <span class="k">if</span> <span class="n">hdlr</span> <span class="ow">in</span> <span class="bp">self</span><span class="o">.</span><span class="n">handlers</span><span class="p">:</span>
    1464. <span class="bp">self</span><span class="o">.</span><span class="n">handlers</span><span class="o">.</span><span class="n">remove</span><span class="p">(</span><span class="n">hdlr</span><span class="p">)</span>
    1465. <span class="k">finally</span><span class="p">:</span>
    1466. <span class="n">_releaseLock</span><span class="p">()</span>
    1467. <span class="k">def</span> <span class="nf">hasHandlers</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span>
    1468. <span class="sd">&quot;&quot;&quot;</span>
    1469. <span class="sd"> See if this logger has any handlers configured.</span>
    1470. <span class="sd"> Loop through all handlers for this logger and its parents in the</span>
    1471. <span class="sd"> logger hierarchy. Return True if a handler was found, else False.</span>
    1472. <span class="sd"> Stop searching up the hierarchy whenever a logger with the &quot;propagate&quot;</span>
    1473. <span class="sd"> attribute set to zero is found - that will be the last logger which</span>
    1474. <span class="sd"> is checked for the existence of handlers.</span>
    1475. <span class="sd"> &quot;&quot;&quot;</span>
    1476. <span class="n">c</span> <span class="o">=</span> <span class="bp">self</span>
    1477. <span class="n">rv</span> <span class="o">=</span> <span class="kc">False</span>
    1478. <span class="k">while</span> <span class="n">c</span><span class="p">:</span>
    1479. <span class="k">if</span> <span class="n">c</span><span class="o">.</span><span class="n">handlers</span><span class="p">:</span>
    1480. <span class="n">rv</span> <span class="o">=</span> <span class="kc">True</span>
    1481. <span class="k">break</span>
    1482. <span class="k">if</span> <span class="ow">not</span> <span class="n">c</span><span class="o">.</span><span class="n">propagate</span><span class="p">:</span>
    1483. <span class="k">break</span>
    1484. <span class="k">else</span><span class="p">:</span>
    1485. <span class="n">c</span> <span class="o">=</span> <span class="n">c</span><span class="o">.</span><span class="n">parent</span>
    1486. <span class="k">return</span> <span class="n">rv</span>
    1487. <span class="k">def</span> <span class="nf">callHandlers</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">record</span><span class="p">):</span>
    1488. <span class="sd">&quot;&quot;&quot;</span>
    1489. <span class="sd"> Pass a record to all relevant handlers.</span>
    1490. <span class="sd"> Loop through all handlers for this logger and its parents in the</span>
    1491. <span class="sd"> logger hierarchy. If no handler was found, output a one-off error</span>
    1492. <span class="sd"> message to sys.stderr. Stop searching up the hierarchy whenever a</span>
    1493. <span class="sd"> logger with the &quot;propagate&quot; attribute set to zero is found - that</span>
    1494. <span class="sd"> will be the last logger whose handlers are called.</span>
    1495. <span class="sd"> &quot;&quot;&quot;</span>
    1496. <span class="n">c</span> <span class="o">=</span> <span class="bp">self</span>
    1497. <span class="n">found</span> <span class="o">=</span> <span class="mi">0</span>
    1498. <span class="k">while</span> <span class="n">c</span><span class="p">:</span>
    1499. <span class="k">for</span> <span class="n">hdlr</span> <span class="ow">in</span> <span class="n">c</span><span class="o">.</span><span class="n">handlers</span><span class="p">:</span>
    1500. <span class="n">found</span> <span class="o">=</span> <span class="n">found</span> <span class="o">+</span> <span class="mi">1</span>
    1501. <span class="k">if</span> <span class="n">record</span><span class="o">.</span><span class="n">levelno</span> <span class="o">&gt;=</span> <span class="n">hdlr</span><span class="o">.</span><span class="n">level</span><span class="p">:</span>
    1502. <span class="n">hdlr</span><span class="o">.</span><span class="n">handle</span><span class="p">(</span><span class="n">record</span><span class="p">)</span>
    1503. <span class="k">if</span> <span class="ow">not</span> <span class="n">c</span><span class="o">.</span><span class="n">propagate</span><span class="p">:</span>
    1504. <span class="n">c</span> <span class="o">=</span> <span class="kc">None</span> <span class="c1">#break out</span>
    1505. <span class="k">else</span><span class="p">:</span>
    1506. <span class="n">c</span> <span class="o">=</span> <span class="n">c</span><span class="o">.</span><span class="n">parent</span>
    1507. <span class="k">if</span> <span class="p">(</span><span class="n">found</span> <span class="o">==</span> <span class="mi">0</span><span class="p">):</span>
    1508. <span class="k">if</span> <span class="n">lastResort</span><span class="p">:</span>
    1509. <span class="k">if</span> <span class="n">record</span><span class="o">.</span><span class="n">levelno</span> <span class="o">&gt;=</span> <span class="n">lastResort</span><span class="o">.</span><span class="n">level</span><span class="p">:</span>
    1510. <span class="n">lastResort</span><span class="o">.</span><span class="n">handle</span><span class="p">(</span><span class="n">record</span><span class="p">)</span>
    1511. <span class="k">elif</span> <span class="n">raiseExceptions</span> <span class="ow">and</span> <span class="ow">not</span> <span class="bp">self</span><span class="o">.</span><span class="n">manager</span><span class="o">.</span><span class="n">emittedNoHandlerWarning</span><span class="p">:</span>
    1512. <span class="n">sys</span><span class="o">.</span><span class="n">stderr</span><span class="o">.</span><span class="n">write</span><span class="p">(</span><span class="s2">&quot;No handlers could be found for logger&quot;</span>
    1513. <span class="s2">&quot; </span><span class="se">\&quot;</span><span class="si">%s</span><span class="se">\&quot;\n</span><span class="s2">&quot;</span> <span class="o">%</span> <span class="bp">self</span><span class="o">.</span><span class="n">name</span><span class="p">)</span>
    1514. <span class="bp">self</span><span class="o">.</span><span class="n">manager</span><span class="o">.</span><span class="n">emittedNoHandlerWarning</span> <span class="o">=</span> <span class="kc">True</span>
    1515. <span class="k">def</span> <span class="nf">getEffectiveLevel</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span>
    1516. <span class="sd">&quot;&quot;&quot;</span>
    1517. <span class="sd"> Get the effective level for this logger.</span>
    1518. <span class="sd"> Loop through this logger and its parents in the logger hierarchy,</span>
    1519. <span class="sd"> looking for a non-zero logging level. Return the first one found.</span>
    1520. <span class="sd"> &quot;&quot;&quot;</span>
    1521. <span class="n">logger</span> <span class="o">=</span> <span class="bp">self</span>
    1522. <span class="k">while</span> <span class="n">logger</span><span class="p">:</span>
    1523. <span class="k">if</span> <span class="n">logger</span><span class="o">.</span><span class="n">level</span><span class="p">:</span>
    1524. <span class="k">return</span> <span class="n">logger</span><span class="o">.</span><span class="n">level</span>
    1525. <span class="n">logger</span> <span class="o">=</span> <span class="n">logger</span><span class="o">.</span><span class="n">parent</span>
    1526. <span class="k">return</span> <span class="n">NOTSET</span>
    1527. <span class="k">def</span> <span class="nf">isEnabledFor</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">level</span><span class="p">):</span>
    1528. <span class="sd">&quot;&quot;&quot;</span>
    1529. <span class="sd"> Is this logger enabled for level &#39;level&#39;?</span>
    1530. <span class="sd"> &quot;&quot;&quot;</span>
    1531. <span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">disabled</span><span class="p">:</span>
    1532. <span class="k">return</span> <span class="kc">False</span>
    1533. <span class="k">try</span><span class="p">:</span>
    1534. <span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">_cache</span><span class="p">[</span><span class="n">level</span><span class="p">]</span>
    1535. <span class="k">except</span> <span class="ne">KeyError</span><span class="p">:</span>
    1536. <span class="n">_acquireLock</span><span class="p">()</span>
    1537. <span class="k">try</span><span class="p">:</span>
    1538. <span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">manager</span><span class="o">.</span><span class="n">disable</span> <span class="o">&gt;=</span> <span class="n">level</span><span class="p">:</span>
    1539. <span class="n">is_enabled</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">_cache</span><span class="p">[</span><span class="n">level</span><span class="p">]</span> <span class="o">=</span> <span class="kc">False</span>
    1540. <span class="k">else</span><span class="p">:</span>
    1541. <span class="n">is_enabled</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">_cache</span><span class="p">[</span><span class="n">level</span><span class="p">]</span> <span class="o">=</span> <span class="p">(</span>
    1542. <span class="n">level</span> <span class="o">&gt;=</span> <span class="bp">self</span><span class="o">.</span><span class="n">getEffectiveLevel</span><span class="p">()</span>
    1543. <span class="p">)</span>
    1544. <span class="k">finally</span><span class="p">:</span>
    1545. <span class="n">_releaseLock</span><span class="p">()</span>
    1546. <span class="k">return</span> <span class="n">is_enabled</span>
    1547. <span class="k">def</span> <span class="nf">getChild</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">suffix</span><span class="p">):</span>
    1548. <span class="sd">&quot;&quot;&quot;</span>
    1549. <span class="sd"> Get a logger which is a descendant to this one.</span>
    1550. <span class="sd"> This is a convenience method, such that</span>
    1551. <span class="sd"> logging.getLogger(&#39;abc&#39;).getChild(&#39;def.ghi&#39;)</span>
    1552. <span class="sd"> is the same as</span>
    1553. <span class="sd"> logging.getLogger(&#39;abc.def.ghi&#39;)</span>
    1554. <span class="sd"> It&#39;s useful, for example, when the parent logger is named using</span>
    1555. <span class="sd"> __name__ rather than a literal string.</span>
    1556. <span class="sd"> &quot;&quot;&quot;</span>
    1557. <span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">root</span> <span class="ow">is</span> <span class="ow">not</span> <span class="bp">self</span><span class="p">:</span>
    1558. <span class="n">suffix</span> <span class="o">=</span> <span class="s1">&#39;.&#39;</span><span class="o">.</span><span class="n">join</span><span class="p">((</span><span class="bp">self</span><span class="o">.</span><span class="n">name</span><span class="p">,</span> <span class="n">suffix</span><span class="p">))</span>
    1559. <span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">manager</span><span class="o">.</span><span class="n">getLogger</span><span class="p">(</span><span class="n">suffix</span><span class="p">)</span>
    1560. <span class="k">def</span> <span class="fm">__repr__</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span>
    1561. <span class="n">level</span> <span class="o">=</span> <span class="n">getLevelName</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">getEffectiveLevel</span><span class="p">())</span>
    1562. <span class="k">return</span> <span class="s1">&#39;&lt;</span><span class="si">%s</span><span class="s1"> </span><span class="si">%s</span><span class="s1"> (</span><span class="si">%s</span><span class="s1">)&gt;&#39;</span> <span class="o">%</span> <span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="vm">__class__</span><span class="o">.</span><span class="vm">__name__</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">name</span><span class="p">,</span> <span class="n">level</span><span class="p">)</span>
    1563. <span class="k">def</span> <span class="nf">__reduce__</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span>
    1564. <span class="c1"># In general, only the root logger will not be accessible via its name.</span>
    1565. <span class="c1"># However, the root logger&#39;s class has its own __reduce__ method.</span>
    1566. <span class="k">if</span> <span class="n">getLogger</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">name</span><span class="p">)</span> <span class="ow">is</span> <span class="ow">not</span> <span class="bp">self</span><span class="p">:</span>
    1567. <span class="kn">import</span> <span class="nn">pickle</span>
    1568. <span class="k">raise</span> <span class="n">pickle</span><span class="o">.</span><span class="n">PicklingError</span><span class="p">(</span><span class="s1">&#39;logger cannot be pickled&#39;</span><span class="p">)</span>
    1569. <span class="k">return</span> <span class="n">getLogger</span><span class="p">,</span> <span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">name</span><span class="p">,)</span>
    1570. <span class="k">class</span> <span class="nc">RootLogger</span><span class="p">(</span><span class="n">Logger</span><span class="p">):</span>
    1571. <span class="sd">&quot;&quot;&quot;</span>
    1572. <span class="sd"> A root logger is not that different to any other logger, except that</span>
    1573. <span class="sd"> it must have a logging level and there is only one instance of it in</span>
    1574. <span class="sd"> the hierarchy.</span>
    1575. <span class="sd"> &quot;&quot;&quot;</span>
    1576. <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">level</span><span class="p">):</span>
    1577. <span class="sd">&quot;&quot;&quot;</span>
    1578. <span class="sd"> Initialize the logger with the name &quot;root&quot;.</span>
    1579. <span class="sd"> &quot;&quot;&quot;</span>
    1580. <span class="n">Logger</span><span class="o">.</span><span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="s2">&quot;root&quot;</span><span class="p">,</span> <span class="n">level</span><span class="p">)</span>
    1581. <span class="k">def</span> <span class="nf">__reduce__</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span>
    1582. <span class="k">return</span> <span class="n">getLogger</span><span class="p">,</span> <span class="p">()</span>
    1583. <span class="n">_loggerClass</span> <span class="o">=</span> <span class="n">Logger</span>
    1584. <span class="k">class</span> <span class="nc">LoggerAdapter</span><span class="p">(</span><span class="nb">object</span><span class="p">):</span>
    1585. <span class="sd">&quot;&quot;&quot;</span>
    1586. <span class="sd"> An adapter for loggers which makes it easier to specify contextual</span>
    1587. <span class="sd"> information in logging output.</span>
    1588. <span class="sd"> &quot;&quot;&quot;</span>
    1589. <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">logger</span><span class="p">,</span> <span class="n">extra</span><span class="p">):</span>
    1590. <span class="sd">&quot;&quot;&quot;</span>
    1591. <span class="sd"> Initialize the adapter with a logger and a dict-like object which</span>
    1592. <span class="sd"> provides contextual information. This constructor signature allows</span>
    1593. <span class="sd"> easy stacking of LoggerAdapters, if so desired.</span>
    1594. <span class="sd"> You can effectively pass keyword arguments as shown in the</span>
    1595. <span class="sd"> following example:</span>
    1596. <span class="sd"> adapter = LoggerAdapter(someLogger, dict(p1=v1, p2=&quot;v2&quot;))</span>
    1597. <span class="sd"> &quot;&quot;&quot;</span>
    1598. <span class="bp">self</span><span class="o">.</span><span class="n">logger</span> <span class="o">=</span> <span class="n">logger</span>
    1599. <span class="bp">self</span><span class="o">.</span><span class="n">extra</span> <span class="o">=</span> <span class="n">extra</span>
    1600. <span class="k">def</span> <span class="nf">process</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">msg</span><span class="p">,</span> <span class="n">kwargs</span><span class="p">):</span>
    1601. <span class="sd">&quot;&quot;&quot;</span>
    1602. <span class="sd"> Process the logging message and keyword arguments passed in to</span>
    1603. <span class="sd"> a logging call to insert contextual information. You can either</span>
    1604. <span class="sd"> manipulate the message itself, the keyword args or both. Return</span>
    1605. <span class="sd"> the message and kwargs modified (or not) to suit your needs.</span>
    1606. <span class="sd"> Normally, you&#39;ll only need to override this one method in a</span>
    1607. <span class="sd"> LoggerAdapter subclass for your specific needs.</span>
    1608. <span class="sd"> &quot;&quot;&quot;</span>
    1609. <span class="n">kwargs</span><span class="p">[</span><span class="s2">&quot;extra&quot;</span><span class="p">]</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">extra</span>
    1610. <span class="k">return</span> <span class="n">msg</span><span class="p">,</span> <span class="n">kwargs</span>
    1611. <span class="c1">#</span>
    1612. <span class="c1"># Boilerplate convenience methods</span>
    1613. <span class="c1">#</span>
    1614. <span class="k">def</span> <span class="nf">debug</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">msg</span><span class="p">,</span> <span class="o">*</span><span class="n">args</span><span class="p">,</span> <span class="o">**</span><span class="n">kwargs</span><span class="p">):</span>
    1615. <span class="sd">&quot;&quot;&quot;</span>
    1616. <span class="sd"> Delegate a debug call to the underlying logger.</span>
    1617. <span class="sd"> &quot;&quot;&quot;</span>
    1618. <span class="bp">self</span><span class="o">.</span><span class="n">log</span><span class="p">(</span><span class="n">DEBUG</span><span class="p">,</span> <span class="n">msg</span><span class="p">,</span> <span class="o">*</span><span class="n">args</span><span class="p">,</span> <span class="o">**</span><span class="n">kwargs</span><span class="p">)</span>
    1619. <span class="k">def</span> <span class="nf">info</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">msg</span><span class="p">,</span> <span class="o">*</span><span class="n">args</span><span class="p">,</span> <span class="o">**</span><span class="n">kwargs</span><span class="p">):</span>
    1620. <span class="sd">&quot;&quot;&quot;</span>
    1621. <span class="sd"> Delegate an info call to the underlying logger.</span>
    1622. <span class="sd"> &quot;&quot;&quot;</span>
    1623. <span class="bp">self</span><span class="o">.</span><span class="n">log</span><span class="p">(</span><span class="n">INFO</span><span class="p">,</span> <span class="n">msg</span><span class="p">,</span> <span class="o">*</span><span class="n">args</span><span class="p">,</span> <span class="o">**</span><span class="n">kwargs</span><span class="p">)</span>
    1624. <span class="k">def</span> <span class="nf">warning</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">msg</span><span class="p">,</span> <span class="o">*</span><span class="n">args</span><span class="p">,</span> <span class="o">**</span><span class="n">kwargs</span><span class="p">):</span>
    1625. <span class="sd">&quot;&quot;&quot;</span>
    1626. <span class="sd"> Delegate a warning call to the underlying logger.</span>
    1627. <span class="sd"> &quot;&quot;&quot;</span>
    1628. <span class="bp">self</span><span class="o">.</span><span class="n">log</span><span class="p">(</span><span class="n">WARNING</span><span class="p">,</span> <span class="n">msg</span><span class="p">,</span> <span class="o">*</span><span class="n">args</span><span class="p">,</span> <span class="o">**</span><span class="n">kwargs</span><span class="p">)</span>
    1629. <span class="k">def</span> <span class="nf">warn</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">msg</span><span class="p">,</span> <span class="o">*</span><span class="n">args</span><span class="p">,</span> <span class="o">**</span><span class="n">kwargs</span><span class="p">):</span>
    1630. <span class="n">warnings</span><span class="o">.</span><span class="n">warn</span><span class="p">(</span><span class="s2">&quot;The &#39;warn&#39; method is deprecated, &quot;</span>
    1631. <span class="s2">&quot;use &#39;warning&#39; instead&quot;</span><span class="p">,</span> <span class="ne">DeprecationWarning</span><span class="p">,</span> <span class="mi">2</span><span class="p">)</span>
    1632. <span class="bp">self</span><span class="o">.</span><span class="n">warning</span><span class="p">(</span><span class="n">msg</span><span class="p">,</span> <span class="o">*</span><span class="n">args</span><span class="p">,</span> <span class="o">**</span><span class="n">kwargs</span><span class="p">)</span>
    1633. <span class="k">def</span> <span class="nf">error</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">msg</span><span class="p">,</span> <span class="o">*</span><span class="n">args</span><span class="p">,</span> <span class="o">**</span><span class="n">kwargs</span><span class="p">):</span>
    1634. <span class="sd">&quot;&quot;&quot;</span>
    1635. <span class="sd"> Delegate an error call to the underlying logger.</span>
    1636. <span class="sd"> &quot;&quot;&quot;</span>
    1637. <span class="bp">self</span><span class="o">.</span><span class="n">log</span><span class="p">(</span><span class="n">ERROR</span><span class="p">,</span> <span class="n">msg</span><span class="p">,</span> <span class="o">*</span><span class="n">args</span><span class="p">,</span> <span class="o">**</span><span class="n">kwargs</span><span class="p">)</span>
    1638. <span class="k">def</span> <span class="nf">exception</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">msg</span><span class="p">,</span> <span class="o">*</span><span class="n">args</span><span class="p">,</span> <span class="n">exc_info</span><span class="o">=</span><span class="kc">True</span><span class="p">,</span> <span class="o">**</span><span class="n">kwargs</span><span class="p">):</span>
    1639. <span class="sd">&quot;&quot;&quot;</span>
    1640. <span class="sd"> Delegate an exception call to the underlying logger.</span>
    1641. <span class="sd"> &quot;&quot;&quot;</span>
    1642. <span class="bp">self</span><span class="o">.</span><span class="n">log</span><span class="p">(</span><span class="n">ERROR</span><span class="p">,</span> <span class="n">msg</span><span class="p">,</span> <span class="o">*</span><span class="n">args</span><span class="p">,</span> <span class="n">exc_info</span><span class="o">=</span><span class="n">exc_info</span><span class="p">,</span> <span class="o">**</span><span class="n">kwargs</span><span class="p">)</span>
    1643. <span class="k">def</span> <span class="nf">critical</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">msg</span><span class="p">,</span> <span class="o">*</span><span class="n">args</span><span class="p">,</span> <span class="o">**</span><span class="n">kwargs</span><span class="p">):</span>
    1644. <span class="sd">&quot;&quot;&quot;</span>
    1645. <span class="sd"> Delegate a critical call to the underlying logger.</span>
    1646. <span class="sd"> &quot;&quot;&quot;</span>
    1647. <span class="bp">self</span><span class="o">.</span><span class="n">log</span><span class="p">(</span><span class="n">CRITICAL</span><span class="p">,</span> <span class="n">msg</span><span class="p">,</span> <span class="o">*</span><span class="n">args</span><span class="p">,</span> <span class="o">**</span><span class="n">kwargs</span><span class="p">)</span>
    1648. <span class="k">def</span> <span class="nf">log</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">level</span><span class="p">,</span> <span class="n">msg</span><span class="p">,</span> <span class="o">*</span><span class="n">args</span><span class="p">,</span> <span class="o">**</span><span class="n">kwargs</span><span class="p">):</span>
    1649. <span class="sd">&quot;&quot;&quot;</span>
    1650. <span class="sd"> Delegate a log call to the underlying logger, after adding</span>
    1651. <span class="sd"> contextual information from this adapter instance.</span>
    1652. <span class="sd"> &quot;&quot;&quot;</span>
    1653. <span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">isEnabledFor</span><span class="p">(</span><span class="n">level</span><span class="p">):</span>
    1654. <span class="n">msg</span><span class="p">,</span> <span class="n">kwargs</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">process</span><span class="p">(</span><span class="n">msg</span><span class="p">,</span> <span class="n">kwargs</span><span class="p">)</span>
    1655. <span class="bp">self</span><span class="o">.</span><span class="n">logger</span><span class="o">.</span><span class="n">log</span><span class="p">(</span><span class="n">level</span><span class="p">,</span> <span class="n">msg</span><span class="p">,</span> <span class="o">*</span><span class="n">args</span><span class="p">,</span> <span class="o">**</span><span class="n">kwargs</span><span class="p">)</span>
    1656. <span class="k">def</span> <span class="nf">isEnabledFor</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">level</span><span class="p">):</span>
    1657. <span class="sd">&quot;&quot;&quot;</span>
    1658. <span class="sd"> Is this logger enabled for level &#39;level&#39;?</span>
    1659. <span class="sd"> &quot;&quot;&quot;</span>
    1660. <span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">logger</span><span class="o">.</span><span class="n">isEnabledFor</span><span class="p">(</span><span class="n">level</span><span class="p">)</span>
    1661. <span class="k">def</span> <span class="nf">setLevel</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">level</span><span class="p">):</span>
    1662. <span class="sd">&quot;&quot;&quot;</span>
    1663. <span class="sd"> Set the specified level on the underlying logger.</span>
    1664. <span class="sd"> &quot;&quot;&quot;</span>
    1665. <span class="bp">self</span><span class="o">.</span><span class="n">logger</span><span class="o">.</span><span class="n">setLevel</span><span class="p">(</span><span class="n">level</span><span class="p">)</span>
    1666. <span class="k">def</span> <span class="nf">getEffectiveLevel</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span>
    1667. <span class="sd">&quot;&quot;&quot;</span>
    1668. <span class="sd"> Get the effective level for the underlying logger.</span>
    1669. <span class="sd"> &quot;&quot;&quot;</span>
    1670. <span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">logger</span><span class="o">.</span><span class="n">getEffectiveLevel</span><span class="p">()</span>
    1671. <span class="k">def</span> <span class="nf">hasHandlers</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span>
    1672. <span class="sd">&quot;&quot;&quot;</span>
    1673. <span class="sd"> See if the underlying logger has any handlers.</span>
    1674. <span class="sd"> &quot;&quot;&quot;</span>
    1675. <span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">logger</span><span class="o">.</span><span class="n">hasHandlers</span><span class="p">()</span>
    1676. <span class="k">def</span> <span class="nf">_log</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">level</span><span class="p">,</span> <span class="n">msg</span><span class="p">,</span> <span class="n">args</span><span class="p">,</span> <span class="n">exc_info</span><span class="o">=</span><span class="kc">None</span><span class="p">,</span> <span class="n">extra</span><span class="o">=</span><span class="kc">None</span><span class="p">,</span> <span class="n">stack_info</span><span class="o">=</span><span class="kc">False</span><span class="p">):</span>
    1677. <span class="sd">&quot;&quot;&quot;</span>
    1678. <span class="sd"> Low-level log implementation, proxied to allow nested logger adapters.</span>
    1679. <span class="sd"> &quot;&quot;&quot;</span>
    1680. <span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">logger</span><span class="o">.</span><span class="n">_log</span><span class="p">(</span>
    1681. <span class="n">level</span><span class="p">,</span>
    1682. <span class="n">msg</span><span class="p">,</span>
    1683. <span class="n">args</span><span class="p">,</span>
    1684. <span class="n">exc_info</span><span class="o">=</span><span class="n">exc_info</span><span class="p">,</span>
    1685. <span class="n">extra</span><span class="o">=</span><span class="n">extra</span><span class="p">,</span>
    1686. <span class="n">stack_info</span><span class="o">=</span><span class="n">stack_info</span><span class="p">,</span>
    1687. <span class="p">)</span>
    1688. <span class="nd">@property</span>
    1689. <span class="k">def</span> <span class="nf">manager</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span>
    1690. <span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">logger</span><span class="o">.</span><span class="n">manager</span>
    1691. <span class="nd">@manager</span><span class="o">.</span><span class="n">setter</span>
    1692. <span class="k">def</span> <span class="nf">manager</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">value</span><span class="p">):</span>
    1693. <span class="bp">self</span><span class="o">.</span><span class="n">logger</span><span class="o">.</span><span class="n">manager</span> <span class="o">=</span> <span class="n">value</span>
    1694. <span class="nd">@property</span>
    1695. <span class="k">def</span> <span class="nf">name</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span>
    1696. <span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">logger</span><span class="o">.</span><span class="n">name</span>
    1697. <span class="k">def</span> <span class="fm">__repr__</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span>
    1698. <span class="n">logger</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">logger</span>
    1699. <span class="n">level</span> <span class="o">=</span> <span class="n">getLevelName</span><span class="p">(</span><span class="n">logger</span><span class="o">.</span><span class="n">getEffectiveLevel</span><span class="p">())</span>
    1700. <span class="k">return</span> <span class="s1">&#39;&lt;</span><span class="si">%s</span><span class="s1"> </span><span class="si">%s</span><span class="s1"> (</span><span class="si">%s</span><span class="s1">)&gt;&#39;</span> <span class="o">%</span> <span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="vm">__class__</span><span class="o">.</span><span class="vm">__name__</span><span class="p">,</span> <span class="n">logger</span><span class="o">.</span><span class="n">name</span><span class="p">,</span> <span class="n">level</span><span class="p">)</span>
    1701. <span class="n">root</span> <span class="o">=</span> <span class="n">RootLogger</span><span class="p">(</span><span class="n">WARNING</span><span class="p">)</span>
    1702. <span class="n">Logger</span><span class="o">.</span><span class="n">root</span> <span class="o">=</span> <span class="n">root</span>
    1703. <span class="n">Logger</span><span class="o">.</span><span class="n">manager</span> <span class="o">=</span> <span class="n">Manager</span><span class="p">(</span><span class="n">Logger</span><span class="o">.</span><span class="n">root</span><span class="p">)</span>
    1704. <span class="c1">#---------------------------------------------------------------------------</span>
    1705. <span class="c1"># Configuration classes and functions</span>
    1706. <span class="c1">#---------------------------------------------------------------------------</span>
    1707. <span class="k">def</span> <span class="nf">basicConfig</span><span class="p">(</span><span class="o">**</span><span class="n">kwargs</span><span class="p">):</span>
    1708. <span class="sd">&quot;&quot;&quot;</span>
    1709. <span class="sd"> Do basic configuration for the logging system.</span>
    1710. <span class="sd"> This function does nothing if the root logger already has handlers</span>
    1711. <span class="sd"> configured, unless the keyword argument *force* is set to ``True``.</span>
    1712. <span class="sd"> It is a convenience method intended for use by simple scripts</span>
    1713. <span class="sd"> to do one-shot configuration of the logging package.</span>
    1714. <span class="sd"> The default behaviour is to create a StreamHandler which writes to</span>
    1715. <span class="sd"> sys.stderr, set a formatter using the BASIC_FORMAT format string, and</span>
    1716. <span class="sd"> add the handler to the root logger.</span>
    1717. <span class="sd"> A number of optional keyword arguments may be specified, which can alter</span>
    1718. <span class="sd"> the default behaviour.</span>
    1719. <span class="sd"> filename Specifies that a FileHandler be created, using the specified</span>
    1720. <span class="sd"> filename, rather than a StreamHandler.</span>
    1721. <span class="sd"> filemode Specifies the mode to open the file, if filename is specified</span>
    1722. <span class="sd"> (if filemode is unspecified, it defaults to &#39;a&#39;).</span>
    1723. <span class="sd"> format Use the specified format string for the handler.</span>
    1724. <span class="sd"> datefmt Use the specified date/time format.</span>
    1725. <span class="sd"> style If a format string is specified, use this to specify the</span>
    1726. <span class="sd"> type of format string (possible values &#39;%&#39;, &#39;{&#39;, &#39;$&#39;, for</span>
    1727. <span class="sd"> %-formatting, :meth:`str.format` and :class:`string.Template`</span>
    1728. <span class="sd"> - defaults to &#39;%&#39;).</span>
    1729. <span class="sd"> level Set the root logger level to the specified level.</span>
    1730. <span class="sd"> stream Use the specified stream to initialize the StreamHandler. Note</span>
    1731. <span class="sd"> that this argument is incompatible with &#39;filename&#39; - if both</span>
    1732. <span class="sd"> are present, &#39;stream&#39; is ignored.</span>
    1733. <span class="sd"> handlers If specified, this should be an iterable of already created</span>
    1734. <span class="sd"> handlers, which will be added to the root handler. Any handler</span>
    1735. <span class="sd"> in the list which does not have a formatter assigned will be</span>
    1736. <span class="sd"> assigned the formatter created in this function.</span>
    1737. <span class="sd"> force If this keyword is specified as true, any existing handlers</span>
    1738. <span class="sd"> attached to the root logger are removed and closed, before</span>
    1739. <span class="sd"> carrying out the configuration as specified by the other</span>
    1740. <span class="sd"> arguments.</span>
    1741. <span class="sd"> encoding If specified together with a filename, this encoding is passed to</span>
    1742. <span class="sd"> the created FileHandler, causing it to be used when the file is</span>
    1743. <span class="sd"> opened.</span>
    1744. <span class="sd"> errors If specified together with a filename, this value is passed to the</span>
    1745. <span class="sd"> created FileHandler, causing it to be used when the file is</span>
    1746. <span class="sd"> opened in text mode. If not specified, the default value is</span>
    1747. <span class="sd"> `backslashreplace`.</span>
    1748. <span class="sd"> Note that you could specify a stream created using open(filename, mode)</span>
    1749. <span class="sd"> rather than passing the filename and mode in. However, it should be</span>
    1750. <span class="sd"> remembered that StreamHandler does not close its stream (since it may be</span>
    1751. <span class="sd"> using sys.stdout or sys.stderr), whereas FileHandler closes its stream</span>
    1752. <span class="sd"> when the handler is closed.</span>
    1753. <span class="sd"> .. versionchanged:: 3.2</span>
    1754. <span class="sd"> Added the ``style`` parameter.</span>
    1755. <span class="sd"> .. versionchanged:: 3.3</span>
    1756. <span class="sd"> Added the ``handlers`` parameter. A ``ValueError`` is now thrown for</span>
    1757. <span class="sd"> incompatible arguments (e.g. ``handlers`` specified together with</span>
    1758. <span class="sd"> ``filename``/``filemode``, or ``filename``/``filemode`` specified</span>
    1759. <span class="sd"> together with ``stream``, or ``handlers`` specified together with</span>
    1760. <span class="sd"> ``stream``.</span>
    1761. <span class="sd"> .. versionchanged:: 3.8</span>
    1762. <span class="sd"> Added the ``force`` parameter.</span>
    1763. <span class="sd"> .. versionchanged:: 3.9</span>
    1764. <span class="sd"> Added the ``encoding`` and ``errors`` parameters.</span>
    1765. <span class="sd"> &quot;&quot;&quot;</span>
    1766. <span class="c1"># Add thread safety in case someone mistakenly calls</span>
    1767. <span class="c1"># basicConfig() from multiple threads</span>
    1768. <span class="n">_acquireLock</span><span class="p">()</span>
    1769. <span class="k">try</span><span class="p">:</span>
    1770. <span class="n">force</span> <span class="o">=</span> <span class="n">kwargs</span><span class="o">.</span><span class="n">pop</span><span class="p">(</span><span class="s1">&#39;force&#39;</span><span class="p">,</span> <span class="kc">False</span><span class="p">)</span>
    1771. <span class="n">encoding</span> <span class="o">=</span> <span class="n">kwargs</span><span class="o">.</span><span class="n">pop</span><span class="p">(</span><span class="s1">&#39;encoding&#39;</span><span class="p">,</span> <span class="kc">None</span><span class="p">)</span>
    1772. <span class="n">errors</span> <span class="o">=</span> <span class="n">kwargs</span><span class="o">.</span><span class="n">pop</span><span class="p">(</span><span class="s1">&#39;errors&#39;</span><span class="p">,</span> <span class="s1">&#39;backslashreplace&#39;</span><span class="p">)</span>
    1773. <span class="k">if</span> <span class="n">force</span><span class="p">:</span>
    1774. <span class="k">for</span> <span class="n">h</span> <span class="ow">in</span> <span class="n">root</span><span class="o">.</span><span class="n">handlers</span><span class="p">[:]:</span>
    1775. <span class="n">root</span><span class="o">.</span><span class="n">removeHandler</span><span class="p">(</span><span class="n">h</span><span class="p">)</span>
    1776. <span class="n">h</span><span class="o">.</span><span class="n">close</span><span class="p">()</span>
    1777. <span class="k">if</span> <span class="nb">len</span><span class="p">(</span><span class="n">root</span><span class="o">.</span><span class="n">handlers</span><span class="p">)</span> <span class="o">==</span> <span class="mi">0</span><span class="p">:</span>
    1778. <span class="n">handlers</span> <span class="o">=</span> <span class="n">kwargs</span><span class="o">.</span><span class="n">pop</span><span class="p">(</span><span class="s2">&quot;handlers&quot;</span><span class="p">,</span> <span class="kc">None</span><span class="p">)</span>
    1779. <span class="k">if</span> <span class="n">handlers</span> <span class="ow">is</span> <span class="kc">None</span><span class="p">:</span>
    1780. <span class="k">if</span> <span class="s2">&quot;stream&quot;</span> <span class="ow">in</span> <span class="n">kwargs</span> <span class="ow">and</span> <span class="s2">&quot;filename&quot;</span> <span class="ow">in</span> <span class="n">kwargs</span><span class="p">:</span>
    1781. <span class="k">raise</span> <span class="ne">ValueError</span><span class="p">(</span><span class="s2">&quot;&#39;stream&#39; and &#39;filename&#39; should not be &quot;</span>
    1782. <span class="s2">&quot;specified together&quot;</span><span class="p">)</span>
    1783. <span class="k">else</span><span class="p">:</span>
    1784. <span class="k">if</span> <span class="s2">&quot;stream&quot;</span> <span class="ow">in</span> <span class="n">kwargs</span> <span class="ow">or</span> <span class="s2">&quot;filename&quot;</span> <span class="ow">in</span> <span class="n">kwargs</span><span class="p">:</span>
    1785. <span class="k">raise</span> <span class="ne">ValueError</span><span class="p">(</span><span class="s2">&quot;&#39;stream&#39; or &#39;filename&#39; should not be &quot;</span>
    1786. <span class="s2">&quot;specified together with &#39;handlers&#39;&quot;</span><span class="p">)</span>
    1787. <span class="k">if</span> <span class="n">handlers</span> <span class="ow">is</span> <span class="kc">None</span><span class="p">:</span>
    1788. <span class="n">filename</span> <span class="o">=</span> <span class="n">kwargs</span><span class="o">.</span><span class="n">pop</span><span class="p">(</span><span class="s2">&quot;filename&quot;</span><span class="p">,</span> <span class="kc">None</span><span class="p">)</span>
    1789. <span class="n">mode</span> <span class="o">=</span> <span class="n">kwargs</span><span class="o">.</span><span class="n">pop</span><span class="p">(</span><span class="s2">&quot;filemode&quot;</span><span class="p">,</span> <span class="s1">&#39;a&#39;</span><span class="p">)</span>
    1790. <span class="k">if</span> <span class="n">filename</span><span class="p">:</span>
    1791. <span class="k">if</span> <span class="s1">&#39;b&#39;</span><span class="ow">in</span> <span class="n">mode</span><span class="p">:</span>
    1792. <span class="n">errors</span> <span class="o">=</span> <span class="kc">None</span>
    1793. <span class="n">h</span> <span class="o">=</span> <span class="n">FileHandler</span><span class="p">(</span><span class="n">filename</span><span class="p">,</span> <span class="n">mode</span><span class="p">,</span>
    1794. <span class="n">encoding</span><span class="o">=</span><span class="n">encoding</span><span class="p">,</span> <span class="n">errors</span><span class="o">=</span><span class="n">errors</span><span class="p">)</span>
    1795. <span class="k">else</span><span class="p">:</span>
    1796. <span class="n">stream</span> <span class="o">=</span> <span class="n">kwargs</span><span class="o">.</span><span class="n">pop</span><span class="p">(</span><span class="s2">&quot;stream&quot;</span><span class="p">,</span> <span class="kc">None</span><span class="p">)</span>
    1797. <span class="n">h</span> <span class="o">=</span> <span class="n">StreamHandler</span><span class="p">(</span><span class="n">stream</span><span class="p">)</span>
    1798. <span class="n">handlers</span> <span class="o">=</span> <span class="p">[</span><span class="n">h</span><span class="p">]</span>
    1799. <span class="n">dfs</span> <span class="o">=</span> <span class="n">kwargs</span><span class="o">.</span><span class="n">pop</span><span class="p">(</span><span class="s2">&quot;datefmt&quot;</span><span class="p">,</span> <span class="kc">None</span><span class="p">)</span>
    1800. <span class="n">style</span> <span class="o">=</span> <span class="n">kwargs</span><span class="o">.</span><span class="n">pop</span><span class="p">(</span><span class="s2">&quot;style&quot;</span><span class="p">,</span> <span class="s1">&#39;%&#39;</span><span class="p">)</span>
    1801. <span class="k">if</span> <span class="n">style</span> <span class="ow">not</span> <span class="ow">in</span> <span class="n">_STYLES</span><span class="p">:</span>
    1802. <span class="k">raise</span> <span class="ne">ValueError</span><span class="p">(</span><span class="s1">&#39;Style must be one of: </span><span class="si">%s</span><span class="s1">&#39;</span> <span class="o">%</span> <span class="s1">&#39;,&#39;</span><span class="o">.</span><span class="n">join</span><span class="p">(</span>
    1803. <span class="n">_STYLES</span><span class="o">.</span><span class="n">keys</span><span class="p">()))</span>
    1804. <span class="n">fs</span> <span class="o">=</span> <span class="n">kwargs</span><span class="o">.</span><span class="n">pop</span><span class="p">(</span><span class="s2">&quot;format&quot;</span><span class="p">,</span> <span class="n">_STYLES</span><span class="p">[</span><span class="n">style</span><span class="p">][</span><span class="mi">1</span><span class="p">])</span>
    1805. <span class="n">fmt</span> <span class="o">=</span> <span class="n">Formatter</span><span class="p">(</span><span class="n">fs</span><span class="p">,</span> <span class="n">dfs</span><span class="p">,</span> <span class="n">style</span><span class="p">)</span>
    1806. <span class="k">for</span> <span class="n">h</span> <span class="ow">in</span> <span class="n">handlers</span><span class="p">:</span>
    1807. <span class="k">if</span> <span class="n">h</span><span class="o">.</span><span class="n">formatter</span> <span class="ow">is</span> <span class="kc">None</span><span class="p">:</span>
    1808. <span class="n">h</span><span class="o">.</span><span class="n">setFormatter</span><span class="p">(</span><span class="n">fmt</span><span class="p">)</span>
    1809. <span class="n">root</span><span class="o">.</span><span class="n">addHandler</span><span class="p">(</span><span class="n">h</span><span class="p">)</span>
    1810. <span class="n">level</span> <span class="o">=</span> <span class="n">kwargs</span><span class="o">.</span><span class="n">pop</span><span class="p">(</span><span class="s2">&quot;level&quot;</span><span class="p">,</span> <span class="kc">None</span><span class="p">)</span>
    1811. <span class="k">if</span> <span class="n">level</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span><span class="p">:</span>
    1812. <span class="n">root</span><span class="o">.</span><span class="n">setLevel</span><span class="p">(</span><span class="n">level</span><span class="p">)</span>
    1813. <span class="k">if</span> <span class="n">kwargs</span><span class="p">:</span>
    1814. <span class="n">keys</span> <span class="o">=</span> <span class="s1">&#39;, &#39;</span><span class="o">.</span><span class="n">join</span><span class="p">(</span><span class="n">kwargs</span><span class="o">.</span><span class="n">keys</span><span class="p">())</span>
    1815. <span class="k">raise</span> <span class="ne">ValueError</span><span class="p">(</span><span class="s1">&#39;Unrecognised argument(s): </span><span class="si">%s</span><span class="s1">&#39;</span> <span class="o">%</span> <span class="n">keys</span><span class="p">)</span>
    1816. <span class="k">finally</span><span class="p">:</span>
    1817. <span class="n">_releaseLock</span><span class="p">()</span>
    1818. <span class="c1">#---------------------------------------------------------------------------</span>
    1819. <span class="c1"># Utility functions at module level.</span>
    1820. <span class="c1"># Basically delegate everything to the root logger.</span>
    1821. <span class="c1">#---------------------------------------------------------------------------</span>
    1822. <span class="k">def</span> <span class="nf">getLogger</span><span class="p">(</span><span class="n">name</span><span class="o">=</span><span class="kc">None</span><span class="p">):</span>
    1823. <span class="sd">&quot;&quot;&quot;</span>
    1824. <span class="sd"> Return a logger with the specified name, creating it if necessary.</span>
    1825. <span class="sd"> If no name is specified, return the root logger.</span>
    1826. <span class="sd"> &quot;&quot;&quot;</span>
    1827. <span class="k">if</span> <span class="ow">not</span> <span class="n">name</span> <span class="ow">or</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">name</span><span class="p">,</span> <span class="nb">str</span><span class="p">)</span> <span class="ow">and</span> <span class="n">name</span> <span class="o">==</span> <span class="n">root</span><span class="o">.</span><span class="n">name</span><span class="p">:</span>
    1828. <span class="k">return</span> <span class="n">root</span>
    1829. <span class="k">return</span> <span class="n">Logger</span><span class="o">.</span><span class="n">manager</span><span class="o">.</span><span class="n">getLogger</span><span class="p">(</span><span class="n">name</span><span class="p">)</span>
    1830. <span class="k">def</span> <span class="nf">critical</span><span class="p">(</span><span class="n">msg</span><span class="p">,</span> <span class="o">*</span><span class="n">args</span><span class="p">,</span> <span class="o">**</span><span class="n">kwargs</span><span class="p">):</span>
    1831. <span class="sd">&quot;&quot;&quot;</span>
    1832. <span class="sd"> Log a message with severity &#39;CRITICAL&#39; on the root logger. If the logger</span>
    1833. <span class="sd"> has no handlers, call basicConfig() to add a console handler with a</span>
    1834. <span class="sd"> pre-defined format.</span>
    1835. <span class="sd"> &quot;&quot;&quot;</span>
    1836. <span class="k">if</span> <span class="nb">len</span><span class="p">(</span><span class="n">root</span><span class="o">.</span><span class="n">handlers</span><span class="p">)</span> <span class="o">==</span> <span class="mi">0</span><span class="p">:</span>
    1837. <span class="n">basicConfig</span><span class="p">()</span>
    1838. <span class="n">root</span><span class="o">.</span><span class="n">critical</span><span class="p">(</span><span class="n">msg</span><span class="p">,</span> <span class="o">*</span><span class="n">args</span><span class="p">,</span> <span class="o">**</span><span class="n">kwargs</span><span class="p">)</span>
    1839. <span class="n">fatal</span> <span class="o">=</span> <span class="n">critical</span>
    1840. <span class="k">def</span> <span class="nf">error</span><span class="p">(</span><span class="n">msg</span><span class="p">,</span> <span class="o">*</span><span class="n">args</span><span class="p">,</span> <span class="o">**</span><span class="n">kwargs</span><span class="p">):</span>
    1841. <span class="sd">&quot;&quot;&quot;</span>
    1842. <span class="sd"> Log a message with severity &#39;ERROR&#39; on the root logger. If the logger has</span>
    1843. <span class="sd"> no handlers, call basicConfig() to add a console handler with a pre-defined</span>
    1844. <span class="sd"> format.</span>
    1845. <span class="sd"> &quot;&quot;&quot;</span>
    1846. <span class="k">if</span> <span class="nb">len</span><span class="p">(</span><span class="n">root</span><span class="o">.</span><span class="n">handlers</span><span class="p">)</span> <span class="o">==</span> <span class="mi">0</span><span class="p">:</span>
    1847. <span class="n">basicConfig</span><span class="p">()</span>
    1848. <span class="n">root</span><span class="o">.</span><span class="n">error</span><span class="p">(</span><span class="n">msg</span><span class="p">,</span> <span class="o">*</span><span class="n">args</span><span class="p">,</span> <span class="o">**</span><span class="n">kwargs</span><span class="p">)</span>
    1849. <span class="k">def</span> <span class="nf">exception</span><span class="p">(</span><span class="n">msg</span><span class="p">,</span> <span class="o">*</span><span class="n">args</span><span class="p">,</span> <span class="n">exc_info</span><span class="o">=</span><span class="kc">True</span><span class="p">,</span> <span class="o">**</span><span class="n">kwargs</span><span class="p">):</span>
    1850. <span class="sd">&quot;&quot;&quot;</span>
    1851. <span class="sd"> Log a message with severity &#39;ERROR&#39; on the root logger, with exception</span>
    1852. <span class="sd"> information. If the logger has no handlers, basicConfig() is called to add</span>
    1853. <span class="sd"> a console handler with a pre-defined format.</span>
    1854. <span class="sd"> &quot;&quot;&quot;</span>
    1855. <span class="n">error</span><span class="p">(</span><span class="n">msg</span><span class="p">,</span> <span class="o">*</span><span class="n">args</span><span class="p">,</span> <span class="n">exc_info</span><span class="o">=</span><span class="n">exc_info</span><span class="p">,</span> <span class="o">**</span><span class="n">kwargs</span><span class="p">)</span>
    1856. <span class="k">def</span> <span class="nf">warning</span><span class="p">(</span><span class="n">msg</span><span class="p">,</span> <span class="o">*</span><span class="n">args</span><span class="p">,</span> <span class="o">**</span><span class="n">kwargs</span><span class="p">):</span>
    1857. <span class="sd">&quot;&quot;&quot;</span>
    1858. <span class="sd"> Log a message with severity &#39;WARNING&#39; on the root logger. If the logger has</span>
    1859. <span class="sd"> no handlers, call basicConfig() to add a console handler with a pre-defined</span>
    1860. <span class="sd"> format.</span>
    1861. <span class="sd"> &quot;&quot;&quot;</span>
    1862. <span class="k">if</span> <span class="nb">len</span><span class="p">(</span><span class="n">root</span><span class="o">.</span><span class="n">handlers</span><span class="p">)</span> <span class="o">==</span> <span class="mi">0</span><span class="p">:</span>
    1863. <span class="n">basicConfig</span><span class="p">()</span>
    1864. <span class="n">root</span><span class="o">.</span><span class="n">warning</span><span class="p">(</span><span class="n">msg</span><span class="p">,</span> <span class="o">*</span><span class="n">args</span><span class="p">,</span> <span class="o">**</span><span class="n">kwargs</span><span class="p">)</span>
    1865. <span class="k">def</span> <span class="nf">warn</span><span class="p">(</span><span class="n">msg</span><span class="p">,</span> <span class="o">*</span><span class="n">args</span><span class="p">,</span> <span class="o">**</span><span class="n">kwargs</span><span class="p">):</span>
    1866. <span class="n">warnings</span><span class="o">.</span><span class="n">warn</span><span class="p">(</span><span class="s2">&quot;The &#39;warn&#39; function is deprecated, &quot;</span>
    1867. <span class="s2">&quot;use &#39;warning&#39; instead&quot;</span><span class="p">,</span> <span class="ne">DeprecationWarning</span><span class="p">,</span> <span class="mi">2</span><span class="p">)</span>
    1868. <span class="n">warning</span><span class="p">(</span><span class="n">msg</span><span class="p">,</span> <span class="o">*</span><span class="n">args</span><span class="p">,</span> <span class="o">**</span><span class="n">kwargs</span><span class="p">)</span>
    1869. <span class="k">def</span> <span class="nf">info</span><span class="p">(</span><span class="n">msg</span><span class="p">,</span> <span class="o">*</span><span class="n">args</span><span class="p">,</span> <span class="o">**</span><span class="n">kwargs</span><span class="p">):</span>
    1870. <span class="sd">&quot;&quot;&quot;</span>
    1871. <span class="sd"> Log a message with severity &#39;INFO&#39; on the root logger. If the logger has</span>
    1872. <span class="sd"> no handlers, call basicConfig() to add a console handler with a pre-defined</span>
    1873. <span class="sd"> format.</span>
    1874. <span class="sd"> &quot;&quot;&quot;</span>
    1875. <span class="k">if</span> <span class="nb">len</span><span class="p">(</span><span class="n">root</span><span class="o">.</span><span class="n">handlers</span><span class="p">)</span> <span class="o">==</span> <span class="mi">0</span><span class="p">:</span>
    1876. <span class="n">basicConfig</span><span class="p">()</span>
    1877. <span class="n">root</span><span class="o">.</span><span class="n">info</span><span class="p">(</span><span class="n">msg</span><span class="p">,</span> <span class="o">*</span><span class="n">args</span><span class="p">,</span> <span class="o">**</span><span class="n">kwargs</span><span class="p">)</span>
    1878. <span class="k">def</span> <span class="nf">debug</span><span class="p">(</span><span class="n">msg</span><span class="p">,</span> <span class="o">*</span><span class="n">args</span><span class="p">,</span> <span class="o">**</span><span class="n">kwargs</span><span class="p">):</span>
    1879. <span class="sd">&quot;&quot;&quot;</span>
    1880. <span class="sd"> Log a message with severity &#39;DEBUG&#39; on the root logger. If the logger has</span>
    1881. <span class="sd"> no handlers, call basicConfig() to add a console handler with a pre-defined</span>
    1882. <span class="sd"> format.</span>
    1883. <span class="sd"> &quot;&quot;&quot;</span>
    1884. <span class="k">if</span> <span class="nb">len</span><span class="p">(</span><span class="n">root</span><span class="o">.</span><span class="n">handlers</span><span class="p">)</span> <span class="o">==</span> <span class="mi">0</span><span class="p">:</span>
    1885. <span class="n">basicConfig</span><span class="p">()</span>
    1886. <span class="n">root</span><span class="o">.</span><span class="n">debug</span><span class="p">(</span><span class="n">msg</span><span class="p">,</span> <span class="o">*</span><span class="n">args</span><span class="p">,</span> <span class="o">**</span><span class="n">kwargs</span><span class="p">)</span>
    1887. <span class="k">def</span> <span class="nf">log</span><span class="p">(</span><span class="n">level</span><span class="p">,</span> <span class="n">msg</span><span class="p">,</span> <span class="o">*</span><span class="n">args</span><span class="p">,</span> <span class="o">**</span><span class="n">kwargs</span><span class="p">):</span>
    1888. <span class="sd">&quot;&quot;&quot;</span>
    1889. <span class="sd"> Log &#39;msg % args&#39; with the integer severity &#39;level&#39; on the root logger. If</span>
    1890. <span class="sd"> the logger has no handlers, call basicConfig() to add a console handler</span>
    1891. <span class="sd"> with a pre-defined format.</span>
    1892. <span class="sd"> &quot;&quot;&quot;</span>
    1893. <span class="k">if</span> <span class="nb">len</span><span class="p">(</span><span class="n">root</span><span class="o">.</span><span class="n">handlers</span><span class="p">)</span> <span class="o">==</span> <span class="mi">0</span><span class="p">:</span>
    1894. <span class="n">basicConfig</span><span class="p">()</span>
    1895. <span class="n">root</span><span class="o">.</span><span class="n">log</span><span class="p">(</span><span class="n">level</span><span class="p">,</span> <span class="n">msg</span><span class="p">,</span> <span class="o">*</span><span class="n">args</span><span class="p">,</span> <span class="o">**</span><span class="n">kwargs</span><span class="p">)</span>
    1896. <span class="k">def</span> <span class="nf">disable</span><span class="p">(</span><span class="n">level</span><span class="o">=</span><span class="n">CRITICAL</span><span class="p">):</span>
    1897. <span class="sd">&quot;&quot;&quot;</span>
    1898. <span class="sd"> Disable all logging calls of severity &#39;level&#39; and below.</span>
    1899. <span class="sd"> &quot;&quot;&quot;</span>
    1900. <span class="n">root</span><span class="o">.</span><span class="n">manager</span><span class="o">.</span><span class="n">disable</span> <span class="o">=</span> <span class="n">level</span>
    1901. <span class="n">root</span><span class="o">.</span><span class="n">manager</span><span class="o">.</span><span class="n">_clear_cache</span><span class="p">()</span>
    1902. <span class="k">def</span> <span class="nf">shutdown</span><span class="p">(</span><span class="n">handlerList</span><span class="o">=</span><span class="n">_handlerList</span><span class="p">):</span>
    1903. <span class="sd">&quot;&quot;&quot;</span>
    1904. <span class="sd"> Perform any cleanup actions in the logging system (e.g. flushing</span>
    1905. <span class="sd"> buffers).</span>
    1906. <span class="sd"> Should be called at application exit.</span>
    1907. <span class="sd"> &quot;&quot;&quot;</span>
    1908. <span class="k">for</span> <span class="n">wr</span> <span class="ow">in</span> <span class="nb">reversed</span><span class="p">(</span><span class="n">handlerList</span><span class="p">[:]):</span>
    1909. <span class="c1">#errors might occur, for example, if files are locked</span>
    1910. <span class="c1">#we just ignore them if raiseExceptions is not set</span>
    1911. <span class="k">try</span><span class="p">:</span>
    1912. <span class="n">h</span> <span class="o">=</span> <span class="n">wr</span><span class="p">()</span>
    1913. <span class="k">if</span> <span class="n">h</span><span class="p">:</span>
    1914. <span class="k">try</span><span class="p">:</span>
    1915. <span class="n">h</span><span class="o">.</span><span class="n">acquire</span><span class="p">()</span>
    1916. <span class="n">h</span><span class="o">.</span><span class="n">flush</span><span class="p">()</span>
    1917. <span class="n">h</span><span class="o">.</span><span class="n">close</span><span class="p">()</span>
    1918. <span class="k">except</span> <span class="p">(</span><span class="ne">OSError</span><span class="p">,</span> <span class="ne">ValueError</span><span class="p">):</span>
    1919. <span class="c1"># Ignore errors which might be caused</span>
    1920. <span class="c1"># because handlers have been closed but</span>
    1921. <span class="c1"># references to them are still around at</span>
    1922. <span class="c1"># application exit.</span>
    1923. <span class="k">pass</span>
    1924. <span class="k">finally</span><span class="p">:</span>
    1925. <span class="n">h</span><span class="o">.</span><span class="n">release</span><span class="p">()</span>
    1926. <span class="k">except</span><span class="p">:</span> <span class="c1"># ignore everything, as we&#39;re shutting down</span>
    1927. <span class="k">if</span> <span class="n">raiseExceptions</span><span class="p">:</span>
    1928. <span class="k">raise</span>
    1929. <span class="c1">#else, swallow</span>
    1930. <span class="c1">#Let&#39;s try and shutdown automatically on application exit...</span>
    1931. <span class="kn">import</span> <span class="nn">atexit</span>
    1932. <span class="n">atexit</span><span class="o">.</span><span class="n">register</span><span class="p">(</span><span class="n">shutdown</span><span class="p">)</span>
    1933. <span class="c1"># Null handler</span>
    1934. <span class="k">class</span> <span class="nc">NullHandler</span><span class="p">(</span><span class="n">Handler</span><span class="p">):</span>
    1935. <span class="sd">&quot;&quot;&quot;</span>
    1936. <span class="sd"> This handler does nothing. It&#39;s intended to be used to avoid the</span>
    1937. <span class="sd"> &quot;No handlers could be found for logger XXX&quot; one-off warning. This is</span>
    1938. <span class="sd"> important for library code, which may contain code to log events. If a user</span>
    1939. <span class="sd"> of the library does not configure logging, the one-off warning might be</span>
    1940. <span class="sd"> produced; to avoid this, the library developer simply needs to instantiate</span>
    1941. <span class="sd"> a NullHandler and add it to the top-level logger of the library module or</span>
    1942. <span class="sd"> package.</span>
    1943. <span class="sd"> &quot;&quot;&quot;</span>
    1944. <span class="k">def</span> <span class="nf">handle</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">record</span><span class="p">):</span>
    1945. <span class="sd">&quot;&quot;&quot;Stub.&quot;&quot;&quot;</span>
    1946. <span class="k">def</span> <span class="nf">emit</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">record</span><span class="p">):</span>
    1947. <span class="sd">&quot;&quot;&quot;Stub.&quot;&quot;&quot;</span>
    1948. <span class="k">def</span> <span class="nf">createLock</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span>
    1949. <span class="bp">self</span><span class="o">.</span><span class="n">lock</span> <span class="o">=</span> <span class="kc">None</span>
    1950. <span class="k">def</span> <span class="nf">_at_fork_reinit</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span>
    1951. <span class="k">pass</span>
    1952. <span class="c1"># Warnings integration</span>
    1953. <span class="n">_warnings_showwarning</span> <span class="o">=</span> <span class="kc">None</span>
    1954. <span class="k">def</span> <span class="nf">_showwarning</span><span class="p">(</span><span class="n">message</span><span class="p">,</span> <span class="n">category</span><span class="p">,</span> <span class="n">filename</span><span class="p">,</span> <span class="n">lineno</span><span class="p">,</span> <span class="n">file</span><span class="o">=</span><span class="kc">None</span><span class="p">,</span> <span class="n">line</span><span class="o">=</span><span class="kc">None</span><span class="p">):</span>
    1955. <span class="sd">&quot;&quot;&quot;</span>
    1956. <span class="sd"> Implementation of showwarnings which redirects to logging, which will first</span>
    1957. <span class="sd"> check to see if the file parameter is None. If a file is specified, it will</span>
    1958. <span class="sd"> delegate to the original warnings implementation of showwarning. Otherwise,</span>
    1959. <span class="sd"> it will call warnings.formatwarning and will log the resulting string to a</span>
    1960. <span class="sd"> warnings logger named &quot;py.warnings&quot; with level logging.WARNING.</span>
    1961. <span class="sd"> &quot;&quot;&quot;</span>
    1962. <span class="k">if</span> <span class="n">file</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span><span class="p">:</span>
    1963. <span class="k">if</span> <span class="n">_warnings_showwarning</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span><span class="p">:</span>
    1964. <span class="n">_warnings_showwarning</span><span class="p">(</span><span class="n">message</span><span class="p">,</span> <span class="n">category</span><span class="p">,</span> <span class="n">filename</span><span class="p">,</span> <span class="n">lineno</span><span class="p">,</span> <span class="n">file</span><span class="p">,</span> <span class="n">line</span><span class="p">)</span>
    1965. <span class="k">else</span><span class="p">:</span>
    1966. <span class="n">s</span> <span class="o">=</span> <span class="n">warnings</span><span class="o">.</span><span class="n">formatwarning</span><span class="p">(</span><span class="n">message</span><span class="p">,</span> <span class="n">category</span><span class="p">,</span> <span class="n">filename</span><span class="p">,</span> <span class="n">lineno</span><span class="p">,</span> <span class="n">line</span><span class="p">)</span>
    1967. <span class="n">logger</span> <span class="o">=</span> <span class="n">getLogger</span><span class="p">(</span><span class="s2">&quot;py.warnings&quot;</span><span class="p">)</span>
    1968. <span class="k">if</span> <span class="ow">not</span> <span class="n">logger</span><span class="o">.</span><span class="n">handlers</span><span class="p">:</span>
    1969. <span class="n">logger</span><span class="o">.</span><span class="n">addHandler</span><span class="p">(</span><span class="n">NullHandler</span><span class="p">())</span>
    1970. <span class="n">logger</span><span class="o">.</span><span class="n">warning</span><span class="p">(</span><span class="s2">&quot;</span><span class="si">%s</span><span class="s2">&quot;</span><span class="p">,</span> <span class="n">s</span><span class="p">)</span>
    1971. <span class="k">def</span> <span class="nf">captureWarnings</span><span class="p">(</span><span class="n">capture</span><span class="p">):</span>
    1972. <span class="sd">&quot;&quot;&quot;</span>
    1973. <span class="sd"> If capture is true, redirect all warnings to the logging package.</span>
    1974. <span class="sd"> If capture is False, ensure that warnings are not redirected to logging</span>
    1975. <span class="sd"> but to their original destinations.</span>
    1976. <span class="sd"> &quot;&quot;&quot;</span>
    1977. <span class="k">global</span> <span class="n">_warnings_showwarning</span>
    1978. <span class="k">if</span> <span class="n">capture</span><span class="p">:</span>
    1979. <span class="k">if</span> <span class="n">_warnings_showwarning</span> <span class="ow">is</span> <span class="kc">None</span><span class="p">:</span>
    1980. <span class="n">_warnings_showwarning</span> <span class="o">=</span> <span class="n">warnings</span><span class="o">.</span><span class="n">showwarning</span>
    1981. <span class="n">warnings</span><span class="o">.</span><span class="n">showwarning</span> <span class="o">=</span> <span class="n">_showwarning</span>
    1982. <span class="k">else</span><span class="p">:</span>
    1983. <span class="k">if</span> <span class="n">_warnings_showwarning</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span><span class="p">:</span>
    1984. <span class="n">warnings</span><span class="o">.</span><span class="n">showwarning</span> <span class="o">=</span> <span class="n">_warnings_showwarning</span>
    1985. <span class="n">_warnings_showwarning</span> <span class="o">=</span> <span class="kc">None</span>
    1986. </pre></div>
    1987. </div>
    1988. </div>
    1989. <footer>
    1990. <hr/>
    1991. <div role="contentinfo">
    1992. <p>&#169; Copyright 2021, SuperGradients team.</p>
    1993. </div>
    1994. Built with <a href="https://www.sphinx-doc.org/">Sphinx</a> using a
    1995. <a href="https://github.com/readthedocs/sphinx_rtd_theme">theme</a>
    1996. provided by <a href="https://readthedocs.org">Read the Docs</a>.
    1997. </footer>
    1998. </div>
    1999. </div>
    2000. </section>
    2001. </div>
    2002. <script>
    2003. jQuery(function () {
    2004. SphinxRtdTheme.Navigation.enable(true);
    2005. });
    2006. </script>
    2007. </body>
    2008. </html>
    Discard
    1
    2
    3
    4
    5
    6
    7
    8
    9
    10
    11
    12
    13
    14
    15
    16
    17
    18
    19
    20
    21
    22
    23
    24
    25
    26
    27
    28
    29
    30
    31
    32
    33
    34
    35
    36
    37
    38
    39
    40
    41
    42
    43
    44
    45
    46
    47
    48
    49
    50
    51
    52
    53
    54
    55
    56
    57
    58
    59
    60
    61
    62
    63
    64
    65
    66
    67
    68
    69
    70
    71
    72
    73
    74
    75
    76
    77
    78
    79
    80
    81
    82
    83
    84
    85
    86
    87
    88
    89
    90
    91
    92
    93
    94
    95
    96
    97
    98
    99
    100
    101
    102
    103
    104
    105
    106
    107
    108
    109
    110
    111
    112
    113
    114
    115
    116
    117
    118
    119
    120
    121
    122
    123
    124
    125
    126
    127
    128
    129
    130
    131
    132
    133
    134
    135
    136
    137
    138
    139
    140
    141
    142
    143
    144
    1. <!DOCTYPE html>
    2. <html class="writer-html5" lang="en" >
    3. <head>
    4. <meta charset="utf-8" />
    5. <meta name="viewport" content="width=device-width, initial-scale=1.0" />
    6. <title>super_gradients.common.abstractions.abstract_logger &mdash; SuperGradients 1.0 documentation</title>
    7. <link rel="stylesheet" href="../../../../_static/pygments.css" type="text/css" />
    8. <link rel="stylesheet" href="../../../../_static/css/theme.css" type="text/css" />
    9. <link rel="stylesheet" href="../../../../_static/graphviz.css" type="text/css" />
    10. <!--[if lt IE 9]>
    11. <script src="../../../../_static/js/html5shiv.min.js"></script>
    12. <![endif]-->
    13. <script data-url_root="../../../../" id="documentation_options" src="../../../../_static/documentation_options.js"></script>
    14. <script src="../../../../_static/jquery.js"></script>
    15. <script src="../../../../_static/underscore.js"></script>
    16. <script src="../../../../_static/doctools.js"></script>
    17. <script src="../../../../_static/js/theme.js"></script>
    18. <link rel="index" title="Index" href="../../../../genindex.html" />
    19. <link rel="search" title="Search" href="../../../../search.html" />
    20. </head>
    21. <body class="wy-body-for-nav">
    22. <div class="wy-grid-for-nav">
    23. <nav data-toggle="wy-nav-shift" class="wy-nav-side">
    24. <div class="wy-side-scroll">
    25. <div class="wy-side-nav-search" >
    26. <a href="../../../../index.html" class="icon icon-home"> SuperGradients
    27. </a>
    28. <div role="search">
    29. <form id="rtd-search-form" class="wy-form" action="../../../../search.html" method="get">
    30. <input type="text" name="q" placeholder="Search docs" />
    31. <input type="hidden" name="check_keywords" value="yes" />
    32. <input type="hidden" name="area" value="default" />
    33. </form>
    34. </div>
    35. </div><div class="wy-menu wy-menu-vertical" data-spy="affix" role="navigation" aria-label="Navigation menu">
    36. <p class="caption"><span class="caption-text">Welcome To SuperGradients</span></p>
    37. <ul>
    38. <li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html">Fill our 4-question quick survey! We will raffle free SuperGradients swag between those who will participate -&gt; Fill Survey</a></li>
    39. <li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#supergradients">SuperGradients</a></li>
    40. </ul>
    41. <p class="caption"><span class="caption-text">Technical Documentation</span></p>
    42. <ul>
    43. <li class="toctree-l1"><a class="reference internal" href="../../../../super_gradients.common.html">Common package</a></li>
    44. <li class="toctree-l1"><a class="reference internal" href="../../../../super_gradients.training.html">Training package</a></li>
    45. </ul>
    46. <p class="caption"><span class="caption-text">User Guide</span></p>
    47. <ul>
    48. <li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html">What is SuperGradients?</a></li>
    49. <li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#introducing-the-supergradients-library">Introducing the SuperGradients library</a></li>
    50. <li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#installation">Installation</a></li>
    51. <li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#integrating-your-training-code-complete-walkthrough">Integrating Your Training Code - Complete Walkthrough</a></li>
    52. <li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#training-parameters">Training Parameters</a></li>
    53. <li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#logs-and-checkpoints">Logs and Checkpoints</a></li>
    54. <li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#dataset-parameters">Dataset Parameters</a></li>
    55. <li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#network-architectures">Network Architectures</a></li>
    56. <li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#pretrained-models">Pretrained Models</a></li>
    57. <li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#how-to-reproduce-our-training-recipes">How To Reproduce Our Training Recipes</a></li>
    58. <li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#professional-tools-integration">Professional Tools Integration</a></li>
    59. <li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#supergradients-faq">SuperGradients FAQ</a></li>
    60. </ul>
    61. </div>
    62. </div>
    63. </nav>
    64. <section data-toggle="wy-nav-shift" class="wy-nav-content-wrap"><nav class="wy-nav-top" aria-label="Mobile navigation menu" >
    65. <i data-toggle="wy-nav-top" class="fa fa-bars"></i>
    66. <a href="../../../../index.html">SuperGradients</a>
    67. </nav>
    68. <div class="wy-nav-content">
    69. <div class="rst-content">
    70. <div role="navigation" aria-label="Page navigation">
    71. <ul class="wy-breadcrumbs">
    72. <li><a href="../../../../index.html" class="icon icon-home"></a> &raquo;</li>
    73. <li><a href="../../../index.html">Module code</a> &raquo;</li>
    74. <li>super_gradients.common.abstractions.abstract_logger</li>
    75. <li class="wy-breadcrumbs-aside">
    76. </li>
    77. </ul>
    78. <hr/>
    79. </div>
    80. <div role="main" class="document" itemscope="itemscope" itemtype="http://schema.org/Article">
    81. <div itemprop="articleBody">
    82. <h1>Source code for super_gradients.common.abstractions.abstract_logger</h1><div class="highlight"><pre>
    83. <span></span><span class="kn">import</span> <span class="nn">logging</span>
    84. <span class="kn">import</span> <span class="nn">logging.config</span>
    85. <span class="kn">from</span> <span class="nn">super_gradients.common.auto_logging</span> <span class="kn">import</span> <span class="n">AutoLoggerConfig</span>
    86. <span class="kn">from</span> <span class="nn">super_gradients.common.environment.environment_config</span> <span class="kn">import</span> <span class="n">DEFAULT_LOGGING_LEVEL</span>
    87. <div class="viewcode-block" id="get_logger"><a class="viewcode-back" href="../../../../super_gradients.common.abstractions.html#super_gradients.common.abstractions.abstract_logger.get_logger">[docs]</a><span class="k">def</span> <span class="nf">get_logger</span><span class="p">(</span>
    88. <span class="n">logger_name</span><span class="p">:</span> <span class="nb">str</span><span class="p">,</span> <span class="n">training_log_path</span><span class="o">=</span><span class="kc">None</span><span class="p">,</span> <span class="n">logs_dir_path</span><span class="o">=</span><span class="kc">None</span><span class="p">,</span> <span class="n">log_level</span><span class="o">=</span><span class="n">DEFAULT_LOGGING_LEVEL</span>
    89. <span class="p">)</span> <span class="o">-&gt;</span> <span class="n">logging</span><span class="o">.</span><span class="n">Logger</span><span class="p">:</span>
    90. <span class="n">config_dict</span> <span class="o">=</span> <span class="n">AutoLoggerConfig</span><span class="o">.</span><span class="n">generate_config_for_module_name</span><span class="p">(</span>
    91. <span class="n">module_name</span><span class="o">=</span><span class="n">logger_name</span><span class="p">,</span> <span class="n">training_log_path</span><span class="o">=</span><span class="n">training_log_path</span><span class="p">,</span> <span class="n">logs_dir_path</span><span class="o">=</span><span class="n">logs_dir_path</span><span class="p">,</span> <span class="n">log_level</span><span class="o">=</span><span class="n">log_level</span>
    92. <span class="p">)</span>
    93. <span class="n">logging</span><span class="o">.</span><span class="n">config</span><span class="o">.</span><span class="n">dictConfig</span><span class="p">(</span><span class="n">config_dict</span><span class="p">)</span>
    94. <span class="n">logger</span><span class="p">:</span> <span class="n">logging</span><span class="o">.</span><span class="n">Logger</span> <span class="o">=</span> <span class="n">logging</span><span class="o">.</span><span class="n">getLogger</span><span class="p">(</span><span class="n">logger_name</span><span class="p">)</span>
    95. <span class="k">return</span> <span class="n">logger</span></div>
    96. <div class="viewcode-block" id="ILogger"><a class="viewcode-back" href="../../../../super_gradients.common.abstractions.html#super_gradients.common.abstractions.abstract_logger.ILogger">[docs]</a><span class="k">class</span> <span class="nc">ILogger</span><span class="p">:</span>
    97. <span class="sd">&quot;&quot;&quot;</span>
    98. <span class="sd"> Provides logging capabilities to the derived class.</span>
    99. <span class="sd"> &quot;&quot;&quot;</span>
    100. <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">logger_name</span><span class="p">:</span> <span class="nb">str</span> <span class="o">=</span> <span class="kc">None</span><span class="p">):</span>
    101. <span class="n">logger_name</span> <span class="o">=</span> <span class="n">logger_name</span> <span class="k">if</span> <span class="n">logger_name</span> <span class="k">else</span> <span class="nb">str</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="vm">__module__</span><span class="p">)</span>
    102. <span class="bp">self</span><span class="o">.</span><span class="n">_logger</span><span class="p">:</span> <span class="n">logging</span><span class="o">.</span><span class="n">Logger</span> <span class="o">=</span> <span class="n">get_logger</span><span class="p">(</span><span class="n">logger_name</span><span class="p">)</span></div>
    103. </pre></div>
    104. </div>
    105. </div>
    106. <footer>
    107. <hr/>
    108. <div role="contentinfo">
    109. <p>&#169; Copyright 2021, SuperGradients team.</p>
    110. </div>
    111. Built with <a href="https://www.sphinx-doc.org/">Sphinx</a> using a
    112. <a href="https://github.com/readthedocs/sphinx_rtd_theme">theme</a>
    113. provided by <a href="https://readthedocs.org">Read the Docs</a>.
    114. </footer>
    115. </div>
    116. </div>
    117. </section>
    118. </div>
    119. <script>
    120. jQuery(function () {
    121. SphinxRtdTheme.Navigation.enable(true);
    122. });
    123. </script>
    124. </body>
    125. </html>
    Discard
    Some lines were truncated since they exceed the maximum allowed length of 500, please use a local Git client to see the full diff.
    @@ -3,10 +3,11 @@
     <head>
     <head>
       <meta charset="utf-8" />
       <meta charset="utf-8" />
       <meta name="viewport" content="width=device-width, initial-scale=1.0" />
       <meta name="viewport" content="width=device-width, initial-scale=1.0" />
    -  <title>super_gradients.common.auto_logging.auto_logger &mdash; SuperGradients 1.0 documentation</title>
    +  <title>super_gradients.common.auto_logging.auto_logger &mdash; SuperGradients 3.0.3 documentation</title>
           <link rel="stylesheet" href="../../../../_static/pygments.css" type="text/css" />
           <link rel="stylesheet" href="../../../../_static/pygments.css" type="text/css" />
           <link rel="stylesheet" href="../../../../_static/css/theme.css" type="text/css" />
           <link rel="stylesheet" href="../../../../_static/css/theme.css" type="text/css" />
           <link rel="stylesheet" href="../../../../_static/graphviz.css" type="text/css" />
           <link rel="stylesheet" href="../../../../_static/graphviz.css" type="text/css" />
    +      <link rel="stylesheet" href="../../../../_static/custom.css" type="text/css" />
       <!--[if lt IE 9]>
       <!--[if lt IE 9]>
         <script src="../../../../_static/js/html5shiv.min.js"></script>
         <script src="../../../../_static/js/html5shiv.min.js"></script>
       <![endif]-->
       <![endif]-->
    @@ -14,7 +15,9 @@
             <script data-url_root="../../../../" id="documentation_options" src="../../../../_static/documentation_options.js"></script>
             <script data-url_root="../../../../" id="documentation_options" src="../../../../_static/documentation_options.js"></script>
             <script src="../../../../_static/jquery.js"></script>
             <script src="../../../../_static/jquery.js"></script>
             <script src="../../../../_static/underscore.js"></script>
             <script src="../../../../_static/underscore.js"></script>
    +        <script src="../../../../_static/_sphinx_javascript_frameworks_compat.js"></script>
             <script src="../../../../_static/doctools.js"></script>
             <script src="../../../../_static/doctools.js"></script>
    +        <script src="../../../../_static/sphinx_highlight.js"></script>
         <script src="../../../../_static/js/theme.js"></script>
         <script src="../../../../_static/js/theme.js"></script>
         <link rel="index" title="Index" href="../../../../genindex.html" />
         <link rel="index" title="Index" href="../../../../genindex.html" />
         <link rel="search" title="Search" href="../../../../search.html" /> 
         <link rel="search" title="Search" href="../../../../search.html" /> 
    @@ -35,30 +38,29 @@
       </form>
       </form>
     </div>
     </div>
             </div><div class="wy-menu wy-menu-vertical" data-spy="affix" role="navigation" aria-label="Navigation menu">
             </div><div class="wy-menu wy-menu-vertical" data-spy="affix" role="navigation" aria-label="Navigation menu">
    -              <p class="caption"><span class="caption-text">Welcome To SuperGradients</span></p>
    +              <p class="caption" role="heading"><span class="caption-text">Welcome To SuperGradients</span></p>
     <ul>
     <ul>
    -<li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html">Fill our 4-question quick survey! We will raffle free SuperGradients swag between those who will participate -&gt; Fill Survey</a></li>
    -<li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#supergradients">SuperGradients</a></li>
    +<li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html">Version 3 is out! Notebooks have been updated!</a></li>
    +<li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#build-with-supergradients">Build with SuperGradients</a></li>
    +<li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#quick-installation">Quick Installation</a></li>
    +<li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#what-s-new">What’s New</a></li>
    +<li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#coming-soon">Coming soon</a></li>
    +<li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#table-of-content">Table of Content</a></li>
    +<li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#getting-started">Getting Started</a></li>
    +<li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#advanced-features">Advanced Features</a></li>
    +<li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#installation-methods">Installation Methods</a></li>
    +<li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#implemented-model-architectures">Implemented Model Architectures</a></li>
    +<li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#documentation">Documentation</a></li>
    +<li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#contributing">Contributing</a></li>
    +<li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#citation">Citation</a></li>
    +<li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#community">Community</a></li>
    +<li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#license">License</a></li>
    +<li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#deci-platform">Deci Platform</a></li>
     </ul>
     </ul>
    -<p class="caption"><span class="caption-text">Technical Documentation</span></p>
    +<p class="caption" role="heading"><span class="caption-text">Technical Documentation</span></p>
     <ul>
     <ul>
     <li class="toctree-l1"><a class="reference internal" href="../../../../super_gradients.common.html">Common package</a></li>
     <li class="toctree-l1"><a class="reference internal" href="../../../../super_gradients.common.html">Common package</a></li>
     <li class="toctree-l1"><a class="reference internal" href="../../../../super_gradients.training.html">Training package</a></li>
     <li class="toctree-l1"><a class="reference internal" href="../../../../super_gradients.training.html">Training package</a></li>
    -</ul>
    -<p class="caption"><span class="caption-text">User Guide</span></p>
    -<ul>
    -<li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html">What is SuperGradients?</a></li>
    -<li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#introducing-the-supergradients-library">Introducing the SuperGradients library</a></li>
    -<li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#installation">Installation</a></li>
    -<li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#integrating-your-training-code-complete-walkthrough">Integrating Your Training Code - Complete Walkthrough</a></li>
    -<li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#training-parameters">Training Parameters</a></li>
    -<li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#logs-and-checkpoints">Logs and Checkpoints</a></li>
    -<li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#dataset-parameters">Dataset Parameters</a></li>
    -<li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#network-architectures">Network Architectures</a></li>
    -<li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#pretrained-models">Pretrained Models</a></li>
    -<li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#how-to-reproduce-our-training-recipes">How To Reproduce Our Training Recipes</a></li>
    -<li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#professional-tools-integration">Professional Tools Integration</a></li>
    -<li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#supergradients-faq">SuperGradients FAQ</a></li>
     </ul>
     </ul>
     
     
             </div>
             </div>
    @@ -86,116 +88,130 @@
                <div itemprop="articleBody">
                <div itemprop="articleBody">
                  
                  
       <h1>Source code for super_gradients.common.auto_logging.auto_logger</h1><div class="highlight"><pre>
       <h1>Source code for super_gradients.common.auto_logging.auto_logger</h1><div class="highlight"><pre>
    -<span></span><span class="kn">import</span> <span class="nn">json</span>
    +<span></span><span class="kn">import</span> <span class="nn">logging</span>
     <span class="kn">import</span> <span class="nn">os</span>
     <span class="kn">import</span> <span class="nn">os</span>
    +<span class="kn">import</span> <span class="nn">sys</span>
    +<span class="kn">import</span> <span class="nn">time</span>
    +<span class="kn">from</span> <span class="nn">typing</span> <span class="kn">import</span> <span class="n">Union</span>
     
     
    -<span class="kn">import</span> <span class="nn">pkg_resources</span>
    -
    -<span class="kn">from</span> <span class="nn">super_gradients.common.environment.environment_config</span> <span class="kn">import</span> <span class="n">DEFAULT_LOGGING_LEVEL</span>
     
     
    -
    -<div class="viewcode-block" id="AutoLoggerConfig"><a class="viewcode-back" href="../../../../super_gradients.common.auto_logging.html#super_gradients.common.auto_logging.auto_logger.AutoLoggerConfig">[docs]</a><span class="k">class</span> <span class="nc">AutoLoggerConfig</span><span class="p">:</span>
    +<div class="viewcode-block" id="AutoLoggerConfig"><a class="viewcode-back" href="../../../../super_gradients.common.html#super_gradients.common.AutoLoggerConfig">[docs]</a><span class="k">class</span> <span class="nc">AutoLoggerConfig</span><span class="p">:</span>
         <span class="sd">&quot;&quot;&quot;</span>
         <span class="sd">&quot;&quot;&quot;</span>
    -<span class="sd">    A Class for the Automated Logging Config that is created from the JSON config file (auto_logging_conf)</span>
    +<span class="sd">    A Class for the Automated Logging Config</span>
     <span class="sd">    &quot;&quot;&quot;</span>
     <span class="sd">    &quot;&quot;&quot;</span>
     
     
    -<div class="viewcode-block" id="AutoLoggerConfig.generate_config_for_module_name"><a class="viewcode-back" href="../../../../super_gradients.common.auto_logging.html#super_gradients.common.auto_logging.auto_logger.AutoLoggerConfig.generate_config_for_module_name">[docs]</a>    <span class="nd">@staticmethod</span>
    -    <span class="k">def</span> <span class="nf">generate_config_for_module_name</span><span class="p">(</span>
    -        <span class="n">module_name</span><span class="p">,</span>
    -        <span class="n">training_log_path</span><span class="o">=</span><span class="kc">None</span><span class="p">,</span>
    -        <span class="n">log_level</span><span class="o">=</span><span class="n">DEFAULT_LOGGING_LEVEL</span><span class="p">,</span>
    -        <span class="n">max_bytes</span><span class="o">=</span><span class="mi">10485760</span><span class="p">,</span>
    -        <span class="n">logs_dir_path</span><span class="o">=</span><span class="kc">None</span><span class="p">,</span>
    -        <span class="n">handlers_list</span><span class="o">=</span><span class="kc">None</span><span class="p">,</span>
    -    <span class="p">)</span> <span class="o">-&gt;</span> <span class="nb">dict</span><span class="p">:</span>
    +    <span class="n">FILE_LOGGING_LEVEL</span> <span class="o">=</span> <span class="n">os</span><span class="o">.</span><span class="n">environ</span><span class="o">.</span><span class="n">get</span><span class="p">(</span><span class="s2">&quot;LOG_LEVEL&quot;</span><span class="p">,</span> <span class="s2">&quot;DEBUG&quot;</span><span class="p">)</span><span class="o">.</span><span class="n">upper</span><span class="p">()</span>
    +    <span class="n">CONSOLE_LOGGING_LEVEL</span> <span class="o">=</span> <span class="n">os</span><span class="o">.</span><span class="n">environ</span><span class="o">.</span><span class="n">get</span><span class="p">(</span><span class="s2">&quot;CONSOLE_LOG_LEVEL&quot;</span><span class="p">,</span> <span class="s2">&quot;INFO&quot;</span><span class="p">)</span><span class="o">.</span><span class="n">upper</span><span class="p">()</span>
    +
    +    <span class="n">filename</span><span class="p">:</span> <span class="n">Union</span><span class="p">[</span><span class="nb">str</span><span class="p">,</span> <span class="kc">None</span><span class="p">]</span>
    +
    +    <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span>
    +        <span class="bp">self</span><span class="o">.</span><span class="n">filename</span> <span class="o">=</span> <span class="kc">None</span>
    +
    +    <span class="k">def</span> <span class="nf">_setup_default_logging</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">log_level</span><span class="p">:</span> <span class="nb">str</span> <span class="o">=</span> <span class="kc">None</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="kc">None</span><span class="p">:</span>
             <span class="sd">&quot;&quot;&quot;</span>
             <span class="sd">&quot;&quot;&quot;</span>
    -<span class="sd">        generate_config_for_module_name - Returns a Config Dict For Logging</span>
    -<span class="sd">            :param module_name:     The Python Module name to create auto_logging for</span>
    -<span class="sd">            :param log_level:       Minimal log level to set for the new auto_logging</span>
    -<span class="sd">            :param max_bytes:       Max size for the log file before rotation starts</span>
    -<span class="sd">            :param handlers_list:    A list specifying the handlers (Console, etc..) - Better Leave Empty or None</span>
    -<span class="sd">            :param training_log_path: Path to training log file which all modules of super_gradients will write to. Ignored</span>
    -<span class="sd">             when set to None.</span>
    -<span class="sd">            :param logs_dir_path: Path to sg_logs directory (default=None), where module logs will be saved. When set</span>
    -<span class="sd">                to None- module logs will be saved in ~/sg_logs (created if path does not exist). Main use case is for</span>
    -<span class="sd">                testing.</span>
    -
    -
    -<span class="sd">            :return: python dict() with the new auto_logging for the module</span>
    +<span class="sd">        Setup default logging configuration. Usually happens when app starts, and we don&#39;t have</span>
    +<span class="sd">        experiment dir yet.</span>
    +<span class="sd">        The default log directory will be `~/sg_logs`</span>
    +<span class="sd">        :param log_level: The default log level to use. If None, uses LOG_LEVEL and CONSOLE_LOG_LEVEL environment vars.</span>
    +<span class="sd">        :return: None</span>
     <span class="sd">        &quot;&quot;&quot;</span>
     <span class="sd">        &quot;&quot;&quot;</span>
     
     
    -        <span class="c1"># LOADING THE ORIGINAL ROOT CONFIG FILE</span>
    -        <span class="n">conf_file_name</span> <span class="o">=</span> <span class="s2">&quot;auto_logging_conf.json&quot;</span>
    -        <span class="n">conf_file_path</span> <span class="o">=</span> <span class="n">os</span><span class="o">.</span><span class="n">path</span><span class="o">.</span><span class="n">join</span><span class="p">(</span>
    -            <span class="n">pkg_resources</span><span class="o">.</span><span class="n">resource_filename</span><span class="p">(</span><span class="s2">&quot;super_gradients&quot;</span><span class="p">,</span> <span class="s2">&quot;/common/auto_logging/&quot;</span><span class="p">),</span> <span class="n">conf_file_name</span>
    +        <span class="c1"># There is no _easy_ way to log all events to a single file, when using DDP or DataLoader with num_workers &gt; 1</span>
    +        <span class="c1"># on Windows platform. In both these cases a multiple processes will be spawned and multiple logs may be created.</span>
    +        <span class="c1"># Therefore the log file will have the parent PID to being able to discriminate the logs corresponding to a single run.</span>
    +        <span class="n">timestamp</span> <span class="o">=</span> <span class="n">time</span><span class="o">.</span><span class="n">strftime</span><span class="p">(</span><span class="s2">&quot;%Y_%m_</span><span class="si">%d</span><span class="s2">_%H_%M_%S&quot;</span><span class="p">,</span> <span class="n">time</span><span class="o">.</span><span class="n">localtime</span><span class="p">())</span>
    +        <span class="bp">self</span><span class="o">.</span><span class="n">_setup_logging</span><span class="p">(</span>
    +            <span class="n">filename</span><span class="o">=</span><span class="n">os</span><span class="o">.</span><span class="n">path</span><span class="o">.</span><span class="n">expanduser</span><span class="p">(</span><span class="sa">f</span><span class="s2">&quot;~/sg_logs/sg_logs_</span><span class="si">{</span><span class="n">os</span><span class="o">.</span><span class="n">getppid</span><span class="p">()</span><span class="si">}</span><span class="s2">_</span><span class="si">{</spa
    +            <span class="n">copy_already_logged_messages</span><span class="o">=</span><span class="kc">False</span><span class="p">,</span>
    +            <span class="n">filemode</span><span class="o">=</span><span class="s2">&quot;w&quot;</span><span class="p">,</span>
    +            <span class="n">log_level</span><span class="o">=</span><span class="n">log_level</span><span class="p">,</span>
    +        <span class="p">)</span>
    +
    +    <span class="k">def</span> <span class="nf">_setup_logging</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">filename</span><span class="p">:</span> <span class="nb">str</span><span class="p">,</span> <span class="n">copy_already_logged_messages</span><span class="p">:</span> <span class="nb">bool</span><span class="p">,</span> <span class="n">filemode</span><span class="p">:</span> <span class="nb">str</span> <span class="o">=</span> <span c
    +        <span class="sd">&quot;&quot;&quot;</span>
    +<span class="sd">        Sets the logging configuration to store messages to specific file</span>
    +<span class="sd">        :param filename: Output log file</span>
    +<span class="sd">        :param filemode: Open mode for file</span>
    +<span class="sd">        :param copy_already_logged_messages: Controls whether messages from previous log configuration should be copied</span>
    +<span class="sd">               to new place. This is helpful to transfer diagnostic messages (from the app start) to experiment dir.</span>
    +<span class="sd">        :param log_level: The default log level to use. If None, uses LOG_LEVEL and CONSOLE_LOG_LEVEL environment vars.</span>
    +<span class="sd">        :return:</span>
    +<span class="sd">        &quot;&quot;&quot;</span>
    +        <span class="n">os</span><span class="o">.</span><span class="n">makedirs</span><span class="p">(</span><span class="n">os</span><span class="o">.</span><span class="n">path</span><span class="o">.</span><span class="n">dirname</span><span class="p">(</span><span class="n">filename</span><span class="p">),</span> <span class="n">exist_ok</span><span class="o">=</span><span class="kc">True</span><span class="p">)</span>
    +
    +        <span class="k">if</span> <span class="n">copy_already_logged_messages</span> <span class="ow">and</span> <span class="bp">self</span><span class="o">.</span><span class="n">filename</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span> <span class="ow">and</span> <span class="n">os</span><span class="o">.</span><span class="n">path</span><span class="o">.</span><span class="n">exists</span><span class="p">(</span><span class="bp">self</span><span cl
    +            <span class="k">with</span> <span class="nb">open</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">filename</span><span class="p">,</span> <span class="s2">&quot;r&quot;</span><span class="p">,</span> <span class="n">encoding</span><span class="o">=</span><span class="s2">&quot;utf-8&quot;</span><span class="p">)</span> <span class="k">as</span> <span class="n">src</span><span class="p">:</span>
    +                <span class="k">with</span> <span class="nb">open</span><span class="p">(</span><span class="n">filename</span><span class="p">,</span> <span class="s2">&quot;w&quot;</span><span class="p">)</span> <span class="k">as</span> <span class="n">dst</span><span class="p">:</span>
    +                    <span class="n">dst</span><span class="o">.</span><span class="n">write</span><span class="p">(</span><span class="n">src</span><span class="o">.</span><span class="n">read</span><span class="p">())</span>
    +
    +        <span class="n">file_logging_level</span> <span class="o">=</span> <span class="n">log_level</span> <span class="ow">or</span> <span class="bp">self</span><span class="o">.</span><span class="n">FILE_LOGGING_LEVEL</span>
    +        <span class="n">console_logging_level</span> <span class="o">=</span> <span class="n">log_level</span> <span class="ow">or</span> <span class="bp">self</span><span class="o">.</span><span class="n">CONSOLE_LOGGING_LEVEL</span>
    +
    +        <span class="n">cur_version</span> <span class="o">=</span> <span class="n">sys</span><span class="o">.</span><span class="n">version_info</span>
    +        <span class="n">python_38</span> <span class="o">=</span> <span class="p">(</span><span class="mi">3</span><span class="p">,</span> <span class="mi">8</span><span class="p">)</span>
    +        <span class="n">python_39</span> <span class="o">=</span> <span class="p">(</span><span class="mi">3</span><span class="p">,</span> <span class="mi">9</span><span class="p">)</span>
    +        <span class="n">manager</span> <span class="o">=</span> <span class="n">logging</span><span class="o">.</span><span class="n">getLogger</span><span class="p">(</span><span class="s2">&quot;&quot;</span><span class="p">)</span><span class="o">.</span><span class="n">manager</span>
    +
    +        <span class="n">extra_kwargs</span> <span class="o">=</span> <span class="p">{}</span>
    +        <span class="k">if</span> <span class="n">cur_version</span> <span class="o">&gt;=</span> <span class="n">python_38</span><span class="p">:</span>
    +            <span class="n">extra_kwargs</span> <span class="o">=</span> <span class="nb">dict</span><span class="p">(</span>
    +                <span class="n">force</span><span class="o">=</span><span class="kc">True</span><span class="p">,</span>
    +            <span class="p">)</span>
    +        <span class="k">else</span><span class="p">:</span>
    +            <span class="c1"># If the logging does not support force=True, we should manually delete handlers</span>
    +            <span class="k">del</span> <span class="n">manager</span><span class="o">.</span><span class="n">root</span><span class="o">.</span><span class="n">handlers</span><span class="p">[:]</span>
    +
    +        <span class="k">if</span> <span class="n">cur_version</span> <span class="o">&gt;=</span> <span class="n">python_39</span><span class="p">:</span>
    +            <span class="n">extra_kwargs</span><span class="p">[</span><span class="s2">&quot;encoding&quot;</span><span class="p">]</span> <span class="o">=</span> <span class="s2">&quot;utf-8&quot;</span>
    +
    +        <span class="n">logging</span><span class="o">.</span><span class="n">basicConfig</span><span class="p">(</span>
    +            <span class="n">filename</span><span class="o">=</span><span class="n">filename</span><span class="p">,</span>
    +            <span class="n">filemode</span><span class="o">=</span><span class="n">filemode</span><span class="p">,</span>
    +            <span class="nb">format</span><span class="o">=</span><span class="s2">&quot;</span><span class="si">%(asctime)s</span><span class="s2"> </span><span class="si">%(levelname)s</span><span class="s2"> - </span><span class="si">%(name)s</span><span class="s2"> - </span><span class="si">%(message)s</span><span class="s2">&quot;</span><span class="p">,</span>
    +            <span class="n">datefmt</span><span class="o">=</span><span class="s2">&quot;[%Y-%m-</span><span class="si">%d</span><span class="s2"> %H:%M:%S]&quot;</span><span class="p">,</span>
    +            <span class="n">level</span><span class="o">=</span><span class="n">file_logging_level</span><span class="p">,</span>
    +            <span class="o">**</span><span class="n">extra_kwargs</span><span class="p">,</span>
             <span class="p">)</span>
             <span class="p">)</span>
     
     
    -        <span class="k">with</span> <span class="nb">open</span><span class="p">(</span><span class="n">conf_file_path</span><span class="p">,</span> <span class="s2">&quot;r&quot;</span><span class="p">)</span> <span class="k">as</span> <span class="n">logging_configuration_file</span><span class="p">:</span>
    -            <span class="n">config_dict</span> <span class="o">=</span> <span class="n">json</span><span class="o">.</span><span class="n">load</span><span class="p">(</span><span class="n">logging_configuration_file</span><span class="p">)</span>
    -
    -        <span class="c1"># CREATING THE PATH TO THE &quot;HOME&quot; FOLDER WITH THE LOG FILE NAME</span>
    -        <span class="k">if</span> <span class="ow">not</span> <span class="n">logs_dir_path</span><span class="p">:</span>
    -            <span class="n">log_file_name</span> <span class="o">=</span> <span class="n">module_name</span> <span class="o">+</span> <span class="s2">&quot;.log&quot;</span>
    -            <span class="n">user_dir</span> <span class="o">=</span> <span class="n">os</span><span class="o">.</span><span class="n">path</span><span class="o">.</span><span class="n">expanduser</span><span class="p">(</span><span class="sa">r</span><span class="s2">&quot;~&quot;</span><span class="p">)</span>
    -            <span class="n">logs_dir_path</span> <span class="o">=</span> <span class="n">os</span><span class="o">.</span><span class="n">path</span><span class="o">.</span><span class="n">join</span><span class="p">(</span><span class="n">user_dir</span><span class="p">,</span> <span class="s2">&quot;sg_logs&quot;</span><span class="p">)</span>
    -
    -        <span class="k">if</span> <span class="ow">not</span> <span class="n">os</span><span class="o">.</span><span class="n">path</span><span class="o">.</span><span class="n">exists</span><span class="p">(</span><span class="n">logs_dir_path</span><span class="p">):</span>
    -            <span class="k">try</span><span class="p">:</span>
    -                <span class="n">os</span><span class="o">.</span><span class="n">mkdir</span><span class="p">(</span><span class="n">logs_dir_path</span><span class="p">)</span>
    -            <span class="k">except</span> <span class="ne">Exception</span> <span class="k">as</span> <span class="n">ex</span><span class="p">:</span>
    -                <span class="nb">print</span><span class="p">(</span>
    -                    <span class="s2">&quot;[WARNING] - sg_logs folder was not found and couldn&#39;t be created from the code - &quot;</span>
    -                    <span class="s2">&quot;All of the Log output will be sent to Console!&quot;</span> <span class="o">+</span> <span class="nb">str</span><span class="p">(</span><span class="n">ex</span><span class="p">)</span>
    -                <span class="p">)</span>
    -
    -            <span class="c1"># HANDLERS LIST IS EMPTY AS CONSOLE IS ONLY ROOT HANDLER BECAUSE MODULE LOGGERS PROPAGATE THEIR LOGS UP.</span>
    -            <span class="n">handlers_list</span> <span class="o">=</span> <span class="p">[]</span>
    -            <span class="n">logger</span> <span class="o">=</span> <span class="p">{</span><span class="s2">&quot;level&quot;</span><span class="p">:</span> <span class="n">log_level</span><span class="p">,</span> <span class="s2">&quot;handlers&quot;</span><span class="p">:</span> <span class="n">handlers_list</span><span class="p">,</span> <span class="s2">&quot;propagate&quot;</span><span class="p">:</span> <span class="kc">True</span><span class="p">}</span>
    -            <span class="n">config_dict</span><span class="p">[</span><span class="s2">&quot;loggers&quot;</span><span class="p">][</span><span class="n">module_name</span><span class="p">]</span> <span class="o">=</span> <span class="n">logger</span>
    -
    -            <span class="k">return</span> <span class="n">config_dict</span>
    -
    -        <span class="n">log_file_path</span> <span class="o">=</span> <span class="n">os</span><span class="o">.</span><span class="n">path</span><span class="o">.</span><span class="n">join</span><span class="p">(</span><span class="n">logs_dir_path</span><span class="p">,</span> <span class="n">log_file_name</span><span class="p">)</span>
    -
    -        <span class="c1"># THE ENTRIES TO ADD TO THE ORIGINAL CONFIGURATION</span>
    -        <span class="n">handler_name</span> <span class="o">=</span> <span class="n">module_name</span> <span class="o">+</span> <span class="s2">&quot;_file_handler&quot;</span>
    -        <span class="n">file_handler</span> <span class="o">=</span> <span class="p">{</span>
    -            <span class="s2">&quot;class&quot;</span><span class="p">:</span> <span class="s2">&quot;logging.handlers.RotatingFileHandler&quot;</span><span class="p">,</span>
    -            <span class="s2">&quot;level&quot;</span><span class="p">:</span> <span class="n">log_level</span><span class="p">,</span>
    -            <span class="s2">&quot;formatter&quot;</span><span class="p">:</span> <span class="s2">&quot;fileFormatter&quot;</span><span class="p">,</span>
    -            <span class="s2">&quot;filename&quot;</span><span class="p">:</span> <span class="n">log_file_path</span><span class="p">,</span>
    -            <span class="s2">&quot;maxBytes&quot;</span><span class="p">:</span> <span class="n">max_bytes</span><span class="p">,</span>
    -            <span class="s2">&quot;backupCount&quot;</span><span class="p">:</span> <span class="mi">20</span><span class="p">,</span>
    -            <span class="s2">&quot;encoding&quot;</span><span class="p">:</span> <span class="s2">&quot;utf8&quot;</span><span class="p">,</span>
    -        <span class="p">}</span>
    -
    -        <span class="c1"># CREATING ONLY A FILE HANDLER, CONSOLE IS ONLY ROOT HANDLER AS MODULE LOGGERS PROPAGATE THEIR LOGS UP.</span>
    -        <span class="k">if</span> <span class="n">handlers_list</span> <span class="ow">is</span> <span class="kc">None</span> <span class="ow">or</span> <span class="n">handlers_list</span><span class="o">.</span><span class="n">empty</span><span class="p">():</span>
    -            <span class="n">handlers_list</span> <span class="o">=</span> <span class="p">[</span><span class="n">handler_name</span><span class="p">]</span>
    -
    -        <span class="n">logger</span> <span class="o">=</span> <span class="p">{</span><span class="s2">&quot;level&quot;</span><span class="p">:</span> <span class="n">log_level</span><span class="p">,</span> <span class="s2">&quot;handlers&quot;</span><span class="p">:</span> <span class="n">handlers_list</span><span class="p">,</span> <span class="s2">&quot;propagate&quot;</span><span class="p">:</span> <span class="kc">True</span><span class="p">}</span>
    -
    -        <span class="c1"># ADDING THE NEW LOGGER ENTRIES TO THE CONFIG DICT</span>
    -        <span class="n">config_dict</span><span class="p">[</span><span class="s2">&quot;handlers&quot;</span><span class="p">][</span><span class="n">handler_name</span><span class="p">]</span> <span class="o">=</span> <span class="n">file_handler</span>
    -        <span class="n">config_dict</span><span class="p">[</span><span class="s2">&quot;loggers&quot;</span><span class="p">][</span><span class="n">module_name</span><span class="p">]</span> <span class="o">=</span> <span class="n">logger</span>
    -        <span class="n">config_dict</span><span class="p">[</span><span class="s2">&quot;root&quot;</span><span class="p">][</span><span class="s2">&quot;handlers&quot;</span><span class="p">]</span><span class="o">.</span><span class="n">append</span><span class="p">(</span><span class="n">handler_name</span><span class="p">)</span>
    -
    -        <span class="k">if</span> <span class="n">training_log_path</span><span class="p">:</span>
    -            <span class="n">training_file_handler</span> <span class="o">=</span> <span class="p">{</span>
    -                <span class="s2">&quot;class&quot;</span><span class="p">:</span> <span class="s2">&quot;logging.handlers.RotatingFileHandler&quot;</span><span class="p">,</span>
    -                <span class="s2">&quot;level&quot;</span><span class="p">:</span> <span class="n">log_level</span><span class="p">,</span>
    -                <span class="s2">&quot;formatter&quot;</span><span class="p">:</span> <span class="s2">&quot;fileFormatter&quot;</span><span class="p">,</span>
    -                <span class="s2">&quot;filename&quot;</span><span class="p">:</span> <span class="n">training_log_path</span><span class="p">,</span>
    -                <span class="s2">&quot;maxBytes&quot;</span><span class="p">:</span> <span class="n">max_bytes</span><span class="p">,</span>
    -                <span class="s2">&quot;backupCount&quot;</span><span class="p">:</span> <span class="mi">20</span><span class="p">,</span>
    -                <span class="s2">&quot;encoding&quot;</span><span class="p">:</span> <span class="s2">&quot;utf8&quot;</span><span class="p">,</span>
    -            <span class="p">}</span>
    -
    -            <span class="c1"># ALL OF DECI_TRAINER MODULES LOGGERS PROPAGATE UP TO THE ROOT SO THE ADD TRAIN FILE HANDLER FOR THE ROOT.</span>
    -            <span class="n">config_dict</span><span class="p">[</span><span class="s2">&quot;handlers&quot;</span><span class="p">][</span><span class="s2">&quot;training&quot;</span><span class="p">]</span> <span class="o">=</span> <span class="n">training_file_handler</span>
    -            <span class="n">config_dict</span><span class="p">[</span><span class="s2">&quot;root&quot;</span><span class="p">][</span><span class="s2">&quot;handlers&quot;</span><span class="p">]</span><span class="o">.</span><span class="n">append</span><span class="p">(</span><span class="s2">&quot;training&quot;</span><span class="p">)</span>
    -
    -        <span class="k">return</span> <span class="n">config_dict</span></div></div>
    +        <span class="c1"># Add console handler</span>
    +        <span class="n">console_handler</span> <span class="o">=</span> <span class="n">logging</span><span class="o">.</span><span class="n">StreamHandler</span><span class="p">()</span>
    +        <span class="n">console_handler</span><span class="o">.</span><span class="n">setLevel</span><span class="p">(</span><span class="n">console_logging_level</span><span class="p">)</span>
    +        <span class="n">console_handler</span><span class="o">.</span><span class="n">setFormatter</span><span class="p">(</span>
    +            <span class="n">logging</span><span class="o">.</span><span class="n">Formatter</span><span class="p">(</span>
    +                <span class="s2">&quot;</span><span class="si">%(asctime)s</span><span class="s2"> </span><span class="si">%(levelname)s</span><span class="s2"> - </span><span class="si">%(filename)s</span><span class="s2"> - </span><span class="si">%(message)s</span><span class="s2">&quot;</span><span class="p">,</span>
    +                <span class="n">datefmt</span><span class="o">=</span><span class="s2">&quot;[%Y-%m-</span><span class="si">%d</span><span class="s2"> %H:%M:%S]&quot;</span><span class="p">,</span>
    +            <span class="p">)</span>
    +        <span class="p">)</span>
    +        <span class="n">manager</span><span class="o">.</span><span class="n">root</span><span class="o">.</span><span class="n">handlers</span><span class="o">.</span><span class="n">append</span><span class="p">(</span><span class="n">console_handler</span><span class="p">)</span>
    +
    +        <span class="bp">self</span><span class="o">.</span><span class="n">filename</span> <span class="o">=</span> <span class="n">filename</span>
    +
    +<div class="viewcode-block" id="AutoLoggerConfig.get_instance"><a class="viewcode-back" href="../../../../super_gradients.common.html#super_gradients.common.AutoLoggerConfig.get_instance">[docs]</a>    <span class="nd">@classmethod</span>
    +    <span class="k">def</span> <span class="nf">get_instance</span><span class="p">(</span><span class="bp">cls</span><span class="p">):</span>
    +        <span class="k">global</span> <span class="n">_super_gradients_logger_config</span>
    +        <span class="k">if</span> <span class="n">_super_gradients_logger_config</span> <span class="ow">is</span> <span class="kc">None</span><span class="p">:</span>
    +            <span class="n">_super_gradients_logger_config</span> <span class="o">=</span> <span class="bp">cls</span><span class="p">()</span>
    +            <span class="n">_super_gradients_logger_config</span><span class="o">.</span><span class="n">_setup_default_logging</span><span class="p">()</span>
    +
    +        <span class="k">return</span> <span class="n">_super_gradients_logger_config</span></div>
    +
    +<div class="viewcode-block" id="AutoLoggerConfig.get_log_file_path"><a class="viewcode-back" href="../../../../super_gradients.common.html#super_gradients.common.AutoLoggerConfig.get_log_file_path">[docs]</a>    <span class="nd">@classmethod</span>
    +    <span class="k">def</span> <span class="nf">get_log_file_path</span><span class="p">(</span><span class="bp">cls</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="nb">str</span><span class="p">:</span>
    +        <span class="sd">&quot;&quot;&quot;</span>
    +<span class="sd">        Return the current log file used to store log messages</span>
    +<span class="sd">        :return: Full path to log file</span>
    +<span class="sd">        &quot;&quot;&quot;</span>
    +        <span class="bp">self</span> <span class="o">=</span> <span class="bp">cls</span><span class="o">.</span><span class="n">get_instance</span><span class="p">()</span>
    +        <span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">filename</span></div>
    +
    +<div class="viewcode-block" id="AutoLoggerConfig.setup_logging"><a class="viewcode-back" href="../../../../super_gradients.common.html#super_gradients.common.AutoLoggerConfig.setup_logging">[docs]</a>    <span class="nd">@classmethod</span>
    +    <span class="k">def</span> <span class="nf">setup_logging</span><span class="p">(</span><span class="bp">cls</span><span class="p">,</span> <span class="n">filename</span><span class="p">:</span> <span class="nb">str</span><span class="p">,</span> <span class="n">copy_already_logged_messages</span><span class="p">:</span> <span class="nb">bool</span><span class="p">,</span> <span class="n">filemode</span><span class="p">:</span> <span class="nb">str</span> <span class="o">=</span> <span cla
    +        <span class="bp">self</span> <span class="o">=</span> <span class="bp">cls</span><span class="o">.</span><span class="n">get_instance</span><span class="p">()</span>
    +        <span class="bp">self</span><span class="o">.</span><span class="n">_setup_logging</span><span class="p">(</span><span class="n">filename</span><span class="p">,</span> <span class="n">copy_already_logged_messages</span><span class="p">,</span> <span class="n">filemode</span><span class="p">,</span> <span class="n">log_level</span><span class="p">)</span></div></div>
    +
    +
    +<span class="n">_super_gradients_logger_config</span> <span class="o">=</span> <span class="kc">None</span>
     </pre></div>
     </pre></div>
     
     
                </div>
                </div>
    @@ -225,4 +241,4 @@
       </script> 
       </script> 
     
     
     </body>
     </body>
    -</html>
    +</html>
    Discard
    1
    2
    3
    4
    5
    6
    7
    8
    9
    10
    11
    12
    13
    14
    15
    16
    17
    18
    19
    20
    21
    22
    23
    24
    25
    26
    27
    28
    29
    30
    31
    32
    33
    34
    35
    36
    37
    38
    39
    40
    41
    42
    43
    44
    45
    46
    47
    48
    49
    50
    51
    52
    53
    54
    55
    56
    57
    58
    59
    60
    61
    62
    63
    64
    65
    66
    67
    68
    69
    70
    71
    72
    73
    74
    75
    76
    77
    78
    79
    80
    81
    82
    83
    84
    85
    86
    87
    88
    89
    90
    91
    92
    93
    94
    95
    96
    97
    98
    99
    100
    101
    102
    103
    104
    105
    106
    107
    108
    109
    110
    111
    112
    113
    114
    115
    116
    117
    118
    119
    120
    121
    122
    123
    124
    125
    126
    127
    128
    129
    130
    131
    132
    133
    134
    135
    136
    137
    138
    139
    140
    141
    142
    143
    144
    145
    146
    147
    148
    149
    150
    151
    152
    153
    154
    155
    156
    157
    158
    159
    160
    161
    162
    163
    164
    165
    166
    167
    168
    169
    170
    171
    172
    173
    174
    175
    176
    177
    178
    179
    180
    181
    182
    183
    184
    185
    186
    187
    188
    189
    190
    191
    192
    193
    194
    195
    196
    197
    198
    199
    200
    201
    202
    203
    204
    205
    206
    207
    208
    209
    210
    211
    212
    213
    214
    215
    216
    217
    218
    219
    220
    221
    222
    223
    224
    225
    226
    227
    228
    229
    230
    231
    232
    233
    234
    235
    236
    237
    238
    239
    240
    241
    242
    243
    244
    245
    246
    247
    248
    249
    250
    251
    252
    253
    254
    255
    256
    257
    258
    259
    260
    261
    262
    263
    264
    265
    266
    267
    268
    269
    270
    271
    272
    273
    274
    275
    276
    277
    278
    279
    280
    281
    282
    283
    284
    285
    286
    287
    288
    289
    290
    291
    292
    293
    294
    1. <!DOCTYPE html>
    2. <html class="writer-html5" lang="en" >
    3. <head>
    4. <meta charset="utf-8" />
    5. <meta name="viewport" content="width=device-width, initial-scale=1.0" />
    6. <title>super_gradients.common.auto_logging.console_logging &mdash; SuperGradients 3.0.3 documentation</title>
    7. <link rel="stylesheet" href="../../../../_static/pygments.css" type="text/css" />
    8. <link rel="stylesheet" href="../../../../_static/css/theme.css" type="text/css" />
    9. <link rel="stylesheet" href="../../../../_static/graphviz.css" type="text/css" />
    10. <link rel="stylesheet" href="../../../../_static/custom.css" type="text/css" />
    11. <!--[if lt IE 9]>
    12. <script src="../../../../_static/js/html5shiv.min.js"></script>
    13. <![endif]-->
    14. <script data-url_root="../../../../" id="documentation_options" src="../../../../_static/documentation_options.js"></script>
    15. <script src="../../../../_static/jquery.js"></script>
    16. <script src="../../../../_static/underscore.js"></script>
    17. <script src="../../../../_static/_sphinx_javascript_frameworks_compat.js"></script>
    18. <script src="../../../../_static/doctools.js"></script>
    19. <script src="../../../../_static/sphinx_highlight.js"></script>
    20. <script src="../../../../_static/js/theme.js"></script>
    21. <link rel="index" title="Index" href="../../../../genindex.html" />
    22. <link rel="search" title="Search" href="../../../../search.html" />
    23. </head>
    24. <body class="wy-body-for-nav">
    25. <div class="wy-grid-for-nav">
    26. <nav data-toggle="wy-nav-shift" class="wy-nav-side">
    27. <div class="wy-side-scroll">
    28. <div class="wy-side-nav-search" >
    29. <a href="../../../../index.html" class="icon icon-home"> SuperGradients
    30. </a>
    31. <div role="search">
    32. <form id="rtd-search-form" class="wy-form" action="../../../../search.html" method="get">
    33. <input type="text" name="q" placeholder="Search docs" />
    34. <input type="hidden" name="check_keywords" value="yes" />
    35. <input type="hidden" name="area" value="default" />
    36. </form>
    37. </div>
    38. </div><div class="wy-menu wy-menu-vertical" data-spy="affix" role="navigation" aria-label="Navigation menu">
    39. <p class="caption" role="heading"><span class="caption-text">Welcome To SuperGradients</span></p>
    40. <ul>
    41. <li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html">Version 3 is out! Notebooks have been updated!</a></li>
    42. <li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#build-with-supergradients">Build with SuperGradients</a></li>
    43. <li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#quick-installation">Quick Installation</a></li>
    44. <li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#what-s-new">What’s New</a></li>
    45. <li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#coming-soon">Coming soon</a></li>
    46. <li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#table-of-content">Table of Content</a></li>
    47. <li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#getting-started">Getting Started</a></li>
    48. <li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#advanced-features">Advanced Features</a></li>
    49. <li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#installation-methods">Installation Methods</a></li>
    50. <li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#implemented-model-architectures">Implemented Model Architectures</a></li>
    51. <li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#documentation">Documentation</a></li>
    52. <li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#contributing">Contributing</a></li>
    53. <li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#citation">Citation</a></li>
    54. <li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#community">Community</a></li>
    55. <li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#license">License</a></li>
    56. <li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#deci-platform">Deci Platform</a></li>
    57. </ul>
    58. <p class="caption" role="heading"><span class="caption-text">Technical Documentation</span></p>
    59. <ul>
    60. <li class="toctree-l1"><a class="reference internal" href="../../../../super_gradients.common.html">Common package</a></li>
    61. <li class="toctree-l1"><a class="reference internal" href="../../../../super_gradients.training.html">Training package</a></li>
    62. </ul>
    63. </div>
    64. </div>
    65. </nav>
    66. <section data-toggle="wy-nav-shift" class="wy-nav-content-wrap"><nav class="wy-nav-top" aria-label="Mobile navigation menu" >
    67. <i data-toggle="wy-nav-top" class="fa fa-bars"></i>
    68. <a href="../../../../index.html">SuperGradients</a>
    69. </nav>
    70. <div class="wy-nav-content">
    71. <div class="rst-content">
    72. <div role="navigation" aria-label="Page navigation">
    73. <ul class="wy-breadcrumbs">
    74. <li><a href="../../../../index.html" class="icon icon-home"></a> &raquo;</li>
    75. <li><a href="../../../index.html">Module code</a> &raquo;</li>
    76. <li>super_gradients.common.auto_logging.console_logging</li>
    77. <li class="wy-breadcrumbs-aside">
    78. </li>
    79. </ul>
    80. <hr/>
    81. </div>
    82. <div role="main" class="document" itemscope="itemscope" itemtype="http://schema.org/Article">
    83. <div itemprop="articleBody">
    84. <h1>Source code for super_gradients.common.auto_logging.console_logging</h1><div class="highlight"><pre>
    85. <span></span><span class="kn">import</span> <span class="nn">os</span>
    86. <span class="kn">import</span> <span class="nn">sys</span>
    87. <span class="kn">from</span> <span class="nn">datetime</span> <span class="kn">import</span> <span class="n">datetime</span>
    88. <span class="kn">from</span> <span class="nn">pathlib</span> <span class="kn">import</span> <span class="n">Path</span>
    89. <span class="kn">from</span> <span class="nn">io</span> <span class="kn">import</span> <span class="n">StringIO</span>
    90. <span class="kn">import</span> <span class="nn">atexit</span>
    91. <span class="kn">from</span> <span class="nn">threading</span> <span class="kn">import</span> <span class="n">Lock</span>
    92. <span class="kn">from</span> <span class="nn">super_gradients.common.environment.env_helpers</span> <span class="kn">import</span> <span class="n">multi_process_safe</span><span class="p">,</span> <span class="n">is_main_process</span>
    93. <span class="k">class</span> <span class="nc">BufferWriter</span><span class="p">:</span>
    94. <span class="sd">&quot;&quot;&quot;File writer buffer that opens a file only when flushing and under the condition that threshold buffersize was reached.&quot;&quot;&quot;</span>
    95. <span class="n">FILE_BUFFER_SIZE</span> <span class="o">=</span> <span class="mi">10_000</span> <span class="c1"># Number of chars to be buffered before writing the buffer on disk.</span>
    96. <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">filename</span><span class="p">:</span> <span class="nb">str</span><span class="p">,</span> <span class="n">buffer</span><span class="p">:</span> <span class="n">StringIO</span><span class="p">,</span> <span class="n">buffer_size</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span> <span class="n">lock</span><span class="p">:</span> <span class="n">Lock</span><span class="p">):</span>
    97. <span class="sd">&quot;&quot;&quot;</span>
    98. <span class="sd"> :param filename: Name of the file where to write the bugger</span>
    99. <span class="sd"> :param buffer: Buffer object</span>
    100. <span class="sd"> :param buffer_size: Number of chars to be buffered before writing the buffer on disk.</span>
    101. <span class="sd"> :param lock: Thread lock to prevent multiple threads to write at the same time</span>
    102. <span class="sd"> &quot;&quot;&quot;</span>
    103. <span class="bp">self</span><span class="o">.</span><span class="n">buffer</span> <span class="o">=</span> <span class="n">buffer</span>
    104. <span class="bp">self</span><span class="o">.</span><span class="n">filename</span> <span class="o">=</span> <span class="n">filename</span>
    105. <span class="bp">self</span><span class="o">.</span><span class="n">buffer_size</span> <span class="o">=</span> <span class="n">buffer_size</span>
    106. <span class="bp">self</span><span class="o">.</span><span class="n">lock</span> <span class="o">=</span> <span class="n">lock</span>
    107. <span class="k">def</span> <span class="nf">write</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">data</span><span class="p">:</span> <span class="nb">str</span><span class="p">):</span>
    108. <span class="sd">&quot;&quot;&quot;Write to buffer (not on disk).&quot;&quot;&quot;</span>
    109. <span class="k">with</span> <span class="bp">self</span><span class="o">.</span><span class="n">lock</span><span class="p">:</span>
    110. <span class="bp">self</span><span class="o">.</span><span class="n">buffer</span><span class="o">.</span><span class="n">write</span><span class="p">(</span><span class="n">data</span><span class="p">)</span>
    111. <span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">_require_flush</span><span class="p">():</span>
    112. <span class="bp">self</span><span class="o">.</span><span class="n">flush</span><span class="p">()</span>
    113. <span class="k">def</span> <span class="nf">flush</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">force</span><span class="p">:</span> <span class="nb">bool</span> <span class="o">=</span> <span class="kc">False</span><span class="p">):</span>
    114. <span class="sd">&quot;&quot;&quot;Write the buffer on disk if relevant.&quot;&quot;&quot;</span>
    115. <span class="k">if</span> <span class="n">force</span> <span class="ow">or</span> <span class="bp">self</span><span class="o">.</span><span class="n">_require_flush</span><span class="p">():</span>
    116. <span class="k">with</span> <span class="bp">self</span><span class="o">.</span><span class="n">lock</span><span class="p">:</span>
    117. <span class="n">os</span><span class="o">.</span><span class="n">makedirs</span><span class="p">(</span><span class="n">os</span><span class="o">.</span><span class="n">path</span><span class="o">.</span><span class="n">dirname</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">filename</span><span class="p">),</span> <span class="n">exist_ok</span><span class="o">=</span><span class="kc">True</span><span class="p">)</span>
    118. <span class="k">with</span> <span class="nb">open</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">filename</span><span class="p">,</span> <span class="s2">&quot;a&quot;</span><span class="p">,</span> <span class="n">encoding</span><span class="o">=</span><span class="s2">&quot;utf-8&quot;</span><span class="p">)</span> <span class="k">as</span> <span class="n">f</span><span class="p">:</span>
    119. <span class="n">f</span><span class="o">.</span><span class="n">write</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">buffer</span><span class="o">.</span><span class="n">getvalue</span><span class="p">())</span>
    120. <span class="bp">self</span><span class="o">.</span><span class="n">buffer</span><span class="o">.</span><span class="n">truncate</span><span class="p">(</span><span class="mi">0</span><span class="p">)</span>
    121. <span class="bp">self</span><span class="o">.</span><span class="n">buffer</span><span class="o">.</span><span class="n">seek</span><span class="p">(</span><span class="mi">0</span><span class="p">)</span>
    122. <span class="k">def</span> <span class="nf">_require_flush</span><span class="p">(</span><span class="bp">self</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="nb">bool</span><span class="p">:</span>
    123. <span class="sd">&quot;&quot;&quot;Indicate if a buffer is needed (i.e. if buffer size above threshold)&quot;&quot;&quot;</span>
    124. <span class="k">return</span> <span class="nb">len</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">buffer</span><span class="o">.</span><span class="n">getvalue</span><span class="p">())</span> <span class="o">&gt;</span> <span class="bp">self</span><span class="o">.</span><span class="n">buffer_size</span>
    125. <span class="k">class</span> <span class="nc">StderrTee</span><span class="p">(</span><span class="n">BufferWriter</span><span class="p">):</span>
    126. <span class="sd">&quot;&quot;&quot;Duplicate the stderr stream to save it into a given file.&quot;&quot;&quot;</span>
    127. <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">filename</span><span class="p">:</span> <span class="nb">str</span><span class="p">,</span> <span class="n">buffer</span><span class="p">:</span> <span class="n">StringIO</span><span class="p">,</span> <span class="n">buffer_size</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span> <span class="n">lock</span><span class="p">:</span> <span class="n">Lock</span><span class="p">):</span>
    128. <span class="sd">&quot;&quot;&quot;</span>
    129. <span class="sd"> :param filename: Name of the file where to write the bugger</span>
    130. <span class="sd"> :param buffer: Buffer object</span>
    131. <span class="sd"> :param buffer_size: Number of chars to be buffered before writing the buffer on disk.</span>
    132. <span class="sd"> :param lock: Thread lock to prevent multiple threads to write at the same time</span>
    133. <span class="sd"> &quot;&quot;&quot;</span>
    134. <span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="fm">__init__</span><span class="p">(</span><span class="n">filename</span><span class="p">,</span> <span class="n">buffer</span><span class="p">,</span> <span class="n">buffer_size</span><span class="p">,</span> <span class="n">lock</span><span class="p">)</span>
    135. <span class="bp">self</span><span class="o">.</span><span class="n">stderr</span> <span class="o">=</span> <span class="n">sys</span><span class="o">.</span><span class="n">stderr</span>
    136. <span class="n">sys</span><span class="o">.</span><span class="n">stderr</span> <span class="o">=</span> <span class="bp">self</span>
    137. <span class="k">def</span> <span class="fm">__del__</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span>
    138. <span class="n">sys</span><span class="o">.</span><span class="n">stderr</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">stderr</span>
    139. <span class="k">def</span> <span class="nf">write</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">data</span><span class="p">):</span>
    140. <span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="n">write</span><span class="p">(</span><span class="n">data</span><span class="p">)</span>
    141. <span class="bp">self</span><span class="o">.</span><span class="n">stderr</span><span class="o">.</span><span class="n">write</span><span class="p">(</span><span class="n">data</span><span class="p">)</span>
    142. <span class="k">def</span> <span class="fm">__getattr__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">attr</span><span class="p">):</span>
    143. <span class="k">return</span> <span class="nb">getattr</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">stderr</span><span class="p">,</span> <span class="n">attr</span><span class="p">)</span>
    144. <span class="k">class</span> <span class="nc">StdoutTee</span><span class="p">(</span><span class="n">BufferWriter</span><span class="p">):</span>
    145. <span class="sd">&quot;&quot;&quot;Duplicate the stdout stream to save it into a given file.&quot;&quot;&quot;</span>
    146. <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">filename</span><span class="p">:</span> <span class="nb">str</span><span class="p">,</span> <span class="n">buffer</span><span class="p">,</span> <span class="n">buffer_size</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span> <span class="n">lock</span><span class="p">:</span> <span class="n">Lock</span><span class="p">):</span>
    147. <span class="sd">&quot;&quot;&quot;</span>
    148. <span class="sd"> :param filename: Name of the file where to write the bugger</span>
    149. <span class="sd"> :param buffer: Buffer object</span>
    150. <span class="sd"> :param buffer_size: Number of chars to be buffered before writing the buffer on disk.</span>
    151. <span class="sd"> :param lock: Thread lock to prevent multiple threads to write at the same time</span>
    152. <span class="sd"> &quot;&quot;&quot;</span>
    153. <span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="fm">__init__</span><span class="p">(</span><span class="n">filename</span><span class="p">,</span> <span class="n">buffer</span><span class="p">,</span> <span class="n">buffer_size</span><span class="p">,</span> <span class="n">lock</span><span class="p">)</span>
    154. <span class="bp">self</span><span class="o">.</span><span class="n">stdout</span> <span class="o">=</span> <span class="n">sys</span><span class="o">.</span><span class="n">stdout</span>
    155. <span class="n">sys</span><span class="o">.</span><span class="n">stdout</span> <span class="o">=</span> <span class="bp">self</span>
    156. <span class="k">def</span> <span class="fm">__del__</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span>
    157. <span class="n">sys</span><span class="o">.</span><span class="n">stdout</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">stdout</span>
    158. <span class="k">def</span> <span class="nf">write</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">data</span><span class="p">):</span>
    159. <span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="n">write</span><span class="p">(</span><span class="n">data</span><span class="p">)</span>
    160. <span class="bp">self</span><span class="o">.</span><span class="n">stdout</span><span class="o">.</span><span class="n">write</span><span class="p">(</span><span class="n">data</span><span class="p">)</span>
    161. <span class="k">def</span> <span class="fm">__getattr__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">attr</span><span class="p">):</span>
    162. <span class="k">return</span> <span class="nb">getattr</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">stdout</span><span class="p">,</span> <span class="n">attr</span><span class="p">)</span>
    163. <span class="k">def</span> <span class="nf">copy_file</span><span class="p">(</span><span class="n">src_filename</span><span class="p">:</span> <span class="nb">str</span><span class="p">,</span> <span class="n">dest_filename</span><span class="p">:</span> <span class="nb">str</span><span class="p">,</span> <span class="n">copy_mode</span><span class="p">:</span> <span class="nb">str</span> <span class="o">=</span> <span class="s2">&quot;w&quot;</span><span class="p">):</span>
    164. <span class="sd">&quot;&quot;&quot;Copy a file from source to destination. Also works when the destination folder does not exist.&quot;&quot;&quot;</span>
    165. <span class="n">os</span><span class="o">.</span><span class="n">makedirs</span><span class="p">(</span><span class="n">os</span><span class="o">.</span><span class="n">path</span><span class="o">.</span><span class="n">dirname</span><span class="p">(</span><span class="n">dest_filename</span><span class="p">),</span> <span class="n">exist_ok</span><span class="o">=</span><span class="kc">True</span><span class="p">)</span>
    166. <span class="k">if</span> <span class="n">os</span><span class="o">.</span><span class="n">path</span><span class="o">.</span><span class="n">exists</span><span class="p">(</span><span class="n">src_filename</span><span class="p">):</span>
    167. <span class="k">with</span> <span class="nb">open</span><span class="p">(</span><span class="n">src_filename</span><span class="p">,</span> <span class="s2">&quot;r&quot;</span><span class="p">,</span> <span class="n">encoding</span><span class="o">=</span><span class="s2">&quot;utf-8&quot;</span><span class="p">)</span> <span class="k">as</span> <span class="n">src</span><span class="p">:</span>
    168. <span class="k">with</span> <span class="nb">open</span><span class="p">(</span><span class="n">dest_filename</span><span class="p">,</span> <span class="n">copy_mode</span><span class="p">,</span> <span class="n">encoding</span><span class="o">=</span><span class="s2">&quot;utf-8&quot;</span><span class="p">)</span> <span class="k">as</span> <span class="n">dst</span><span class="p">:</span>
    169. <span class="n">dst</span><span class="o">.</span><span class="n">write</span><span class="p">(</span><span class="n">src</span><span class="o">.</span><span class="n">read</span><span class="p">())</span>
    170. <div class="viewcode-block" id="ConsoleSink"><a class="viewcode-back" href="../../../../super_gradients.common.html#super_gradients.common.auto_logging.ConsoleSink">[docs]</a><span class="k">class</span> <span class="nc">ConsoleSink</span><span class="p">:</span>
    171. <span class="sd">&quot;&quot;&quot;Singleton responsible to sink the console streams (stdout/stderr) into a file.&quot;&quot;&quot;</span>
    172. <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span>
    173. <span class="bp">self</span><span class="o">.</span><span class="n">_setup</span><span class="p">()</span>
    174. <span class="n">atexit</span><span class="o">.</span><span class="n">register</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">_flush</span><span class="p">)</span> <span class="c1"># Flush at the end of the process</span>
    175. <span class="nd">@multi_process_safe</span>
    176. <span class="k">def</span> <span class="nf">_setup</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span>
    177. <span class="sd">&quot;&quot;&quot;On instantiation, setup the default sink file.&quot;&quot;&quot;</span>
    178. <span class="n">filename</span> <span class="o">=</span> <span class="n">Path</span><span class="o">.</span><span class="n">home</span><span class="p">()</span> <span class="o">/</span> <span class="s2">&quot;sg_logs&quot;</span> <span class="o">/</span> <span class="s2">&quot;console.log&quot;</span>
    179. <span class="n">filename</span><span class="o">.</span><span class="n">parent</span><span class="o">.</span><span class="n">mkdir</span><span class="p">(</span><span class="n">exist_ok</span><span class="o">=</span><span class="kc">True</span><span class="p">)</span>
    180. <span class="bp">self</span><span class="o">.</span><span class="n">filename</span> <span class="o">=</span> <span class="nb">str</span><span class="p">(</span><span class="n">filename</span><span class="p">)</span>
    181. <span class="n">os</span><span class="o">.</span><span class="n">makedirs</span><span class="p">(</span><span class="n">os</span><span class="o">.</span><span class="n">path</span><span class="o">.</span><span class="n">dirname</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">filename</span><span class="p">),</span> <span class="n">exist_ok</span><span class="o">=</span><span class="kc">True</span><span class="p">)</span>
    182. <span class="n">buffer</span> <span class="o">=</span> <span class="n">StringIO</span><span class="p">()</span>
    183. <span class="n">lock</span> <span class="o">=</span> <span class="n">Lock</span><span class="p">()</span>
    184. <span class="bp">self</span><span class="o">.</span><span class="n">stdout</span> <span class="o">=</span> <span class="n">StdoutTee</span><span class="p">(</span><span class="n">filename</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">filename</span><span class="p">,</span> <span class="n">buffer</span><span class="o">=</span><span class="n">buffer</span><span class="p">,</span> <span class="n">buffer_size</span><span class="o">=</span><span class="n">BufferWriter</span><span class="o">.</span><span class="n">FILE_BUFFER_SIZE</span><span class="p">,</span> <span class="n">lock</span><span class="o">=</span><span class="n">lock</span><span class="p">)</span>
    185. <span class="bp">self</span><span class="o">.</span><span class="n">stderr</span> <span class="o">=</span> <span class="n">StderrTee</span><span class="p">(</span><span class="n">filename</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">filename</span><span class="p">,</span> <span class="n">buffer</span><span class="o">=</span><span class="n">buffer</span><span class="p">,</span> <span class="n">buffer_size</span><span class="o">=</span><span class="n">BufferWriter</span><span class="o">.</span><span class="n">FILE_BUFFER_SIZE</span><span class="p">,</span> <span class="n">lock</span><span class="o">=</span><span class="n">lock</span><span class="p">)</span>
    186. <span class="c1"># We don&#39;t want to rewrite this for subprocesses when using DDP.</span>
    187. <span class="k">if</span> <span class="n">is_main_process</span><span class="p">():</span>
    188. <span class="k">with</span> <span class="nb">open</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">filename</span><span class="p">,</span> <span class="n">mode</span><span class="o">=</span><span class="s2">&quot;w&quot;</span><span class="p">,</span> <span class="n">encoding</span><span class="o">=</span><span class="s2">&quot;utf-8&quot;</span><span class="p">)</span> <span class="k">as</span> <span class="n">f</span><span class="p">:</span>
    189. <span class="n">f</span><span class="o">.</span><span class="n">write</span><span class="p">(</span><span class="s2">&quot;============================================================</span><span class="se">\n</span><span class="s2">&quot;</span><span class="p">)</span>
    190. <span class="n">f</span><span class="o">.</span><span class="n">write</span><span class="p">(</span><span class="sa">f</span><span class="s1">&#39;New run started at </span><span class="si">{</span><span class="n">datetime</span><span class="o">.</span><span class="n">now</span><span class="p">()</span><span class="o">.</span><span class="n">strftime</span><span class="p">(</span><span class="s2">&quot;%Y-%m-</span><span class="si">%d</span><span class="s2">.%H:%M:%S.</span><span class="si">%f</span><span class="s2">&quot;</span><span class="p">)</span><span class="si">}</span><span class="se">\n</span><span class="s1">&#39;</span><span class="p">)</span>
    191. <span class="n">f</span><span class="o">.</span><span class="n">write</span><span class="p">(</span><span class="sa">f</span><span class="s1">&#39;sys.argv: &quot;</span><span class="si">{</span><span class="s2">&quot; &quot;</span><span class="o">.</span><span class="n">join</span><span class="p">(</span><span class="n">sys</span><span class="o">.</span><span class="n">argv</span><span class="p">)</span><span class="si">}</span><span class="s1">&quot;</span><span class="se">\n</span><span class="s1">&#39;</span><span class="p">)</span>
    192. <span class="n">f</span><span class="o">.</span><span class="n">write</span><span class="p">(</span><span class="s2">&quot;============================================================</span><span class="se">\n</span><span class="s2">&quot;</span><span class="p">)</span>
    193. <span class="bp">self</span><span class="o">.</span><span class="n">stdout</span><span class="o">.</span><span class="n">write</span><span class="p">(</span><span class="sa">f</span><span class="s2">&quot;The console stream is logged into </span><span class="si">{</span><span class="bp">self</span><span class="o">.</span><span class="n">filename</span><span class="si">}</span><span class="se">\n</span><span class="s2">&quot;</span><span class="p">)</span>
    194. <span class="nd">@multi_process_safe</span>
    195. <span class="k">def</span> <span class="nf">_set_location</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">filename</span><span class="p">:</span> <span class="nb">str</span><span class="p">):</span>
    196. <span class="sd">&quot;&quot;&quot;Copy and redirect the sink file into another location.&quot;&quot;&quot;</span>
    197. <span class="bp">self</span><span class="o">.</span><span class="n">_flush</span><span class="p">()</span>
    198. <span class="n">prev_filename</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">filename</span>
    199. <span class="n">copy_file</span><span class="p">(</span><span class="n">src_filename</span><span class="o">=</span><span class="n">prev_filename</span><span class="p">,</span> <span class="n">dest_filename</span><span class="o">=</span><span class="n">filename</span><span class="p">,</span> <span class="n">copy_mode</span><span class="o">=</span><span class="s2">&quot;a&quot;</span><span class="p">)</span>
    200. <span class="bp">self</span><span class="o">.</span><span class="n">filename</span> <span class="o">=</span> <span class="n">filename</span>
    201. <span class="bp">self</span><span class="o">.</span><span class="n">stdout</span><span class="o">.</span><span class="n">filename</span> <span class="o">=</span> <span class="n">filename</span>
    202. <span class="bp">self</span><span class="o">.</span><span class="n">stderr</span><span class="o">.</span><span class="n">filename</span> <span class="o">=</span> <span class="n">filename</span>
    203. <span class="bp">self</span><span class="o">.</span><span class="n">stdout</span><span class="o">.</span><span class="n">write</span><span class="p">(</span><span class="sa">f</span><span class="s2">&quot;The console stream is now moved to </span><span class="si">{</span><span class="n">filename</span><span class="si">}</span><span class="se">\n</span><span class="s2">&quot;</span><span class="p">)</span>
    204. <div class="viewcode-block" id="ConsoleSink.set_location"><a class="viewcode-back" href="../../../../super_gradients.common.html#super_gradients.common.auto_logging.ConsoleSink.set_location">[docs]</a> <span class="nd">@staticmethod</span>
    205. <span class="k">def</span> <span class="nf">set_location</span><span class="p">(</span><span class="n">filename</span><span class="p">:</span> <span class="nb">str</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="kc">None</span><span class="p">:</span>
    206. <span class="sd">&quot;&quot;&quot;Copy and redirect the sink file into another location.&quot;&quot;&quot;</span>
    207. <span class="n">_console_sink</span><span class="o">.</span><span class="n">_set_location</span><span class="p">(</span><span class="n">filename</span><span class="p">)</span></div>
    208. <span class="nd">@multi_process_safe</span>
    209. <span class="k">def</span> <span class="nf">_flush</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span>
    210. <span class="sd">&quot;&quot;&quot;Force the flush on stdout and stderr.&quot;&quot;&quot;</span>
    211. <span class="bp">self</span><span class="o">.</span><span class="n">stdout</span><span class="o">.</span><span class="n">flush</span><span class="p">(</span><span class="n">force</span><span class="o">=</span><span class="kc">True</span><span class="p">)</span>
    212. <span class="bp">self</span><span class="o">.</span><span class="n">stderr</span><span class="o">.</span><span class="n">flush</span><span class="p">(</span><span class="n">force</span><span class="o">=</span><span class="kc">True</span><span class="p">)</span>
    213. <div class="viewcode-block" id="ConsoleSink.flush"><a class="viewcode-back" href="../../../../super_gradients.common.html#super_gradients.common.auto_logging.ConsoleSink.flush">[docs]</a> <span class="nd">@staticmethod</span>
    214. <span class="k">def</span> <span class="nf">flush</span><span class="p">():</span>
    215. <span class="sd">&quot;&quot;&quot;Force the flush on stdout and stderr.&quot;&quot;&quot;</span>
    216. <span class="n">_console_sink</span><span class="o">.</span><span class="n">_flush</span><span class="p">()</span></div>
    217. <div class="viewcode-block" id="ConsoleSink.get_filename"><a class="viewcode-back" href="../../../../super_gradients.common.html#super_gradients.common.auto_logging.ConsoleSink.get_filename">[docs]</a> <span class="nd">@staticmethod</span>
    218. <span class="k">def</span> <span class="nf">get_filename</span><span class="p">():</span>
    219. <span class="sd">&quot;&quot;&quot;Get the filename of the sink.&quot;&quot;&quot;</span>
    220. <span class="k">return</span> <span class="n">_console_sink</span><span class="o">.</span><span class="n">filename</span></div></div>
    221. <span class="n">_console_sink</span> <span class="o">=</span> <span class="n">ConsoleSink</span><span class="p">()</span>
    222. </pre></div>
    223. </div>
    224. </div>
    225. <footer>
    226. <hr/>
    227. <div role="contentinfo">
    228. <p>&#169; Copyright 2021, SuperGradients team.</p>
    229. </div>
    230. Built with <a href="https://www.sphinx-doc.org/">Sphinx</a> using a
    231. <a href="https://github.com/readthedocs/sphinx_rtd_theme">theme</a>
    232. provided by <a href="https://readthedocs.org">Read the Docs</a>.
    233. </footer>
    234. </div>
    235. </div>
    236. </section>
    237. </div>
    238. <script>
    239. jQuery(function () {
    240. SphinxRtdTheme.Navigation.enable(true);
    241. });
    242. </script>
    243. </body>
    244. </html>
    Discard
    Some lines were truncated since they exceed the maximum allowed length of 500, please use a local Git client to see the full diff.
    @@ -3,10 +3,11 @@
     <head>
     <head>
       <meta charset="utf-8" />
       <meta charset="utf-8" />
       <meta name="viewport" content="width=device-width, initial-scale=1.0" />
       <meta name="viewport" content="width=device-width, initial-scale=1.0" />
    -  <title>super_gradients.common.aws_connection.aws_connector &mdash; SuperGradients 1.0 documentation</title>
    +  <title>super_gradients.common.aws_connection.aws_connector &mdash; SuperGradients 3.0.3 documentation</title>
           <link rel="stylesheet" href="../../../../_static/pygments.css" type="text/css" />
           <link rel="stylesheet" href="../../../../_static/pygments.css" type="text/css" />
           <link rel="stylesheet" href="../../../../_static/css/theme.css" type="text/css" />
           <link rel="stylesheet" href="../../../../_static/css/theme.css" type="text/css" />
           <link rel="stylesheet" href="../../../../_static/graphviz.css" type="text/css" />
           <link rel="stylesheet" href="../../../../_static/graphviz.css" type="text/css" />
    +      <link rel="stylesheet" href="../../../../_static/custom.css" type="text/css" />
       <!--[if lt IE 9]>
       <!--[if lt IE 9]>
         <script src="../../../../_static/js/html5shiv.min.js"></script>
         <script src="../../../../_static/js/html5shiv.min.js"></script>
       <![endif]-->
       <![endif]-->
    @@ -14,7 +15,9 @@
             <script data-url_root="../../../../" id="documentation_options" src="../../../../_static/documentation_options.js"></script>
             <script data-url_root="../../../../" id="documentation_options" src="../../../../_static/documentation_options.js"></script>
             <script src="../../../../_static/jquery.js"></script>
             <script src="../../../../_static/jquery.js"></script>
             <script src="../../../../_static/underscore.js"></script>
             <script src="../../../../_static/underscore.js"></script>
    +        <script src="../../../../_static/_sphinx_javascript_frameworks_compat.js"></script>
             <script src="../../../../_static/doctools.js"></script>
             <script src="../../../../_static/doctools.js"></script>
    +        <script src="../../../../_static/sphinx_highlight.js"></script>
         <script src="../../../../_static/js/theme.js"></script>
         <script src="../../../../_static/js/theme.js"></script>
         <link rel="index" title="Index" href="../../../../genindex.html" />
         <link rel="index" title="Index" href="../../../../genindex.html" />
         <link rel="search" title="Search" href="../../../../search.html" /> 
         <link rel="search" title="Search" href="../../../../search.html" /> 
    @@ -35,30 +38,29 @@
       </form>
       </form>
     </div>
     </div>
             </div><div class="wy-menu wy-menu-vertical" data-spy="affix" role="navigation" aria-label="Navigation menu">
             </div><div class="wy-menu wy-menu-vertical" data-spy="affix" role="navigation" aria-label="Navigation menu">
    -              <p class="caption"><span class="caption-text">Welcome To SuperGradients</span></p>
    +              <p class="caption" role="heading"><span class="caption-text">Welcome To SuperGradients</span></p>
     <ul>
     <ul>
    -<li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html">Fill our 4-question quick survey! We will raffle free SuperGradients swag between those who will participate -&gt; Fill Survey</a></li>
    -<li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#supergradients">SuperGradients</a></li>
    +<li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html">Version 3 is out! Notebooks have been updated!</a></li>
    +<li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#build-with-supergradients">Build with SuperGradients</a></li>
    +<li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#quick-installation">Quick Installation</a></li>
    +<li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#what-s-new">What’s New</a></li>
    +<li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#coming-soon">Coming soon</a></li>
    +<li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#table-of-content">Table of Content</a></li>
    +<li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#getting-started">Getting Started</a></li>
    +<li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#advanced-features">Advanced Features</a></li>
    +<li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#installation-methods">Installation Methods</a></li>
    +<li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#implemented-model-architectures">Implemented Model Architectures</a></li>
    +<li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#documentation">Documentation</a></li>
    +<li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#contributing">Contributing</a></li>
    +<li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#citation">Citation</a></li>
    +<li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#community">Community</a></li>
    +<li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#license">License</a></li>
    +<li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#deci-platform">Deci Platform</a></li>
     </ul>
     </ul>
    -<p class="caption"><span class="caption-text">Technical Documentation</span></p>
    +<p class="caption" role="heading"><span class="caption-text">Technical Documentation</span></p>
     <ul>
     <ul>
     <li class="toctree-l1"><a class="reference internal" href="../../../../super_gradients.common.html">Common package</a></li>
     <li class="toctree-l1"><a class="reference internal" href="../../../../super_gradients.common.html">Common package</a></li>
     <li class="toctree-l1"><a class="reference internal" href="../../../../super_gradients.training.html">Training package</a></li>
     <li class="toctree-l1"><a class="reference internal" href="../../../../super_gradients.training.html">Training package</a></li>
    -</ul>
    -<p class="caption"><span class="caption-text">User Guide</span></p>
    -<ul>
    -<li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html">What is SuperGradients?</a></li>
    -<li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#introducing-the-supergradients-library">Introducing the SuperGradients library</a></li>
    -<li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#installation">Installation</a></li>
    -<li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#integrating-your-training-code-complete-walkthrough">Integrating Your Training Code - Complete Walkthrough</a></li>
    -<li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#training-parameters">Training Parameters</a></li>
    -<li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#logs-and-checkpoints">Logs and Checkpoints</a></li>
    -<li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#dataset-parameters">Dataset Parameters</a></li>
    -<li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#network-architectures">Network Architectures</a></li>
    -<li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#pretrained-models">Pretrained Models</a></li>
    -<li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#how-to-reproduce-our-training-recipes">How To Reproduce Our Training Recipes</a></li>
    -<li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#professional-tools-integration">Professional Tools Integration</a></li>
    -<li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#supergradients-faq">SuperGradients FAQ</a></li>
     </ul>
     </ul>
     
     
             </div>
             </div>
    @@ -93,7 +95,7 @@
     <span class="kn">from</span> <span class="nn">botocore.exceptions</span> <span class="kn">import</span> <span class="n">ClientError</span><span class="p">,</span> <span class="n">ProfileNotFound</span>
     <span class="kn">from</span> <span class="nn">botocore.exceptions</span> <span class="kn">import</span> <span class="n">ClientError</span><span class="p">,</span> <span class="n">ProfileNotFound</span>
     
     
     
     
    -<div class="viewcode-block" id="AWSConnector"><a class="viewcode-back" href="../../../../super_gradients.common.aws_connection.html#super_gradients.common.AWSConnector">[docs]</a><span class="k">class</span> <span class="nc">AWSConnector</span><span class="p">:</span>
    +<div class="viewcode-block" id="AWSConnector"><a class="viewcode-back" href="../../../../super_gradients.common.html#super_gradients.common.AWSConnector">[docs]</a><span class="k">class</span> <span class="nc">AWSConnector</span><span class="p">:</span>
         <span class="sd">&quot;&quot;&quot;</span>
         <span class="sd">&quot;&quot;&quot;</span>
     <span class="sd">    AWSConnector - Connects to AWS using Credentials File or IAM Role</span>
     <span class="sd">    AWSConnector - Connects to AWS using Credentials File or IAM Role</span>
     <span class="sd">    &quot;&quot;&quot;</span>
     <span class="sd">    &quot;&quot;&quot;</span>
    @@ -133,7 +135,7 @@
                         <span class="n">ex</span><span class="p">))</span>
                         <span class="n">ex</span><span class="p">))</span>
                 <span class="k">return</span> <span class="kc">None</span>
                 <span class="k">return</span> <span class="kc">None</span>
     
     
    -<div class="viewcode-block" id="AWSConnector.get_aws_session"><a class="viewcode-back" href="../../../../super_gradients.common.aws_connection.html#super_gradients.common.AWSConnector.get_aws_session">[docs]</a>    <span class="nd">@staticmethod</span>
    +<div class="viewcode-block" id="AWSConnector.get_aws_session"><a class="viewcode-back" href="../../../../super_gradients.common.html#super_gradients.common.AWSConnector.get_aws_session">[docs]</a>    <span class="nd">@staticmethod</span>
         <span class="k">def</span> <span class="nf">get_aws_session</span><span class="p">(</span><span class="n">profile_name</span><span class="p">:</span> <span class="nb">str</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="n">boto3</span><span class="o">.</span><span class="n">Session</span><span class="p">:</span>
         <span class="k">def</span> <span class="nf">get_aws_session</span><span class="p">(</span><span class="n">profile_name</span><span class="p">:</span> <span class="nb">str</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="n">boto3</span><span class="o">.</span><span class="n">Session</span><span class="p">:</span>
             <span class="sd">&quot;&quot;&quot;</span>
             <span class="sd">&quot;&quot;&quot;</span>
     <span class="sd">        get_aws_session - Connects to AWS to retrieve an AWS Session</span>
     <span class="sd">        get_aws_session - Connects to AWS to retrieve an AWS Session</span>
    @@ -149,7 +151,7 @@
     
     
             <span class="k">return</span> <span class="n">aws_session</span></div>
             <span class="k">return</span> <span class="n">aws_session</span></div>
     
     
    -<div class="viewcode-block" id="AWSConnector.get_aws_client_for_service_name"><a class="viewcode-back" href="../../../../super_gradients.common.aws_connection.html#super_gradients.common.AWSConnector.get_aws_client_for_service_name">[docs]</a>    <span class="nd">@staticmethod</span>
    +<div class="viewcode-block" id="AWSConnector.get_aws_client_for_service_name"><a class="viewcode-back" href="../../../../super_gradients.common.html#super_gradients.common.AWSConnector.get_aws_client_for_service_name">[docs]</a>    <span class="nd">@staticmethod</span>
         <span class="k">def</span> <span class="nf">get_aws_client_for_service_name</span><span class="p">(</span><span class="n">profile_name</span><span class="p">:</span> <span class="nb">str</span><span class="p">,</span> <span class="n">service_name</span><span class="p">:</span> <span class="nb">str</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="n">boto3</span><span class="o">.</span><span class="n">Session</span><span class="o">.</span><span class="n">client</span><
         <span class="k">def</span> <span class="nf">get_aws_client_for_service_name</span><span class="p">(</span><span class="n">profile_name</span><span class="p">:</span> <span class="nb">str</span><span class="p">,</span> <span class="n">service_name</span><span class="p">:</span> <span class="nb">str</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="n">boto3</span><span class="o">.</span><span class="n">Session</span><span class="o">.</span><span class="n">client</span><
             <span class="sd">&quot;&quot;&quot;</span>
             <span class="sd">&quot;&quot;&quot;</span>
     <span class="sd">        get_aws_client_for_service_name - Connects to AWS to retrieve the relevant Client</span>
     <span class="sd">        get_aws_client_for_service_name - Connects to AWS to retrieve the relevant Client</span>
    @@ -166,7 +168,7 @@
     
     
             <span class="k">return</span> <span class="n">aws_session</span><span class="o">.</span><span class="n">client</span><span class="p">(</span><span class="n">service_name</span><span class="o">=</span><span class="n">service_name</span><span class="p">)</span></div>
             <span class="k">return</span> <span class="n">aws_session</span><span class="o">.</span><span class="n">client</span><span class="p">(</span><span class="n">service_name</span><span class="o">=</span><span class="n">service_name</span><span class="p">)</span></div>
     
     
    -<div class="viewcode-block" id="AWSConnector.get_aws_resource_for_service_name"><a class="viewcode-back" href="../../../../super_gradients.common.aws_connection.html#super_gradients.common.AWSConnector.get_aws_resource_for_service_name">[docs]</a>    <span class="nd">@staticmethod</span>
    +<div class="viewcode-block" id="AWSConnector.get_aws_resource_for_service_name"><a class="viewcode-back" href="../../../../super_gradients.common.html#super_gradients.common.AWSConnector.get_aws_resource_for_service_name">[docs]</a>    <span class="nd">@staticmethod</span>
         <span class="k">def</span> <span class="nf">get_aws_resource_for_service_name</span><span class="p">(</span><span class="n">profile_name</span><span class="p">:</span> <span class="nb">str</span><span class="p">,</span> <span class="n">service_name</span><span class="p">:</span> <span class="nb">str</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="n">boto3</span><span class="o">.</span><span class="n">Session</span><span class="o">.</span><span class="n">resource</sp
         <span class="k">def</span> <span class="nf">get_aws_resource_for_service_name</span><span class="p">(</span><span class="n">profile_name</span><span class="p">:</span> <span class="nb">str</span><span class="p">,</span> <span class="n">service_name</span><span class="p">:</span> <span class="nb">str</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="n">boto3</span><span class="o">.</span><span class="n">Session</span><span class="o">.</span><span class="n">resource</sp
             <span class="sd">&quot;&quot;&quot;</span>
             <span class="sd">&quot;&quot;&quot;</span>
     <span class="sd">        Connects to AWS to retrieve the relevant Resource (More functionality then Client)</span>
     <span class="sd">        Connects to AWS to retrieve the relevant Resource (More functionality then Client)</span>
    @@ -183,7 +185,7 @@
     
     
             <span class="k">return</span> <span class="n">aws_session</span><span class="o">.</span><span class="n">resource</span><span class="p">(</span><span class="n">service_name</span><span class="o">=</span><span class="n">service_name</span><span class="p">)</span></div>
             <span class="k">return</span> <span class="n">aws_session</span><span class="o">.</span><span class="n">resource</span><span class="p">(</span><span class="n">service_name</span><span class="o">=</span><span class="n">service_name</span><span class="p">)</span></div>
     
     
    -<div class="viewcode-block" id="AWSConnector.is_client_error"><a class="viewcode-back" href="../../../../super_gradients.common.aws_connection.html#super_gradients.common.AWSConnector.is_client_error">[docs]</a>    <span class="nd">@staticmethod</span>
    +<div class="viewcode-block" id="AWSConnector.is_client_error"><a class="viewcode-back" href="../../../../super_gradients.common.html#super_gradients.common.AWSConnector.is_client_error">[docs]</a>    <span class="nd">@staticmethod</span>
         <span class="k">def</span> <span class="nf">is_client_error</span><span class="p">(</span><span class="n">code</span><span class="p">):</span>
         <span class="k">def</span> <span class="nf">is_client_error</span><span class="p">(</span><span class="n">code</span><span class="p">):</span>
             <span class="n">e</span> <span class="o">=</span> <span class="n">sys</span><span class="o">.</span><span class="n">exc_info</span><span class="p">()[</span><span class="mi">1</span><span class="p">]</span>
             <span class="n">e</span> <span class="o">=</span> <span class="n">sys</span><span class="o">.</span><span class="n">exc_info</span><span class="p">()[</span><span class="mi">1</span><span class="p">]</span>
             <span class="k">if</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">e</span><span class="p">,</span> <span class="n">ClientError</span><span class="p">)</span> <span class="ow">and</span> <span class="n">e</span><span class="o">.</span><span class="n">response</span><span class="p">[</span><span class="s2">&quot;Error&quot;</span><span class="p">][</span><span class="s2">&quot;Code&quot;</span><span class="p">]</span> <span class="o">==</span> <span class
             <span class="k">if</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">e</span><span class="p">,</span> <span class="n">ClientError</span><span class="p">)</span> <span class="ow">and</span> <span class="n">e</span><span class="o">.</span><span class="n">response</span><span class="p">[</span><span class="s2">&quot;Error&quot;</span><span class="p">][</span><span class="s2">&quot;Code&quot;</span><span class="p">]</span> <span class="o">==</span> <span class
    @@ -218,4 +220,4 @@
       </script> 
       </script> 
     
     
     </body>
     </body>
    -</html>
    +</html>
    Discard
    1
    2
    3
    4
    5
    6
    7
    8
    9
    10
    11
    12
    13
    14
    15
    16
    17
    18
    19
    20
    21
    22
    23
    24
    25
    26
    27
    28
    29
    30
    31
    32
    33
    34
    35
    36
    37
    38
    39
    40
    41
    42
    43
    44
    45
    46
    47
    48
    49
    50
    51
    52
    53
    54
    55
    56
    57
    58
    59
    60
    61
    62
    63
    64
    65
    66
    67
    68
    69
    70
    71
    72
    73
    74
    75
    76
    77
    78
    79
    80
    81
    82
    83
    84
    85
    86
    87
    88
    89
    90
    91
    92
    93
    94
    95
    96
    97
    98
    99
    100
    101
    102
    103
    104
    105
    106
    107
    108
    109
    110
    111
    112
    113
    114
    115
    116
    117
    118
    119
    120
    121
    122
    123
    124
    125
    126
    127
    128
    129
    130
    131
    132
    133
    134
    135
    136
    137
    138
    139
    140
    141
    142
    143
    144
    145
    146
    147
    148
    149
    150
    151
    152
    153
    154
    155
    156
    157
    158
    159
    160
    161
    162
    163
    164
    165
    166
    167
    168
    169
    170
    171
    172
    173
    174
    175
    176
    177
    178
    179
    180
    181
    182
    183
    184
    185
    186
    187
    188
    189
    190
    191
    192
    193
    194
    195
    196
    197
    198
    199
    200
    201
    202
    203
    204
    205
    206
    207
    208
    209
    210
    211
    212
    213
    214
    215
    216
    217
    218
    219
    220
    221
    222
    223
    224
    225
    226
    227
    228
    229
    230
    231
    232
    233
    234
    235
    236
    237
    238
    239
    240
    241
    242
    243
    244
    245
    246
    247
    248
    1. <!DOCTYPE html>
    2. <html class="writer-html5" lang="en" >
    3. <head>
    4. <meta charset="utf-8" />
    5. <meta name="viewport" content="width=device-width, initial-scale=1.0" />
    6. <title>super_gradients.common.aws_connection.aws_secrets_manager_connector &mdash; SuperGradients 1.0 documentation</title>
    7. <link rel="stylesheet" href="../../../../_static/pygments.css" type="text/css" />
    8. <link rel="stylesheet" href="../../../../_static/css/theme.css" type="text/css" />
    9. <link rel="stylesheet" href="../../../../_static/graphviz.css" type="text/css" />
    10. <!--[if lt IE 9]>
    11. <script src="../../../../_static/js/html5shiv.min.js"></script>
    12. <![endif]-->
    13. <script data-url_root="../../../../" id="documentation_options" src="../../../../_static/documentation_options.js"></script>
    14. <script src="../../../../_static/jquery.js"></script>
    15. <script src="../../../../_static/underscore.js"></script>
    16. <script src="../../../../_static/doctools.js"></script>
    17. <script src="../../../../_static/js/theme.js"></script>
    18. <link rel="index" title="Index" href="../../../../genindex.html" />
    19. <link rel="search" title="Search" href="../../../../search.html" />
    20. </head>
    21. <body class="wy-body-for-nav">
    22. <div class="wy-grid-for-nav">
    23. <nav data-toggle="wy-nav-shift" class="wy-nav-side">
    24. <div class="wy-side-scroll">
    25. <div class="wy-side-nav-search" >
    26. <a href="../../../../index.html" class="icon icon-home"> SuperGradients
    27. </a>
    28. <div role="search">
    29. <form id="rtd-search-form" class="wy-form" action="../../../../search.html" method="get">
    30. <input type="text" name="q" placeholder="Search docs" />
    31. <input type="hidden" name="check_keywords" value="yes" />
    32. <input type="hidden" name="area" value="default" />
    33. </form>
    34. </div>
    35. </div><div class="wy-menu wy-menu-vertical" data-spy="affix" role="navigation" aria-label="Navigation menu">
    36. <p class="caption"><span class="caption-text">Welcome To SuperGradients</span></p>
    37. <ul>
    38. <li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html">Fill our 4-question quick survey! We will raffle free SuperGradients swag between those who will participate -&gt; Fill Survey</a></li>
    39. <li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#supergradients">SuperGradients</a></li>
    40. </ul>
    41. <p class="caption"><span class="caption-text">Technical Documentation</span></p>
    42. <ul>
    43. <li class="toctree-l1"><a class="reference internal" href="../../../../super_gradients.common.html">Common package</a></li>
    44. <li class="toctree-l1"><a class="reference internal" href="../../../../super_gradients.training.html">Training package</a></li>
    45. </ul>
    46. <p class="caption"><span class="caption-text">User Guide</span></p>
    47. <ul>
    48. <li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html">What is SuperGradients?</a></li>
    49. <li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#introducing-the-supergradients-library">Introducing the SuperGradients library</a></li>
    50. <li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#installation">Installation</a></li>
    51. <li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#integrating-your-training-code-complete-walkthrough">Integrating Your Training Code - Complete Walkthrough</a></li>
    52. <li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#training-parameters">Training Parameters</a></li>
    53. <li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#logs-and-checkpoints">Logs and Checkpoints</a></li>
    54. <li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#dataset-parameters">Dataset Parameters</a></li>
    55. <li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#network-architectures">Network Architectures</a></li>
    56. <li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#pretrained-models">Pretrained Models</a></li>
    57. <li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#how-to-reproduce-our-training-recipes">How To Reproduce Our Training Recipes</a></li>
    58. <li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#professional-tools-integration">Professional Tools Integration</a></li>
    59. <li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#supergradients-faq">SuperGradients FAQ</a></li>
    60. </ul>
    61. </div>
    62. </div>
    63. </nav>
    64. <section data-toggle="wy-nav-shift" class="wy-nav-content-wrap"><nav class="wy-nav-top" aria-label="Mobile navigation menu" >
    65. <i data-toggle="wy-nav-top" class="fa fa-bars"></i>
    66. <a href="../../../../index.html">SuperGradients</a>
    67. </nav>
    68. <div class="wy-nav-content">
    69. <div class="rst-content">
    70. <div role="navigation" aria-label="Page navigation">
    71. <ul class="wy-breadcrumbs">
    72. <li><a href="../../../../index.html" class="icon icon-home"></a> &raquo;</li>
    73. <li><a href="../../../index.html">Module code</a> &raquo;</li>
    74. <li>super_gradients.common.aws_connection.aws_secrets_manager_connector</li>
    75. <li class="wy-breadcrumbs-aside">
    76. </li>
    77. </ul>
    78. <hr/>
    79. </div>
    80. <div role="main" class="document" itemscope="itemscope" itemtype="http://schema.org/Article">
    81. <div itemprop="articleBody">
    82. <h1>Source code for super_gradients.common.aws_connection.aws_secrets_manager_connector</h1><div class="highlight"><pre>
    83. <span></span><span class="kn">import</span> <span class="nn">json</span>
    84. <span class="kn">import</span> <span class="nn">logging</span>
    85. <span class="kn">from</span> <span class="nn">super_gradients.common</span> <span class="kn">import</span> <span class="n">AWSConnector</span>
    86. <span class="kn">from</span> <span class="nn">super_gradients.common</span> <span class="kn">import</span> <span class="n">explicit_params_validation</span>
    87. <div class="viewcode-block" id="AWSSecretsManagerConnector"><a class="viewcode-back" href="../../../../super_gradients.common.aws_connection.html#super_gradients.common.aws_connection.aws_secrets_manager_connector.AWSSecretsManagerConnector">[docs]</a><span class="k">class</span> <span class="nc">AWSSecretsManagerConnector</span><span class="p">:</span>
    88. <span class="sd">&quot;&quot;&quot;</span>
    89. <span class="sd"> AWSSecretsManagerConnector - This class handles the AWS Secrets Manager connection</span>
    90. <span class="sd"> &quot;&quot;&quot;</span>
    91. <span class="vm">__slots__</span> <span class="o">=</span> <span class="p">[]</span> <span class="c1"># Making the class immutable for runtime safety</span>
    92. <span class="n">current_environment_client</span> <span class="o">=</span> <span class="kc">None</span>
    93. <span class="n">DECI_ENVIRONMENTS</span> <span class="o">=</span> <span class="p">[</span><span class="s1">&#39;research&#39;</span><span class="p">,</span> <span class="s1">&#39;development&#39;</span><span class="p">,</span> <span class="s1">&#39;staging&#39;</span><span class="p">,</span> <span class="s1">&#39;production&#39;</span><span class="p">]</span>
    94. <span class="nd">@staticmethod</span>
    95. <span class="nd">@explicit_params_validation</span><span class="p">(</span><span class="n">validation_type</span><span class="o">=</span><span class="s1">&#39;NoneOrEmpty&#39;</span><span class="p">)</span>
    96. <span class="k">def</span> <span class="nf">get_secret_value_for_secret_key</span><span class="p">(</span><span class="n">aws_env</span><span class="p">:</span> <span class="nb">str</span><span class="p">,</span> <span class="n">secret_name</span><span class="p">:</span> <span class="nb">str</span><span class="p">,</span> <span class="n">secret_key</span><span class="p">:</span> <span class="nb">str</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="nb">str</span><span class="p">:</span>
    97. <span class="sd">&quot;&quot;&quot;</span>
    98. <span class="sd"> get_secret_value_for_secret_key - Gets a Secret Value from AWS Secrets Manager for the Provided Key</span>
    99. <span class="sd"> :param aws_env: The environment to get the secret for</span>
    100. <span class="sd"> :param secret_name: The Secret Name stored in Secrets Manager</span>
    101. <span class="sd"> :param secret_key: The Secret Key To retrieve it&#39;s value from AWS</span>
    102. <span class="sd"> :return: str: The Secret Value</span>
    103. <span class="sd"> &quot;&quot;&quot;</span>
    104. <span class="n">current_class_name</span> <span class="o">=</span> <span class="vm">__class__</span><span class="o">.</span><span class="vm">__name__</span>
    105. <span class="n">logger</span> <span class="o">=</span> <span class="n">logging</span><span class="o">.</span><span class="n">getLogger</span><span class="p">(</span><span class="n">current_class_name</span><span class="p">)</span>
    106. <span class="n">secret_key</span> <span class="o">=</span> <span class="n">secret_key</span><span class="o">.</span><span class="n">upper</span><span class="p">()</span>
    107. <span class="n">aws_secrets_dict</span> <span class="o">=</span> <span class="n">AWSSecretsManagerConnector</span><span class="o">.</span><span class="n">__get_secrets_manager_dict_for_secret_name</span><span class="p">(</span>
    108. <span class="n">aws_env</span><span class="o">=</span><span class="n">aws_env</span><span class="p">,</span> <span class="n">secret_name</span><span class="o">=</span><span class="n">secret_name</span><span class="p">)</span>
    109. <span class="n">secret_key</span> <span class="o">=</span> <span class="s1">&#39;.&#39;</span><span class="o">.</span><span class="n">join</span><span class="p">([</span><span class="n">aws_env</span><span class="o">.</span><span class="n">upper</span><span class="p">(),</span> <span class="n">secret_key</span><span class="p">])</span>
    110. <span class="k">if</span> <span class="n">secret_key</span> <span class="ow">not</span> <span class="ow">in</span> <span class="n">aws_secrets_dict</span><span class="o">.</span><span class="n">keys</span><span class="p">():</span>
    111. <span class="n">error</span> <span class="o">=</span> <span class="sa">f</span><span class="s1">&#39;[</span><span class="si">{</span><span class="n">current_class_name</span><span class="si">}</span><span class="s1">] - Secret Key (</span><span class="si">{</span><span class="n">secret_key</span><span class="si">}</span><span class="s1">) not Found in AWS Secret: &#39;</span> <span class="o">+</span> <span class="n">secret_name</span>
    112. <span class="n">logger</span><span class="o">.</span><span class="n">error</span><span class="p">(</span><span class="n">error</span><span class="p">)</span>
    113. <span class="k">raise</span> <span class="ne">EnvironmentError</span><span class="p">(</span><span class="n">error</span><span class="p">)</span>
    114. <span class="k">else</span><span class="p">:</span>
    115. <span class="k">return</span> <span class="n">aws_secrets_dict</span><span class="p">[</span><span class="n">secret_key</span><span class="p">]</span>
    116. <span class="nd">@staticmethod</span>
    117. <span class="nd">@explicit_params_validation</span><span class="p">(</span><span class="n">validation_type</span><span class="o">=</span><span class="s1">&#39;NoneOrEmpty&#39;</span><span class="p">)</span>
    118. <span class="k">def</span> <span class="nf">get_secret_values_dict_for_secret_key_properties</span><span class="p">(</span><span class="n">env</span><span class="p">:</span> <span class="nb">str</span><span class="p">,</span> <span class="n">secret_key</span><span class="p">:</span> <span class="nb">str</span><span class="p">,</span> <span class="n">secret_name</span><span class="p">:</span> <span class="nb">str</span><span class="p">,</span>
    119. <span class="n">db_properties_set</span><span class="p">:</span> <span class="nb">set</span> <span class="o">=</span> <span class="kc">None</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="nb">dict</span><span class="p">:</span>
    120. <span class="sd">&quot;&quot;&quot;</span>
    121. <span class="sd"> get_config_dict - Returns the config dict of the properties from the properties dict</span>
    122. <span class="sd"> :param env: The environment to open the dict for</span>
    123. <span class="sd"> :param secret_key: The Secret Key</span>
    124. <span class="sd"> :param secret_name: The Secret to Retrieve to from AWS secrets manager (usually project name)</span>
    125. <span class="sd"> :param db_properties_set: The set of the properties to get secrets values for</span>
    126. <span class="sd"> :return: dict The secrets dict for the requested property</span>
    127. <span class="sd"> &quot;&quot;&quot;</span>
    128. <span class="n">current_class_name</span> <span class="o">=</span> <span class="vm">__class__</span><span class="o">.</span><span class="vm">__name__</span>
    129. <span class="n">logger</span> <span class="o">=</span> <span class="n">logging</span><span class="o">.</span><span class="n">getLogger</span><span class="p">(</span><span class="n">current_class_name</span><span class="p">)</span>
    130. <span class="n">aws_secrets_dict</span> <span class="o">=</span> <span class="n">AWSSecretsManagerConnector</span><span class="o">.</span><span class="n">__get_secrets_manager_dict_for_secret_name</span><span class="p">(</span>
    131. <span class="n">aws_env</span><span class="o">=</span><span class="n">env</span><span class="p">,</span> <span class="n">secret_name</span><span class="o">=</span><span class="n">secret_name</span><span class="p">)</span>
    132. <span class="n">aws_env_safe_secrets</span> <span class="o">=</span> <span class="p">{}</span>
    133. <span class="c1"># FILL THE DICT VALUES FROM THE AWS SECRETS RESPONSE</span>
    134. <span class="k">if</span> <span class="n">db_properties_set</span><span class="p">:</span>
    135. <span class="k">for</span> <span class="n">secret_key_property</span> <span class="ow">in</span> <span class="n">db_properties_set</span><span class="p">:</span>
    136. <span class="n">secret_key_to_retrieve</span> <span class="o">=</span> <span class="s1">&#39;.&#39;</span><span class="o">.</span><span class="n">join</span><span class="p">([</span><span class="n">env</span><span class="o">.</span><span class="n">upper</span><span class="p">(),</span> <span class="n">secret_key</span><span class="p">,</span> <span class="n">secret_key_property</span><span class="p">])</span>
    137. <span class="k">if</span> <span class="n">secret_key_to_retrieve</span> <span class="ow">not</span> <span class="ow">in</span> <span class="n">aws_secrets_dict</span><span class="p">:</span>
    138. <span class="n">error</span> <span class="o">=</span> <span class="sa">f</span><span class="s1">&#39;[</span><span class="si">{</span><span class="n">current_class_name</span><span class="si">}</span><span class="s1">] - Error retrieving data from AWS Secrets Manager for Secret Key &quot;</span><span class="si">{</span><span class="n">secret_name</span><span class="si">}</span><span class="s1">&quot;: &#39;</span> \
    139. <span class="sa">f</span><span class="s1">&#39;The secret property &quot;</span><span class="si">{</span><span class="n">secret_key_property</span><span class="si">}</span><span class="s1">&quot; Does Not Exist&#39;</span>
    140. <span class="n">logger</span><span class="o">.</span><span class="n">error</span><span class="p">(</span><span class="n">error</span><span class="p">)</span>
    141. <span class="k">raise</span> <span class="ne">EnvironmentError</span><span class="p">(</span><span class="n">error</span><span class="p">)</span>
    142. <span class="k">else</span><span class="p">:</span>
    143. <span class="n">env_stripped_key_name</span> <span class="o">=</span> <span class="n">secret_key_to_retrieve</span><span class="o">.</span><span class="n">lstrip</span><span class="p">(</span><span class="n">env</span><span class="o">.</span><span class="n">upper</span><span class="p">())</span><span class="o">.</span><span class="n">lstrip</span><span class="p">(</span><span class="s1">&#39;.&#39;</span><span class="p">)</span>
    144. <span class="n">aws_env_safe_secrets</span><span class="p">[</span><span class="n">env_stripped_key_name</span><span class="p">]</span> <span class="o">=</span> <span class="n">aws_secrets_dict</span><span class="p">[</span><span class="n">secret_key_to_retrieve</span><span class="p">]</span>
    145. <span class="k">else</span><span class="p">:</span>
    146. <span class="c1"># &quot;db_properties_set&quot; is not specified - validating and returning all the secret keys and values for</span>
    147. <span class="c1"># the secret name.</span>
    148. <span class="k">for</span> <span class="n">secret_key_name</span><span class="p">,</span> <span class="n">secret_value</span> <span class="ow">in</span> <span class="n">aws_secrets_dict</span><span class="o">.</span><span class="n">items</span><span class="p">():</span>
    149. <span class="n">secret_key_to_retrieve</span> <span class="o">=</span> <span class="s1">&#39;.&#39;</span><span class="o">.</span><span class="n">join</span><span class="p">([</span><span class="n">env</span><span class="o">.</span><span class="n">upper</span><span class="p">(),</span> <span class="n">secret_key</span><span class="p">])</span>
    150. <span class="k">assert</span> <span class="n">secret_key_name</span><span class="o">.</span><span class="n">startswith</span><span class="p">(</span>
    151. <span class="n">env</span><span class="o">.</span><span class="n">upper</span><span class="p">()),</span> <span class="sa">f</span><span class="s1">&#39;The secret key property &quot;</span><span class="si">{</span><span class="n">secret_key_name</span><span class="si">}</span><span class="s1">&quot;, found in secret named </span><span class="si">{</span><span class="n">secret_name</span><span class="si">}</span><span class="s1">, is not following the convention of &#39;</span> \
    152. <span class="sa">f</span><span class="s1">&#39;environment prefix. please add the environment prefix &quot;</span><span class="si">{</span><span class="n">env</span><span class="o">.</span><span class="n">upper</span><span class="p">()</span><span class="si">}</span><span class="s1">&quot; to property &quot;</span><span class="si">{</span><span class="n">secret_key_name</span><span class="si">}</span><span class="s1">&quot;&#39;</span>
    153. <span class="k">if</span> <span class="n">secret_key_name</span><span class="o">.</span><span class="n">startswith</span><span class="p">(</span><span class="n">secret_key_to_retrieve</span><span class="p">):</span>
    154. <span class="n">env_stripped_key_name</span> <span class="o">=</span> <span class="n">secret_key_name</span><span class="o">.</span><span class="n">lstrip</span><span class="p">(</span><span class="n">env</span><span class="o">.</span><span class="n">upper</span><span class="p">())</span><span class="o">.</span><span class="n">lstrip</span><span class="p">(</span><span class="s1">&#39;.&#39;</span><span class="p">)</span>
    155. <span class="n">aws_env_safe_secrets</span><span class="p">[</span><span class="n">env_stripped_key_name</span><span class="p">]</span> <span class="o">=</span> <span class="n">secret_value</span>
    156. <span class="k">return</span> <span class="n">aws_env_safe_secrets</span>
    157. <span class="nd">@staticmethod</span>
    158. <span class="k">def</span> <span class="nf">__get_secrets_manager_dict_for_secret_name</span><span class="p">(</span><span class="n">aws_env</span><span class="p">:</span> <span class="nb">str</span><span class="p">,</span> <span class="n">secret_name</span><span class="p">:</span> <span class="nb">str</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="nb">dict</span><span class="p">:</span>
    159. <span class="sd">&quot;&quot;&quot;</span>
    160. <span class="sd"> __get_secrets_manager_dict_for_secret_name</span>
    161. <span class="sd"> :param aws_env: The environment to open the dict for</span>
    162. <span class="sd"> :param secret_name: The Secret to Retrieve to from AWS secrets manager (usually project name)</span>
    163. <span class="sd"> :return: python Dictionary with the key/value pairs stored in AWS Secrets Manager</span>
    164. <span class="sd"> &quot;&quot;&quot;</span>
    165. <span class="n">current_class_name</span> <span class="o">=</span> <span class="vm">__class__</span><span class="o">.</span><span class="vm">__name__</span>
    166. <span class="n">logger</span> <span class="o">=</span> <span class="n">logging</span><span class="o">.</span><span class="n">getLogger</span><span class="p">(</span><span class="n">current_class_name</span><span class="p">)</span>
    167. <span class="n">secrets_path</span> <span class="o">=</span> <span class="n">AWSSecretsManagerConnector</span><span class="o">.</span><span class="n">__get_secrets_path_from_secret_name</span><span class="p">(</span><span class="n">aws_env</span><span class="p">,</span> <span class="n">secret_name</span><span class="p">)</span>
    168. <span class="k">try</span><span class="p">:</span>
    169. <span class="k">if</span> <span class="ow">not</span> <span class="n">AWSSecretsManagerConnector</span><span class="o">.</span><span class="n">current_environment_client</span><span class="p">:</span>
    170. <span class="n">logger</span><span class="o">.</span><span class="n">debug</span><span class="p">(</span><span class="s1">&#39;Initializing a new secrets manager client...&#39;</span><span class="p">)</span>
    171. <span class="n">AWSSecretsManagerConnector</span><span class="o">.</span><span class="n">current_environment_client</span> <span class="o">=</span> <span class="n">AWSConnector</span><span class="o">.</span><span class="n">get_aws_client_for_service_name</span><span class="p">(</span>
    172. <span class="n">profile_name</span><span class="o">=</span><span class="n">aws_env</span><span class="p">,</span>
    173. <span class="n">service_name</span><span class="o">=</span><span class="s1">&#39;secretsmanager&#39;</span><span class="p">)</span>
    174. <span class="n">logger</span><span class="o">.</span><span class="n">debug</span><span class="p">(</span><span class="sa">f</span><span class="s1">&#39;Fetching the secret &quot;</span><span class="si">{</span><span class="n">secret_name</span><span class="si">}</span><span class="s1">&quot; in env &quot;</span><span class="si">{</span><span class="n">aws_env</span><span class="si">}</span><span class="s1">&quot;&#39;</span><span class="p">)</span>
    175. <span class="n">aws_secrets</span> <span class="o">=</span> <span class="n">AWSSecretsManagerConnector</span><span class="o">.</span><span class="n">current_environment_client</span><span class="o">.</span><span class="n">get_secret_value</span><span class="p">(</span><span class="n">SecretId</span><span class="o">=</span><span class="n">secrets_path</span><span class="p">)</span>
    176. <span class="n">aws_secrets_dict</span> <span class="o">=</span> <span class="n">json</span><span class="o">.</span><span class="n">loads</span><span class="p">(</span><span class="n">aws_secrets</span><span class="p">[</span><span class="s1">&#39;SecretString&#39;</span><span class="p">])</span>
    177. <span class="k">return</span> <span class="n">aws_secrets_dict</span>
    178. <span class="k">except</span> <span class="ne">Exception</span> <span class="k">as</span> <span class="n">ex</span><span class="p">:</span>
    179. <span class="n">error</span> <span class="o">=</span> <span class="s1">&#39;[&#39;</span> <span class="o">+</span> <span class="n">current_class_name</span> <span class="o">+</span> <span class="s1">&#39;] - Caught Exception while trying to connect to aws to get credentials from secrets manager: &#39;</span> <span class="o">+</span> <span class="s1">&#39;&quot;&#39;</span> <span class="o">+</span> <span class="nb">str</span><span class="p">(</span>
    180. <span class="n">ex</span><span class="p">)</span> <span class="o">+</span> <span class="s1">&#39;&quot;&#39;</span> <span class="o">+</span> <span class="s1">&#39; for &#39;</span> <span class="o">+</span> <span class="nb">str</span><span class="p">(</span><span class="n">secrets_path</span><span class="p">)</span>
    181. <span class="n">logger</span><span class="o">.</span><span class="n">error</span><span class="p">(</span><span class="n">error</span><span class="p">)</span>
    182. <span class="k">raise</span> <span class="ne">EnvironmentError</span><span class="p">(</span><span class="n">error</span><span class="p">)</span>
    183. <span class="nd">@staticmethod</span>
    184. <span class="k">def</span> <span class="nf">__get_secrets_path_from_secret_name</span><span class="p">(</span><span class="n">aws_env</span><span class="p">:</span> <span class="nb">str</span><span class="p">,</span> <span class="n">secret_name</span><span class="p">:</span> <span class="nb">str</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="nb">str</span><span class="p">:</span>
    185. <span class="sd">&quot;&quot;&quot;</span>
    186. <span class="sd"> __get_secrets_path_from_secret_name - Extracts the full secret path based on the Environment</span>
    187. <span class="sd"> :param aws_env: Env</span>
    188. <span class="sd"> :param secret_name: Secret Name</span>
    189. <span class="sd"> :return: str: The full secret path</span>
    190. <span class="sd"> &quot;&quot;&quot;</span>
    191. <span class="n">current_class_name</span> <span class="o">=</span> <span class="vm">__class__</span><span class="o">.</span><span class="vm">__name__</span>
    192. <span class="n">logger</span> <span class="o">=</span> <span class="n">logging</span><span class="o">.</span><span class="n">getLogger</span><span class="p">(</span><span class="n">current_class_name</span><span class="p">)</span>
    193. <span class="c1"># Checking for lowercase exact match, in order to prevent any implicit usage of the environments.</span>
    194. <span class="k">if</span> <span class="n">aws_env</span> <span class="ow">not</span> <span class="ow">in</span> <span class="n">AWSSecretsManagerConnector</span><span class="o">.</span><span class="n">DECI_ENVIRONMENTS</span><span class="p">:</span>
    195. <span class="n">logger</span><span class="o">.</span><span class="n">critical</span><span class="p">(</span><span class="s1">&#39;[&#39;</span> <span class="o">+</span> <span class="n">current_class_name</span> <span class="o">+</span> <span class="s1">&#39; ] - wrong environment param... Exiting&#39;</span><span class="p">)</span>
    196. <span class="k">raise</span> <span class="ne">Exception</span><span class="p">(</span><span class="s1">&#39;[&#39;</span> <span class="o">+</span> <span class="n">current_class_name</span> <span class="o">+</span> <span class="s1">&#39;] - wrong environment param&#39;</span><span class="p">)</span>
    197. <span class="n">secrets_path</span> <span class="o">=</span> <span class="s1">&#39;/&#39;</span><span class="o">.</span><span class="n">join</span><span class="p">([</span><span class="n">aws_env</span><span class="p">,</span> <span class="n">secret_name</span><span class="p">])</span>
    198. <span class="k">return</span> <span class="n">secrets_path</span></div>
    199. </pre></div>
    200. </div>
    201. </div>
    202. <footer>
    203. <hr/>
    204. <div role="contentinfo">
    205. <p>&#169; Copyright 2021, SuperGradients team.</p>
    206. </div>
    207. Built with <a href="https://www.sphinx-doc.org/">Sphinx</a> using a
    208. <a href="https://github.com/readthedocs/sphinx_rtd_theme">theme</a>
    209. provided by <a href="https://readthedocs.org">Read the Docs</a>.
    210. </footer>
    211. </div>
    212. </div>
    213. </section>
    214. </div>
    215. <script>
    216. jQuery(function () {
    217. SphinxRtdTheme.Navigation.enable(true);
    218. });
    219. </script>
    220. </body>
    221. </html>
    Discard
    1
    2
    3
    4
    5
    6
    7
    8
    9
    10
    11
    12
    13
    14
    15
    16
    17
    18
    19
    20
    21
    22
    23
    24
    25
    26
    27
    28
    29
    30
    31
    32
    33
    34
    35
    36
    37
    38
    39
    40
    41
    42
    43
    44
    45
    46
    47
    48
    49
    50
    51
    52
    53
    54
    55
    56
    57
    58
    59
    60
    61
    62
    63
    64
    65
    66
    67
    68
    69
    70
    71
    72
    73
    74
    75
    76
    77
    78
    79
    80
    81
    82
    83
    84
    85
    86
    87
    88
    89
    90
    91
    92
    93
    94
    95
    96
    97
    98
    99
    100
    101
    102
    103
    104
    105
    106
    107
    108
    109
    110
    111
    112
    113
    114
    115
    116
    117
    118
    119
    120
    121
    122
    123
    124
    125
    126
    127
    128
    129
    130
    131
    132
    133
    1. <!DOCTYPE html>
    2. <html class="writer-html5" lang="en" >
    3. <head>
    4. <meta charset="utf-8" />
    5. <meta name="viewport" content="width=device-width, initial-scale=1.0" />
    6. <title>super_gradients.common.crash_handler.crash_handler &mdash; SuperGradients 3.0.3 documentation</title>
    7. <link rel="stylesheet" href="../../../../_static/pygments.css" type="text/css" />
    8. <link rel="stylesheet" href="../../../../_static/css/theme.css" type="text/css" />
    9. <link rel="stylesheet" href="../../../../_static/graphviz.css" type="text/css" />
    10. <link rel="stylesheet" href="../../../../_static/custom.css" type="text/css" />
    11. <!--[if lt IE 9]>
    12. <script src="../../../../_static/js/html5shiv.min.js"></script>
    13. <![endif]-->
    14. <script data-url_root="../../../../" id="documentation_options" src="../../../../_static/documentation_options.js"></script>
    15. <script src="../../../../_static/jquery.js"></script>
    16. <script src="../../../../_static/underscore.js"></script>
    17. <script src="../../../../_static/_sphinx_javascript_frameworks_compat.js"></script>
    18. <script src="../../../../_static/doctools.js"></script>
    19. <script src="../../../../_static/sphinx_highlight.js"></script>
    20. <script src="../../../../_static/js/theme.js"></script>
    21. <link rel="index" title="Index" href="../../../../genindex.html" />
    22. <link rel="search" title="Search" href="../../../../search.html" />
    23. </head>
    24. <body class="wy-body-for-nav">
    25. <div class="wy-grid-for-nav">
    26. <nav data-toggle="wy-nav-shift" class="wy-nav-side">
    27. <div class="wy-side-scroll">
    28. <div class="wy-side-nav-search" >
    29. <a href="../../../../index.html" class="icon icon-home"> SuperGradients
    30. </a>
    31. <div role="search">
    32. <form id="rtd-search-form" class="wy-form" action="../../../../search.html" method="get">
    33. <input type="text" name="q" placeholder="Search docs" />
    34. <input type="hidden" name="check_keywords" value="yes" />
    35. <input type="hidden" name="area" value="default" />
    36. </form>
    37. </div>
    38. </div><div class="wy-menu wy-menu-vertical" data-spy="affix" role="navigation" aria-label="Navigation menu">
    39. <p class="caption" role="heading"><span class="caption-text">Welcome To SuperGradients</span></p>
    40. <ul>
    41. <li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html">Version 3 is out! Notebooks have been updated!</a></li>
    42. <li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#build-with-supergradients">Build with SuperGradients</a></li>
    43. <li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#quick-installation">Quick Installation</a></li>
    44. <li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#what-s-new">What’s New</a></li>
    45. <li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#coming-soon">Coming soon</a></li>
    46. <li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#table-of-content">Table of Content</a></li>
    47. <li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#getting-started">Getting Started</a></li>
    48. <li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#advanced-features">Advanced Features</a></li>
    49. <li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#installation-methods">Installation Methods</a></li>
    50. <li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#implemented-model-architectures">Implemented Model Architectures</a></li>
    51. <li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#documentation">Documentation</a></li>
    52. <li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#contributing">Contributing</a></li>
    53. <li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#citation">Citation</a></li>
    54. <li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#community">Community</a></li>
    55. <li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#license">License</a></li>
    56. <li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#deci-platform">Deci Platform</a></li>
    57. </ul>
    58. <p class="caption" role="heading"><span class="caption-text">Technical Documentation</span></p>
    59. <ul>
    60. <li class="toctree-l1"><a class="reference internal" href="../../../../super_gradients.common.html">Common package</a></li>
    61. <li class="toctree-l1"><a class="reference internal" href="../../../../super_gradients.training.html">Training package</a></li>
    62. </ul>
    63. </div>
    64. </div>
    65. </nav>
    66. <section data-toggle="wy-nav-shift" class="wy-nav-content-wrap"><nav class="wy-nav-top" aria-label="Mobile navigation menu" >
    67. <i data-toggle="wy-nav-top" class="fa fa-bars"></i>
    68. <a href="../../../../index.html">SuperGradients</a>
    69. </nav>
    70. <div class="wy-nav-content">
    71. <div class="rst-content">
    72. <div role="navigation" aria-label="Page navigation">
    73. <ul class="wy-breadcrumbs">
    74. <li><a href="../../../../index.html" class="icon icon-home"></a> &raquo;</li>
    75. <li><a href="../../../index.html">Module code</a> &raquo;</li>
    76. <li>super_gradients.common.crash_handler.crash_handler</li>
    77. <li class="wy-breadcrumbs-aside">
    78. </li>
    79. </ul>
    80. <hr/>
    81. </div>
    82. <div role="main" class="document" itemscope="itemscope" itemtype="http://schema.org/Article">
    83. <div itemprop="articleBody">
    84. <h1>Source code for super_gradients.common.crash_handler.crash_handler</h1><div class="highlight"><pre>
    85. <span></span><span class="kn">import</span> <span class="nn">sys</span>
    86. <span class="kn">from</span> <span class="nn">super_gradients.common.crash_handler.crash_tips_setup</span> <span class="kn">import</span> <span class="n">setup_crash_tips</span>
    87. <span class="kn">from</span> <span class="nn">super_gradients.common.crash_handler.exception_monitoring_setup</span> <span class="kn">import</span> <span class="n">setup_pro_user_monitoring</span>
    88. <span class="kn">from</span> <span class="nn">super_gradients.common.crash_handler.exception</span> <span class="kn">import</span> <span class="n">register_exceptions</span>
    89. <div class="viewcode-block" id="setup_crash_handler"><a class="viewcode-back" href="../../../../super_gradients.common.html#super_gradients.common.setup_crash_handler">[docs]</a><span class="k">def</span> <span class="nf">setup_crash_handler</span><span class="p">():</span>
    90. <span class="sd">&quot;&quot;&quot;Setup the environment to handle crashes, with crash tips and more.&quot;&quot;&quot;</span>
    91. <span class="n">is_setup_crash_tips</span> <span class="o">=</span> <span class="n">setup_crash_tips</span><span class="p">()</span>
    92. <span class="n">is_setup_pro_user_monitoring</span> <span class="o">=</span> <span class="n">setup_pro_user_monitoring</span><span class="p">()</span>
    93. <span class="k">if</span> <span class="n">is_setup_crash_tips</span> <span class="ow">or</span> <span class="n">is_setup_pro_user_monitoring</span><span class="p">:</span> <span class="c1"># We don&#39;t want to wrap sys.excepthook when not required</span>
    94. <span class="n">sys</span><span class="o">.</span><span class="n">excepthook</span> <span class="o">=</span> <span class="n">register_exceptions</span><span class="p">(</span><span class="n">sys</span><span class="o">.</span><span class="n">excepthook</span><span class="p">)</span></div>
    95. </pre></div>
    96. </div>
    97. </div>
    98. <footer>
    99. <hr/>
    100. <div role="contentinfo">
    101. <p>&#169; Copyright 2021, SuperGradients team.</p>
    102. </div>
    103. Built with <a href="https://www.sphinx-doc.org/">Sphinx</a> using a
    104. <a href="https://github.com/readthedocs/sphinx_rtd_theme">theme</a>
    105. provided by <a href="https://readthedocs.org">Read the Docs</a>.
    106. </footer>
    107. </div>
    108. </div>
    109. </section>
    110. </div>
    111. <script>
    112. jQuery(function () {
    113. SphinxRtdTheme.Navigation.enable(true);
    114. });
    115. </script>
    116. </body>
    117. </html>
    Discard
    Some lines were truncated since they exceed the maximum allowed length of 500, please use a local Git client to see the full diff.
    @@ -3,10 +3,11 @@
     <head>
     <head>
       <meta charset="utf-8" />
       <meta charset="utf-8" />
       <meta name="viewport" content="width=device-width, initial-scale=1.0" />
       <meta name="viewport" content="width=device-width, initial-scale=1.0" />
    -  <title>super_gradients.common.data_connection.s3_connector &mdash; SuperGradients 1.0 documentation</title>
    +  <title>super_gradients.common.data_connection.s3_connector &mdash; SuperGradients 3.0.3 documentation</title>
           <link rel="stylesheet" href="../../../../_static/pygments.css" type="text/css" />
           <link rel="stylesheet" href="../../../../_static/pygments.css" type="text/css" />
           <link rel="stylesheet" href="../../../../_static/css/theme.css" type="text/css" />
           <link rel="stylesheet" href="../../../../_static/css/theme.css" type="text/css" />
           <link rel="stylesheet" href="../../../../_static/graphviz.css" type="text/css" />
           <link rel="stylesheet" href="../../../../_static/graphviz.css" type="text/css" />
    +      <link rel="stylesheet" href="../../../../_static/custom.css" type="text/css" />
       <!--[if lt IE 9]>
       <!--[if lt IE 9]>
         <script src="../../../../_static/js/html5shiv.min.js"></script>
         <script src="../../../../_static/js/html5shiv.min.js"></script>
       <![endif]-->
       <![endif]-->
    @@ -14,7 +15,9 @@
             <script data-url_root="../../../../" id="documentation_options" src="../../../../_static/documentation_options.js"></script>
             <script data-url_root="../../../../" id="documentation_options" src="../../../../_static/documentation_options.js"></script>
             <script src="../../../../_static/jquery.js"></script>
             <script src="../../../../_static/jquery.js"></script>
             <script src="../../../../_static/underscore.js"></script>
             <script src="../../../../_static/underscore.js"></script>
    +        <script src="../../../../_static/_sphinx_javascript_frameworks_compat.js"></script>
             <script src="../../../../_static/doctools.js"></script>
             <script src="../../../../_static/doctools.js"></script>
    +        <script src="../../../../_static/sphinx_highlight.js"></script>
         <script src="../../../../_static/js/theme.js"></script>
         <script src="../../../../_static/js/theme.js"></script>
         <link rel="index" title="Index" href="../../../../genindex.html" />
         <link rel="index" title="Index" href="../../../../genindex.html" />
         <link rel="search" title="Search" href="../../../../search.html" /> 
         <link rel="search" title="Search" href="../../../../search.html" /> 
    @@ -35,30 +38,29 @@
       </form>
       </form>
     </div>
     </div>
             </div><div class="wy-menu wy-menu-vertical" data-spy="affix" role="navigation" aria-label="Navigation menu">
             </div><div class="wy-menu wy-menu-vertical" data-spy="affix" role="navigation" aria-label="Navigation menu">
    -              <p class="caption"><span class="caption-text">Welcome To SuperGradients</span></p>
    +              <p class="caption" role="heading"><span class="caption-text">Welcome To SuperGradients</span></p>
     <ul>
     <ul>
    -<li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html">Fill our 4-question quick survey! We will raffle free SuperGradients swag between those who will participate -&gt; Fill Survey</a></li>
    -<li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#supergradients">SuperGradients</a></li>
    +<li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html">Version 3 is out! Notebooks have been updated!</a></li>
    +<li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#build-with-supergradients">Build with SuperGradients</a></li>
    +<li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#quick-installation">Quick Installation</a></li>
    +<li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#what-s-new">What’s New</a></li>
    +<li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#coming-soon">Coming soon</a></li>
    +<li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#table-of-content">Table of Content</a></li>
    +<li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#getting-started">Getting Started</a></li>
    +<li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#advanced-features">Advanced Features</a></li>
    +<li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#installation-methods">Installation Methods</a></li>
    +<li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#implemented-model-architectures">Implemented Model Architectures</a></li>
    +<li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#documentation">Documentation</a></li>
    +<li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#contributing">Contributing</a></li>
    +<li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#citation">Citation</a></li>
    +<li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#community">Community</a></li>
    +<li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#license">License</a></li>
    +<li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#deci-platform">Deci Platform</a></li>
     </ul>
     </ul>
    -<p class="caption"><span class="caption-text">Technical Documentation</span></p>
    +<p class="caption" role="heading"><span class="caption-text">Technical Documentation</span></p>
     <ul>
     <ul>
     <li class="toctree-l1"><a class="reference internal" href="../../../../super_gradients.common.html">Common package</a></li>
     <li class="toctree-l1"><a class="reference internal" href="../../../../super_gradients.common.html">Common package</a></li>
     <li class="toctree-l1"><a class="reference internal" href="../../../../super_gradients.training.html">Training package</a></li>
     <li class="toctree-l1"><a class="reference internal" href="../../../../super_gradients.training.html">Training package</a></li>
    -</ul>
    -<p class="caption"><span class="caption-text">User Guide</span></p>
    -<ul>
    -<li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html">What is SuperGradients?</a></li>
    -<li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#introducing-the-supergradients-library">Introducing the SuperGradients library</a></li>
    -<li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#installation">Installation</a></li>
    -<li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#integrating-your-training-code-complete-walkthrough">Integrating Your Training Code - Complete Walkthrough</a></li>
    -<li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#training-parameters">Training Parameters</a></li>
    -<li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#logs-and-checkpoints">Logs and Checkpoints</a></li>
    -<li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#dataset-parameters">Dataset Parameters</a></li>
    -<li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#network-architectures">Network Architectures</a></li>
    -<li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#pretrained-models">Pretrained Models</a></li>
    -<li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#how-to-reproduce-our-training-recipes">How To Reproduce Our Training Recipes</a></li>
    -<li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#professional-tools-integration">Professional Tools Integration</a></li>
    -<li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#supergradients-faq">SuperGradients FAQ</a></li>
     </ul>
     </ul>
     
     
             </div>
             </div>
    @@ -98,11 +100,11 @@
     <span class="kn">from</span> <span class="nn">super_gradients.common.abstractions.abstract_logger</span> <span class="kn">import</span> <span class="n">ILogger</span>
     <span class="kn">from</span> <span class="nn">super_gradients.common.abstractions.abstract_logger</span> <span class="kn">import</span> <span class="n">ILogger</span>
     
     
     
     
    -<div class="viewcode-block" id="KeyNotExistInBucketError"><a class="viewcode-back" href="../../../../super_gradients.common.data_connection.html#super_gradients.common.KeyNotExistInBucketError">[docs]</a><span class="k">class</span> <span class="nc">KeyNotExistInBucketError</span><span class="p">(</span><span class="ne">Exception</span><span class="p">):</span>
    -    <span class="k">pass</span></div>
    +<span class="k">class</span> <span class="nc">KeyNotExistInBucketError</span><span class="p">(</span><span class="ne">Exception</span><span class="p">):</span>
    +    <span class="k">pass</span>
     
     
     
     
    -<div class="viewcode-block" id="S3Connector"><a class="viewcode-back" href="../../../../super_gradients.common.data_connection.html#super_gradients.common.S3Connector">[docs]</a><span class="k">class</span> <span class="nc">S3Connector</span><span class="p">(</span><span class="n">ILogger</span><span class="p">):</span>
    +<div class="viewcode-block" id="S3Connector"><a class="viewcode-back" href="../../../../super_gradients.common.html#super_gradients.common.S3Connector">[docs]</a><span class="k">class</span> <span class="nc">S3Connector</span><span class="p">(</span><span class="n">ILogger</span><span class="p">):</span>
         <span class="sd">&quot;&quot;&quot;</span>
         <span class="sd">&quot;&quot;&quot;</span>
     <span class="sd">    S3Connector - S3 Connection Manager</span>
     <span class="sd">    S3Connector - S3 Connection Manager</span>
     <span class="sd">    &quot;&quot;&quot;</span>
     <span class="sd">    &quot;&quot;&quot;</span>
    @@ -458,7 +460,7 @@
                                                              <span class="n">ExpiresIn</span><span class="o">=</span><span class="n">expiration</span><span class="p">)</span>
                                                              <span class="n">ExpiresIn</span><span class="o">=</span><span class="n">expiration</span><span class="p">)</span>
             <span class="k">return</span> <span class="n">response</span>
             <span class="k">return</span> <span class="n">response</span>
     
     
    -<div class="viewcode-block" id="S3Connector.convert_content_length_to_mb"><a class="viewcode-back" href="../../../../super_gradients.common.data_connection.html#super_gradients.common.S3Connector.convert_content_length_to_mb">[docs]</a>    <span class="nd">@staticmethod</span>
    +<div class="viewcode-block" id="S3Connector.convert_content_length_to_mb"><a class="viewcode-back" href="../../../../super_gradients.common.html#super_gradients.common.S3Connector.convert_content_length_to_mb">[docs]</a>    <span class="nd">@staticmethod</span>
         <span class="k">def</span> <span class="nf">convert_content_length_to_mb</span><span class="p">(</span><span class="n">content_length</span><span class="p">):</span>
         <span class="k">def</span> <span class="nf">convert_content_length_to_mb</span><span class="p">(</span><span class="n">content_length</span><span class="p">):</span>
             <span class="k">return</span> <span class="nb">round</span><span class="p">(</span><span class="nb">float</span><span class="p">(</span><span class="sa">f</span><span class="s1">&#39;</span><span class="si">{</span><span class="n">content_length</span> <span class="o">/</span> <span class="p">(</span><span class="mf">1e+6</span><span class="p">)</span><span class="si">:</span><span class="s1">2f</span><span class="si">}</span><span class="s1">&#39;</span><span class="p">),</span> <span 
             <span class="k">return</span> <span class="nb">round</span><span class="p">(</span><span class="nb">float</span><span class="p">(</span><span class="sa">f</span><span class="s1">&#39;</span><span class="si">{</span><span class="n">content_length</span> <span class="o">/</span> <span class="p">(</span><span class="mf">1e+6</span><span class="p">)</span><span class="si">:</span><span class="s1">2f</span><span class="si">}</span><span class="s1">&#39;</span><span class="p">),</span> <span 
     
     
    @@ -524,4 +526,4 @@
       </script> 
       </script> 
     
     
     </body>
     </body>
    -</html>
    +</html>
    Discard
    @@ -3,10 +3,11 @@
     <head>
     <head>
       <meta charset="utf-8" />
       <meta charset="utf-8" />
       <meta name="viewport" content="width=device-width, initial-scale=1.0" />
       <meta name="viewport" content="width=device-width, initial-scale=1.0" />
    -  <title>super_gradients.common.data_interface.adnn_model_repository_data_interface &mdash; SuperGradients 1.0 documentation</title>
    +  <title>super_gradients.common.data_interface.adnn_model_repository_data_interface &mdash; SuperGradients 3.0.3 documentation</title>
           <link rel="stylesheet" href="../../../../_static/pygments.css" type="text/css" />
           <link rel="stylesheet" href="../../../../_static/pygments.css" type="text/css" />
           <link rel="stylesheet" href="../../../../_static/css/theme.css" type="text/css" />
           <link rel="stylesheet" href="../../../../_static/css/theme.css" type="text/css" />
           <link rel="stylesheet" href="../../../../_static/graphviz.css" type="text/css" />
           <link rel="stylesheet" href="../../../../_static/graphviz.css" type="text/css" />
    +      <link rel="stylesheet" href="../../../../_static/custom.css" type="text/css" />
       <!--[if lt IE 9]>
       <!--[if lt IE 9]>
         <script src="../../../../_static/js/html5shiv.min.js"></script>
         <script src="../../../../_static/js/html5shiv.min.js"></script>
       <![endif]-->
       <![endif]-->
    @@ -14,7 +15,9 @@
             <script data-url_root="../../../../" id="documentation_options" src="../../../../_static/documentation_options.js"></script>
             <script data-url_root="../../../../" id="documentation_options" src="../../../../_static/documentation_options.js"></script>
             <script src="../../../../_static/jquery.js"></script>
             <script src="../../../../_static/jquery.js"></script>
             <script src="../../../../_static/underscore.js"></script>
             <script src="../../../../_static/underscore.js"></script>
    +        <script src="../../../../_static/_sphinx_javascript_frameworks_compat.js"></script>
             <script src="../../../../_static/doctools.js"></script>
             <script src="../../../../_static/doctools.js"></script>
    +        <script src="../../../../_static/sphinx_highlight.js"></script>
         <script src="../../../../_static/js/theme.js"></script>
         <script src="../../../../_static/js/theme.js"></script>
         <link rel="index" title="Index" href="../../../../genindex.html" />
         <link rel="index" title="Index" href="../../../../genindex.html" />
         <link rel="search" title="Search" href="../../../../search.html" /> 
         <link rel="search" title="Search" href="../../../../search.html" /> 
    @@ -35,30 +38,29 @@
       </form>
       </form>
     </div>
     </div>
             </div><div class="wy-menu wy-menu-vertical" data-spy="affix" role="navigation" aria-label="Navigation menu">
             </div><div class="wy-menu wy-menu-vertical" data-spy="affix" role="navigation" aria-label="Navigation menu">
    -              <p class="caption"><span class="caption-text">Welcome To SuperGradients</span></p>
    +              <p class="caption" role="heading"><span class="caption-text">Welcome To SuperGradients</span></p>
     <ul>
     <ul>
    -<li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html">Fill our 4-question quick survey! We will raffle free SuperGradients swag between those who will participate -&gt; Fill Survey</a></li>
    -<li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#supergradients">SuperGradients</a></li>
    +<li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html">Version 3 is out! Notebooks have been updated!</a></li>
    +<li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#build-with-supergradients">Build with SuperGradients</a></li>
    +<li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#quick-installation">Quick Installation</a></li>
    +<li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#what-s-new">What’s New</a></li>
    +<li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#coming-soon">Coming soon</a></li>
    +<li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#table-of-content">Table of Content</a></li>
    +<li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#getting-started">Getting Started</a></li>
    +<li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#advanced-features">Advanced Features</a></li>
    +<li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#installation-methods">Installation Methods</a></li>
    +<li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#implemented-model-architectures">Implemented Model Architectures</a></li>
    +<li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#documentation">Documentation</a></li>
    +<li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#contributing">Contributing</a></li>
    +<li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#citation">Citation</a></li>
    +<li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#community">Community</a></li>
    +<li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#license">License</a></li>
    +<li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#deci-platform">Deci Platform</a></li>
     </ul>
     </ul>
    -<p class="caption"><span class="caption-text">Technical Documentation</span></p>
    +<p class="caption" role="heading"><span class="caption-text">Technical Documentation</span></p>
     <ul>
     <ul>
     <li class="toctree-l1"><a class="reference internal" href="../../../../super_gradients.common.html">Common package</a></li>
     <li class="toctree-l1"><a class="reference internal" href="../../../../super_gradients.common.html">Common package</a></li>
     <li class="toctree-l1"><a class="reference internal" href="../../../../super_gradients.training.html">Training package</a></li>
     <li class="toctree-l1"><a class="reference internal" href="../../../../super_gradients.training.html">Training package</a></li>
    -</ul>
    -<p class="caption"><span class="caption-text">User Guide</span></p>
    -<ul>
    -<li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html">What is SuperGradients?</a></li>
    -<li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#introducing-the-supergradients-library">Introducing the SuperGradients library</a></li>
    -<li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#installation">Installation</a></li>
    -<li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#integrating-your-training-code-complete-walkthrough">Integrating Your Training Code - Complete Walkthrough</a></li>
    -<li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#training-parameters">Training Parameters</a></li>
    -<li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#logs-and-checkpoints">Logs and Checkpoints</a></li>
    -<li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#dataset-parameters">Dataset Parameters</a></li>
    -<li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#network-architectures">Network Architectures</a></li>
    -<li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#pretrained-models">Pretrained Models</a></li>
    -<li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#how-to-reproduce-our-training-recipes">How To Reproduce Our Training Recipes</a></li>
    -<li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#professional-tools-integration">Professional Tools Integration</a></li>
    -<li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#supergradients-faq">SuperGradients FAQ</a></li>
     </ul>
     </ul>
     
     
             </div>
             </div>
    @@ -92,11 +94,11 @@
     <span class="kn">from</span> <span class="nn">super_gradients.common.abstractions.abstract_logger</span> <span class="kn">import</span> <span class="n">ILogger</span>
     <span class="kn">from</span> <span class="nn">super_gradients.common.abstractions.abstract_logger</span> <span class="kn">import</span> <span class="n">ILogger</span>
     
     
     
     
    -<div class="viewcode-block" id="ModelCheckpointNotFoundException"><a class="viewcode-back" href="../../../../super_gradients.common.data_interface.html#super_gradients.common.ModelCheckpointNotFoundException">[docs]</a><span class="k">class</span> <span class="nc">ModelCheckpointNotFoundException</span><span class="p">(</span><span class="ne">RuntimeError</span><span class="p">):</span>
    -    <span class="k">pass</span></div>
    +<span class="k">class</span> <span class="nc">ModelCheckpointNotFoundException</span><span class="p">(</span><span class="ne">RuntimeError</span><span class="p">):</span>
    +    <span class="k">pass</span>
     
     
     
     
    -<div class="viewcode-block" id="ADNNModelRepositoryDataInterfaces"><a class="viewcode-back" href="../../../../super_gradients.common.data_interface.html#super_gradients.common.ADNNModelRepositoryDataInterfaces">[docs]</a><span class="k">class</span> <span class="nc">ADNNModelRepositoryDataInterfaces</span><span class="p">(</span><span class="n">ILogger</span><span class="p">):</span>
    +<div class="viewcode-block" id="ADNNModelRepositoryDataInterfaces"><a class="viewcode-back" href="../../../../super_gradients.common.html#super_gradients.common.ADNNModelRepositoryDataInterfaces">[docs]</a><span class="k">class</span> <span class="nc">ADNNModelRepositoryDataInterfaces</span><span class="p">(</span><span class="n">ILogger</span><span class="p">):</span>
         <span class="sd">&quot;&quot;&quot;</span>
         <span class="sd">&quot;&quot;&quot;</span>
     <span class="sd">    ResearchModelRepositoryDataInterface</span>
     <span class="sd">    ResearchModelRepositoryDataInterface</span>
     <span class="sd">    &quot;&quot;&quot;</span>
     <span class="sd">    &quot;&quot;&quot;</span>
    @@ -293,4 +295,4 @@
       </script> 
       </script> 
     
     
     </body>
     </body>
    -</html>
    +</html>
    Discard
    @@ -3,10 +3,11 @@
     <head>
     <head>
       <meta charset="utf-8" />
       <meta charset="utf-8" />
       <meta name="viewport" content="width=device-width, initial-scale=1.0" />
       <meta name="viewport" content="width=device-width, initial-scale=1.0" />
    -  <title>super_gradients.common.data_interface.dataset_data_interface &mdash; SuperGradients 1.0 documentation</title>
    +  <title>super_gradients.common.data_interface.dataset_data_interface &mdash; SuperGradients 3.0.3 documentation</title>
           <link rel="stylesheet" href="../../../../_static/pygments.css" type="text/css" />
           <link rel="stylesheet" href="../../../../_static/pygments.css" type="text/css" />
           <link rel="stylesheet" href="../../../../_static/css/theme.css" type="text/css" />
           <link rel="stylesheet" href="../../../../_static/css/theme.css" type="text/css" />
           <link rel="stylesheet" href="../../../../_static/graphviz.css" type="text/css" />
           <link rel="stylesheet" href="../../../../_static/graphviz.css" type="text/css" />
    +      <link rel="stylesheet" href="../../../../_static/custom.css" type="text/css" />
       <!--[if lt IE 9]>
       <!--[if lt IE 9]>
         <script src="../../../../_static/js/html5shiv.min.js"></script>
         <script src="../../../../_static/js/html5shiv.min.js"></script>
       <![endif]-->
       <![endif]-->
    @@ -14,7 +15,9 @@
             <script data-url_root="../../../../" id="documentation_options" src="../../../../_static/documentation_options.js"></script>
             <script data-url_root="../../../../" id="documentation_options" src="../../../../_static/documentation_options.js"></script>
             <script src="../../../../_static/jquery.js"></script>
             <script src="../../../../_static/jquery.js"></script>
             <script src="../../../../_static/underscore.js"></script>
             <script src="../../../../_static/underscore.js"></script>
    +        <script src="../../../../_static/_sphinx_javascript_frameworks_compat.js"></script>
             <script src="../../../../_static/doctools.js"></script>
             <script src="../../../../_static/doctools.js"></script>
    +        <script src="../../../../_static/sphinx_highlight.js"></script>
         <script src="../../../../_static/js/theme.js"></script>
         <script src="../../../../_static/js/theme.js"></script>
         <link rel="index" title="Index" href="../../../../genindex.html" />
         <link rel="index" title="Index" href="../../../../genindex.html" />
         <link rel="search" title="Search" href="../../../../search.html" /> 
         <link rel="search" title="Search" href="../../../../search.html" /> 
    @@ -35,30 +38,29 @@
       </form>
       </form>
     </div>
     </div>
             </div><div class="wy-menu wy-menu-vertical" data-spy="affix" role="navigation" aria-label="Navigation menu">
             </div><div class="wy-menu wy-menu-vertical" data-spy="affix" role="navigation" aria-label="Navigation menu">
    -              <p class="caption"><span class="caption-text">Welcome To SuperGradients</span></p>
    +              <p class="caption" role="heading"><span class="caption-text">Welcome To SuperGradients</span></p>
     <ul>
     <ul>
    -<li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html">Fill our 4-question quick survey! We will raffle free SuperGradients swag between those who will participate -&gt; Fill Survey</a></li>
    -<li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#supergradients">SuperGradients</a></li>
    +<li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html">Version 3 is out! Notebooks have been updated!</a></li>
    +<li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#build-with-supergradients">Build with SuperGradients</a></li>
    +<li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#quick-installation">Quick Installation</a></li>
    +<li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#what-s-new">What’s New</a></li>
    +<li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#coming-soon">Coming soon</a></li>
    +<li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#table-of-content">Table of Content</a></li>
    +<li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#getting-started">Getting Started</a></li>
    +<li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#advanced-features">Advanced Features</a></li>
    +<li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#installation-methods">Installation Methods</a></li>
    +<li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#implemented-model-architectures">Implemented Model Architectures</a></li>
    +<li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#documentation">Documentation</a></li>
    +<li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#contributing">Contributing</a></li>
    +<li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#citation">Citation</a></li>
    +<li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#community">Community</a></li>
    +<li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#license">License</a></li>
    +<li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#deci-platform">Deci Platform</a></li>
     </ul>
     </ul>
    -<p class="caption"><span class="caption-text">Technical Documentation</span></p>
    +<p class="caption" role="heading"><span class="caption-text">Technical Documentation</span></p>
     <ul>
     <ul>
     <li class="toctree-l1"><a class="reference internal" href="../../../../super_gradients.common.html">Common package</a></li>
     <li class="toctree-l1"><a class="reference internal" href="../../../../super_gradients.common.html">Common package</a></li>
     <li class="toctree-l1"><a class="reference internal" href="../../../../super_gradients.training.html">Training package</a></li>
     <li class="toctree-l1"><a class="reference internal" href="../../../../super_gradients.training.html">Training package</a></li>
    -</ul>
    -<p class="caption"><span class="caption-text">User Guide</span></p>
    -<ul>
    -<li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html">What is SuperGradients?</a></li>
    -<li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#introducing-the-supergradients-library">Introducing the SuperGradients library</a></li>
    -<li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#installation">Installation</a></li>
    -<li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#integrating-your-training-code-complete-walkthrough">Integrating Your Training Code - Complete Walkthrough</a></li>
    -<li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#training-parameters">Training Parameters</a></li>
    -<li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#logs-and-checkpoints">Logs and Checkpoints</a></li>
    -<li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#dataset-parameters">Dataset Parameters</a></li>
    -<li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#network-architectures">Network Architectures</a></li>
    -<li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#pretrained-models">Pretrained Models</a></li>
    -<li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#how-to-reproduce-our-training-recipes">How To Reproduce Our Training Recipes</a></li>
    -<li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#professional-tools-integration">Professional Tools Integration</a></li>
    -<li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#supergradients-faq">SuperGradients FAQ</a></li>
     </ul>
     </ul>
     
     
             </div>
             </div>
    @@ -92,7 +94,7 @@
     <span class="kn">import</span> <span class="nn">zipfile</span>
     <span class="kn">import</span> <span class="nn">zipfile</span>
     
     
     
     
    -<div class="viewcode-block" id="DatasetDataInterface"><a class="viewcode-back" href="../../../../super_gradients.common.data_interface.html#super_gradients.common.DatasetDataInterface">[docs]</a><span class="k">class</span> <span class="nc">DatasetDataInterface</span><span class="p">:</span>
    +<div class="viewcode-block" id="DatasetDataInterface"><a class="viewcode-back" href="../../../../super_gradients.common.html#super_gradients.common.DatasetDataInterface">[docs]</a><span class="k">class</span> <span class="nc">DatasetDataInterface</span><span class="p">:</span>
     
     
         <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">env</span><span class="p">:</span> <span class="nb">str</span><span class="p">,</span> <span class="n">data_connection_source</span><span class="p">:</span> <span class="nb">str</span> <span class="o">=</span> <span class="s1">&#39;s3&#39;</span><span class="p">):</span>
         <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">env</span><span class="p">:</span> <span class="nb">str</span><span class="p">,</span> <span class="n">data_connection_source</span><span class="p">:</span> <span class="nb">str</span> <span class="o">=</span> <span class="s1">&#39;s3&#39;</span><span class="p">):</span>
             <span class="sd">&quot;&quot;&quot;</span>
             <span class="sd">&quot;&quot;&quot;</span>
    @@ -167,4 +169,4 @@
       </script> 
       </script> 
     
     
     </body>
     </body>
    -</html>
    +</html>
    Discard
    @@ -3,10 +3,11 @@
     <head>
     <head>
       <meta charset="utf-8" />
       <meta charset="utf-8" />
       <meta name="viewport" content="width=device-width, initial-scale=1.0" />
       <meta name="viewport" content="width=device-width, initial-scale=1.0" />
    -  <title>super_gradients.common.data_types.enum.deep_learning_task &mdash; SuperGradients 1.0 documentation</title>
    +  <title>super_gradients.common.data_types.enum.deep_learning_task &mdash; SuperGradients 3.0.3 documentation</title>
           <link rel="stylesheet" href="../../../../../_static/pygments.css" type="text/css" />
           <link rel="stylesheet" href="../../../../../_static/pygments.css" type="text/css" />
           <link rel="stylesheet" href="../../../../../_static/css/theme.css" type="text/css" />
           <link rel="stylesheet" href="../../../../../_static/css/theme.css" type="text/css" />
           <link rel="stylesheet" href="../../../../../_static/graphviz.css" type="text/css" />
           <link rel="stylesheet" href="../../../../../_static/graphviz.css" type="text/css" />
    +      <link rel="stylesheet" href="../../../../../_static/custom.css" type="text/css" />
       <!--[if lt IE 9]>
       <!--[if lt IE 9]>
         <script src="../../../../../_static/js/html5shiv.min.js"></script>
         <script src="../../../../../_static/js/html5shiv.min.js"></script>
       <![endif]-->
       <![endif]-->
    @@ -14,7 +15,9 @@
             <script data-url_root="../../../../../" id="documentation_options" src="../../../../../_static/documentation_options.js"></script>
             <script data-url_root="../../../../../" id="documentation_options" src="../../../../../_static/documentation_options.js"></script>
             <script src="../../../../../_static/jquery.js"></script>
             <script src="../../../../../_static/jquery.js"></script>
             <script src="../../../../../_static/underscore.js"></script>
             <script src="../../../../../_static/underscore.js"></script>
    +        <script src="../../../../../_static/_sphinx_javascript_frameworks_compat.js"></script>
             <script src="../../../../../_static/doctools.js"></script>
             <script src="../../../../../_static/doctools.js"></script>
    +        <script src="../../../../../_static/sphinx_highlight.js"></script>
         <script src="../../../../../_static/js/theme.js"></script>
         <script src="../../../../../_static/js/theme.js"></script>
         <link rel="index" title="Index" href="../../../../../genindex.html" />
         <link rel="index" title="Index" href="../../../../../genindex.html" />
         <link rel="search" title="Search" href="../../../../../search.html" /> 
         <link rel="search" title="Search" href="../../../../../search.html" /> 
    @@ -35,30 +38,29 @@
       </form>
       </form>
     </div>
     </div>
             </div><div class="wy-menu wy-menu-vertical" data-spy="affix" role="navigation" aria-label="Navigation menu">
             </div><div class="wy-menu wy-menu-vertical" data-spy="affix" role="navigation" aria-label="Navigation menu">
    -              <p class="caption"><span class="caption-text">Welcome To SuperGradients</span></p>
    +              <p class="caption" role="heading"><span class="caption-text">Welcome To SuperGradients</span></p>
     <ul>
     <ul>
    -<li class="toctree-l1"><a class="reference internal" href="../../../../../welcome.html">Fill our 4-question quick survey! We will raffle free SuperGradients swag between those who will participate -&gt; Fill Survey</a></li>
    -<li class="toctree-l1"><a class="reference internal" href="../../../../../welcome.html#supergradients">SuperGradients</a></li>
    +<li class="toctree-l1"><a class="reference internal" href="../../../../../welcome.html">Version 3 is out! Notebooks have been updated!</a></li>
    +<li class="toctree-l1"><a class="reference internal" href="../../../../../welcome.html#build-with-supergradients">Build with SuperGradients</a></li>
    +<li class="toctree-l1"><a class="reference internal" href="../../../../../welcome.html#quick-installation">Quick Installation</a></li>
    +<li class="toctree-l1"><a class="reference internal" href="../../../../../welcome.html#what-s-new">What’s New</a></li>
    +<li class="toctree-l1"><a class="reference internal" href="../../../../../welcome.html#coming-soon">Coming soon</a></li>
    +<li class="toctree-l1"><a class="reference internal" href="../../../../../welcome.html#table-of-content">Table of Content</a></li>
    +<li class="toctree-l1"><a class="reference internal" href="../../../../../welcome.html#getting-started">Getting Started</a></li>
    +<li class="toctree-l1"><a class="reference internal" href="../../../../../welcome.html#advanced-features">Advanced Features</a></li>
    +<li class="toctree-l1"><a class="reference internal" href="../../../../../welcome.html#installation-methods">Installation Methods</a></li>
    +<li class="toctree-l1"><a class="reference internal" href="../../../../../welcome.html#implemented-model-architectures">Implemented Model Architectures</a></li>
    +<li class="toctree-l1"><a class="reference internal" href="../../../../../welcome.html#documentation">Documentation</a></li>
    +<li class="toctree-l1"><a class="reference internal" href="../../../../../welcome.html#contributing">Contributing</a></li>
    +<li class="toctree-l1"><a class="reference internal" href="../../../../../welcome.html#citation">Citation</a></li>
    +<li class="toctree-l1"><a class="reference internal" href="../../../../../welcome.html#community">Community</a></li>
    +<li class="toctree-l1"><a class="reference internal" href="../../../../../welcome.html#license">License</a></li>
    +<li class="toctree-l1"><a class="reference internal" href="../../../../../welcome.html#deci-platform">Deci Platform</a></li>
     </ul>
     </ul>
    -<p class="caption"><span class="caption-text">Technical Documentation</span></p>
    +<p class="caption" role="heading"><span class="caption-text">Technical Documentation</span></p>
     <ul>
     <ul>
     <li class="toctree-l1"><a class="reference internal" href="../../../../../super_gradients.common.html">Common package</a></li>
     <li class="toctree-l1"><a class="reference internal" href="../../../../../super_gradients.common.html">Common package</a></li>
     <li class="toctree-l1"><a class="reference internal" href="../../../../../super_gradients.training.html">Training package</a></li>
     <li class="toctree-l1"><a class="reference internal" href="../../../../../super_gradients.training.html">Training package</a></li>
    -</ul>
    -<p class="caption"><span class="caption-text">User Guide</span></p>
    -<ul>
    -<li class="toctree-l1"><a class="reference internal" href="../../../../../user_guide.html">What is SuperGradients?</a></li>
    -<li class="toctree-l1"><a class="reference internal" href="../../../../../user_guide.html#introducing-the-supergradients-library">Introducing the SuperGradients library</a></li>
    -<li class="toctree-l1"><a class="reference internal" href="../../../../../user_guide.html#installation">Installation</a></li>
    -<li class="toctree-l1"><a class="reference internal" href="../../../../../user_guide.html#integrating-your-training-code-complete-walkthrough">Integrating Your Training Code - Complete Walkthrough</a></li>
    -<li class="toctree-l1"><a class="reference internal" href="../../../../../user_guide.html#training-parameters">Training Parameters</a></li>
    -<li class="toctree-l1"><a class="reference internal" href="../../../../../user_guide.html#logs-and-checkpoints">Logs and Checkpoints</a></li>
    -<li class="toctree-l1"><a class="reference internal" href="../../../../../user_guide.html#dataset-parameters">Dataset Parameters</a></li>
    -<li class="toctree-l1"><a class="reference internal" href="../../../../../user_guide.html#network-architectures">Network Architectures</a></li>
    -<li class="toctree-l1"><a class="reference internal" href="../../../../../user_guide.html#pretrained-models">Pretrained Models</a></li>
    -<li class="toctree-l1"><a class="reference internal" href="../../../../../user_guide.html#how-to-reproduce-our-training-recipes">How To Reproduce Our Training Recipes</a></li>
    -<li class="toctree-l1"><a class="reference internal" href="../../../../../user_guide.html#professional-tools-integration">Professional Tools Integration</a></li>
    -<li class="toctree-l1"><a class="reference internal" href="../../../../../user_guide.html#supergradients-faq">SuperGradients FAQ</a></li>
     </ul>
     </ul>
     
     
             </div>
             </div>
    @@ -89,7 +91,7 @@
     <span></span><span class="kn">from</span> <span class="nn">enum</span> <span class="kn">import</span> <span class="n">Enum</span>
     <span></span><span class="kn">from</span> <span class="nn">enum</span> <span class="kn">import</span> <span class="n">Enum</span>
     
     
     
     
    -<div class="viewcode-block" id="DeepLearningTask"><a class="viewcode-back" href="../../../../../super_gradients.common.data_types.enum.html#super_gradients.common.DeepLearningTask">[docs]</a><span class="k">class</span> <span class="nc">DeepLearningTask</span><span class="p">(</span><span class="nb">str</span><span class="p">,</span> <span class="n">Enum</span><span class="p">):</span>
    +<div class="viewcode-block" id="DeepLearningTask"><a class="viewcode-back" href="../../../../../super_gradients.common.html#super_gradients.common.DeepLearningTask">[docs]</a><span class="k">class</span> <span class="nc">DeepLearningTask</span><span class="p">(</span><span class="nb">str</span><span class="p">,</span> <span class="n">Enum</span><span class="p">):</span>
         <span class="n">CLASSIFICATION</span> <span class="o">=</span> <span class="s1">&#39;classification&#39;</span>
         <span class="n">CLASSIFICATION</span> <span class="o">=</span> <span class="s1">&#39;classification&#39;</span>
         <span class="n">SEMANTIC_SEGMENTATION</span> <span class="o">=</span> <span class="s1">&#39;semantic_segmentation&#39;</span>
         <span class="n">SEMANTIC_SEGMENTATION</span> <span class="o">=</span> <span class="s1">&#39;semantic_segmentation&#39;</span>
         <span class="n">OBJECT_DETECTION</span> <span class="o">=</span> <span class="s1">&#39;object_detection&#39;</span>
         <span class="n">OBJECT_DETECTION</span> <span class="o">=</span> <span class="s1">&#39;object_detection&#39;</span>
    @@ -126,4 +128,4 @@
       </script> 
       </script> 
     
     
     </body>
     </body>
    -</html>
    +</html>
    Discard
    @@ -3,10 +3,11 @@
     <head>
     <head>
       <meta charset="utf-8" />
       <meta charset="utf-8" />
       <meta name="viewport" content="width=device-width, initial-scale=1.0" />
       <meta name="viewport" content="width=device-width, initial-scale=1.0" />
    -  <title>super_gradients.common.data_types.enum.evaluation_type &mdash; SuperGradients 1.0 documentation</title>
    +  <title>super_gradients.common.data_types.enum.evaluation_type &mdash; SuperGradients 3.0.3 documentation</title>
           <link rel="stylesheet" href="../../../../../_static/pygments.css" type="text/css" />
           <link rel="stylesheet" href="../../../../../_static/pygments.css" type="text/css" />
           <link rel="stylesheet" href="../../../../../_static/css/theme.css" type="text/css" />
           <link rel="stylesheet" href="../../../../../_static/css/theme.css" type="text/css" />
           <link rel="stylesheet" href="../../../../../_static/graphviz.css" type="text/css" />
           <link rel="stylesheet" href="../../../../../_static/graphviz.css" type="text/css" />
    +      <link rel="stylesheet" href="../../../../../_static/custom.css" type="text/css" />
       <!--[if lt IE 9]>
       <!--[if lt IE 9]>
         <script src="../../../../../_static/js/html5shiv.min.js"></script>
         <script src="../../../../../_static/js/html5shiv.min.js"></script>
       <![endif]-->
       <![endif]-->
    @@ -14,7 +15,9 @@
             <script data-url_root="../../../../../" id="documentation_options" src="../../../../../_static/documentation_options.js"></script>
             <script data-url_root="../../../../../" id="documentation_options" src="../../../../../_static/documentation_options.js"></script>
             <script src="../../../../../_static/jquery.js"></script>
             <script src="../../../../../_static/jquery.js"></script>
             <script src="../../../../../_static/underscore.js"></script>
             <script src="../../../../../_static/underscore.js"></script>
    +        <script src="../../../../../_static/_sphinx_javascript_frameworks_compat.js"></script>
             <script src="../../../../../_static/doctools.js"></script>
             <script src="../../../../../_static/doctools.js"></script>
    +        <script src="../../../../../_static/sphinx_highlight.js"></script>
         <script src="../../../../../_static/js/theme.js"></script>
         <script src="../../../../../_static/js/theme.js"></script>
         <link rel="index" title="Index" href="../../../../../genindex.html" />
         <link rel="index" title="Index" href="../../../../../genindex.html" />
         <link rel="search" title="Search" href="../../../../../search.html" /> 
         <link rel="search" title="Search" href="../../../../../search.html" /> 
    @@ -35,30 +38,29 @@
       </form>
       </form>
     </div>
     </div>
             </div><div class="wy-menu wy-menu-vertical" data-spy="affix" role="navigation" aria-label="Navigation menu">
             </div><div class="wy-menu wy-menu-vertical" data-spy="affix" role="navigation" aria-label="Navigation menu">
    -              <p class="caption"><span class="caption-text">Welcome To SuperGradients</span></p>
    +              <p class="caption" role="heading"><span class="caption-text">Welcome To SuperGradients</span></p>
     <ul>
     <ul>
    -<li class="toctree-l1"><a class="reference internal" href="../../../../../welcome.html">Fill our 4-question quick survey! We will raffle free SuperGradients swag between those who will participate -&gt; Fill Survey</a></li>
    -<li class="toctree-l1"><a class="reference internal" href="../../../../../welcome.html#supergradients">SuperGradients</a></li>
    +<li class="toctree-l1"><a class="reference internal" href="../../../../../welcome.html">Version 3 is out! Notebooks have been updated!</a></li>
    +<li class="toctree-l1"><a class="reference internal" href="../../../../../welcome.html#build-with-supergradients">Build with SuperGradients</a></li>
    +<li class="toctree-l1"><a class="reference internal" href="../../../../../welcome.html#quick-installation">Quick Installation</a></li>
    +<li class="toctree-l1"><a class="reference internal" href="../../../../../welcome.html#what-s-new">What’s New</a></li>
    +<li class="toctree-l1"><a class="reference internal" href="../../../../../welcome.html#coming-soon">Coming soon</a></li>
    +<li class="toctree-l1"><a class="reference internal" href="../../../../../welcome.html#table-of-content">Table of Content</a></li>
    +<li class="toctree-l1"><a class="reference internal" href="../../../../../welcome.html#getting-started">Getting Started</a></li>
    +<li class="toctree-l1"><a class="reference internal" href="../../../../../welcome.html#advanced-features">Advanced Features</a></li>
    +<li class="toctree-l1"><a class="reference internal" href="../../../../../welcome.html#installation-methods">Installation Methods</a></li>
    +<li class="toctree-l1"><a class="reference internal" href="../../../../../welcome.html#implemented-model-architectures">Implemented Model Architectures</a></li>
    +<li class="toctree-l1"><a class="reference internal" href="../../../../../welcome.html#documentation">Documentation</a></li>
    +<li class="toctree-l1"><a class="reference internal" href="../../../../../welcome.html#contributing">Contributing</a></li>
    +<li class="toctree-l1"><a class="reference internal" href="../../../../../welcome.html#citation">Citation</a></li>
    +<li class="toctree-l1"><a class="reference internal" href="../../../../../welcome.html#community">Community</a></li>
    +<li class="toctree-l1"><a class="reference internal" href="../../../../../welcome.html#license">License</a></li>
    +<li class="toctree-l1"><a class="reference internal" href="../../../../../welcome.html#deci-platform">Deci Platform</a></li>
     </ul>
     </ul>
    -<p class="caption"><span class="caption-text">Technical Documentation</span></p>
    +<p class="caption" role="heading"><span class="caption-text">Technical Documentation</span></p>
     <ul>
     <ul>
     <li class="toctree-l1"><a class="reference internal" href="../../../../../super_gradients.common.html">Common package</a></li>
     <li class="toctree-l1"><a class="reference internal" href="../../../../../super_gradients.common.html">Common package</a></li>
     <li class="toctree-l1"><a class="reference internal" href="../../../../../super_gradients.training.html">Training package</a></li>
     <li class="toctree-l1"><a class="reference internal" href="../../../../../super_gradients.training.html">Training package</a></li>
    -</ul>
    -<p class="caption"><span class="caption-text">User Guide</span></p>
    -<ul>
    -<li class="toctree-l1"><a class="reference internal" href="../../../../../user_guide.html">What is SuperGradients?</a></li>
    -<li class="toctree-l1"><a class="reference internal" href="../../../../../user_guide.html#introducing-the-supergradients-library">Introducing the SuperGradients library</a></li>
    -<li class="toctree-l1"><a class="reference internal" href="../../../../../user_guide.html#installation">Installation</a></li>
    -<li class="toctree-l1"><a class="reference internal" href="../../../../../user_guide.html#integrating-your-training-code-complete-walkthrough">Integrating Your Training Code - Complete Walkthrough</a></li>
    -<li class="toctree-l1"><a class="reference internal" href="../../../../../user_guide.html#training-parameters">Training Parameters</a></li>
    -<li class="toctree-l1"><a class="reference internal" href="../../../../../user_guide.html#logs-and-checkpoints">Logs and Checkpoints</a></li>
    -<li class="toctree-l1"><a class="reference internal" href="../../../../../user_guide.html#dataset-parameters">Dataset Parameters</a></li>
    -<li class="toctree-l1"><a class="reference internal" href="../../../../../user_guide.html#network-architectures">Network Architectures</a></li>
    -<li class="toctree-l1"><a class="reference internal" href="../../../../../user_guide.html#pretrained-models">Pretrained Models</a></li>
    -<li class="toctree-l1"><a class="reference internal" href="../../../../../user_guide.html#how-to-reproduce-our-training-recipes">How To Reproduce Our Training Recipes</a></li>
    -<li class="toctree-l1"><a class="reference internal" href="../../../../../user_guide.html#professional-tools-integration">Professional Tools Integration</a></li>
    -<li class="toctree-l1"><a class="reference internal" href="../../../../../user_guide.html#supergradients-faq">SuperGradients FAQ</a></li>
     </ul>
     </ul>
     
     
             </div>
             </div>
    @@ -89,11 +91,11 @@
     <span></span><span class="kn">from</span> <span class="nn">enum</span> <span class="kn">import</span> <span class="n">Enum</span>
     <span></span><span class="kn">from</span> <span class="nn">enum</span> <span class="kn">import</span> <span class="n">Enum</span>
     
     
     
     
    -<div class="viewcode-block" id="EvaluationType"><a class="viewcode-back" href="../../../../../super_gradients.common.data_types.enum.html#super_gradients.common.EvaluationType">[docs]</a><span class="k">class</span> <span class="nc">EvaluationType</span><span class="p">(</span><span class="nb">str</span><span class="p">,</span> <span class="n">Enum</span><span class="p">):</span>
    +<div class="viewcode-block" id="EvaluationType"><a class="viewcode-back" href="../../../../../super_gradients.training.html#super_gradients.common.EvaluationType">[docs]</a><span class="k">class</span> <span class="nc">EvaluationType</span><span class="p">(</span><span class="nb">str</span><span class="p">,</span> <span class="n">Enum</span><span class="p">):</span>
         <span class="sd">&quot;&quot;&quot;</span>
         <span class="sd">&quot;&quot;&quot;</span>
     <span class="sd">    EvaluationType</span>
     <span class="sd">    EvaluationType</span>
     
     
    -<span class="sd">    Passed to SgModel.evaluate(..), and controls which phase callbacks should be triggered (if at all).</span>
    +<span class="sd">    Passed to Trainer.evaluate(..), and controls which phase callbacks should be triggered (if at all).</span>
     
     
     <span class="sd">        Attributes:</span>
     <span class="sd">        Attributes:</span>
     <span class="sd">            TEST</span>
     <span class="sd">            TEST</span>
    @@ -131,4 +133,4 @@
       </script> 
       </script> 
     
     
     </body>
     </body>
    -</html>
    +</html>
    Discard
    @@ -3,10 +3,11 @@
     <head>
     <head>
       <meta charset="utf-8" />
       <meta charset="utf-8" />
       <meta name="viewport" content="width=device-width, initial-scale=1.0" />
       <meta name="viewport" content="width=device-width, initial-scale=1.0" />
    -  <title>super_gradients.common.data_types.enum.multi_gpu_mode &mdash; SuperGradients 1.0 documentation</title>
    +  <title>super_gradients.common.data_types.enum.multi_gpu_mode &mdash; SuperGradients 3.0.3 documentation</title>
           <link rel="stylesheet" href="../../../../../_static/pygments.css" type="text/css" />
           <link rel="stylesheet" href="../../../../../_static/pygments.css" type="text/css" />
           <link rel="stylesheet" href="../../../../../_static/css/theme.css" type="text/css" />
           <link rel="stylesheet" href="../../../../../_static/css/theme.css" type="text/css" />
           <link rel="stylesheet" href="../../../../../_static/graphviz.css" type="text/css" />
           <link rel="stylesheet" href="../../../../../_static/graphviz.css" type="text/css" />
    +      <link rel="stylesheet" href="../../../../../_static/custom.css" type="text/css" />
       <!--[if lt IE 9]>
       <!--[if lt IE 9]>
         <script src="../../../../../_static/js/html5shiv.min.js"></script>
         <script src="../../../../../_static/js/html5shiv.min.js"></script>
       <![endif]-->
       <![endif]-->
    @@ -14,7 +15,9 @@
             <script data-url_root="../../../../../" id="documentation_options" src="../../../../../_static/documentation_options.js"></script>
             <script data-url_root="../../../../../" id="documentation_options" src="../../../../../_static/documentation_options.js"></script>
             <script src="../../../../../_static/jquery.js"></script>
             <script src="../../../../../_static/jquery.js"></script>
             <script src="../../../../../_static/underscore.js"></script>
             <script src="../../../../../_static/underscore.js"></script>
    +        <script src="../../../../../_static/_sphinx_javascript_frameworks_compat.js"></script>
             <script src="../../../../../_static/doctools.js"></script>
             <script src="../../../../../_static/doctools.js"></script>
    +        <script src="../../../../../_static/sphinx_highlight.js"></script>
         <script src="../../../../../_static/js/theme.js"></script>
         <script src="../../../../../_static/js/theme.js"></script>
         <link rel="index" title="Index" href="../../../../../genindex.html" />
         <link rel="index" title="Index" href="../../../../../genindex.html" />
         <link rel="search" title="Search" href="../../../../../search.html" /> 
         <link rel="search" title="Search" href="../../../../../search.html" /> 
    @@ -35,30 +38,29 @@
       </form>
       </form>
     </div>
     </div>
             </div><div class="wy-menu wy-menu-vertical" data-spy="affix" role="navigation" aria-label="Navigation menu">
             </div><div class="wy-menu wy-menu-vertical" data-spy="affix" role="navigation" aria-label="Navigation menu">
    -              <p class="caption"><span class="caption-text">Welcome To SuperGradients</span></p>
    +              <p class="caption" role="heading"><span class="caption-text">Welcome To SuperGradients</span></p>
     <ul>
     <ul>
    -<li class="toctree-l1"><a class="reference internal" href="../../../../../welcome.html">Fill our 4-question quick survey! We will raffle free SuperGradients swag between those who will participate -&gt; Fill Survey</a></li>
    -<li class="toctree-l1"><a class="reference internal" href="../../../../../welcome.html#supergradients">SuperGradients</a></li>
    +<li class="toctree-l1"><a class="reference internal" href="../../../../../welcome.html">Version 3 is out! Notebooks have been updated!</a></li>
    +<li class="toctree-l1"><a class="reference internal" href="../../../../../welcome.html#build-with-supergradients">Build with SuperGradients</a></li>
    +<li class="toctree-l1"><a class="reference internal" href="../../../../../welcome.html#quick-installation">Quick Installation</a></li>
    +<li class="toctree-l1"><a class="reference internal" href="../../../../../welcome.html#what-s-new">What’s New</a></li>
    +<li class="toctree-l1"><a class="reference internal" href="../../../../../welcome.html#coming-soon">Coming soon</a></li>
    +<li class="toctree-l1"><a class="reference internal" href="../../../../../welcome.html#table-of-content">Table of Content</a></li>
    +<li class="toctree-l1"><a class="reference internal" href="../../../../../welcome.html#getting-started">Getting Started</a></li>
    +<li class="toctree-l1"><a class="reference internal" href="../../../../../welcome.html#advanced-features">Advanced Features</a></li>
    +<li class="toctree-l1"><a class="reference internal" href="../../../../../welcome.html#installation-methods">Installation Methods</a></li>
    +<li class="toctree-l1"><a class="reference internal" href="../../../../../welcome.html#implemented-model-architectures">Implemented Model Architectures</a></li>
    +<li class="toctree-l1"><a class="reference internal" href="../../../../../welcome.html#documentation">Documentation</a></li>
    +<li class="toctree-l1"><a class="reference internal" href="../../../../../welcome.html#contributing">Contributing</a></li>
    +<li class="toctree-l1"><a class="reference internal" href="../../../../../welcome.html#citation">Citation</a></li>
    +<li class="toctree-l1"><a class="reference internal" href="../../../../../welcome.html#community">Community</a></li>
    +<li class="toctree-l1"><a class="reference internal" href="../../../../../welcome.html#license">License</a></li>
    +<li class="toctree-l1"><a class="reference internal" href="../../../../../welcome.html#deci-platform">Deci Platform</a></li>
     </ul>
     </ul>
    -<p class="caption"><span class="caption-text">Technical Documentation</span></p>
    +<p class="caption" role="heading"><span class="caption-text">Technical Documentation</span></p>
     <ul>
     <ul>
     <li class="toctree-l1"><a class="reference internal" href="../../../../../super_gradients.common.html">Common package</a></li>
     <li class="toctree-l1"><a class="reference internal" href="../../../../../super_gradients.common.html">Common package</a></li>
     <li class="toctree-l1"><a class="reference internal" href="../../../../../super_gradients.training.html">Training package</a></li>
     <li class="toctree-l1"><a class="reference internal" href="../../../../../super_gradients.training.html">Training package</a></li>
    -</ul>
    -<p class="caption"><span class="caption-text">User Guide</span></p>
    -<ul>
    -<li class="toctree-l1"><a class="reference internal" href="../../../../../user_guide.html">What is SuperGradients?</a></li>
    -<li class="toctree-l1"><a class="reference internal" href="../../../../../user_guide.html#introducing-the-supergradients-library">Introducing the SuperGradients library</a></li>
    -<li class="toctree-l1"><a class="reference internal" href="../../../../../user_guide.html#installation">Installation</a></li>
    -<li class="toctree-l1"><a class="reference internal" href="../../../../../user_guide.html#integrating-your-training-code-complete-walkthrough">Integrating Your Training Code - Complete Walkthrough</a></li>
    -<li class="toctree-l1"><a class="reference internal" href="../../../../../user_guide.html#training-parameters">Training Parameters</a></li>
    -<li class="toctree-l1"><a class="reference internal" href="../../../../../user_guide.html#logs-and-checkpoints">Logs and Checkpoints</a></li>
    -<li class="toctree-l1"><a class="reference internal" href="../../../../../user_guide.html#dataset-parameters">Dataset Parameters</a></li>
    -<li class="toctree-l1"><a class="reference internal" href="../../../../../user_guide.html#network-architectures">Network Architectures</a></li>
    -<li class="toctree-l1"><a class="reference internal" href="../../../../../user_guide.html#pretrained-models">Pretrained Models</a></li>
    -<li class="toctree-l1"><a class="reference internal" href="../../../../../user_guide.html#how-to-reproduce-our-training-recipes">How To Reproduce Our Training Recipes</a></li>
    -<li class="toctree-l1"><a class="reference internal" href="../../../../../user_guide.html#professional-tools-integration">Professional Tools Integration</a></li>
    -<li class="toctree-l1"><a class="reference internal" href="../../../../../user_guide.html#supergradients-faq">SuperGradients FAQ</a></li>
     </ul>
     </ul>
     
     
             </div>
             </div>
    @@ -87,9 +89,10 @@
                  
                  
       <h1>Source code for super_gradients.common.data_types.enum.multi_gpu_mode</h1><div class="highlight"><pre>
       <h1>Source code for super_gradients.common.data_types.enum.multi_gpu_mode</h1><div class="highlight"><pre>
     <span></span><span class="kn">from</span> <span class="nn">enum</span> <span class="kn">import</span> <span class="n">Enum</span>
     <span></span><span class="kn">from</span> <span class="nn">enum</span> <span class="kn">import</span> <span class="n">Enum</span>
    +<span class="kn">import</span> <span class="nn">stringcase</span>
     
     
     
     
    -<div class="viewcode-block" id="MultiGPUMode"><a class="viewcode-back" href="../../../../../super_gradients.training.sg_model.html#super_gradients.common.MultiGPUMode">[docs]</a><span class="k">class</span> <span class="nc">MultiGPUMode</span><span class="p">(</span><span class="nb">str</span><span class="p">,</span> <span class="n">Enum</span><span class="p">):</span>
    +<div class="viewcode-block" id="MultiGPUMode"><a class="viewcode-back" href="../../../../../super_gradients.training.html#super_gradients.common.MultiGPUMode">[docs]</a><span class="k">class</span> <span class="nc">MultiGPUMode</span><span class="p">(</span><span class="nb">str</span><span class="p">,</span> <span class="n">Enum</span><span class="p">):</span>
         <span class="sd">&quot;&quot;&quot;</span>
         <span class="sd">&quot;&quot;&quot;</span>
     <span class="sd">    MultiGPUMode</span>
     <span class="sd">    MultiGPUMode</span>
     
     
    @@ -98,10 +101,26 @@
     <span class="sd">            DATA_PARALLEL             - Multiple GPUs, Synchronous</span>
     <span class="sd">            DATA_PARALLEL             - Multiple GPUs, Synchronous</span>
     <span class="sd">            DISTRIBUTED_DATA_PARALLEL - Multiple GPUs, Asynchronous</span>
     <span class="sd">            DISTRIBUTED_DATA_PARALLEL - Multiple GPUs, Asynchronous</span>
     <span class="sd">    &quot;&quot;&quot;</span>
     <span class="sd">    &quot;&quot;&quot;</span>
    -    <span class="n">OFF</span> <span class="o">=</span> <span class="s1">&#39;Off&#39;</span>
    -    <span class="n">DATA_PARALLEL</span> <span class="o">=</span> <span class="s1">&#39;DP&#39;</span>
    -    <span class="n">DISTRIBUTED_DATA_PARALLEL</span> <span class="o">=</span> <span class="s1">&#39;DDP&#39;</span>
    -    <span class="n">AUTO</span> <span class="o">=</span> <span class="s2">&quot;AUTO&quot;</span></div>
    +
    +    <span class="n">OFF</span> <span class="o">=</span> <span class="s2">&quot;Off&quot;</span>
    +    <span class="n">DATA_PARALLEL</span> <span class="o">=</span> <span class="s2">&quot;DP&quot;</span>
    +    <span class="n">DISTRIBUTED_DATA_PARALLEL</span> <span class="o">=</span> <span class="s2">&quot;DDP&quot;</span>
    +    <span class="n">AUTO</span> <span class="o">=</span> <span class="s2">&quot;AUTO&quot;</span>
    +
    +<div class="viewcode-block" id="MultiGPUMode.dict"><a class="viewcode-back" href="../../../../../super_gradients.training.html#super_gradients.common.MultiGPUMode.dict">[docs]</a>    <span class="nd">@classmethod</span>
    +    <span class="k">def</span> <span class="nf">dict</span><span class="p">(</span><span class="bp">cls</span><span class="p">):</span>
    +        <span class="sd">&quot;&quot;&quot;</span>
    +<span class="sd">        return dictionary mapping from the mode name (in call string cases) to the enum value</span>
    +<span class="sd">        &quot;&quot;&quot;</span>
    +        <span class="n">out_dict</span> <span class="o">=</span> <span class="nb">dict</span><span class="p">()</span>
    +        <span class="k">for</span> <span class="n">mode</span> <span class="ow">in</span> <span class="n">MultiGPUMode</span><span class="p">:</span>
    +            <span class="n">out_dict</span><span class="p">[</span><span class="n">mode</span><span class="o">.</span><span class="n">value</span><span class="p">]</span> <span class="o">=</span> <span class="n">mode</span>
    +            <span class="n">out_dict</span><span class="p">[</span><span class="n">mode</span><span class="o">.</span><span class="n">name</span><span class="p">]</span> <span class="o">=</span> <span class="n">mode</span>
    +            <span class="n">out_dict</span><span class="p">[</span><span class="n">stringcase</span><span class="o">.</span><span class="n">capitalcase</span><span class="p">(</span><span class="n">mode</span><span class="o">.</span><span class="n">name</span><span class="p">)]</span> <span class="o">=</span> <span class="n">mode</span>
    +            <span class="n">out_dict</span><span class="p">[</span><span class="n">stringcase</span><span class="o">.</span><span class="n">camelcase</span><span class="p">(</span><span class="n">mode</span><span class="o">.</span><span class="n">name</span><span class="p">)]</span> <span class="o">=</span> <span class="n">mode</span>
    +            <span class="n">out_dict</span><span class="p">[</span><span class="n">stringcase</span><span class="o">.</span><span class="n">lowercase</span><span class="p">(</span><span class="n">mode</span><span class="o">.</span><span class="n">name</span><span class="p">)]</span> <span class="o">=</span> <span class="n">mode</span>
    +        <span class="n">out_dict</span><span class="p">[</span><span class="kc">False</span><span class="p">]</span> <span class="o">=</span> <span class="n">MultiGPUMode</span><span class="o">.</span><span class="n">OFF</span>
    +        <span class="k">return</span> <span class="n">out_dict</span></div></div>
     </pre></div>
     </pre></div>
     
     
                </div>
                </div>
    @@ -131,4 +150,4 @@
       </script> 
       </script> 
     
     
     </body>
     </body>
    -</html>
    +</html>
    Discard
    @@ -3,10 +3,11 @@
     <head>
     <head>
       <meta charset="utf-8" />
       <meta charset="utf-8" />
       <meta name="viewport" content="width=device-width, initial-scale=1.0" />
       <meta name="viewport" content="width=device-width, initial-scale=1.0" />
    -  <title>super_gradients.common.data_types.enum.strict_load &mdash; SuperGradients 1.0 documentation</title>
    +  <title>super_gradients.common.data_types.enum.strict_load &mdash; SuperGradients 3.0.3 documentation</title>
           <link rel="stylesheet" href="../../../../../_static/pygments.css" type="text/css" />
           <link rel="stylesheet" href="../../../../../_static/pygments.css" type="text/css" />
           <link rel="stylesheet" href="../../../../../_static/css/theme.css" type="text/css" />
           <link rel="stylesheet" href="../../../../../_static/css/theme.css" type="text/css" />
           <link rel="stylesheet" href="../../../../../_static/graphviz.css" type="text/css" />
           <link rel="stylesheet" href="../../../../../_static/graphviz.css" type="text/css" />
    +      <link rel="stylesheet" href="../../../../../_static/custom.css" type="text/css" />
       <!--[if lt IE 9]>
       <!--[if lt IE 9]>
         <script src="../../../../../_static/js/html5shiv.min.js"></script>
         <script src="../../../../../_static/js/html5shiv.min.js"></script>
       <![endif]-->
       <![endif]-->
    @@ -14,7 +15,9 @@
             <script data-url_root="../../../../../" id="documentation_options" src="../../../../../_static/documentation_options.js"></script>
             <script data-url_root="../../../../../" id="documentation_options" src="../../../../../_static/documentation_options.js"></script>
             <script src="../../../../../_static/jquery.js"></script>
             <script src="../../../../../_static/jquery.js"></script>
             <script src="../../../../../_static/underscore.js"></script>
             <script src="../../../../../_static/underscore.js"></script>
    +        <script src="../../../../../_static/_sphinx_javascript_frameworks_compat.js"></script>
             <script src="../../../../../_static/doctools.js"></script>
             <script src="../../../../../_static/doctools.js"></script>
    +        <script src="../../../../../_static/sphinx_highlight.js"></script>
         <script src="../../../../../_static/js/theme.js"></script>
         <script src="../../../../../_static/js/theme.js"></script>
         <link rel="index" title="Index" href="../../../../../genindex.html" />
         <link rel="index" title="Index" href="../../../../../genindex.html" />
         <link rel="search" title="Search" href="../../../../../search.html" /> 
         <link rel="search" title="Search" href="../../../../../search.html" /> 
    @@ -35,30 +38,29 @@
       </form>
       </form>
     </div>
     </div>
             </div><div class="wy-menu wy-menu-vertical" data-spy="affix" role="navigation" aria-label="Navigation menu">
             </div><div class="wy-menu wy-menu-vertical" data-spy="affix" role="navigation" aria-label="Navigation menu">
    -              <p class="caption"><span class="caption-text">Welcome To SuperGradients</span></p>
    +              <p class="caption" role="heading"><span class="caption-text">Welcome To SuperGradients</span></p>
     <ul>
     <ul>
    -<li class="toctree-l1"><a class="reference internal" href="../../../../../welcome.html">Fill our 4-question quick survey! We will raffle free SuperGradients swag between those who will participate -&gt; Fill Survey</a></li>
    -<li class="toctree-l1"><a class="reference internal" href="../../../../../welcome.html#supergradients">SuperGradients</a></li>
    +<li class="toctree-l1"><a class="reference internal" href="../../../../../welcome.html">Version 3 is out! Notebooks have been updated!</a></li>
    +<li class="toctree-l1"><a class="reference internal" href="../../../../../welcome.html#build-with-supergradients">Build with SuperGradients</a></li>
    +<li class="toctree-l1"><a class="reference internal" href="../../../../../welcome.html#quick-installation">Quick Installation</a></li>
    +<li class="toctree-l1"><a class="reference internal" href="../../../../../welcome.html#what-s-new">What’s New</a></li>
    +<li class="toctree-l1"><a class="reference internal" href="../../../../../welcome.html#coming-soon">Coming soon</a></li>
    +<li class="toctree-l1"><a class="reference internal" href="../../../../../welcome.html#table-of-content">Table of Content</a></li>
    +<li class="toctree-l1"><a class="reference internal" href="../../../../../welcome.html#getting-started">Getting Started</a></li>
    +<li class="toctree-l1"><a class="reference internal" href="../../../../../welcome.html#advanced-features">Advanced Features</a></li>
    +<li class="toctree-l1"><a class="reference internal" href="../../../../../welcome.html#installation-methods">Installation Methods</a></li>
    +<li class="toctree-l1"><a class="reference internal" href="../../../../../welcome.html#implemented-model-architectures">Implemented Model Architectures</a></li>
    +<li class="toctree-l1"><a class="reference internal" href="../../../../../welcome.html#documentation">Documentation</a></li>
    +<li class="toctree-l1"><a class="reference internal" href="../../../../../welcome.html#contributing">Contributing</a></li>
    +<li class="toctree-l1"><a class="reference internal" href="../../../../../welcome.html#citation">Citation</a></li>
    +<li class="toctree-l1"><a class="reference internal" href="../../../../../welcome.html#community">Community</a></li>
    +<li class="toctree-l1"><a class="reference internal" href="../../../../../welcome.html#license">License</a></li>
    +<li class="toctree-l1"><a class="reference internal" href="../../../../../welcome.html#deci-platform">Deci Platform</a></li>
     </ul>
     </ul>
    -<p class="caption"><span class="caption-text">Technical Documentation</span></p>
    +<p class="caption" role="heading"><span class="caption-text">Technical Documentation</span></p>
     <ul>
     <ul>
     <li class="toctree-l1"><a class="reference internal" href="../../../../../super_gradients.common.html">Common package</a></li>
     <li class="toctree-l1"><a class="reference internal" href="../../../../../super_gradients.common.html">Common package</a></li>
     <li class="toctree-l1"><a class="reference internal" href="../../../../../super_gradients.training.html">Training package</a></li>
     <li class="toctree-l1"><a class="reference internal" href="../../../../../super_gradients.training.html">Training package</a></li>
    -</ul>
    -<p class="caption"><span class="caption-text">User Guide</span></p>
    -<ul>
    -<li class="toctree-l1"><a class="reference internal" href="../../../../../user_guide.html">What is SuperGradients?</a></li>
    -<li class="toctree-l1"><a class="reference internal" href="../../../../../user_guide.html#introducing-the-supergradients-library">Introducing the SuperGradients library</a></li>
    -<li class="toctree-l1"><a class="reference internal" href="../../../../../user_guide.html#installation">Installation</a></li>
    -<li class="toctree-l1"><a class="reference internal" href="../../../../../user_guide.html#integrating-your-training-code-complete-walkthrough">Integrating Your Training Code - Complete Walkthrough</a></li>
    -<li class="toctree-l1"><a class="reference internal" href="../../../../../user_guide.html#training-parameters">Training Parameters</a></li>
    -<li class="toctree-l1"><a class="reference internal" href="../../../../../user_guide.html#logs-and-checkpoints">Logs and Checkpoints</a></li>
    -<li class="toctree-l1"><a class="reference internal" href="../../../../../user_guide.html#dataset-parameters">Dataset Parameters</a></li>
    -<li class="toctree-l1"><a class="reference internal" href="../../../../../user_guide.html#network-architectures">Network Architectures</a></li>
    -<li class="toctree-l1"><a class="reference internal" href="../../../../../user_guide.html#pretrained-models">Pretrained Models</a></li>
    -<li class="toctree-l1"><a class="reference internal" href="../../../../../user_guide.html#how-to-reproduce-our-training-recipes">How To Reproduce Our Training Recipes</a></li>
    -<li class="toctree-l1"><a class="reference internal" href="../../../../../user_guide.html#professional-tools-integration">Professional Tools Integration</a></li>
    -<li class="toctree-l1"><a class="reference internal" href="../../../../../user_guide.html#supergradients-faq">SuperGradients FAQ</a></li>
     </ul>
     </ul>
     
     
             </div>
             </div>
    @@ -89,18 +91,19 @@
     <span></span><span class="kn">from</span> <span class="nn">enum</span> <span class="kn">import</span> <span class="n">Enum</span>
     <span></span><span class="kn">from</span> <span class="nn">enum</span> <span class="kn">import</span> <span class="n">Enum</span>
     
     
     
     
    -<div class="viewcode-block" id="StrictLoad"><a class="viewcode-back" href="../../../../../super_gradients.training.sg_model.html#super_gradients.common.StrictLoad">[docs]</a><span class="k">class</span> <span class="nc">StrictLoad</span><span class="p">(</span><span class="n">Enum</span><span class="p">):</span>
    +<div class="viewcode-block" id="StrictLoad"><a class="viewcode-back" href="../../../../../super_gradients.training.html#super_gradients.common.StrictLoad">[docs]</a><span class="k">class</span> <span class="nc">StrictLoad</span><span class="p">(</span><span class="n">Enum</span><span class="p">):</span>
         <span class="sd">&quot;&quot;&quot;</span>
         <span class="sd">&quot;&quot;&quot;</span>
     <span class="sd">    Wrapper for adding more functionality to torch&#39;s strict_load parameter in load_state_dict().</span>
     <span class="sd">    Wrapper for adding more functionality to torch&#39;s strict_load parameter in load_state_dict().</span>
    -<span class="sd">        Attributes:</span>
    -<span class="sd">            OFF              - Native torch &quot;strict_load = off&quot; behaviour. See nn.Module.load_state_dict() documentation for more details.</span>
    -<span class="sd">            ON               - Native torch &quot;strict_load = on&quot; behaviour. See nn.Module.load_state_dict() documentation for more details.</span>
    -<span class="sd">            NO_KEY_MATCHING  - Allows the usage of SuperGradient&#39;s adapt_checkpoint function, which loads a checkpoint by matching each</span>
    -<span class="sd">                               layer&#39;s shapes (and bypasses the strict matching of the names of each layer (ie: disregards the state_dict key matching)).</span>
    +<span class="sd">    Attributes:</span>
    +<span class="sd">        OFF              - Native torch &quot;strict_load = off&quot; behaviour. See nn.Module.load_state_dict() documentation for more details.</span>
    +<span class="sd">        ON               - Native torch &quot;strict_load = on&quot; behaviour. See nn.Module.load_state_dict() documentation for more details.</span>
    +<span class="sd">        NO_KEY_MATCHING  - Allows the usage of SuperGradient&#39;s adapt_checkpoint function, which loads a checkpoint by matching each</span>
    +<span class="sd">                           layer&#39;s shapes (and bypasses the strict matching of the names of each layer (ie: disregards the state_dict key matching)).</span>
     <span class="sd">    &quot;&quot;&quot;</span>
     <span class="sd">    &quot;&quot;&quot;</span>
    +
         <span class="n">OFF</span> <span class="o">=</span> <span class="kc">False</span>
         <span class="n">OFF</span> <span class="o">=</span> <span class="kc">False</span>
         <span class="n">ON</span> <span class="o">=</span> <span class="kc">True</span>
         <span class="n">ON</span> <span class="o">=</span> <span class="kc">True</span>
    -    <span class="n">NO_KEY_MATCHING</span> <span class="o">=</span> <span class="s1">&#39;no_key_matching&#39;</span></div>
    +    <span class="n">NO_KEY_MATCHING</span> <span class="o">=</span> <span class="s2">&quot;no_key_matching&quot;</span></div>
     </pre></div>
     </pre></div>
     
     
                </div>
                </div>
    @@ -130,4 +133,4 @@
       </script> 
       </script> 
     
     
     </body>
     </body>
    -</html>
    +</html>
    Discard
    1
    2
    3
    4
    5
    6
    7
    8
    9
    10
    11
    12
    13
    14
    15
    16
    17
    18
    19
    20
    21
    22
    23
    24
    25
    26
    27
    28
    29
    30
    31
    32
    33
    34
    35
    36
    37
    38
    39
    40
    41
    42
    43
    44
    45
    46
    47
    48
    49
    50
    51
    52
    53
    54
    55
    56
    57
    58
    59
    60
    61
    62
    63
    64
    65
    66
    67
    68
    69
    70
    71
    72
    73
    74
    75
    76
    77
    78
    79
    80
    81
    82
    83
    84
    85
    86
    87
    88
    89
    90
    91
    92
    93
    94
    95
    96
    97
    98
    99
    100
    101
    102
    103
    104
    105
    106
    107
    108
    109
    110
    111
    112
    113
    114
    115
    116
    117
    118
    119
    120
    121
    122
    123
    124
    125
    126
    127
    128
    1. <!DOCTYPE html>
    2. <html class="writer-html5" lang="en" >
    3. <head>
    4. <meta charset="utf-8" />
    5. <meta name="viewport" content="width=device-width, initial-scale=1.0" />
    6. <title>super_gradients.common.data_types.enum.upsample_mode &mdash; SuperGradients 3.0.3 documentation</title>
    7. <link rel="stylesheet" href="../../../../../_static/pygments.css" type="text/css" />
    8. <link rel="stylesheet" href="../../../../../_static/css/theme.css" type="text/css" />
    9. <link rel="stylesheet" href="../../../../../_static/graphviz.css" type="text/css" />
    10. <link rel="stylesheet" href="../../../../../_static/custom.css" type="text/css" />
    11. <!--[if lt IE 9]>
    12. <script src="../../../../../_static/js/html5shiv.min.js"></script>
    13. <![endif]-->
    14. <script data-url_root="../../../../../" id="documentation_options" src="../../../../../_static/documentation_options.js"></script>
    15. <script src="../../../../../_static/jquery.js"></script>
    16. <script src="../../../../../_static/underscore.js"></script>
    17. <script src="../../../../../_static/_sphinx_javascript_frameworks_compat.js"></script>
    18. <script src="../../../../../_static/doctools.js"></script>
    19. <script src="../../../../../_static/sphinx_highlight.js"></script>
    20. <script src="../../../../../_static/js/theme.js"></script>
    21. <link rel="index" title="Index" href="../../../../../genindex.html" />
    22. <link rel="search" title="Search" href="../../../../../search.html" />
    23. </head>
    24. <body class="wy-body-for-nav">
    25. <div class="wy-grid-for-nav">
    26. <nav data-toggle="wy-nav-shift" class="wy-nav-side">
    27. <div class="wy-side-scroll">
    28. <div class="wy-side-nav-search" >
    29. <a href="../../../../../index.html" class="icon icon-home"> SuperGradients
    30. </a>
    31. <div role="search">
    32. <form id="rtd-search-form" class="wy-form" action="../../../../../search.html" method="get">
    33. <input type="text" name="q" placeholder="Search docs" />
    34. <input type="hidden" name="check_keywords" value="yes" />
    35. <input type="hidden" name="area" value="default" />
    36. </form>
    37. </div>
    38. </div><div class="wy-menu wy-menu-vertical" data-spy="affix" role="navigation" aria-label="Navigation menu">
    39. <p class="caption" role="heading"><span class="caption-text">Welcome To SuperGradients</span></p>
    40. <ul>
    41. <li class="toctree-l1"><a class="reference internal" href="../../../../../welcome.html">Version 3 is out! Notebooks have been updated!</a></li>
    42. <li class="toctree-l1"><a class="reference internal" href="../../../../../welcome.html#build-with-supergradients">Build with SuperGradients</a></li>
    43. <li class="toctree-l1"><a class="reference internal" href="../../../../../welcome.html#quick-installation">Quick Installation</a></li>
    44. <li class="toctree-l1"><a class="reference internal" href="../../../../../welcome.html#what-s-new">What’s New</a></li>
    45. <li class="toctree-l1"><a class="reference internal" href="../../../../../welcome.html#coming-soon">Coming soon</a></li>
    46. <li class="toctree-l1"><a class="reference internal" href="../../../../../welcome.html#table-of-content">Table of Content</a></li>
    47. <li class="toctree-l1"><a class="reference internal" href="../../../../../welcome.html#getting-started">Getting Started</a></li>
    48. <li class="toctree-l1"><a class="reference internal" href="../../../../../welcome.html#advanced-features">Advanced Features</a></li>
    49. <li class="toctree-l1"><a class="reference internal" href="../../../../../welcome.html#installation-methods">Installation Methods</a></li>
    50. <li class="toctree-l1"><a class="reference internal" href="../../../../../welcome.html#implemented-model-architectures">Implemented Model Architectures</a></li>
    51. <li class="toctree-l1"><a class="reference internal" href="../../../../../welcome.html#documentation">Documentation</a></li>
    52. <li class="toctree-l1"><a class="reference internal" href="../../../../../welcome.html#contributing">Contributing</a></li>
    53. <li class="toctree-l1"><a class="reference internal" href="../../../../../welcome.html#citation">Citation</a></li>
    54. <li class="toctree-l1"><a class="reference internal" href="../../../../../welcome.html#community">Community</a></li>
    55. <li class="toctree-l1"><a class="reference internal" href="../../../../../welcome.html#license">License</a></li>
    56. <li class="toctree-l1"><a class="reference internal" href="../../../../../welcome.html#deci-platform">Deci Platform</a></li>
    57. </ul>
    58. <p class="caption" role="heading"><span class="caption-text">Technical Documentation</span></p>
    59. <ul>
    60. <li class="toctree-l1"><a class="reference internal" href="../../../../../super_gradients.common.html">Common package</a></li>
    61. <li class="toctree-l1"><a class="reference internal" href="../../../../../super_gradients.training.html">Training package</a></li>
    62. </ul>
    63. </div>
    64. </div>
    65. </nav>
    66. <section data-toggle="wy-nav-shift" class="wy-nav-content-wrap"><nav class="wy-nav-top" aria-label="Mobile navigation menu" >
    67. <i data-toggle="wy-nav-top" class="fa fa-bars"></i>
    68. <a href="../../../../../index.html">SuperGradients</a>
    69. </nav>
    70. <div class="wy-nav-content">
    71. <div class="rst-content">
    72. <div role="navigation" aria-label="Page navigation">
    73. <ul class="wy-breadcrumbs">
    74. <li><a href="../../../../../index.html" class="icon icon-home"></a> &raquo;</li>
    75. <li><a href="../../../../index.html">Module code</a> &raquo;</li>
    76. <li>super_gradients.common.data_types.enum.upsample_mode</li>
    77. <li class="wy-breadcrumbs-aside">
    78. </li>
    79. </ul>
    80. <hr/>
    81. </div>
    82. <div role="main" class="document" itemscope="itemscope" itemtype="http://schema.org/Article">
    83. <div itemprop="articleBody">
    84. <h1>Source code for super_gradients.common.data_types.enum.upsample_mode</h1><div class="highlight"><pre>
    85. <span></span><span class="kn">from</span> <span class="nn">enum</span> <span class="kn">import</span> <span class="n">Enum</span>
    86. <div class="viewcode-block" id="UpsampleMode"><a class="viewcode-back" href="../../../../../super_gradients.common.html#super_gradients.common.UpsampleMode">[docs]</a><span class="k">class</span> <span class="nc">UpsampleMode</span><span class="p">(</span><span class="n">Enum</span><span class="p">):</span>
    87. <span class="n">NEAREST</span> <span class="o">=</span> <span class="s2">&quot;nearest&quot;</span>
    88. <span class="n">BILINEAR</span> <span class="o">=</span> <span class="s2">&quot;bilinear&quot;</span>
    89. <span class="n">BICUBIC</span> <span class="o">=</span> <span class="s2">&quot;bicubic&quot;</span>
    90. <span class="n">SNPE_BILINEAR</span> <span class="o">=</span> <span class="s2">&quot;snpe_bilinear&quot;</span></div>
    91. </pre></div>
    92. </div>
    93. </div>
    94. <footer>
    95. <hr/>
    96. <div role="contentinfo">
    97. <p>&#169; Copyright 2021, SuperGradients team.</p>
    98. </div>
    99. Built with <a href="https://www.sphinx-doc.org/">Sphinx</a> using a
    100. <a href="https://github.com/readthedocs/sphinx_rtd_theme">theme</a>
    101. provided by <a href="https://readthedocs.org">Read the Docs</a>.
    102. </footer>
    103. </div>
    104. </div>
    105. </section>
    106. </div>
    107. <script>
    108. jQuery(function () {
    109. SphinxRtdTheme.Navigation.enable(true);
    110. });
    111. </script>
    112. </body>
    113. </html>
    Discard
    1
    2
    3
    4
    5
    6
    7
    8
    9
    10
    11
    12
    13
    14
    15
    16
    17
    18
    19
    20
    21
    22
    23
    24
    25
    26
    27
    28
    29
    30
    31
    32
    33
    34
    35
    36
    37
    38
    39
    40
    41
    42
    43
    44
    45
    46
    47
    48
    49
    50
    51
    52
    53
    54
    55
    56
    57
    58
    59
    60
    61
    62
    63
    64
    65
    66
    67
    68
    69
    70
    71
    72
    73
    74
    75
    76
    77
    78
    79
    80
    81
    82
    83
    84
    85
    86
    87
    88
    89
    90
    91
    92
    93
    94
    95
    96
    97
    98
    99
    100
    101
    102
    103
    104
    105
    106
    107
    108
    109
    110
    111
    112
    113
    114
    115
    116
    117
    118
    119
    120
    121
    122
    123
    124
    125
    126
    127
    128
    129
    130
    131
    132
    133
    134
    135
    136
    137
    138
    139
    140
    141
    142
    143
    144
    145
    146
    147
    148
    149
    150
    151
    152
    153
    154
    155
    156
    157
    158
    159
    160
    161
    162
    163
    164
    165
    166
    167
    168
    169
    170
    171
    172
    173
    174
    175
    176
    177
    178
    179
    180
    181
    182
    183
    184
    185
    186
    187
    188
    189
    190
    191
    192
    193
    194
    195
    196
    197
    198
    199
    200
    201
    202
    203
    204
    205
    206
    1. <!DOCTYPE html>
    2. <html class="writer-html5" lang="en" >
    3. <head>
    4. <meta charset="utf-8" />
    5. <meta name="viewport" content="width=device-width, initial-scale=1.0" />
    6. <title>super_gradients.common.decorators.deci_logger &mdash; SuperGradients 1.0 documentation</title>
    7. <link rel="stylesheet" href="../../../../_static/pygments.css" type="text/css" />
    8. <link rel="stylesheet" href="../../../../_static/css/theme.css" type="text/css" />
    9. <link rel="stylesheet" href="../../../../_static/graphviz.css" type="text/css" />
    10. <!--[if lt IE 9]>
    11. <script src="../../../../_static/js/html5shiv.min.js"></script>
    12. <![endif]-->
    13. <script data-url_root="../../../../" id="documentation_options" src="../../../../_static/documentation_options.js"></script>
    14. <script src="../../../../_static/jquery.js"></script>
    15. <script src="../../../../_static/underscore.js"></script>
    16. <script src="../../../../_static/doctools.js"></script>
    17. <script src="../../../../_static/js/theme.js"></script>
    18. <link rel="index" title="Index" href="../../../../genindex.html" />
    19. <link rel="search" title="Search" href="../../../../search.html" />
    20. </head>
    21. <body class="wy-body-for-nav">
    22. <div class="wy-grid-for-nav">
    23. <nav data-toggle="wy-nav-shift" class="wy-nav-side">
    24. <div class="wy-side-scroll">
    25. <div class="wy-side-nav-search" >
    26. <a href="../../../../index.html" class="icon icon-home"> SuperGradients
    27. </a>
    28. <div role="search">
    29. <form id="rtd-search-form" class="wy-form" action="../../../../search.html" method="get">
    30. <input type="text" name="q" placeholder="Search docs" />
    31. <input type="hidden" name="check_keywords" value="yes" />
    32. <input type="hidden" name="area" value="default" />
    33. </form>
    34. </div>
    35. </div><div class="wy-menu wy-menu-vertical" data-spy="affix" role="navigation" aria-label="Navigation menu">
    36. <p class="caption"><span class="caption-text">Welcome To SuperGradients</span></p>
    37. <ul>
    38. <li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html">Fill our 4-question quick survey! We will raffle free SuperGradients swag between those who will participate -&gt; Fill Survey</a></li>
    39. <li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#supergradients">SuperGradients</a></li>
    40. </ul>
    41. <p class="caption"><span class="caption-text">Technical Documentation</span></p>
    42. <ul>
    43. <li class="toctree-l1"><a class="reference internal" href="../../../../super_gradients.common.html">Common package</a></li>
    44. <li class="toctree-l1"><a class="reference internal" href="../../../../super_gradients.training.html">Training package</a></li>
    45. </ul>
    46. <p class="caption"><span class="caption-text">User Guide</span></p>
    47. <ul>
    48. <li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html">What is SuperGradients?</a></li>
    49. <li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#introducing-the-supergradients-library">Introducing the SuperGradients library</a></li>
    50. <li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#installation">Installation</a></li>
    51. <li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#integrating-your-training-code-complete-walkthrough">Integrating Your Training Code - Complete Walkthrough</a></li>
    52. <li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#training-parameters">Training Parameters</a></li>
    53. <li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#logs-and-checkpoints">Logs and Checkpoints</a></li>
    54. <li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#dataset-parameters">Dataset Parameters</a></li>
    55. <li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#network-architectures">Network Architectures</a></li>
    56. <li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#pretrained-models">Pretrained Models</a></li>
    57. <li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#how-to-reproduce-our-training-recipes">How To Reproduce Our Training Recipes</a></li>
    58. <li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#professional-tools-integration">Professional Tools Integration</a></li>
    59. <li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#supergradients-faq">SuperGradients FAQ</a></li>
    60. </ul>
    61. </div>
    62. </div>
    63. </nav>
    64. <section data-toggle="wy-nav-shift" class="wy-nav-content-wrap"><nav class="wy-nav-top" aria-label="Mobile navigation menu" >
    65. <i data-toggle="wy-nav-top" class="fa fa-bars"></i>
    66. <a href="../../../../index.html">SuperGradients</a>
    67. </nav>
    68. <div class="wy-nav-content">
    69. <div class="rst-content">
    70. <div role="navigation" aria-label="Page navigation">
    71. <ul class="wy-breadcrumbs">
    72. <li><a href="../../../../index.html" class="icon icon-home"></a> &raquo;</li>
    73. <li><a href="../../../index.html">Module code</a> &raquo;</li>
    74. <li>super_gradients.common.decorators.deci_logger</li>
    75. <li class="wy-breadcrumbs-aside">
    76. </li>
    77. </ul>
    78. <hr/>
    79. </div>
    80. <div role="main" class="document" itemscope="itemscope" itemtype="http://schema.org/Article">
    81. <div itemprop="articleBody">
    82. <h1>Source code for super_gradients.common.decorators.deci_logger</h1><div class="highlight"><pre>
    83. <div class="viewcode-block" id="deci_func_logger"><a class="viewcode-back" href="../../../../super_gradients.common.decorators.html#super_gradients.common.decorators.deci_logger.deci_func_logger">[docs]</a><span></span><span class="k">def</span> <span class="nf">deci_func_logger</span><span class="p">(</span><span class="n">_func</span><span class="o">=</span><span class="kc">None</span><span class="p">,</span> <span class="o">*</span><span class="p">,</span> <span class="n">name</span><span class="p">:</span> <span class="nb">str</span> <span class="o">=</span> <span class="s1">&#39;abstract_decorator&#39;</span><span class="p">):</span>
    84. <span class="sd">&quot;&quot;&quot;</span>
    85. <span class="sd"> This decorator is used to wrap our functions with logs.</span>
    86. <span class="sd"> It will log every enter and exit of the functon with the equivalent parameters as extras.</span>
    87. <span class="sd"> It will also log exceptions that raises in the function.</span>
    88. <span class="sd"> It will also log the exception time of the function.</span>
    89. <span class="sd"> How it works:`</span>
    90. <span class="sd"> First it will check if the decorator called with name keyword.</span>
    91. <span class="sd"> If so it will return a new decorator that its logger is the name parameter.</span>
    92. <span class="sd"> If not it will return a new decorator that its logger is the wrapped function name.</span>
    93. <span class="sd"> Then the return decorator will return a new function that warps the original function with the new logs.</span>
    94. <span class="sd"> For further understanding advise real-python &quot;fancy decorators documentation&quot;</span>
    95. <span class="sd"> Args:</span>
    96. <span class="sd"> _func (): used when called without name specify. dont pass it directly</span>
    97. <span class="sd"> name (): The name of the logger to save logs by.</span>
    98. <span class="sd"> Returns:</span>
    99. <span class="sd"> a decorator that wraps function with logs logic.</span>
    100. <span class="sd"> &quot;&quot;&quot;</span>
    101. <span class="c1"># TODO: Not Working - Breaks the code, tests does not pass (s3 connector, platform...)</span>
    102. <span class="c1"># TODO: Fix problem with ExplicitParamValidation error (arguments not passed)</span>
    103. <span class="c1"># TODO: Run ALL test suite of deci2 (NOT circieCI test suite, but ALL the tests under tests folders)</span>
    104. <span class="c1"># TODO: Delete/Update all failing tests.</span>
    105. <span class="c1"># def deci_logger_decorator(fn):</span>
    106. <span class="c1">#</span>
    107. <span class="c1"># @functools.wraps(fn)</span>
    108. <span class="c1"># def wrapper_func(*args, **kwargs):</span>
    109. <span class="c1"># try:</span>
    110. <span class="c1">#</span>
    111. <span class="c1"># try:</span>
    112. <span class="c1"># logger.debug(f&quot;Start: {fn.__name__}&quot;, extra={&quot;args&quot;: args, &quot;kwargs&quot;: kwargs})</span>
    113. <span class="c1"># time1 = time.perf_counter()</span>
    114. <span class="c1"># except Exception:</span>
    115. <span class="c1"># # failed to write log - continue.</span>
    116. <span class="c1"># pass</span>
    117. <span class="c1">#</span>
    118. <span class="c1"># result = fn(*args, **kwargs)</span>
    119. <span class="c1">#</span>
    120. <span class="c1"># try:</span>
    121. <span class="c1"># time2 = time.perf_counter()</span>
    122. <span class="c1"># logger.debug(f&quot;End: {fn.__name__}&quot;,</span>
    123. <span class="c1"># extra={&#39;duration&#39;: (time2 - time1) * 1000.0, &#39;return_value&#39;: result})</span>
    124. <span class="c1"># except Exception:</span>
    125. <span class="c1"># # failed to write log - continue.</span>
    126. <span class="c1"># pass</span>
    127. <span class="c1">#</span>
    128. <span class="c1"># return result</span>
    129. <span class="c1">#</span>
    130. <span class="c1"># except Exception as ex:</span>
    131. <span class="c1"># # This exception was raised from inside the function call</span>
    132. <span class="c1"># logger.error(f&quot;Exception: {ex}&quot;, exc_info=ex)</span>
    133. <span class="c1"># raise ex</span>
    134. <span class="c1">#</span>
    135. <span class="c1"># return wrapper_func</span>
    136. <span class="c1"># if _func is None:</span>
    137. <span class="c1"># logger = get_logger(name)</span>
    138. <span class="c1"># return deci_logger_decorator</span>
    139. <span class="c1"># else:</span>
    140. <span class="c1"># logger = get_logger(_func.__name__)</span>
    141. <span class="c1"># return deci_logger_decorator(_func)</span>
    142. <span class="k">return</span> <span class="n">_func</span></div>
    143. <div class="viewcode-block" id="deci_class_logger"><a class="viewcode-back" href="../../../../super_gradients.common.decorators.html#super_gradients.common.decorators.deci_logger.deci_class_logger">[docs]</a><span class="k">def</span> <span class="nf">deci_class_logger</span><span class="p">():</span>
    144. <span class="sd">&quot;&quot;&quot;</span>
    145. <span class="sd"> This decorator wraps every class method with deci_func_logger decorator.</span>
    146. <span class="sd"> It works by checking if class method is callable and if so it will set a new decorated method as the same method name.</span>
    147. <span class="sd"> &quot;&quot;&quot;</span>
    148. <span class="k">def</span> <span class="nf">wrapper</span><span class="p">(</span><span class="bp">cls</span><span class="p">):</span>
    149. <span class="c1"># TODO: Not Working - Breaks the code, tests does not pass (s3 connector, platform...)</span>
    150. <span class="c1"># TODO: Fix problem with ExplicitParamValidation error (arguments not passed)</span>
    151. <span class="c1"># TODO: Run ALL test suite of deci2 (NOT circieCI test suite, but ALL the tests under tests folders)</span>
    152. <span class="c1"># TODO: Delete/Update all failing tests.</span>
    153. <span class="c1"># for attr in cls.__dict__:</span>
    154. <span class="c1"># if callable(getattr(cls, attr)) and attr != &#39;__init__&#39;:</span>
    155. <span class="c1"># decorated_function = deci_func_logger(name=cls.__name__)(getattr(cls, attr))</span>
    156. <span class="c1"># if type(cls.__dict__[attr]) is staticmethod:</span>
    157. <span class="c1"># decorated_function = staticmethod(decorated_function)</span>
    158. <span class="c1"># setattr(cls, attr, decorated_function)</span>
    159. <span class="k">return</span> <span class="bp">cls</span>
    160. <span class="k">return</span> <span class="n">wrapper</span></div>
    161. </pre></div>
    162. </div>
    163. </div>
    164. <footer>
    165. <hr/>
    166. <div role="contentinfo">
    167. <p>&#169; Copyright 2021, SuperGradients team.</p>
    168. </div>
    169. Built with <a href="https://www.sphinx-doc.org/">Sphinx</a> using a
    170. <a href="https://github.com/readthedocs/sphinx_rtd_theme">theme</a>
    171. provided by <a href="https://readthedocs.org">Read the Docs</a>.
    172. </footer>
    173. </div>
    174. </div>
    175. </section>
    176. </div>
    177. <script>
    178. jQuery(function () {
    179. SphinxRtdTheme.Navigation.enable(true);
    180. });
    181. </script>
    182. </body>
    183. </html>
    Discard
    Some lines were truncated since they exceed the maximum allowed length of 500, please use a local Git client to see the full diff.
    @@ -3,10 +3,11 @@
     <head>
     <head>
       <meta charset="utf-8" />
       <meta charset="utf-8" />
       <meta name="viewport" content="width=device-width, initial-scale=1.0" />
       <meta name="viewport" content="width=device-width, initial-scale=1.0" />
    -  <title>super_gradients.common.decorators.explicit_params_validator &mdash; SuperGradients 1.0 documentation</title>
    +  <title>super_gradients.common.decorators.explicit_params_validator &mdash; SuperGradients 3.0.3 documentation</title>
           <link rel="stylesheet" href="../../../../_static/pygments.css" type="text/css" />
           <link rel="stylesheet" href="../../../../_static/pygments.css" type="text/css" />
           <link rel="stylesheet" href="../../../../_static/css/theme.css" type="text/css" />
           <link rel="stylesheet" href="../../../../_static/css/theme.css" type="text/css" />
           <link rel="stylesheet" href="../../../../_static/graphviz.css" type="text/css" />
           <link rel="stylesheet" href="../../../../_static/graphviz.css" type="text/css" />
    +      <link rel="stylesheet" href="../../../../_static/custom.css" type="text/css" />
       <!--[if lt IE 9]>
       <!--[if lt IE 9]>
         <script src="../../../../_static/js/html5shiv.min.js"></script>
         <script src="../../../../_static/js/html5shiv.min.js"></script>
       <![endif]-->
       <![endif]-->
    @@ -14,7 +15,9 @@
             <script data-url_root="../../../../" id="documentation_options" src="../../../../_static/documentation_options.js"></script>
             <script data-url_root="../../../../" id="documentation_options" src="../../../../_static/documentation_options.js"></script>
             <script src="../../../../_static/jquery.js"></script>
             <script src="../../../../_static/jquery.js"></script>
             <script src="../../../../_static/underscore.js"></script>
             <script src="../../../../_static/underscore.js"></script>
    +        <script src="../../../../_static/_sphinx_javascript_frameworks_compat.js"></script>
             <script src="../../../../_static/doctools.js"></script>
             <script src="../../../../_static/doctools.js"></script>
    +        <script src="../../../../_static/sphinx_highlight.js"></script>
         <script src="../../../../_static/js/theme.js"></script>
         <script src="../../../../_static/js/theme.js"></script>
         <link rel="index" title="Index" href="../../../../genindex.html" />
         <link rel="index" title="Index" href="../../../../genindex.html" />
         <link rel="search" title="Search" href="../../../../search.html" /> 
         <link rel="search" title="Search" href="../../../../search.html" /> 
    @@ -35,30 +38,29 @@
       </form>
       </form>
     </div>
     </div>
             </div><div class="wy-menu wy-menu-vertical" data-spy="affix" role="navigation" aria-label="Navigation menu">
             </div><div class="wy-menu wy-menu-vertical" data-spy="affix" role="navigation" aria-label="Navigation menu">
    -              <p class="caption"><span class="caption-text">Welcome To SuperGradients</span></p>
    +              <p class="caption" role="heading"><span class="caption-text">Welcome To SuperGradients</span></p>
     <ul>
     <ul>
    -<li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html">Fill our 4-question quick survey! We will raffle free SuperGradients swag between those who will participate -&gt; Fill Survey</a></li>
    -<li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#supergradients">SuperGradients</a></li>
    +<li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html">Version 3 is out! Notebooks have been updated!</a></li>
    +<li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#build-with-supergradients">Build with SuperGradients</a></li>
    +<li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#quick-installation">Quick Installation</a></li>
    +<li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#what-s-new">What’s New</a></li>
    +<li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#coming-soon">Coming soon</a></li>
    +<li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#table-of-content">Table of Content</a></li>
    +<li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#getting-started">Getting Started</a></li>
    +<li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#advanced-features">Advanced Features</a></li>
    +<li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#installation-methods">Installation Methods</a></li>
    +<li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#implemented-model-architectures">Implemented Model Architectures</a></li>
    +<li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#documentation">Documentation</a></li>
    +<li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#contributing">Contributing</a></li>
    +<li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#citation">Citation</a></li>
    +<li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#community">Community</a></li>
    +<li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#license">License</a></li>
    +<li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#deci-platform">Deci Platform</a></li>
     </ul>
     </ul>
    -<p class="caption"><span class="caption-text">Technical Documentation</span></p>
    +<p class="caption" role="heading"><span class="caption-text">Technical Documentation</span></p>
     <ul>
     <ul>
     <li class="toctree-l1"><a class="reference internal" href="../../../../super_gradients.common.html">Common package</a></li>
     <li class="toctree-l1"><a class="reference internal" href="../../../../super_gradients.common.html">Common package</a></li>
     <li class="toctree-l1"><a class="reference internal" href="../../../../super_gradients.training.html">Training package</a></li>
     <li class="toctree-l1"><a class="reference internal" href="../../../../super_gradients.training.html">Training package</a></li>
    -</ul>
    -<p class="caption"><span class="caption-text">User Guide</span></p>
    -<ul>
    -<li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html">What is SuperGradients?</a></li>
    -<li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#introducing-the-supergradients-library">Introducing the SuperGradients library</a></li>
    -<li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#installation">Installation</a></li>
    -<li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#integrating-your-training-code-complete-walkthrough">Integrating Your Training Code - Complete Walkthrough</a></li>
    -<li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#training-parameters">Training Parameters</a></li>
    -<li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#logs-and-checkpoints">Logs and Checkpoints</a></li>
    -<li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#dataset-parameters">Dataset Parameters</a></li>
    -<li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#network-architectures">Network Architectures</a></li>
    -<li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#pretrained-models">Pretrained Models</a></li>
    -<li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#how-to-reproduce-our-training-recipes">How To Reproduce Our Training Recipes</a></li>
    -<li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#professional-tools-integration">Professional Tools Integration</a></li>
    -<li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#supergradients-faq">SuperGradients FAQ</a></li>
     </ul>
     </ul>
     
     
             </div>
             </div>
    @@ -160,7 +162,7 @@
     
     
     
     
     <span class="c1"># WRAPS THE RETRY DECORATOR CLASS TO ENABLE CALLING WITHOUT PARAMS</span>
     <span class="c1"># WRAPS THE RETRY DECORATOR CLASS TO ENABLE CALLING WITHOUT PARAMS</span>
    -<div class="viewcode-block" id="explicit_params_validation"><a class="viewcode-back" href="../../../../super_gradients.common.decorators.html#super_gradients.common.explicit_params_validation">[docs]</a><span class="k">def</span> <span class="nf">explicit_params_validation</span><span class="p">(</span><span class="n">function</span><span class="p">:</span> <span class="n">Callable</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span> <span class="n">validation_ty
    +<div class="viewcode-block" id="explicit_params_validation"><a class="viewcode-back" href="../../../../super_gradients.common.html#super_gradients.common.explicit_params_validation">[docs]</a><span class="k">def</span> <span class="nf">explicit_params_validation</span><span class="p">(</span><span class="n">function</span><span class="p">:</span> <span class="n">Callable</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span> <span class="n">validation_type</span><s
         <span class="k">if</span> <span class="n">function</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span><span class="p">:</span>
         <span class="k">if</span> <span class="n">function</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span><span class="p">:</span>
             <span class="k">return</span> <span class="n">_ExplicitParamsValidator</span><span class="p">(</span><span class="n">function</span><span class="o">=</span><span class="n">function</span><span class="p">)</span>
             <span class="k">return</span> <span class="n">_ExplicitParamsValidator</span><span class="p">(</span><span class="n">function</span><span class="o">=</span><span class="n">function</span><span class="p">)</span>
         <span class="k">else</span><span class="p">:</span>
         <span class="k">else</span><span class="p">:</span>
    @@ -197,4 +199,4 @@
       </script> 
       </script> 
     
     
     </body>
     </body>
    -</html>
    +</html>
    Discard
    Some lines were truncated since they exceed the maximum allowed length of 500, please use a local Git client to see the full diff.
    @@ -3,10 +3,11 @@
     <head>
     <head>
       <meta charset="utf-8" />
       <meta charset="utf-8" />
       <meta name="viewport" content="width=device-width, initial-scale=1.0" />
       <meta name="viewport" content="width=device-width, initial-scale=1.0" />
    -  <title>super_gradients.common.decorators.singleton &mdash; SuperGradients 1.0 documentation</title>
    +  <title>super_gradients.common.decorators.singleton &mdash; SuperGradients 3.0.3 documentation</title>
           <link rel="stylesheet" href="../../../../_static/pygments.css" type="text/css" />
           <link rel="stylesheet" href="../../../../_static/pygments.css" type="text/css" />
           <link rel="stylesheet" href="../../../../_static/css/theme.css" type="text/css" />
           <link rel="stylesheet" href="../../../../_static/css/theme.css" type="text/css" />
           <link rel="stylesheet" href="../../../../_static/graphviz.css" type="text/css" />
           <link rel="stylesheet" href="../../../../_static/graphviz.css" type="text/css" />
    +      <link rel="stylesheet" href="../../../../_static/custom.css" type="text/css" />
       <!--[if lt IE 9]>
       <!--[if lt IE 9]>
         <script src="../../../../_static/js/html5shiv.min.js"></script>
         <script src="../../../../_static/js/html5shiv.min.js"></script>
       <![endif]-->
       <![endif]-->
    @@ -14,7 +15,9 @@
             <script data-url_root="../../../../" id="documentation_options" src="../../../../_static/documentation_options.js"></script>
             <script data-url_root="../../../../" id="documentation_options" src="../../../../_static/documentation_options.js"></script>
             <script src="../../../../_static/jquery.js"></script>
             <script src="../../../../_static/jquery.js"></script>
             <script src="../../../../_static/underscore.js"></script>
             <script src="../../../../_static/underscore.js"></script>
    +        <script src="../../../../_static/_sphinx_javascript_frameworks_compat.js"></script>
             <script src="../../../../_static/doctools.js"></script>
             <script src="../../../../_static/doctools.js"></script>
    +        <script src="../../../../_static/sphinx_highlight.js"></script>
         <script src="../../../../_static/js/theme.js"></script>
         <script src="../../../../_static/js/theme.js"></script>
         <link rel="index" title="Index" href="../../../../genindex.html" />
         <link rel="index" title="Index" href="../../../../genindex.html" />
         <link rel="search" title="Search" href="../../../../search.html" /> 
         <link rel="search" title="Search" href="../../../../search.html" /> 
    @@ -35,30 +38,29 @@
       </form>
       </form>
     </div>
     </div>
             </div><div class="wy-menu wy-menu-vertical" data-spy="affix" role="navigation" aria-label="Navigation menu">
             </div><div class="wy-menu wy-menu-vertical" data-spy="affix" role="navigation" aria-label="Navigation menu">
    -              <p class="caption"><span class="caption-text">Welcome To SuperGradients</span></p>
    +              <p class="caption" role="heading"><span class="caption-text">Welcome To SuperGradients</span></p>
     <ul>
     <ul>
    -<li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html">Fill our 4-question quick survey! We will raffle free SuperGradients swag between those who will participate -&gt; Fill Survey</a></li>
    -<li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#supergradients">SuperGradients</a></li>
    +<li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html">Version 3 is out! Notebooks have been updated!</a></li>
    +<li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#build-with-supergradients">Build with SuperGradients</a></li>
    +<li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#quick-installation">Quick Installation</a></li>
    +<li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#what-s-new">What’s New</a></li>
    +<li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#coming-soon">Coming soon</a></li>
    +<li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#table-of-content">Table of Content</a></li>
    +<li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#getting-started">Getting Started</a></li>
    +<li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#advanced-features">Advanced Features</a></li>
    +<li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#installation-methods">Installation Methods</a></li>
    +<li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#implemented-model-architectures">Implemented Model Architectures</a></li>
    +<li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#documentation">Documentation</a></li>
    +<li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#contributing">Contributing</a></li>
    +<li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#citation">Citation</a></li>
    +<li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#community">Community</a></li>
    +<li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#license">License</a></li>
    +<li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#deci-platform">Deci Platform</a></li>
     </ul>
     </ul>
    -<p class="caption"><span class="caption-text">Technical Documentation</span></p>
    +<p class="caption" role="heading"><span class="caption-text">Technical Documentation</span></p>
     <ul>
     <ul>
     <li class="toctree-l1"><a class="reference internal" href="../../../../super_gradients.common.html">Common package</a></li>
     <li class="toctree-l1"><a class="reference internal" href="../../../../super_gradients.common.html">Common package</a></li>
     <li class="toctree-l1"><a class="reference internal" href="../../../../super_gradients.training.html">Training package</a></li>
     <li class="toctree-l1"><a class="reference internal" href="../../../../super_gradients.training.html">Training package</a></li>
    -</ul>
    -<p class="caption"><span class="caption-text">User Guide</span></p>
    -<ul>
    -<li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html">What is SuperGradients?</a></li>
    -<li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#introducing-the-supergradients-library">Introducing the SuperGradients library</a></li>
    -<li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#installation">Installation</a></li>
    -<li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#integrating-your-training-code-complete-walkthrough">Integrating Your Training Code - Complete Walkthrough</a></li>
    -<li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#training-parameters">Training Parameters</a></li>
    -<li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#logs-and-checkpoints">Logs and Checkpoints</a></li>
    -<li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#dataset-parameters">Dataset Parameters</a></li>
    -<li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#network-architectures">Network Architectures</a></li>
    -<li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#pretrained-models">Pretrained Models</a></li>
    -<li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#how-to-reproduce-our-training-recipes">How To Reproduce Our Training Recipes</a></li>
    -<li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#professional-tools-integration">Professional Tools Integration</a></li>
    -<li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#supergradients-faq">SuperGradients FAQ</a></li>
     </ul>
     </ul>
     
     
             </div>
             </div>
    @@ -86,7 +88,7 @@
                <div itemprop="articleBody">
                <div itemprop="articleBody">
                  
                  
       <h1>Source code for super_gradients.common.decorators.singleton</h1><div class="highlight"><pre>
       <h1>Source code for super_gradients.common.decorators.singleton</h1><div class="highlight"><pre>
    -<div class="viewcode-block" id="SingletonMeta"><a class="viewcode-back" href="../../../../super_gradients.common.decorators.html#super_gradients.common.SingletonMeta">[docs]</a><span></span><span class="k">class</span> <span class="nc">SingletonMeta</span><span class="p">(</span><span class="nb">type</span><span class="p">):</span>
    +<span></span><span class="k">class</span> <span class="nc">SingletonMeta</span><span class="p">(</span><span class="nb">type</span><span class="p">):</span>
         <span class="sd">&quot;&quot;&quot;</span>
         <span class="sd">&quot;&quot;&quot;</span>
     <span class="sd">    A Singleton meta class.</span>
     <span class="sd">    A Singleton meta class.</span>
     <span class="sd">    A class that derives from this class will have only 1 instance of that type for the process.</span>
     <span class="sd">    A class that derives from this class will have only 1 instance of that type for the process.</span>
    @@ -97,7 +99,7 @@
         <span class="k">def</span> <span class="fm">__call__</span><span class="p">(</span><span class="bp">cls</span><span class="p">,</span> <span class="o">*</span><span class="n">args</span><span class="p">,</span> <span class="o">**</span><span class="n">kwargs</span><span class="p">):</span>
         <span class="k">def</span> <span class="fm">__call__</span><span class="p">(</span><span class="bp">cls</span><span class="p">,</span> <span class="o">*</span><span class="n">args</span><span class="p">,</span> <span class="o">**</span><span class="n">kwargs</span><span class="p">):</span>
             <span class="k">if</span> <span class="bp">cls</span> <span class="ow">not</span> <span class="ow">in</span> <span class="bp">cls</span><span class="o">.</span><span class="n">_instances</span><span class="p">:</span>
             <span class="k">if</span> <span class="bp">cls</span> <span class="ow">not</span> <span class="ow">in</span> <span class="bp">cls</span><span class="o">.</span><span class="n">_instances</span><span class="p">:</span>
                 <span class="bp">cls</span><span class="o">.</span><span class="n">_instances</span><span class="p">[</span><span class="bp">cls</span><span class="p">]</span> <span class="o">=</span> <span class="nb">super</span><span class="p">(</span><span class="n">SingletonMeta</span><span class="p">,</span> <span class="bp">cls</span><span class="p">)</span><span class="o">.</span><span class="fm">__call__</span><span class="p">(</span><span class="o">*</span><span class="n">args</span><span 
                 <span class="bp">cls</span><span class="o">.</span><span class="n">_instances</span><span class="p">[</span><span class="bp">cls</span><span class="p">]</span> <span class="o">=</span> <span class="nb">super</span><span class="p">(</span><span class="n">SingletonMeta</span><span class="p">,</span> <span class="bp">cls</span><span class="p">)</span><span class="o">.</span><span class="fm">__call__</span><span class="p">(</span><span class="o">*</span><span class="n">args</span><span 
    -        <span class="k">return</span> <span class="bp">cls</span><span class="o">.</span><span class="n">_instances</span><span class="p">[</span><span class="bp">cls</span><span class="p">]</span></div>
    +        <span class="k">return</span> <span class="bp">cls</span><span class="o">.</span><span class="n">_instances</span><span class="p">[</span><span class="bp">cls</span><span class="p">]</span>
     
     
     
     
     <span class="k">class</span> <span class="nc">_SingletonWrapper</span><span class="p">:</span>
     <span class="k">class</span> <span class="nc">_SingletonWrapper</span><span class="p">:</span>
    @@ -117,7 +119,7 @@
             <span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">_instance</span>
             <span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">_instance</span>
     
     
     
     
    -<div class="viewcode-block" id="singleton"><a class="viewcode-back" href="../../../../super_gradients.common.decorators.html#super_gradients.common.singleton">[docs]</a><span class="k">def</span> <span class="nf">singleton</span><span class="p">(</span><span class="bp">cls</span><span class="p">):</span>
    +<div class="viewcode-block" id="singleton"><a class="viewcode-back" href="../../../../super_gradients.common.html#super_gradients.common.singleton">[docs]</a><span class="k">def</span> <span class="nf">singleton</span><span class="p">(</span><span class="bp">cls</span><span class="p">):</span>
         <span class="sd">&quot;&quot;&quot;</span>
         <span class="sd">&quot;&quot;&quot;</span>
     <span class="sd">    A singleton decorator. Returns a wrapper objects. A call on that object</span>
     <span class="sd">    A singleton decorator. Returns a wrapper objects. A call on that object</span>
     <span class="sd">    returns a single instance object of decorated class. Use the __wrapped__</span>
     <span class="sd">    returns a single instance object of decorated class. Use the __wrapped__</span>
    @@ -153,4 +155,4 @@
       </script> 
       </script> 
     
     
     </body>
     </body>
    -</html>
    +</html>
    Discard
    Some lines were truncated since they exceed the maximum allowed length of 500, please use a local Git client to see the full diff.
    @@ -3,10 +3,11 @@
     <head>
     <head>
       <meta charset="utf-8" />
       <meta charset="utf-8" />
       <meta name="viewport" content="width=device-width, initial-scale=1.0" />
       <meta name="viewport" content="width=device-width, initial-scale=1.0" />
    -  <title>super_gradients.common.environment.env_helpers &mdash; SuperGradients 1.0 documentation</title>
    +  <title>super_gradients.common.environment.env_helpers &mdash; SuperGradients 3.0.3 documentation</title>
           <link rel="stylesheet" href="../../../../_static/pygments.css" type="text/css" />
           <link rel="stylesheet" href="../../../../_static/pygments.css" type="text/css" />
           <link rel="stylesheet" href="../../../../_static/css/theme.css" type="text/css" />
           <link rel="stylesheet" href="../../../../_static/css/theme.css" type="text/css" />
           <link rel="stylesheet" href="../../../../_static/graphviz.css" type="text/css" />
           <link rel="stylesheet" href="../../../../_static/graphviz.css" type="text/css" />
    +      <link rel="stylesheet" href="../../../../_static/custom.css" type="text/css" />
       <!--[if lt IE 9]>
       <!--[if lt IE 9]>
         <script src="../../../../_static/js/html5shiv.min.js"></script>
         <script src="../../../../_static/js/html5shiv.min.js"></script>
       <![endif]-->
       <![endif]-->
    @@ -14,7 +15,9 @@
             <script data-url_root="../../../../" id="documentation_options" src="../../../../_static/documentation_options.js"></script>
             <script data-url_root="../../../../" id="documentation_options" src="../../../../_static/documentation_options.js"></script>
             <script src="../../../../_static/jquery.js"></script>
             <script src="../../../../_static/jquery.js"></script>
             <script src="../../../../_static/underscore.js"></script>
             <script src="../../../../_static/underscore.js"></script>
    +        <script src="../../../../_static/_sphinx_javascript_frameworks_compat.js"></script>
             <script src="../../../../_static/doctools.js"></script>
             <script src="../../../../_static/doctools.js"></script>
    +        <script src="../../../../_static/sphinx_highlight.js"></script>
         <script src="../../../../_static/js/theme.js"></script>
         <script src="../../../../_static/js/theme.js"></script>
         <link rel="index" title="Index" href="../../../../genindex.html" />
         <link rel="index" title="Index" href="../../../../genindex.html" />
         <link rel="search" title="Search" href="../../../../search.html" /> 
         <link rel="search" title="Search" href="../../../../search.html" /> 
    @@ -35,30 +38,29 @@
       </form>
       </form>
     </div>
     </div>
             </div><div class="wy-menu wy-menu-vertical" data-spy="affix" role="navigation" aria-label="Navigation menu">
             </div><div class="wy-menu wy-menu-vertical" data-spy="affix" role="navigation" aria-label="Navigation menu">
    -              <p class="caption"><span class="caption-text">Welcome To SuperGradients</span></p>
    +              <p class="caption" role="heading"><span class="caption-text">Welcome To SuperGradients</span></p>
     <ul>
     <ul>
    -<li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html">Fill our 4-question quick survey! We will raffle free SuperGradients swag between those who will participate -&gt; Fill Survey</a></li>
    -<li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#supergradients">SuperGradients</a></li>
    +<li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html">Version 3 is out! Notebooks have been updated!</a></li>
    +<li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#build-with-supergradients">Build with SuperGradients</a></li>
    +<li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#quick-installation">Quick Installation</a></li>
    +<li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#what-s-new">What’s New</a></li>
    +<li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#coming-soon">Coming soon</a></li>
    +<li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#table-of-content">Table of Content</a></li>
    +<li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#getting-started">Getting Started</a></li>
    +<li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#advanced-features">Advanced Features</a></li>
    +<li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#installation-methods">Installation Methods</a></li>
    +<li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#implemented-model-architectures">Implemented Model Architectures</a></li>
    +<li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#documentation">Documentation</a></li>
    +<li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#contributing">Contributing</a></li>
    +<li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#citation">Citation</a></li>
    +<li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#community">Community</a></li>
    +<li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#license">License</a></li>
    +<li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#deci-platform">Deci Platform</a></li>
     </ul>
     </ul>
    -<p class="caption"><span class="caption-text">Technical Documentation</span></p>
    +<p class="caption" role="heading"><span class="caption-text">Technical Documentation</span></p>
     <ul>
     <ul>
     <li class="toctree-l1"><a class="reference internal" href="../../../../super_gradients.common.html">Common package</a></li>
     <li class="toctree-l1"><a class="reference internal" href="../../../../super_gradients.common.html">Common package</a></li>
     <li class="toctree-l1"><a class="reference internal" href="../../../../super_gradients.training.html">Training package</a></li>
     <li class="toctree-l1"><a class="reference internal" href="../../../../super_gradients.training.html">Training package</a></li>
    -</ul>
    -<p class="caption"><span class="caption-text">User Guide</span></p>
    -<ul>
    -<li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html">What is SuperGradients?</a></li>
    -<li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#introducing-the-supergradients-library">Introducing the SuperGradients library</a></li>
    -<li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#installation">Installation</a></li>
    -<li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#integrating-your-training-code-complete-walkthrough">Integrating Your Training Code - Complete Walkthrough</a></li>
    -<li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#training-parameters">Training Parameters</a></li>
    -<li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#logs-and-checkpoints">Logs and Checkpoints</a></li>
    -<li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#dataset-parameters">Dataset Parameters</a></li>
    -<li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#network-architectures">Network Architectures</a></li>
    -<li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#pretrained-models">Pretrained Models</a></li>
    -<li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#how-to-reproduce-our-training-recipes">How To Reproduce Our Training Recipes</a></li>
    -<li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#professional-tools-integration">Professional Tools Integration</a></li>
    -<li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#supergradients-faq">SuperGradients FAQ</a></li>
     </ul>
     </ul>
     
     
             </div>
             </div>
    @@ -87,38 +89,56 @@
                  
                  
       <h1>Source code for super_gradients.common.environment.env_helpers</h1><div class="highlight"><pre>
       <h1>Source code for super_gradients.common.environment.env_helpers</h1><div class="highlight"><pre>
     <span></span><span class="kn">import</span> <span class="nn">argparse</span>
     <span></span><span class="kn">import</span> <span class="nn">argparse</span>
    +<span class="kn">import</span> <span class="nn">importlib</span>
     <span class="kn">import</span> <span class="nn">os</span>
     <span class="kn">import</span> <span class="nn">os</span>
    +<span class="kn">import</span> <span class="nn">socket</span>
     <span class="kn">import</span> <span class="nn">sys</span>
     <span class="kn">import</span> <span class="nn">sys</span>
     <span class="kn">from</span> <span class="nn">functools</span> <span class="kn">import</span> <span class="n">wraps</span>
     <span class="kn">from</span> <span class="nn">functools</span> <span class="kn">import</span> <span class="n">wraps</span>
    +<span class="kn">from</span> <span class="nn">typing</span> <span class="kn">import</span> <span class="n">Any</span>
    +
    +<span class="kn">from</span> <span class="nn">omegaconf</span> <span class="kn">import</span> <span class="n">OmegaConf</span>
     
     
     <span class="kn">from</span> <span class="nn">super_gradients.common.environment</span> <span class="kn">import</span> <span class="n">environment_config</span>
     <span class="kn">from</span> <span class="nn">super_gradients.common.environment</span> <span class="kn">import</span> <span class="n">environment_config</span>
     
     
     
     
    -<div class="viewcode-block" id="TerminalColours"><a class="viewcode-back" href="../../../../super_gradients.common.environment.html#super_gradients.common.TerminalColours">[docs]</a><span class="k">class</span> <span class="nc">TerminalColours</span><span class="p">:</span>
    +<span class="k">class</span> <span class="nc">TerminalColours</span><span class="p">:</span>
         <span class="sd">&quot;&quot;&quot;</span>
         <span class="sd">&quot;&quot;&quot;</span>
     <span class="sd">    Usage: https://stackoverflow.com/questions/287871/how-to-print-colored-text-in-python?page=1&amp;tab=votes#tab-top</span>
     <span class="sd">    Usage: https://stackoverflow.com/questions/287871/how-to-print-colored-text-in-python?page=1&amp;tab=votes#tab-top</span>
     <span class="sd">    &quot;&quot;&quot;</span>
     <span class="sd">    &quot;&quot;&quot;</span>
    -    <span class="n">HEADER</span> <span class="o">=</span> <span class="s1">&#39;</span><span class="se">\033</span><span class="s1">[95m&#39;</span>
    -    <span class="n">OKBLUE</span> <span class="o">=</span> <span class="s1">&#39;</span><span class="se">\033</span><span class="s1">[94m&#39;</span>
    -    <span class="n">OKCYAN</span> <span class="o">=</span> <span class="s1">&#39;</span><span class="se">\033</span><span class="s1">[96m&#39;</span>
    -    <span class="n">OKGREEN</span> <span class="o">=</span> <span class="s1">&#39;</span><span class="se">\033</span><span class="s1">[92m&#39;</span>
    -    <span class="n">WARNING</span> <span class="o">=</span> <span class="s1">&#39;</span><span class="se">\033</span><span class="s1">[93m&#39;</span>
    -    <span class="n">FAIL</span> <span class="o">=</span> <span class="s1">&#39;</span><span class="se">\033</span><span class="s1">[91m&#39;</span>
    -    <span class="n">ENDC</span> <span class="o">=</span> <span class="s1">&#39;</span><span class="se">\033</span><span class="s1">[0m&#39;</span>
    -    <span class="n">BOLD</span> <span class="o">=</span> <span class="s1">&#39;</span><span class="se">\033</span><span class="s1">[1m&#39;</span>
    -    <span class="n">UNDERLINE</span> <span class="o">=</span> <span class="s1">&#39;</span><span class="se">\033</span><span class="s1">[4m&#39;</span></div>
    -
    -
    -<div class="viewcode-block" id="ColouredTextFormatter"><a class="viewcode-back" href="../../../../super_gradients.common.environment.html#super_gradients.common.ColouredTextFormatter">[docs]</a><span class="k">class</span> <span class="nc">ColouredTextFormatter</span><span class="p">:</span>
    -<div class="viewcode-block" id="ColouredTextFormatter.print_coloured_text"><a class="viewcode-back" href="../../../../super_gradients.common.environment.html#super_gradients.common.ColouredTextFormatter.print_coloured_text">[docs]</a>    <span class="nd">@staticmethod</span>
    +
    +    <span class="n">HEADER</span> <span class="o">=</span> <span class="s2">&quot;</span><span class="se">\033</span><span class="s2">[95m&quot;</span>
    +    <span class="n">OKBLUE</span> <span class="o">=</span> <span class="s2">&quot;</span><span class="se">\033</span><span class="s2">[94m&quot;</span>
    +    <span class="n">OKCYAN</span> <span class="o">=</span> <span class="s2">&quot;</span><span class="se">\033</span><span class="s2">[96m&quot;</span>
    +    <span class="n">OKGREEN</span> <span class="o">=</span> <span class="s2">&quot;</span><span class="se">\033</span><span class="s2">[92m&quot;</span>
    +    <span class="n">WARNING</span> <span class="o">=</span> <span class="s2">&quot;</span><span class="se">\033</span><span class="s2">[93m&quot;</span>
    +    <span class="n">FAIL</span> <span class="o">=</span> <span class="s2">&quot;</span><span class="se">\033</span><span class="s2">[91m&quot;</span>
    +    <span class="n">ENDC</span> <span class="o">=</span> <span class="s2">&quot;</span><span class="se">\033</span><span class="s2">[0m&quot;</span>
    +    <span class="n">BOLD</span> <span class="o">=</span> <span class="s2">&quot;</span><span class="se">\033</span><span class="s2">[1m&quot;</span>
    +    <span class="n">UNDERLINE</span> <span class="o">=</span> <span class="s2">&quot;</span><span class="se">\033</span><span class="s2">[4m&quot;</span>
    +
    +
    +<span class="k">class</span> <span class="nc">ColouredTextFormatter</span><span class="p">:</span>
    +    <span class="nd">@staticmethod</span>
         <span class="k">def</span> <span class="nf">print_coloured_text</span><span class="p">(</span><span class="n">text</span><span class="p">:</span> <span class="nb">str</span><span class="p">,</span> <span class="n">colour</span><span class="p">:</span> <span class="nb">str</span><span class="p">):</span>
         <span class="k">def</span> <span class="nf">print_coloured_text</span><span class="p">(</span><span class="n">text</span><span class="p">:</span> <span class="nb">str</span><span class="p">,</span> <span class="n">colour</span><span class="p">:</span> <span class="nb">str</span><span class="p">):</span>
             <span class="sd">&quot;&quot;&quot;</span>
             <span class="sd">&quot;&quot;&quot;</span>
     <span class="sd">        Prints a text with colour ascii characters.</span>
     <span class="sd">        Prints a text with colour ascii characters.</span>
     <span class="sd">        &quot;&quot;&quot;</span>
     <span class="sd">        &quot;&quot;&quot;</span>
    -        <span class="k">return</span> <span class="nb">print</span><span class="p">(</span><span class="s1">&#39;&#39;</span><span class="o">.</span><span class="n">join</span><span class="p">([</span><span class="n">colour</span><span class="p">,</span> <span class="n">text</span><span class="p">,</span> <span class="n">TerminalColours</span><span class="o">.</span><span class="n">ENDC</span><span class="p">]))</span></div></div>
    +        <span class="k">return</span> <span class="nb">print</span><span class="p">(</span><span class="s2">&quot;&quot;</span><span class="o">.</span><span class="n">join</span><span class="p">([</span><span class="n">colour</span><span class="p">,</span> <span class="n">text</span><span class="p">,</span> <span class="n">TerminalColours</span><span class="o">.</span><span class="n">ENDC</span><span class="p">]))</span>
    +
    +
    +<span class="k">def</span> <span class="nf">get_cls</span><span class="p">(</span><span class="n">cls_path</span><span class="p">):</span>
    +    <span class="sd">&quot;&quot;&quot;</span>
    +<span class="sd">    A resolver for Hydra/OmegaConf to allow getting a class instead on an instance.</span>
    +<span class="sd">    usage:</span>
    +<span class="sd">    class_of_optimizer: ${class:torch.optim.Adam}</span>
    +<span class="sd">    &quot;&quot;&quot;</span>
    +    <span class="n">module</span> <span class="o">=</span> <span class="s2">&quot;.&quot;</span><span class="o">.</span><span class="n">join</span><span class="p">(</span><span class="n">cls_path</span><span class="o">.</span><span class="n">split</span><span class="p">(</span><span class="s2">&quot;.&quot;</span><span class="p">)[:</span><span class="o">-</span><span class="mi">1</span><span class="p">])</span>
    +    <span class="n">name</span> <span class="o">=</span> <span class="n">cls_path</span><span class="o">.</span><span class="n">split</span><span class="p">(</span><span class="s2">&quot;.&quot;</span><span class="p">)[</span><span class="o">-</span><span class="mi">1</span><span class="p">]</span>
    +    <span class="n">importlib</span><span class="o">.</span><span class="n">import_module</span><span class="p">(</span><span class="n">module</span><span class="p">)</span>
    +    <span class="k">return</span> <span class="nb">getattr</span><span class="p">(</span><span class="n">sys</span><span class="o">.</span><span class="n">modules</span><span class="p">[</span><span class="n">module</span><span class="p">],</span> <span class="n">name</span><span class="p">)</span>
     
     
     
     
    -<div class="viewcode-block" id="get_environ_as_type"><a class="viewcode-back" href="../../../../super_gradients.common.environment.html#super_gradients.common.get_environ_as_type">[docs]</a><span class="k">def</span> <span class="nf">get_environ_as_type</span><span class="p">(</span><span class="n">environment_variable_name</span><span class="p">:</span> <span class="nb">str</span><span class="p">,</span> <span class="n">default</span><span class="o">=</span><span class="kc">None</span><span cl
    +<span class="k">def</span> <span class="nf">get_environ_as_type</span><span class="p">(</span><span class="n">environment_variable_name</span><span class="p">:</span> <span class="nb">str</span><span class="p">,</span> <span class="n">default</span><span class="o">=</span><span class="kc">None</span><span class="p">,</span> <span class="n">cast_to_type</span><span class="p">:</span> <span class="nb">type</span> <span class="o">=</span> <span class="nb">str</span><span class="p">)</span> <span c
         <span class="sd">&quot;&quot;&quot;</span>
         <span class="sd">&quot;&quot;&quot;</span>
     <span class="sd">    Tries to get an environment variable and cast it into a requested type.</span>
     <span class="sd">    Tries to get an environment variable and cast it into a requested type.</span>
     <span class="sd">    :return: cast_to_type object, or None if failed.</span>
     <span class="sd">    :return: cast_to_type object, or None if failed.</span>
    @@ -131,41 +151,101 @@
             <span class="k">except</span> <span class="ne">Exception</span> <span class="k">as</span> <span class="n">e</span><span class="p">:</span>
             <span class="k">except</span> <span class="ne">Exception</span> <span class="k">as</span> <span class="n">e</span><span class="p">:</span>
                 <span class="nb">print</span><span class="p">(</span><span class="n">e</span><span class="p">)</span>
                 <span class="nb">print</span><span class="p">(</span><span class="n">e</span><span class="p">)</span>
                 <span class="k">raise</span> <span class="ne">ValueError</span><span class="p">(</span>
                 <span class="k">raise</span> <span class="ne">ValueError</span><span class="p">(</span>
    -                <span class="sa">f</span><span class="s1">&#39;Failed to cast environment variable </span><span class="si">{</span><span class="n">environment_variable_name</span><span class="si">}</span><span class="s1"> to type </span><span class="si">{</span><span class="n">cast_to_type</span><span class="si">}</span><span class="s1">: the value </span><span class="si">{</span><span class="n">value</span><span class="si">}</span><span class="s1"> is not a valid </span><span class="si">{</spa
    -    <span class="k">return</span></div>
    +                <span class="sa">f</span><span class="s2">&quot;Failed to cast environment variable </span><span class="si">{</span><span class="n">environment_variable_name</span><span class="si">}</span><span class="s2"> to type </span><span class="si">{</span><span class="n">cast_to_type</span><span class="si">}</span><span class="s2">: the value </span><span class="si">{</span><span class="n">value</span><span class="si">}</span><span class="s2"> is not a valid </span><span class="si">{</sp
    +            <span class="p">)</span>
    +    <span class="k">return</span>
    +
    +
    +<span class="k">def</span> <span class="nf">hydra_output_dir_resolver</span><span class="p">(</span><span class="n">ckpt_root_dir</span><span class="p">,</span> <span class="n">experiment_name</span><span class="p">):</span>
    +    <span class="k">if</span> <span class="n">ckpt_root_dir</span> <span class="ow">is</span> <span class="kc">None</span><span class="p">:</span>
    +        <span class="n">output_dir_path</span> <span class="o">=</span> <span class="n">environment_config</span><span class="o">.</span><span class="n">PKG_CHECKPOINTS_DIR</span> <span class="o">+</span> <span class="n">os</span><span class="o">.</span><span class="n">path</span><span class="o">.</span><span class="n">sep</span> <span class="o">+</span> <span class="n">experiment_name</span>
    +    <span class="k">else</span><span class="p">:</span>
    +        <span class="n">output_dir_path</span> <span class="o">=</span> <span class="n">ckpt_root_dir</span> <span class="o">+</span> <span class="n">os</span><span class="o">.</span><span class="n">path</span><span class="o">.</span><span class="n">sep</span> <span class="o">+</span> <span class="n">experiment_name</span>
    +    <span class="k">return</span> <span class="n">output_dir_path</span>
     
     
     
     
    -<div class="viewcode-block" id="init_trainer"><a class="viewcode-back" href="../../../../super_gradients.common.environment.html#super_gradients.common.init_trainer">[docs]</a><span class="k">def</span> <span class="nf">init_trainer</span><span class="p">():</span>
    +<div class="viewcode-block" id="init_trainer"><a class="viewcode-back" href="../../../../super_gradients.common.html#super_gradients.common.init_trainer">[docs]</a><span class="k">def</span> <span class="nf">init_trainer</span><span class="p">():</span>
         <span class="sd">&quot;&quot;&quot;</span>
         <span class="sd">&quot;&quot;&quot;</span>
    -<span class="sd">    a function to initialize the super_gradients environment. This function should be the first thing to be called</span>
    -<span class="sd">    by any code running super_gradients. It resolves conflicts between the different tools, packages and environments used</span>
    -<span class="sd">    and prepares the super_gradients environment.</span>
    +<span class="sd">    Initialize the super_gradients environment.</span>
    +
    +<span class="sd">    This function should be the first thing to be called by any code running super_gradients.</span>
    +<span class="sd">    It resolves conflicts between the different tools, packages and environments used and prepares the super_gradients environment.</span>
     <span class="sd">    &quot;&quot;&quot;</span>
     <span class="sd">    &quot;&quot;&quot;</span>
    +    <span class="k">if</span> <span class="ow">not</span> <span class="n">environment_config</span><span class="o">.</span><span class="n">INIT_TRAINER</span><span class="p">:</span>
    +        <span class="n">register_hydra_resolvers</span><span class="p">()</span>
    +
    +        <span class="c1"># We pop local_rank if it was specified in the args, because it would break</span>
    +        <span class="n">args_local_rank</span> <span class="o">=</span> <span class="n">pop_arg</span><span class="p">(</span><span class="s2">&quot;local_rank&quot;</span><span class="p">,</span> <span class="n">default_value</span><span class="o">=-</span><span class="mi">1</span><span class="p">)</span>
    +
    +        <span class="c1"># Set local_rank with priority order (env variable &gt; args.local_rank &gt; args.default_value)</span>
    +        <span class="n">environment_config</span><span class="o">.</span><span class="n">DDP_LOCAL_RANK</span> <span class="o">=</span> <span class="nb">int</span><span class="p">(</span><span class="n">os</span><span class="o">.</span><span class="n">getenv</span><span class="p">(</span><span class="s2">&quot;LOCAL_RANK&quot;</span><span class="p">,</span> <span class="n">default</span><span class="o">=</span><span class="n">args_local_rank</span><span class="p">))</span>
    +        <span class="n">environment_config</span><span class="o">.</span><span class="n">INIT_TRAINER</span> <span class="o">=</span> <span class="kc">True</span></div>
    +
    +
    +<span class="k">def</span> <span class="nf">register_hydra_resolvers</span><span class="p">():</span>
    +    <span class="sd">&quot;&quot;&quot;Register all the hydra resolvers required for the super-gradients recipes.&quot;&quot;&quot;</span>
    +    <span class="n">OmegaConf</span><span class="o">.</span><span class="n">register_new_resolver</span><span class="p">(</span><span class="s2">&quot;hydra_output_dir&quot;</span><span class="p">,</span> <span class="n">hydra_output_dir_resolver</span><span class="p">,</span> <span class="n">replace</span><span class="o">=</span><span class="kc">True</span><span class="p">)</span>
    +    <span class="n">OmegaConf</span><span class="o">.</span><span class="n">register_new_resolver</span><span class="p">(</span><span class="s2">&quot;class&quot;</span><span class="p">,</span> <span class="k">lambda</span> <span class="o">*</span><span class="n">args</span><span class="p">:</span> <span class="n">get_cls</span><span class="p">(</span><span class="o">*</span><span class="n">args</span><span class="p">),</span> <span class="n">replace</span><span class="o">=</span><span class="k
    +    <span class="n">OmegaConf</span><span class="o">.</span><span class="n">register_new_resolver</span><span class="p">(</span><span class="s2">&quot;add&quot;</span><span class="p">,</span> <span class="k">lambda</span> <span class="o">*</span><span class="n">args</span><span class="p">:</span> <span class="nb">sum</span><span class="p">(</span><span class="n">args</span><span class="p">),</span> <span class="n">replace</span><span class="o">=</span><span class="kc">True</span><span class="p"
    +    <span class="n">OmegaConf</span><span class="o">.</span><span class="n">register_new_resolver</span><span class="p">(</span><span class="s2">&quot;cond&quot;</span><span class="p">,</span> <span class="k">lambda</span> <span class="n">boolean</span><span class="p">,</span> <span class="n">x</span><span class="p">,</span> <span class="n">y</span><span class="p">:</span> <span class="n">x</span> <span class="k">if</span> <span class="n">boolean</span> <span class="k">else</span> <span class="
    +    <span class="n">OmegaConf</span><span class="o">.</span><span class="n">register_new_resolver</span><span class="p">(</span><span class="s2">&quot;getitem&quot;</span><span class="p">,</span> <span class="k">lambda</span> <span class="n">container</span><span class="p">,</span> <span class="n">key</span><span class="p">:</span> <span class="n">container</span><span class="p">[</span><span class="n">key</span><span class="p">],</span> <span class="n">replace</span><span class="o">=</span><sp
    +    <span class="n">OmegaConf</span><span class="o">.</span><span class="n">register_new_resolver</span><span class="p">(</span><span class="s2">&quot;first&quot;</span><span class="p">,</span> <span class="k">lambda</span> <span class="n">lst</span><span class="p">:</span> <span class="n">lst</span><span class="p">[</span><span class="mi">0</span><span class="p">],</span> <span class="n">replace</span><span class="o">=</span><span class="kc">True</span><span class="p">)</span>  <span class="c1
    +    <span class="n">OmegaConf</span><span class="o">.</span><span class="n">register_new_resolver</span><span class="p">(</span><span class="s2">&quot;last&quot;</span><span class="p">,</span> <span class="k">lambda</span> <span class="n">lst</span><span class="p">:</span> <span class="n">lst</span><span class="p">[</span><span class="o">-</span><span class="mi">1</span><span class="p">],</span> <span class="n">replace</span><span class="o">=</span><span class="kc">True</span><span class="p">)<
    +
    +
    +<span class="k">def</span> <span class="nf">pop_arg</span><span class="p">(</span><span class="n">arg_name</span><span class="p">:</span> <span class="nb">str</span><span class="p">,</span> <span class="n">default_value</span><span class="p">:</span> <span class="n">Any</span> <span class="o">=</span> <span class="kc">None</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="n">Any</span><span class="p">:</span>
    +    <span class="sd">&quot;&quot;&quot;Get the specified args and remove them from argv&quot;&quot;&quot;</span>
     
     
         <span class="n">parser</span> <span class="o">=</span> <span class="n">argparse</span><span class="o">.</span><span class="n">ArgumentParser</span><span class="p">()</span>
         <span class="n">parser</span> <span class="o">=</span> <span class="n">argparse</span><span class="o">.</span><span class="n">ArgumentParser</span><span class="p">()</span>
    -    <span class="n">parser</span><span class="o">.</span><span class="n">add_argument</span><span class="p">(</span><span class="s2">&quot;--local_rank&quot;</span><span class="p">,</span> <span class="nb">type</span><span class="o">=</span><span class="nb">int</span><span class="p">,</span> <span class="n">default</span><span class="o">=-</span><span class="mi">1</span><span class="p">)</span>  <span class="c1"># used by DDP</span>
    +    <span class="n">parser</span><span class="o">.</span><span class="n">add_argument</span><span class="p">(</span><span class="sa">f</span><span class="s2">&quot;--</span><span class="si">{</span><span class="n">arg_name</span><span class="si">}</span><span class="s2">&quot;</span><span class="p">,</span> <span class="n">default</span><span class="o">=</span><span class="n">default_value</span><span class="p">)</span>
         <span class="n">args</span><span class="p">,</span> <span class="n">_</span> <span class="o">=</span> <span class="n">parser</span><span class="o">.</span><span class="n">parse_known_args</span><span class="p">()</span>
         <span class="n">args</span><span class="p">,</span> <span class="n">_</span> <span class="o">=</span> <span class="n">parser</span><span class="o">.</span><span class="n">parse_known_args</span><span class="p">()</span>
     
     
    -    <span class="c1"># remove any flags starting with --local_rank from the argv list</span>
    -    <span class="n">to_remove</span> <span class="o">=</span> <span class="nb">list</span><span class="p">(</span><span class="nb">filter</span><span class="p">(</span><span class="k">lambda</span> <span class="n">x</span><span class="p">:</span> <span class="n">x</span><span class="o">.</span><span class="n">startswith</span><span class="p">(</span><span class="s1">&#39;--local_rank&#39;</span><span class="p">),</span> <span class="n">sys</span><span class="o">.</span><span class="n">argv</spa
    -    <span class="k">if</span> <span class="nb">len</span><span class="p">(</span><span class="n">to_remove</span><span class="p">)</span> <span class="o">&gt;</span> <span class="mi">0</span><span class="p">:</span>
    -        <span class="k">for</span> <span class="n">val</span> <span class="ow">in</span> <span class="n">to_remove</span><span class="p">:</span>
    -            <span class="n">sys</span><span class="o">.</span><span class="n">argv</span><span class="o">.</span><span class="n">remove</span><span class="p">(</span><span class="n">val</span><span class="p">)</span>
    -
    -    <span class="n">environment_config</span><span class="o">.</span><span class="n">DDP_LOCAL_RANK</span> <span class="o">=</span> <span class="n">args</span><span class="o">.</span><span class="n">local_rank</span></div>
    +    <span class="c1"># Remove the ddp args to not have a conflict with the use of hydra</span>
    +    <span class="k">for</span> <span class="n">val</span> <span class="ow">in</span> <span class="nb">filter</span><span class="p">(</span><span class="k">lambda</span> <span class="n">x</span><span class="p">:</span> <span class="n">x</span><span class="o">.</span><span class="n">startswith</span><span class="p">(</span><span class="sa">f</span><span class="s2">&quot;--</span><span class="si">{</span><span class="n">arg_name</span><span class="si">}</span><span class="s2">&quot;</span><span cl
    +        <span class="n">environment_config</span><span class="o">.</span><span class="n">EXTRA_ARGS</span><span class="o">.</span><span class="n">append</span><span class="p">(</span><span class="n">val</span><span class="p">)</span>
    +        <span class="n">sys</span><span class="o">.</span><span class="n">argv</span><span class="o">.</span><span class="n">remove</span><span class="p">(</span><span class="n">val</span><span class="p">)</span>
    +    <span class="k">return</span> <span class="nb">vars</span><span class="p">(</span><span class="n">args</span><span class="p">)[</span><span class="n">arg_name</span><span class="p">]</span>
     
     
     
     
    -<div class="viewcode-block" id="is_distributed"><a class="viewcode-back" href="../../../../super_gradients.common.environment.html#super_gradients.common.is_distributed">[docs]</a><span class="k">def</span> <span class="nf">is_distributed</span><span class="p">()</span> <span class="o">-&gt;</span> <span class="nb">bool</span><span class="p">:</span>
    +<div class="viewcode-block" id="is_distributed"><a class="viewcode-back" href="../../../../super_gradients.common.html#super_gradients.common.is_distributed">[docs]</a><span class="k">def</span> <span class="nf">is_distributed</span><span class="p">()</span> <span class="o">-&gt;</span> <span class="nb">bool</span><span class="p">:</span>
         <span class="k">return</span> <span class="n">environment_config</span><span class="o">.</span><span class="n">DDP_LOCAL_RANK</span> <span class="o">&gt;=</span> <span class="mi">0</span></div>
         <span class="k">return</span> <span class="n">environment_config</span><span class="o">.</span><span class="n">DDP_LOCAL_RANK</span> <span class="o">&gt;=</span> <span class="mi">0</span></div>
     
     
     
     
    -<div class="viewcode-block" id="multi_process_safe"><a class="viewcode-back" href="../../../../super_gradients.common.environment.html#super_gradients.common.multi_process_safe">[docs]</a><span class="k">def</span> <span class="nf">multi_process_safe</span><span class="p">(</span><span class="n">func</span><span class="p">):</span>
    +<span class="k">def</span> <span class="nf">is_rank_0</span><span class="p">()</span> <span class="o">-&gt;</span> <span class="nb">bool</span><span class="p">:</span>
    +    <span class="sd">&quot;&quot;&quot;Check if the node was launched with torch.distributed.launch and if the node is of rank 0&quot;&quot;&quot;</span>
    +    <span class="k">return</span> <span class="n">os</span><span class="o">.</span><span class="n">getenv</span><span class="p">(</span><span class="s2">&quot;LOCAL_RANK&quot;</span><span class="p">)</span> <span class="o">==</span> <span class="s2">&quot;0&quot;</span>
    +
    +
    +<span class="k">def</span> <span class="nf">is_launched_using_sg</span><span class="p">():</span>
    +    <span class="sd">&quot;&quot;&quot;Check if the current process is a subprocess launched using SG restart_script_with_ddp&quot;&quot;&quot;</span>
    +    <span class="k">return</span> <span class="n">os</span><span class="o">.</span><span class="n">environ</span><span class="o">.</span><span class="n">get</span><span class="p">(</span><span class="s2">&quot;TORCHELASTIC_RUN_ID&quot;</span><span class="p">)</span> <span class="o">==</span> <span class="s2">&quot;sg_initiated&quot;</span>
    +
    +
    +<span class="k">def</span> <span class="nf">is_main_process</span><span class="p">():</span>
    +    <span class="sd">&quot;&quot;&quot;Check if current process is considered as the main process (i.e. is responsible for sanity check, atexit upload, ...).</span>
    +<span class="sd">    The definition ensures that 1 and only 1 process follows this condition, regardless of how the run was started.</span>
    +
    +<span class="sd">    The rule is as follow:</span>
    +<span class="sd">        - If not DDP: main process is current process</span>
    +<span class="sd">        - If DDP launched using SuperGradients: main process is the launching process (rank=-1)</span>
    +<span class="sd">        - If DDP launched with torch: main process is rank 0</span>
    +<span class="sd">    &quot;&quot;&quot;</span>
    +    <span class="k">if</span> <span class="ow">not</span> <span class="n">is_distributed</span><span class="p">():</span>  <span class="c1"># If no DDP, or DDP launching process</span>
    +        <span class="k">return</span> <span class="kc">True</span>
    +    <span class="k">elif</span> <span class="n">is_rank_0</span><span class="p">()</span> <span class="ow">and</span> <span class="ow">not</span> <span class="n">is_launched_using_sg</span><span class="p">():</span>  <span class="c1"># If DDP launched using torch.distributed.launch or torchrun, we need to run the check on rank 0</span>
    +        <span class="k">return</span> <span class="kc">True</span>
    +    <span class="k">else</span><span class="p">:</span>
    +        <span class="k">return</span> <span class="kc">False</span>
    +
    +
    +<span class="k">def</span> <span class="nf">multi_process_safe</span><span class="p">(</span><span class="n">func</span><span class="p">):</span>
         <span class="sd">&quot;&quot;&quot;</span>
         <span class="sd">&quot;&quot;&quot;</span>
     <span class="sd">    A decorator for making sure a function runs only in main process.</span>
     <span class="sd">    A decorator for making sure a function runs only in main process.</span>
     <span class="sd">    If not in DDP mode (local_rank = -1), the function will run.</span>
     <span class="sd">    If not in DDP mode (local_rank = -1), the function will run.</span>
     <span class="sd">    If in DDP mode, the function will run only in the main process (local_rank = 0)</span>
     <span class="sd">    If in DDP mode, the function will run only in the main process (local_rank = 0)</span>
     <span class="sd">    This works only for functions with no return value</span>
     <span class="sd">    This works only for functions with no return value</span>
     <span class="sd">    &quot;&quot;&quot;</span>
     <span class="sd">    &quot;&quot;&quot;</span>
    +
         <span class="k">def</span> <span class="nf">do_nothing</span><span class="p">(</span><span class="o">*</span><span class="n">args</span><span class="p">,</span> <span class="o">**</span><span class="n">kwargs</span><span class="p">):</span>
         <span class="k">def</span> <span class="nf">do_nothing</span><span class="p">(</span><span class="o">*</span><span class="n">args</span><span class="p">,</span> <span class="o">**</span><span class="n">kwargs</span><span class="p">):</span>
             <span class="k">pass</span>
             <span class="k">pass</span>
     
     
    @@ -176,7 +256,18 @@
             <span class="k">else</span><span class="p">:</span>
             <span class="k">else</span><span class="p">:</span>
                 <span class="k">return</span> <span class="n">do_nothing</span><span class="p">(</span><span class="o">*</span><span class="n">args</span><span class="p">,</span> <span class="o">**</span><span class="n">kwargs</span><span class="p">)</span>
                 <span class="k">return</span> <span class="n">do_nothing</span><span class="p">(</span><span class="o">*</span><span class="n">args</span><span class="p">,</span> <span class="o">**</span><span class="n">kwargs</span><span class="p">)</span>
     
     
    -    <span class="k">return</span> <span class="n">wrapper</span></div>
    +    <span class="k">return</span> <span class="n">wrapper</span>
    +
    +
    +<span class="k">def</span> <span class="nf">find_free_port</span><span class="p">()</span> <span class="o">-&gt;</span> <span class="nb">int</span><span class="p">:</span>
    +    <span class="sd">&quot;&quot;&quot;Find an available port of current machine/node.</span>
    +<span class="sd">    Note: there is still a chance the port could be taken by other processes.&quot;&quot;&quot;</span>
    +
    +    <span class="k">with</span> <span class="n">socket</span><span class="o">.</span><span class="n">socket</span><span class="p">(</span><span class="n">socket</span><span class="o">.</span><span class="n">AF_INET</span><span class="p">,</span> <span class="n">socket</span><span class="o">.</span><span class="n">SOCK_STREAM</span><span class="p">)</span> <span class="k">as</span> <span class="n">sock</span><span class="p">:</span>
    +        <span class="c1"># Binding to port 0 will cause the OS to find an available port for us</span>
    +        <span class="n">sock</span><span class="o">.</span><span class="n">bind</span><span class="p">((</span><span class="s2">&quot;&quot;</span><span class="p">,</span> <span class="mi">0</span><span class="p">))</span>
    +        <span class="n">_ip</span><span class="p">,</span> <span class="n">port</span> <span class="o">=</span> <span class="n">sock</span><span class="o">.</span><span class="n">getsockname</span><span class="p">()</span>
    +    <span class="k">return</span> <span class="n">port</span>
     </pre></div>
     </pre></div>
     
     
                </div>
                </div>
    @@ -206,4 +297,4 @@
       </script> 
       </script> 
     
     
     </body>
     </body>
    -</html>
    +</html>
    Discard
    1
    2
    3
    4
    5
    6
    7
    8
    9
    10
    11
    12
    13
    14
    15
    16
    17
    18
    19
    20
    21
    22
    23
    24
    25
    26
    27
    28
    29
    30
    31
    32
    33
    34
    35
    36
    37
    38
    39
    40
    41
    42
    43
    44
    45
    46
    47
    48
    49
    50
    51
    52
    53
    54
    55
    56
    57
    58
    59
    60
    61
    62
    63
    64
    65
    66
    67
    68
    69
    70
    71
    72
    73
    74
    75
    76
    77
    78
    79
    80
    81
    82
    83
    84
    85
    86
    87
    88
    89
    90
    91
    92
    93
    94
    95
    96
    97
    98
    99
    100
    101
    102
    103
    104
    105
    106
    107
    108
    109
    110
    111
    112
    113
    114
    115
    116
    117
    118
    119
    120
    121
    122
    123
    124
    125
    126
    127
    128
    129
    130
    131
    132
    133
    134
    135
    136
    137
    138
    139
    140
    141
    142
    143
    144
    145
    146
    147
    148
    149
    150
    151
    152
    153
    154
    155
    156
    157
    158
    159
    160
    161
    162
    163
    164
    165
    166
    167
    168
    169
    170
    171
    172
    173
    174
    175
    176
    177
    178
    179
    180
    181
    182
    183
    184
    185
    186
    187
    188
    189
    190
    191
    192
    193
    194
    195
    196
    197
    198
    199
    200
    201
    202
    203
    204
    205
    206
    207
    208
    209
    210
    211
    212
    213
    214
    215
    216
    217
    218
    219
    220
    221
    222
    223
    224
    225
    226
    227
    228
    229
    230
    231
    232
    233
    234
    235
    236
    237
    238
    239
    240
    241
    242
    243
    244
    245
    246
    247
    248
    249
    250
    251
    252
    253
    254
    255
    256
    257
    258
    259
    260
    261
    262
    263
    264
    265
    266
    267
    268
    269
    270
    271
    272
    273
    274
    275
    276
    277
    278
    279
    280
    281
    282
    283
    284
    285
    286
    287
    288
    289
    290
    291
    292
    293
    294
    295
    296
    297
    298
    299
    300
    301
    302
    303
    304
    305
    306
    307
    308
    309
    310
    311
    312
    313
    314
    315
    316
    317
    318
    319
    320
    321
    322
    323
    324
    325
    326
    327
    328
    329
    330
    331
    332
    333
    334
    335
    336
    337
    338
    339
    340
    341
    342
    343
    344
    345
    346
    347
    348
    349
    350
    351
    352
    353
    354
    355
    356
    357
    358
    359
    360
    361
    362
    363
    364
    365
    366
    367
    368
    369
    370
    371
    372
    373
    1. <!DOCTYPE html>
    2. <html class="writer-html5" lang="en" >
    3. <head>
    4. <meta charset="utf-8" />
    5. <meta name="viewport" content="width=device-width, initial-scale=1.0" />
    6. <title>super_gradients.common.object_names &mdash; SuperGradients 3.0.3 documentation</title>
    7. <link rel="stylesheet" href="../../../_static/pygments.css" type="text/css" />
    8. <link rel="stylesheet" href="../../../_static/css/theme.css" type="text/css" />
    9. <link rel="stylesheet" href="../../../_static/graphviz.css" type="text/css" />
    10. <link rel="stylesheet" href="../../../_static/custom.css" type="text/css" />
    11. <!--[if lt IE 9]>
    12. <script src="../../../_static/js/html5shiv.min.js"></script>
    13. <![endif]-->
    14. <script data-url_root="../../../" id="documentation_options" src="../../../_static/documentation_options.js"></script>
    15. <script src="../../../_static/jquery.js"></script>
    16. <script src="../../../_static/underscore.js"></script>
    17. <script src="../../../_static/_sphinx_javascript_frameworks_compat.js"></script>
    18. <script src="../../../_static/doctools.js"></script>
    19. <script src="../../../_static/sphinx_highlight.js"></script>
    20. <script src="../../../_static/js/theme.js"></script>
    21. <link rel="index" title="Index" href="../../../genindex.html" />
    22. <link rel="search" title="Search" href="../../../search.html" />
    23. </head>
    24. <body class="wy-body-for-nav">
    25. <div class="wy-grid-for-nav">
    26. <nav data-toggle="wy-nav-shift" class="wy-nav-side">
    27. <div class="wy-side-scroll">
    28. <div class="wy-side-nav-search" >
    29. <a href="../../../index.html" class="icon icon-home"> SuperGradients
    30. </a>
    31. <div role="search">
    32. <form id="rtd-search-form" class="wy-form" action="../../../search.html" method="get">
    33. <input type="text" name="q" placeholder="Search docs" />
    34. <input type="hidden" name="check_keywords" value="yes" />
    35. <input type="hidden" name="area" value="default" />
    36. </form>
    37. </div>
    38. </div><div class="wy-menu wy-menu-vertical" data-spy="affix" role="navigation" aria-label="Navigation menu">
    39. <p class="caption" role="heading"><span class="caption-text">Welcome To SuperGradients</span></p>
    40. <ul>
    41. <li class="toctree-l1"><a class="reference internal" href="../../../welcome.html">Version 3 is out! Notebooks have been updated!</a></li>
    42. <li class="toctree-l1"><a class="reference internal" href="../../../welcome.html#build-with-supergradients">Build with SuperGradients</a></li>
    43. <li class="toctree-l1"><a class="reference internal" href="../../../welcome.html#quick-installation">Quick Installation</a></li>
    44. <li class="toctree-l1"><a class="reference internal" href="../../../welcome.html#what-s-new">What’s New</a></li>
    45. <li class="toctree-l1"><a class="reference internal" href="../../../welcome.html#coming-soon">Coming soon</a></li>
    46. <li class="toctree-l1"><a class="reference internal" href="../../../welcome.html#table-of-content">Table of Content</a></li>
    47. <li class="toctree-l1"><a class="reference internal" href="../../../welcome.html#getting-started">Getting Started</a></li>
    48. <li class="toctree-l1"><a class="reference internal" href="../../../welcome.html#advanced-features">Advanced Features</a></li>
    49. <li class="toctree-l1"><a class="reference internal" href="../../../welcome.html#installation-methods">Installation Methods</a></li>
    50. <li class="toctree-l1"><a class="reference internal" href="../../../welcome.html#implemented-model-architectures">Implemented Model Architectures</a></li>
    51. <li class="toctree-l1"><a class="reference internal" href="../../../welcome.html#documentation">Documentation</a></li>
    52. <li class="toctree-l1"><a class="reference internal" href="../../../welcome.html#contributing">Contributing</a></li>
    53. <li class="toctree-l1"><a class="reference internal" href="../../../welcome.html#citation">Citation</a></li>
    54. <li class="toctree-l1"><a class="reference internal" href="../../../welcome.html#community">Community</a></li>
    55. <li class="toctree-l1"><a class="reference internal" href="../../../welcome.html#license">License</a></li>
    56. <li class="toctree-l1"><a class="reference internal" href="../../../welcome.html#deci-platform">Deci Platform</a></li>
    57. </ul>
    58. <p class="caption" role="heading"><span class="caption-text">Technical Documentation</span></p>
    59. <ul>
    60. <li class="toctree-l1"><a class="reference internal" href="../../../super_gradients.common.html">Common package</a></li>
    61. <li class="toctree-l1"><a class="reference internal" href="../../../super_gradients.training.html">Training package</a></li>
    62. </ul>
    63. </div>
    64. </div>
    65. </nav>
    66. <section data-toggle="wy-nav-shift" class="wy-nav-content-wrap"><nav class="wy-nav-top" aria-label="Mobile navigation menu" >
    67. <i data-toggle="wy-nav-top" class="fa fa-bars"></i>
    68. <a href="../../../index.html">SuperGradients</a>
    69. </nav>
    70. <div class="wy-nav-content">
    71. <div class="rst-content">
    72. <div role="navigation" aria-label="Page navigation">
    73. <ul class="wy-breadcrumbs">
    74. <li><a href="../../../index.html" class="icon icon-home"></a> &raquo;</li>
    75. <li><a href="../../index.html">Module code</a> &raquo;</li>
    76. <li>super_gradients.common.object_names</li>
    77. <li class="wy-breadcrumbs-aside">
    78. </li>
    79. </ul>
    80. <hr/>
    81. </div>
    82. <div role="main" class="document" itemscope="itemscope" itemtype="http://schema.org/Article">
    83. <div itemprop="articleBody">
    84. <h1>Source code for super_gradients.common.object_names</h1><div class="highlight"><pre>
    85. <div class="viewcode-block" id="Losses"><a class="viewcode-back" href="../../../super_gradients.training.html#super_gradients.training.losses.Losses">[docs]</a><span></span><span class="k">class</span> <span class="nc">Losses</span><span class="p">:</span>
    86. <span class="sd">&quot;&quot;&quot;Static class holding all the supported loss names&quot;&quot;&quot;</span>
    87. <span class="n">CROSS_ENTROPY</span> <span class="o">=</span> <span class="s2">&quot;cross_entropy&quot;</span>
    88. <span class="n">MSE</span> <span class="o">=</span> <span class="s2">&quot;mse&quot;</span>
    89. <span class="n">R_SQUARED_LOSS</span> <span class="o">=</span> <span class="s2">&quot;r_squared_loss&quot;</span>
    90. <span class="n">SHELFNET_OHEM_LOSS</span> <span class="o">=</span> <span class="s2">&quot;shelfnet_ohem_loss&quot;</span>
    91. <span class="n">SHELFNET_SE_LOSS</span> <span class="o">=</span> <span class="s2">&quot;shelfnet_se_loss&quot;</span>
    92. <span class="n">YOLOX_LOSS</span> <span class="o">=</span> <span class="s2">&quot;yolox_loss&quot;</span>
    93. <span class="n">YOLOX_FAST_LOSS</span> <span class="o">=</span> <span class="s2">&quot;yolox_fast_loss&quot;</span>
    94. <span class="n">SSD_LOSS</span> <span class="o">=</span> <span class="s2">&quot;ssd_loss&quot;</span>
    95. <span class="n">STDC_LOSS</span> <span class="o">=</span> <span class="s2">&quot;stdc_loss&quot;</span>
    96. <span class="n">BCE_DICE_LOSS</span> <span class="o">=</span> <span class="s2">&quot;bce_dice_loss&quot;</span>
    97. <span class="n">KD_LOSS</span> <span class="o">=</span> <span class="s2">&quot;kd_loss&quot;</span>
    98. <span class="n">DICE_CE_EDGE_LOSS</span> <span class="o">=</span> <span class="s2">&quot;dice_ce_edge_loss&quot;</span></div>
    99. <div class="viewcode-block" id="Metrics"><a class="viewcode-back" href="../../../super_gradients.training.html#super_gradients.training.losses.Metrics">[docs]</a><span class="k">class</span> <span class="nc">Metrics</span><span class="p">:</span>
    100. <span class="sd">&quot;&quot;&quot;Static class holding all the supported metric names&quot;&quot;&quot;</span>
    101. <span class="n">ACCURACY</span> <span class="o">=</span> <span class="s2">&quot;Accuracy&quot;</span>
    102. <span class="n">TOP5</span> <span class="o">=</span> <span class="s2">&quot;Top5&quot;</span>
    103. <span class="n">DETECTION_METRICS</span> <span class="o">=</span> <span class="s2">&quot;DetectionMetrics&quot;</span>
    104. <span class="n">DETECTION_METRICS_050_095</span> <span class="o">=</span> <span class="s2">&quot;DetectionMetrics_050_095&quot;</span>
    105. <span class="n">DETECTION_METRICS_050</span> <span class="o">=</span> <span class="s2">&quot;DetectionMetrics_050&quot;</span>
    106. <span class="n">DETECTION_METRICS_075</span> <span class="o">=</span> <span class="s2">&quot;DetectionMetrics_075&quot;</span>
    107. <span class="n">IOU</span> <span class="o">=</span> <span class="s2">&quot;IoU&quot;</span>
    108. <span class="n">BINARY_IOU</span> <span class="o">=</span> <span class="s2">&quot;BinaryIOU&quot;</span>
    109. <span class="n">DICE</span> <span class="o">=</span> <span class="s2">&quot;Dice&quot;</span>
    110. <span class="n">BINARY_DICE</span> <span class="o">=</span> <span class="s2">&quot;BinaryDice&quot;</span>
    111. <span class="n">PIXEL_ACCURACY</span> <span class="o">=</span> <span class="s2">&quot;PixelAccuracy&quot;</span></div>
    112. <div class="viewcode-block" id="Transforms"><a class="viewcode-back" href="../../../super_gradients.training.html#super_gradients.training.losses.Transforms">[docs]</a><span class="k">class</span> <span class="nc">Transforms</span><span class="p">:</span>
    113. <span class="sd">&quot;&quot;&quot;Static class holding all the supported transform names&quot;&quot;&quot;</span>
    114. <span class="c1"># From SG</span>
    115. <span class="n">SegRandomFlip</span> <span class="o">=</span> <span class="s2">&quot;SegRandomFlip&quot;</span>
    116. <span class="n">SegResize</span> <span class="o">=</span> <span class="s2">&quot;SegResize&quot;</span>
    117. <span class="n">SegRescale</span> <span class="o">=</span> <span class="s2">&quot;SegRescale&quot;</span>
    118. <span class="n">SegRandomRescale</span> <span class="o">=</span> <span class="s2">&quot;SegRandomRescale&quot;</span>
    119. <span class="n">SegRandomRotate</span> <span class="o">=</span> <span class="s2">&quot;SegRandomRotate&quot;</span>
    120. <span class="n">SegCropImageAndMask</span> <span class="o">=</span> <span class="s2">&quot;SegCropImageAndMask&quot;</span>
    121. <span class="n">SegRandomGaussianBlur</span> <span class="o">=</span> <span class="s2">&quot;SegRandomGaussianBlur&quot;</span>
    122. <span class="n">SegPadShortToCropSize</span> <span class="o">=</span> <span class="s2">&quot;SegPadShortToCropSize&quot;</span>
    123. <span class="n">SegColorJitter</span> <span class="o">=</span> <span class="s2">&quot;SegColorJitter&quot;</span>
    124. <span class="n">DetectionMosaic</span> <span class="o">=</span> <span class="s2">&quot;DetectionMosaic&quot;</span>
    125. <span class="n">DetectionRandomAffine</span> <span class="o">=</span> <span class="s2">&quot;DetectionRandomAffine&quot;</span>
    126. <span class="n">DetectionMixup</span> <span class="o">=</span> <span class="s2">&quot;DetectionMixup&quot;</span>
    127. <span class="n">DetectionHSV</span> <span class="o">=</span> <span class="s2">&quot;DetectionHSV&quot;</span>
    128. <span class="n">DetectionHorizontalFlip</span> <span class="o">=</span> <span class="s2">&quot;DetectionHorizontalFlip&quot;</span>
    129. <span class="n">DetectionPaddedRescale</span> <span class="o">=</span> <span class="s2">&quot;DetectionPaddedRescale&quot;</span>
    130. <span class="n">DetectionTargetsFormat</span> <span class="o">=</span> <span class="s2">&quot;DetectionTargetsFormat&quot;</span>
    131. <span class="n">DetectionTargetsFormatTransform</span> <span class="o">=</span> <span class="s2">&quot;DetectionTargetsFormatTransform&quot;</span>
    132. <span class="n">RandomResizedCropAndInterpolation</span> <span class="o">=</span> <span class="s2">&quot;RandomResizedCropAndInterpolation&quot;</span>
    133. <span class="n">RandAugmentTransform</span> <span class="o">=</span> <span class="s2">&quot;RandAugmentTransform&quot;</span>
    134. <span class="n">Lighting</span> <span class="o">=</span> <span class="s2">&quot;Lighting&quot;</span>
    135. <span class="n">RandomErase</span> <span class="o">=</span> <span class="s2">&quot;RandomErase&quot;</span>
    136. <span class="c1"># From torch</span>
    137. <span class="n">Compose</span> <span class="o">=</span> <span class="s2">&quot;Compose&quot;</span>
    138. <span class="n">ToTensor</span> <span class="o">=</span> <span class="s2">&quot;ToTensor&quot;</span>
    139. <span class="n">PILToTensor</span> <span class="o">=</span> <span class="s2">&quot;PILToTensor&quot;</span>
    140. <span class="n">ConvertImageDtype</span> <span class="o">=</span> <span class="s2">&quot;ConvertImageDtype&quot;</span>
    141. <span class="n">ToPILImage</span> <span class="o">=</span> <span class="s2">&quot;ToPILImage&quot;</span>
    142. <span class="n">Normalize</span> <span class="o">=</span> <span class="s2">&quot;Normalize&quot;</span>
    143. <span class="n">Resize</span> <span class="o">=</span> <span class="s2">&quot;Resize&quot;</span>
    144. <span class="n">CenterCrop</span> <span class="o">=</span> <span class="s2">&quot;CenterCrop&quot;</span>
    145. <span class="n">Pad</span> <span class="o">=</span> <span class="s2">&quot;Pad&quot;</span>
    146. <span class="n">Lambda</span> <span class="o">=</span> <span class="s2">&quot;Lambda&quot;</span>
    147. <span class="n">RandomApply</span> <span class="o">=</span> <span class="s2">&quot;RandomApply&quot;</span>
    148. <span class="n">RandomChoice</span> <span class="o">=</span> <span class="s2">&quot;RandomChoice&quot;</span>
    149. <span class="n">RandomOrder</span> <span class="o">=</span> <span class="s2">&quot;RandomOrder&quot;</span>
    150. <span class="n">RandomCrop</span> <span class="o">=</span> <span class="s2">&quot;RandomCrop&quot;</span>
    151. <span class="n">RandomHorizontalFlip</span> <span class="o">=</span> <span class="s2">&quot;RandomHorizontalFlip&quot;</span>
    152. <span class="n">RandomVerticalFlip</span> <span class="o">=</span> <span class="s2">&quot;RandomVerticalFlip&quot;</span>
    153. <span class="n">RandomResizedCrop</span> <span class="o">=</span> <span class="s2">&quot;RandomResizedCrop&quot;</span>
    154. <span class="n">FiveCrop</span> <span class="o">=</span> <span class="s2">&quot;FiveCrop&quot;</span>
    155. <span class="n">TenCrop</span> <span class="o">=</span> <span class="s2">&quot;TenCrop&quot;</span>
    156. <span class="n">LinearTransformation</span> <span class="o">=</span> <span class="s2">&quot;LinearTransformation&quot;</span>
    157. <span class="n">ColorJitter</span> <span class="o">=</span> <span class="s2">&quot;ColorJitter&quot;</span>
    158. <span class="n">RandomRotation</span> <span class="o">=</span> <span class="s2">&quot;RandomRotation&quot;</span>
    159. <span class="n">RandomAffine</span> <span class="o">=</span> <span class="s2">&quot;RandomAffine&quot;</span>
    160. <span class="n">Grayscale</span> <span class="o">=</span> <span class="s2">&quot;Grayscale&quot;</span>
    161. <span class="n">RandomGrayscale</span> <span class="o">=</span> <span class="s2">&quot;RandomGrayscale&quot;</span>
    162. <span class="n">RandomPerspective</span> <span class="o">=</span> <span class="s2">&quot;RandomPerspective&quot;</span>
    163. <span class="n">RandomErasing</span> <span class="o">=</span> <span class="s2">&quot;RandomErasing&quot;</span>
    164. <span class="n">GaussianBlur</span> <span class="o">=</span> <span class="s2">&quot;GaussianBlur&quot;</span>
    165. <span class="n">InterpolationMode</span> <span class="o">=</span> <span class="s2">&quot;InterpolationMode&quot;</span>
    166. <span class="n">RandomInvert</span> <span class="o">=</span> <span class="s2">&quot;RandomInvert&quot;</span>
    167. <span class="n">RandomPosterize</span> <span class="o">=</span> <span class="s2">&quot;RandomPosterize&quot;</span>
    168. <span class="n">RandomSolarize</span> <span class="o">=</span> <span class="s2">&quot;RandomSolarize&quot;</span>
    169. <span class="n">RandomAdjustSharpness</span> <span class="o">=</span> <span class="s2">&quot;RandomAdjustSharpness&quot;</span>
    170. <span class="n">RandomAutocontrast</span> <span class="o">=</span> <span class="s2">&quot;RandomAutocontrast&quot;</span>
    171. <span class="n">RandomEqualize</span> <span class="o">=</span> <span class="s2">&quot;RandomEqualize&quot;</span></div>
    172. <span class="k">class</span> <span class="nc">Optimizers</span><span class="p">:</span>
    173. <span class="sd">&quot;&quot;&quot;Static class holding all the supported optimizer names&quot;&quot;&quot;</span>
    174. <span class="n">SGD</span> <span class="o">=</span> <span class="s2">&quot;SGD&quot;</span>
    175. <span class="n">ADAM</span> <span class="o">=</span> <span class="s2">&quot;Adam&quot;</span>
    176. <span class="n">RMS_PROP</span> <span class="o">=</span> <span class="s2">&quot;RMSprop&quot;</span>
    177. <span class="n">RMS_PROP_TF</span> <span class="o">=</span> <span class="s2">&quot;RMSpropTF&quot;</span>
    178. <span class="n">LAMB</span> <span class="o">=</span> <span class="s2">&quot;Lamb&quot;</span>
    179. <span class="k">class</span> <span class="nc">Callbacks</span><span class="p">:</span>
    180. <span class="sd">&quot;&quot;&quot;Static class holding all the supported callback names&quot;&quot;&quot;</span>
    181. <span class="n">DECI_LAB_UPLOAD</span> <span class="o">=</span> <span class="s2">&quot;DeciLabUploadCallback&quot;</span>
    182. <span class="n">LR_CALLBACK_BASE</span> <span class="o">=</span> <span class="s2">&quot;LRCallbackBase&quot;</span>
    183. <span class="n">LR_SCHEDULER</span> <span class="o">=</span> <span class="s2">&quot;LRSchedulerCallback&quot;</span>
    184. <span class="n">METRICS_UPDATE</span> <span class="o">=</span> <span class="s2">&quot;MetricsUpdateCallback&quot;</span>
    185. <span class="n">MODEL_CONVERSION_CHECK</span> <span class="o">=</span> <span class="s2">&quot;ModelConversionCheckCallback&quot;</span>
    186. <span class="n">EARLY_STOP</span> <span class="o">=</span> <span class="s2">&quot;EarlyStop&quot;</span>
    187. <span class="n">DETECTION_MULTISCALE_PREPREDICTION</span> <span class="o">=</span> <span class="s2">&quot;DetectionMultiscalePrePredictionCallback&quot;</span>
    188. <span class="n">YOLOX_TRAINING_STAGE_SWITCH</span> <span class="o">=</span> <span class="s2">&quot;YoloXTrainingStageSwitchCallback&quot;</span>
    189. <span class="k">class</span> <span class="nc">LRSchedulers</span><span class="p">:</span>
    190. <span class="sd">&quot;&quot;&quot;Static class to hold all the supported LR Scheduler names&quot;&quot;&quot;</span>
    191. <span class="n">STEP</span> <span class="o">=</span> <span class="s2">&quot;step&quot;</span>
    192. <span class="n">POLY</span> <span class="o">=</span> <span class="s2">&quot;poly&quot;</span>
    193. <span class="n">COSINE</span> <span class="o">=</span> <span class="s2">&quot;cosine&quot;</span>
    194. <span class="n">EXP</span> <span class="o">=</span> <span class="s2">&quot;exp&quot;</span>
    195. <span class="n">FUNCTION</span> <span class="o">=</span> <span class="s2">&quot;function&quot;</span>
    196. <span class="k">class</span> <span class="nc">LRWarmups</span><span class="p">:</span>
    197. <span class="sd">&quot;&quot;&quot;Static class to hold all the supported LR Warmup names&quot;&quot;&quot;</span>
    198. <span class="n">LINEAR_STEP</span> <span class="o">=</span> <span class="s2">&quot;linear_step&quot;</span>
    199. <span class="k">class</span> <span class="nc">Samplers</span><span class="p">:</span>
    200. <span class="sd">&quot;&quot;&quot;Static class to hold all the supported Samplers names&quot;&quot;&quot;</span>
    201. <span class="n">INFINITE</span> <span class="o">=</span> <span class="s2">&quot;InfiniteSampler&quot;</span>
    202. <span class="n">REPEAT_AUG</span> <span class="o">=</span> <span class="s2">&quot;RepeatAugSampler&quot;</span>
    203. <span class="n">DISTRIBUTED</span> <span class="o">=</span> <span class="s2">&quot;DistributedSampler&quot;</span>
    204. <span class="k">class</span> <span class="nc">ContextModules</span><span class="p">:</span>
    205. <span class="sd">&quot;&quot;&quot;Static class to hold all the segmentation context module names&quot;&quot;&quot;</span>
    206. <span class="n">ASPP</span> <span class="o">=</span> <span class="s2">&quot;ASPP&quot;</span>
    207. <span class="n">SPPM</span> <span class="o">=</span> <span class="s2">&quot;SPPM&quot;</span>
    208. <span class="k">class</span> <span class="nc">Models</span><span class="p">:</span>
    209. <span class="sd">&quot;&quot;&quot;Static class to hold all the available model names&quot;&quot;&quot;</span>
    210. <span class="n">RESNET18</span> <span class="o">=</span> <span class="s2">&quot;resnet18&quot;</span>
    211. <span class="n">RESNET34</span> <span class="o">=</span> <span class="s2">&quot;resnet34&quot;</span>
    212. <span class="n">RESNET50_3343</span> <span class="o">=</span> <span class="s2">&quot;resnet50_3343&quot;</span>
    213. <span class="n">RESNET50</span> <span class="o">=</span> <span class="s2">&quot;resnet50&quot;</span>
    214. <span class="n">RESNET101</span> <span class="o">=</span> <span class="s2">&quot;resnet101&quot;</span>
    215. <span class="n">RESNET152</span> <span class="o">=</span> <span class="s2">&quot;resnet152&quot;</span>
    216. <span class="n">RESNET18_CIFAR</span> <span class="o">=</span> <span class="s2">&quot;resnet18_cifar&quot;</span>
    217. <span class="n">CUSTOM_RESNET</span> <span class="o">=</span> <span class="s2">&quot;custom_resnet&quot;</span>
    218. <span class="n">CUSTOM_RESNET50</span> <span class="o">=</span> <span class="s2">&quot;custom_resnet50&quot;</span>
    219. <span class="n">CUSTOM_RESNET_CIFAR</span> <span class="o">=</span> <span class="s2">&quot;custom_resnet_cifar&quot;</span>
    220. <span class="n">CUSTOM_RESNET50_CIFAR</span> <span class="o">=</span> <span class="s2">&quot;custom_resnet50_cifar&quot;</span>
    221. <span class="n">MOBILENET_V2</span> <span class="o">=</span> <span class="s2">&quot;mobilenet_v2&quot;</span>
    222. <span class="n">MOBILE_NET_V2_135</span> <span class="o">=</span> <span class="s2">&quot;mobile_net_v2_135&quot;</span>
    223. <span class="n">CUSTOM_MOBILENET_V2</span> <span class="o">=</span> <span class="s2">&quot;custom_mobilenet_v2&quot;</span>
    224. <span class="n">MOBILENET_V3_LARGE</span> <span class="o">=</span> <span class="s2">&quot;mobilenet_v3_large&quot;</span>
    225. <span class="n">MOBILENET_V3_SMALL</span> <span class="o">=</span> <span class="s2">&quot;mobilenet_v3_small&quot;</span>
    226. <span class="n">MOBILENET_V3_CUSTOM</span> <span class="o">=</span> <span class="s2">&quot;mobilenet_v3_custom&quot;</span>
    227. <span class="n">CUSTOM_DENSENET</span> <span class="o">=</span> <span class="s2">&quot;custom_densenet&quot;</span>
    228. <span class="n">DENSENET121</span> <span class="o">=</span> <span class="s2">&quot;densenet121&quot;</span>
    229. <span class="n">DENSENET161</span> <span class="o">=</span> <span class="s2">&quot;densenet161&quot;</span>
    230. <span class="n">DENSENET169</span> <span class="o">=</span> <span class="s2">&quot;densenet169&quot;</span>
    231. <span class="n">DENSENET201</span> <span class="o">=</span> <span class="s2">&quot;densenet201&quot;</span>
    232. <span class="n">SHELFNET18_LW</span> <span class="o">=</span> <span class="s2">&quot;shelfnet18_lw&quot;</span>
    233. <span class="n">SHELFNET34_LW</span> <span class="o">=</span> <span class="s2">&quot;shelfnet34_lw&quot;</span>
    234. <span class="n">SHELFNET50_3343</span> <span class="o">=</span> <span class="s2">&quot;shelfnet50_3343&quot;</span>
    235. <span class="n">SHELFNET50</span> <span class="o">=</span> <span class="s2">&quot;shelfnet50&quot;</span>
    236. <span class="n">SHELFNET101</span> <span class="o">=</span> <span class="s2">&quot;shelfnet101&quot;</span>
    237. <span class="n">SHUFFLENET_V2_X0_5</span> <span class="o">=</span> <span class="s2">&quot;shufflenet_v2_x0_5&quot;</span>
    238. <span class="n">SHUFFLENET_V2_X1_0</span> <span class="o">=</span> <span class="s2">&quot;shufflenet_v2_x1_0&quot;</span>
    239. <span class="n">SHUFFLENET_V2_X1_5</span> <span class="o">=</span> <span class="s2">&quot;shufflenet_v2_x1_5&quot;</span>
    240. <span class="n">SHUFFLENET_V2_X2_0</span> <span class="o">=</span> <span class="s2">&quot;shufflenet_v2_x2_0&quot;</span>
    241. <span class="n">SHUFFLENET_V2_CUSTOM5</span> <span class="o">=</span> <span class="s2">&quot;shufflenet_v2_custom5&quot;</span>
    242. <span class="n">DARKNET53</span> <span class="o">=</span> <span class="s2">&quot;darknet53&quot;</span>
    243. <span class="n">CSP_DARKNET53</span> <span class="o">=</span> <span class="s2">&quot;csp_darknet53&quot;</span>
    244. <span class="n">RESNEXT50</span> <span class="o">=</span> <span class="s2">&quot;resnext50&quot;</span>
    245. <span class="n">RESNEXT101</span> <span class="o">=</span> <span class="s2">&quot;resnext101&quot;</span>
    246. <span class="n">GOOGLENET_V1</span> <span class="o">=</span> <span class="s2">&quot;googlenet_v1&quot;</span>
    247. <span class="n">EFFICIENTNET_B0</span> <span class="o">=</span> <span class="s2">&quot;efficientnet_b0&quot;</span>
    248. <span class="n">EFFICIENTNET_B1</span> <span class="o">=</span> <span class="s2">&quot;efficientnet_b1&quot;</span>
    249. <span class="n">EFFICIENTNET_B2</span> <span class="o">=</span> <span class="s2">&quot;efficientnet_b2&quot;</span>
    250. <span class="n">EFFICIENTNET_B3</span> <span class="o">=</span> <span class="s2">&quot;efficientnet_b3&quot;</span>
    251. <span class="n">EFFICIENTNET_B4</span> <span class="o">=</span> <span class="s2">&quot;efficientnet_b4&quot;</span>
    252. <span class="n">EFFICIENTNET_B5</span> <span class="o">=</span> <span class="s2">&quot;efficientnet_b5&quot;</span>
    253. <span class="n">EFFICIENTNET_B6</span> <span class="o">=</span> <span class="s2">&quot;efficientnet_b6&quot;</span>
    254. <span class="n">EFFICIENTNET_B7</span> <span class="o">=</span> <span class="s2">&quot;efficientnet_b7&quot;</span>
    255. <span class="n">EFFICIENTNET_B8</span> <span class="o">=</span> <span class="s2">&quot;efficientnet_b8&quot;</span>
    256. <span class="n">EFFICIENTNET_L2</span> <span class="o">=</span> <span class="s2">&quot;efficientnet_l2&quot;</span>
    257. <span class="n">CUSTOMIZEDEFFICIENTNET</span> <span class="o">=</span> <span class="s2">&quot;CustomizedEfficientnet&quot;</span>
    258. <span class="n">REGNETY200</span> <span class="o">=</span> <span class="s2">&quot;regnetY200&quot;</span>
    259. <span class="n">REGNETY400</span> <span class="o">=</span> <span class="s2">&quot;regnetY400&quot;</span>
    260. <span class="n">REGNETY600</span> <span class="o">=</span> <span class="s2">&quot;regnetY600&quot;</span>
    261. <span class="n">REGNETY800</span> <span class="o">=</span> <span class="s2">&quot;regnetY800&quot;</span>
    262. <span class="n">CUSTOM_REGNET</span> <span class="o">=</span> <span class="s2">&quot;custom_regnet&quot;</span>
    263. <span class="n">CUSTOM_ANYNET</span> <span class="o">=</span> <span class="s2">&quot;custom_anynet&quot;</span>
    264. <span class="n">NAS_REGNET</span> <span class="o">=</span> <span class="s2">&quot;nas_regnet&quot;</span>
    265. <span class="n">YOLOX_N</span> <span class="o">=</span> <span class="s2">&quot;yolox_n&quot;</span>
    266. <span class="n">YOLOX_T</span> <span class="o">=</span> <span class="s2">&quot;yolox_t&quot;</span>
    267. <span class="n">YOLOX_S</span> <span class="o">=</span> <span class="s2">&quot;yolox_s&quot;</span>
    268. <span class="n">YOLOX_M</span> <span class="o">=</span> <span class="s2">&quot;yolox_m&quot;</span>
    269. <span class="n">YOLOX_L</span> <span class="o">=</span> <span class="s2">&quot;yolox_l&quot;</span>
    270. <span class="n">YOLOX_X</span> <span class="o">=</span> <span class="s2">&quot;yolox_x&quot;</span>
    271. <span class="n">CUSTOM_YOLO_X</span> <span class="o">=</span> <span class="s2">&quot;custom_yolox&quot;</span>
    272. <span class="n">SSD_MOBILENET_V1</span> <span class="o">=</span> <span class="s2">&quot;ssd_mobilenet_v1&quot;</span>
    273. <span class="n">SSD_LITE_MOBILENET_V2</span> <span class="o">=</span> <span class="s2">&quot;ssd_lite_mobilenet_v2&quot;</span>
    274. <span class="n">REPVGG_A0</span> <span class="o">=</span> <span class="s2">&quot;repvgg_a0&quot;</span>
    275. <span class="n">REPVGG_A1</span> <span class="o">=</span> <span class="s2">&quot;repvgg_a1&quot;</span>
    276. <span class="n">REPVGG_A2</span> <span class="o">=</span> <span class="s2">&quot;repvgg_a2&quot;</span>
    277. <span class="n">REPVGG_B0</span> <span class="o">=</span> <span class="s2">&quot;repvgg_b0&quot;</span>
    278. <span class="n">REPVGG_B1</span> <span class="o">=</span> <span class="s2">&quot;repvgg_b1&quot;</span>
    279. <span class="n">REPVGG_B2</span> <span class="o">=</span> <span class="s2">&quot;repvgg_b2&quot;</span>
    280. <span class="n">REPVGG_B3</span> <span class="o">=</span> <span class="s2">&quot;repvgg_b3&quot;</span>
    281. <span class="n">REPVGG_D2SE</span> <span class="o">=</span> <span class="s2">&quot;repvgg_d2se&quot;</span>
    282. <span class="n">REPVGG_CUSTOM</span> <span class="o">=</span> <span class="s2">&quot;repvgg_custom&quot;</span>
    283. <span class="n">DDRNET_23</span> <span class="o">=</span> <span class="s2">&quot;ddrnet_23&quot;</span>
    284. <span class="n">DDRNET_23_SLIM</span> <span class="o">=</span> <span class="s2">&quot;ddrnet_23_slim&quot;</span>
    285. <span class="n">CUSTOM_DDRNET_23</span> <span class="o">=</span> <span class="s2">&quot;custom_ddrnet_23&quot;</span>
    286. <span class="n">STDC1_CLASSIFICATION</span> <span class="o">=</span> <span class="s2">&quot;stdc1_classification&quot;</span>
    287. <span class="n">STDC2_CLASSIFICATION</span> <span class="o">=</span> <span class="s2">&quot;stdc2_classification&quot;</span>
    288. <span class="n">STDC1_SEG</span> <span class="o">=</span> <span class="s2">&quot;stdc1_seg&quot;</span>
    289. <span class="n">STDC1_SEG50</span> <span class="o">=</span> <span class="s2">&quot;stdc1_seg50&quot;</span>
    290. <span class="n">STDC1_SEG75</span> <span class="o">=</span> <span class="s2">&quot;stdc1_seg75&quot;</span>
    291. <span class="n">STDC2_SEG</span> <span class="o">=</span> <span class="s2">&quot;stdc2_seg&quot;</span>
    292. <span class="n">STDC2_SEG50</span> <span class="o">=</span> <span class="s2">&quot;stdc2_seg50&quot;</span>
    293. <span class="n">STDC2_SEG75</span> <span class="o">=</span> <span class="s2">&quot;stdc2_seg75&quot;</span>
    294. <span class="n">CUSTOM_STDC</span> <span class="o">=</span> <span class="s2">&quot;custom_stdc&quot;</span>
    295. <span class="n">REGSEG48</span> <span class="o">=</span> <span class="s2">&quot;regseg48&quot;</span>
    296. <span class="n">KD_MODULE</span> <span class="o">=</span> <span class="s2">&quot;kd_module&quot;</span>
    297. <span class="n">VIT_BASE</span> <span class="o">=</span> <span class="s2">&quot;vit_base&quot;</span>
    298. <span class="n">VIT_LARGE</span> <span class="o">=</span> <span class="s2">&quot;vit_large&quot;</span>
    299. <span class="n">VIT_HUGE</span> <span class="o">=</span> <span class="s2">&quot;vit_huge&quot;</span>
    300. <span class="n">BEIT_BASE_PATCH16_224</span> <span class="o">=</span> <span class="s2">&quot;beit_base_patch16_224&quot;</span>
    301. <span class="n">BEIT_LARGE_PATCH16_224</span> <span class="o">=</span> <span class="s2">&quot;beit_large_patch16_224&quot;</span>
    302. <span class="n">PP_LITE_T_SEG</span> <span class="o">=</span> <span class="s2">&quot;pp_lite_t_seg&quot;</span>
    303. <span class="n">PP_LITE_T_SEG50</span> <span class="o">=</span> <span class="s2">&quot;pp_lite_t_seg50&quot;</span>
    304. <span class="n">PP_LITE_T_SEG75</span> <span class="o">=</span> <span class="s2">&quot;pp_lite_t_seg75&quot;</span>
    305. <span class="n">PP_LITE_B_SEG</span> <span class="o">=</span> <span class="s2">&quot;pp_lite_b_seg&quot;</span>
    306. <span class="n">PP_LITE_B_SEG50</span> <span class="o">=</span> <span class="s2">&quot;pp_lite_b_seg50&quot;</span>
    307. <span class="n">PP_LITE_B_SEG75</span> <span class="o">=</span> <span class="s2">&quot;pp_lite_b_seg75&quot;</span>
    308. <span class="n">UNET_CUSTOM</span> <span class="o">=</span> <span class="s2">&quot;unet_custom&quot;</span>
    309. </pre></div>
    310. </div>
    311. </div>
    312. <footer>
    313. <hr/>
    314. <div role="contentinfo">
    315. <p>&#169; Copyright 2021, SuperGradients team.</p>
    316. </div>
    317. Built with <a href="https://www.sphinx-doc.org/">Sphinx</a> using a
    318. <a href="https://github.com/readthedocs/sphinx_rtd_theme">theme</a>
    319. provided by <a href="https://readthedocs.org">Read the Docs</a>.
    320. </footer>
    321. </div>
    322. </div>
    323. </section>
    324. </div>
    325. <script>
    326. jQuery(function () {
    327. SphinxRtdTheme.Navigation.enable(true);
    328. });
    329. </script>
    330. </body>
    331. </html>
    Discard
    1
    2
    3
    4
    5
    6
    7
    8
    9
    10
    11
    12
    13
    14
    15
    16
    17
    18
    19
    20
    21
    22
    23
    24
    25
    26
    27
    28
    29
    30
    31
    32
    33
    34
    35
    36
    37
    38
    39
    40
    41
    42
    43
    44
    45
    46
    47
    48
    49
    50
    51
    52
    53
    54
    55
    56
    57
    58
    59
    60
    61
    62
    63
    64
    65
    66
    67
    68
    69
    70
    71
    72
    73
    74
    75
    76
    77
    78
    79
    80
    81
    82
    83
    84
    85
    86
    87
    88
    89
    90
    91
    92
    93
    94
    95
    96
    97
    98
    99
    100
    101
    102
    103
    104
    105
    106
    107
    108
    109
    110
    111
    112
    113
    114
    115
    116
    117
    118
    119
    120
    121
    122
    123
    124
    125
    126
    127
    128
    129
    130
    131
    132
    133
    134
    135
    136
    137
    138
    139
    140
    141
    142
    143
    144
    145
    146
    147
    148
    149
    150
    151
    152
    153
    154
    155
    156
    157
    158
    159
    160
    161
    162
    163
    164
    165
    166
    167
    168
    169
    170
    171
    172
    173
    174
    175
    176
    177
    178
    179
    180
    181
    182
    183
    184
    185
    186
    187
    188
    189
    190
    191
    192
    193
    194
    195
    196
    197
    198
    199
    200
    201
    202
    203
    204
    205
    206
    207
    208
    209
    210
    211
    212
    213
    214
    215
    216
    217
    218
    219
    220
    221
    222
    223
    224
    225
    226
    227
    228
    229
    230
    231
    232
    233
    234
    235
    236
    237
    238
    239
    240
    241
    242
    243
    244
    245
    246
    247
    248
    249
    250
    251
    252
    253
    254
    255
    256
    257
    258
    259
    260
    261
    262
    263
    264
    265
    266
    267
    268
    269
    270
    271
    272
    273
    274
    275
    276
    277
    278
    279
    280
    281
    282
    283
    284
    285
    286
    287
    288
    289
    290
    291
    292
    293
    294
    295
    296
    297
    298
    299
    300
    301
    302
    303
    304
    305
    306
    307
    308
    309
    310
    311
    312
    313
    314
    315
    316
    317
    318
    319
    320
    321
    322
    323
    324
    325
    326
    327
    328
    329
    330
    331
    332
    333
    334
    335
    336
    337
    338
    339
    340
    341
    342
    343
    344
    345
    346
    347
    348
    349
    350
    351
    352
    353
    354
    355
    356
    357
    358
    359
    360
    361
    362
    363
    364
    365
    366
    367
    368
    369
    370
    371
    372
    373
    374
    375
    376
    377
    378
    379
    380
    381
    382
    383
    384
    385
    386
    387
    388
    389
    390
    391
    392
    393
    394
    395
    396
    397
    398
    399
    400
    401
    402
    403
    404
    405
    406
    407
    408
    409
    410
    411
    412
    413
    414
    415
    416
    417
    418
    419
    420
    421
    422
    423
    424
    425
    426
    427
    428
    429
    430
    431
    432
    433
    434
    435
    436
    437
    438
    439
    440
    441
    442
    443
    444
    445
    446
    447
    448
    449
    450
    451
    452
    453
    454
    455
    456
    457
    458
    459
    460
    461
    462
    463
    464
    465
    466
    467
    468
    469
    470
    471
    472
    473
    474
    475
    476
    477
    478
    479
    480
    481
    482
    483
    484
    485
    486
    487
    488
    489
    490
    491
    492
    493
    494
    495
    496
    497
    498
    499
    500
    501
    502
    503
    504
    505
    506
    507
    508
    509
    510
    511
    512
    513
    514
    515
    516
    517
    518
    519
    520
    521
    522
    523
    524
    525
    526
    527
    528
    529
    530
    531
    532
    533
    534
    535
    536
    537
    538
    539
    540
    541
    542
    543
    544
    545
    546
    547
    548
    549
    550
    551
    552
    553
    554
    555
    556
    557
    558
    559
    560
    561
    562
    563
    564
    565
    566
    567
    568
    569
    570
    571
    572
    573
    574
    575
    576
    577
    578
    579
    580
    581
    582
    583
    584
    585
    586
    587
    588
    589
    590
    591
    592
    593
    594
    595
    596
    597
    598
    599
    600
    601
    602
    603
    604
    605
    606
    607
    608
    609
    610
    611
    612
    613
    614
    615
    616
    617
    618
    619
    620
    621
    622
    623
    624
    625
    626
    627
    628
    629
    630
    631
    632
    633
    634
    635
    636
    637
    638
    639
    640
    641
    642
    643
    644
    645
    646
    647
    648
    649
    650
    651
    652
    653
    654
    655
    656
    657
    658
    659
    660
    661
    662
    663
    664
    665
    666
    667
    668
    669
    670
    671
    672
    673
    674
    675
    676
    677
    678
    679
    680
    681
    682
    683
    684
    685
    686
    687
    688
    689
    690
    691
    692
    693
    694
    695
    696
    697
    698
    699
    700
    701
    702
    703
    704
    705
    706
    707
    708
    709
    710
    711
    712
    713
    714
    715
    716
    717
    718
    719
    720
    721
    722
    723
    724
    725
    726
    727
    728
    729
    730
    731
    732
    733
    734
    735
    736
    737
    738
    739
    740
    741
    742
    743
    744
    745
    746
    747
    748
    749
    750
    751
    752
    753
    754
    755
    756
    757
    758
    759
    760
    761
    762
    763
    764
    765
    766
    767
    768
    769
    770
    771
    772
    773
    774
    775
    776
    777
    778
    779
    780
    781
    782
    783
    784
    785
    1. <!DOCTYPE html>
    2. <html class="writer-html5" lang="en" >
    3. <head>
    4. <meta charset="utf-8" />
    5. <meta name="viewport" content="width=device-width, initial-scale=1.0" />
    6. <title>super_gradients.training.dataloaders.dataloaders &mdash; SuperGradients 3.0.3 documentation</title>
    7. <link rel="stylesheet" href="../../../../_static/pygments.css" type="text/css" />
    8. <link rel="stylesheet" href="../../../../_static/css/theme.css" type="text/css" />
    9. <link rel="stylesheet" href="../../../../_static/graphviz.css" type="text/css" />
    10. <link rel="stylesheet" href="../../../../_static/custom.css" type="text/css" />
    11. <!--[if lt IE 9]>
    12. <script src="../../../../_static/js/html5shiv.min.js"></script>
    13. <![endif]-->
    14. <script data-url_root="../../../../" id="documentation_options" src="../../../../_static/documentation_options.js"></script>
    15. <script src="../../../../_static/jquery.js"></script>
    16. <script src="../../../../_static/underscore.js"></script>
    17. <script src="../../../../_static/_sphinx_javascript_frameworks_compat.js"></script>
    18. <script src="../../../../_static/doctools.js"></script>
    19. <script src="../../../../_static/sphinx_highlight.js"></script>
    20. <script src="../../../../_static/js/theme.js"></script>
    21. <link rel="index" title="Index" href="../../../../genindex.html" />
    22. <link rel="search" title="Search" href="../../../../search.html" />
    23. </head>
    24. <body class="wy-body-for-nav">
    25. <div class="wy-grid-for-nav">
    26. <nav data-toggle="wy-nav-shift" class="wy-nav-side">
    27. <div class="wy-side-scroll">
    28. <div class="wy-side-nav-search" >
    29. <a href="../../../../index.html" class="icon icon-home"> SuperGradients
    30. </a>
    31. <div role="search">
    32. <form id="rtd-search-form" class="wy-form" action="../../../../search.html" method="get">
    33. <input type="text" name="q" placeholder="Search docs" />
    34. <input type="hidden" name="check_keywords" value="yes" />
    35. <input type="hidden" name="area" value="default" />
    36. </form>
    37. </div>
    38. </div><div class="wy-menu wy-menu-vertical" data-spy="affix" role="navigation" aria-label="Navigation menu">
    39. <p class="caption" role="heading"><span class="caption-text">Welcome To SuperGradients</span></p>
    40. <ul>
    41. <li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html">Version 3 is out! Notebooks have been updated!</a></li>
    42. <li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#build-with-supergradients">Build with SuperGradients</a></li>
    43. <li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#quick-installation">Quick Installation</a></li>
    44. <li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#what-s-new">What’s New</a></li>
    45. <li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#coming-soon">Coming soon</a></li>
    46. <li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#table-of-content">Table of Content</a></li>
    47. <li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#getting-started">Getting Started</a></li>
    48. <li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#advanced-features">Advanced Features</a></li>
    49. <li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#installation-methods">Installation Methods</a></li>
    50. <li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#implemented-model-architectures">Implemented Model Architectures</a></li>
    51. <li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#documentation">Documentation</a></li>
    52. <li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#contributing">Contributing</a></li>
    53. <li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#citation">Citation</a></li>
    54. <li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#community">Community</a></li>
    55. <li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#license">License</a></li>
    56. <li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#deci-platform">Deci Platform</a></li>
    57. </ul>
    58. <p class="caption" role="heading"><span class="caption-text">Technical Documentation</span></p>
    59. <ul>
    60. <li class="toctree-l1"><a class="reference internal" href="../../../../super_gradients.common.html">Common package</a></li>
    61. <li class="toctree-l1"><a class="reference internal" href="../../../../super_gradients.training.html">Training package</a></li>
    62. </ul>
    63. </div>
    64. </div>
    65. </nav>
    66. <section data-toggle="wy-nav-shift" class="wy-nav-content-wrap"><nav class="wy-nav-top" aria-label="Mobile navigation menu" >
    67. <i data-toggle="wy-nav-top" class="fa fa-bars"></i>
    68. <a href="../../../../index.html">SuperGradients</a>
    69. </nav>
    70. <div class="wy-nav-content">
    71. <div class="rst-content">
    72. <div role="navigation" aria-label="Page navigation">
    73. <ul class="wy-breadcrumbs">
    74. <li><a href="../../../../index.html" class="icon icon-home"></a> &raquo;</li>
    75. <li><a href="../../../index.html">Module code</a> &raquo;</li>
    76. <li>super_gradients.training.dataloaders.dataloaders</li>
    77. <li class="wy-breadcrumbs-aside">
    78. </li>
    79. </ul>
    80. <hr/>
    81. </div>
    82. <div role="main" class="document" itemscope="itemscope" itemtype="http://schema.org/Article">
    83. <div itemprop="articleBody">
    84. <h1>Source code for super_gradients.training.dataloaders.dataloaders</h1><div class="highlight"><pre>
    85. <span></span><span class="kn">import</span> <span class="nn">os.path</span>
    86. <span class="kn">import</span> <span class="nn">pkg_resources</span>
    87. <span class="kn">from</span> <span class="nn">typing</span> <span class="kn">import</span> <span class="n">Dict</span>
    88. <span class="kn">import</span> <span class="nn">hydra</span>
    89. <span class="kn">from</span> <span class="nn">hydra</span> <span class="kn">import</span> <span class="n">compose</span><span class="p">,</span> <span class="n">initialize_config_dir</span>
    90. <span class="kn">from</span> <span class="nn">hydra.core.global_hydra</span> <span class="kn">import</span> <span class="n">GlobalHydra</span>
    91. <span class="kn">import</span> <span class="nn">numpy</span> <span class="k">as</span> <span class="nn">np</span>
    92. <span class="kn">import</span> <span class="nn">torch</span>
    93. <span class="kn">from</span> <span class="nn">torch.utils.data</span> <span class="kn">import</span> <span class="n">BatchSampler</span><span class="p">,</span> <span class="n">DataLoader</span><span class="p">,</span> <span class="n">TensorDataset</span>
    94. <span class="kn">import</span> <span class="nn">super_gradients</span>
    95. <span class="kn">from</span> <span class="nn">super_gradients.training.datasets.detection_datasets.pascal_voc_detection</span> <span class="kn">import</span> <span class="p">(</span>
    96. <span class="n">PascalVOCUnifiedDetectionTrainDataset</span><span class="p">,</span>
    97. <span class="n">PascalVOCDetectionDataset</span><span class="p">,</span>
    98. <span class="p">)</span>
    99. <span class="kn">from</span> <span class="nn">super_gradients.training.utils</span> <span class="kn">import</span> <span class="n">get_param</span>
    100. <span class="kn">from</span> <span class="nn">super_gradients.training.utils.hydra_utils</span> <span class="kn">import</span> <span class="n">normalize_path</span>
    101. <span class="kn">from</span> <span class="nn">super_gradients.training.datasets</span> <span class="kn">import</span> <span class="n">ImageNetDataset</span>
    102. <span class="kn">from</span> <span class="nn">super_gradients.training.datasets.detection_datasets</span> <span class="kn">import</span> <span class="n">COCODetectionDataset</span>
    103. <span class="kn">from</span> <span class="nn">super_gradients.training.datasets.classification_datasets.cifar</span> <span class="kn">import</span> <span class="p">(</span>
    104. <span class="n">Cifar10</span><span class="p">,</span>
    105. <span class="n">Cifar100</span><span class="p">,</span>
    106. <span class="p">)</span>
    107. <span class="kn">from</span> <span class="nn">super_gradients.training.datasets.segmentation_datasets</span> <span class="kn">import</span> <span class="p">(</span>
    108. <span class="n">CityscapesDataset</span><span class="p">,</span>
    109. <span class="n">CoCoSegmentationDataSet</span><span class="p">,</span>
    110. <span class="n">PascalVOC2012SegmentationDataSet</span><span class="p">,</span>
    111. <span class="n">PascalVOCAndAUGUnifiedDataset</span><span class="p">,</span>
    112. <span class="n">SuperviselyPersonsDataset</span><span class="p">,</span>
    113. <span class="p">)</span>
    114. <span class="kn">from</span> <span class="nn">super_gradients.common.factories.samplers_factory</span> <span class="kn">import</span> <span class="n">SamplersFactory</span>
    115. <span class="kn">from</span> <span class="nn">super_gradients.training.utils.distributed_training_utils</span> <span class="kn">import</span> <span class="p">(</span>
    116. <span class="n">wait_for_the_master</span><span class="p">,</span>
    117. <span class="n">get_local_rank</span><span class="p">,</span>
    118. <span class="p">)</span>
    119. <span class="kn">from</span> <span class="nn">super_gradients.common.abstractions.abstract_logger</span> <span class="kn">import</span> <span class="n">get_logger</span>
    120. <span class="kn">from</span> <span class="nn">super_gradients.training.utils.utils</span> <span class="kn">import</span> <span class="n">override_default_params_without_nones</span>
    121. <span class="n">logger</span> <span class="o">=</span> <span class="n">get_logger</span><span class="p">(</span><span class="vm">__name__</span><span class="p">)</span>
    122. <div class="viewcode-block" id="get_data_loader"><a class="viewcode-back" href="../../../../super_gradients.training.html#super_gradients.training.dataloaders.get_data_loader">[docs]</a><span class="k">def</span> <span class="nf">get_data_loader</span><span class="p">(</span><span class="n">config_name</span><span class="p">,</span> <span class="n">dataset_cls</span><span class="p">,</span> <span class="n">train</span><span class="p">,</span> <span class="n">dataset_params</span><span class="o">=</span><span class="kc">None</span><span class="p">,</span> <span class="n">dataloader_params</span><span class="o">=</span><span class="kc">None</span><span class="p">):</span>
    123. <span class="sd">&quot;&quot;&quot;</span>
    124. <span class="sd"> Class for creating dataloaders for taking defaults from yaml files in src/super_gradients/recipes.</span>
    125. <span class="sd"> :param config_name: yaml config filename in recipes (for example coco2017_yolox).</span>
    126. <span class="sd"> :param dataset_cls: torch dataset uninitialized class.</span>
    127. <span class="sd"> :param train: controls whether to take</span>
    128. <span class="sd"> cfg.dataset_params.train_dataloader_params or cfg.dataset_params.valid_dataloader_params as defaults for the dataset constructor</span>
    129. <span class="sd"> and</span>
    130. <span class="sd"> cfg.dataset_params.train_dataset_params or cfg.dataset_params.valid_dataset_params as defaults for DataLoader contructor.</span>
    131. <span class="sd"> :param dataset_params: dataset params that override the yaml configured defaults, then passed to the dataset_cls.__init__.</span>
    132. <span class="sd"> :param dataloader_params: DataLoader params that override the yaml configured defaults, then passed to the DataLoader.__init__</span>
    133. <span class="sd"> :return: DataLoader</span>
    134. <span class="sd"> &quot;&quot;&quot;</span>
    135. <span class="k">if</span> <span class="n">dataloader_params</span> <span class="ow">is</span> <span class="kc">None</span><span class="p">:</span>
    136. <span class="n">dataloader_params</span> <span class="o">=</span> <span class="nb">dict</span><span class="p">()</span>
    137. <span class="k">if</span> <span class="n">dataset_params</span> <span class="ow">is</span> <span class="kc">None</span><span class="p">:</span>
    138. <span class="n">dataset_params</span> <span class="o">=</span> <span class="nb">dict</span><span class="p">()</span>
    139. <span class="n">GlobalHydra</span><span class="o">.</span><span class="n">instance</span><span class="p">()</span><span class="o">.</span><span class="n">clear</span><span class="p">()</span>
    140. <span class="n">sg_recipes_dir</span> <span class="o">=</span> <span class="n">pkg_resources</span><span class="o">.</span><span class="n">resource_filename</span><span class="p">(</span><span class="s2">&quot;super_gradients.recipes&quot;</span><span class="p">,</span> <span class="s2">&quot;&quot;</span><span class="p">)</span>
    141. <span class="n">dataset_config</span> <span class="o">=</span> <span class="n">os</span><span class="o">.</span><span class="n">path</span><span class="o">.</span><span class="n">join</span><span class="p">(</span><span class="s2">&quot;dataset_params&quot;</span><span class="p">,</span> <span class="n">config_name</span><span class="p">)</span>
    142. <span class="k">with</span> <span class="n">initialize_config_dir</span><span class="p">(</span><span class="n">config_dir</span><span class="o">=</span><span class="n">normalize_path</span><span class="p">(</span><span class="n">sg_recipes_dir</span><span class="p">),</span> <span class="n">version_base</span><span class="o">=</span><span class="s2">&quot;1.2&quot;</span><span class="p">):</span>
    143. <span class="c1"># config is relative to a module</span>
    144. <span class="n">cfg</span> <span class="o">=</span> <span class="n">compose</span><span class="p">(</span><span class="n">config_name</span><span class="o">=</span><span class="n">normalize_path</span><span class="p">(</span><span class="n">dataset_config</span><span class="p">))</span>
    145. <span class="n">dataset_params</span> <span class="o">=</span> <span class="n">_process_dataset_params</span><span class="p">(</span><span class="n">cfg</span><span class="p">,</span> <span class="n">dataset_params</span><span class="p">,</span> <span class="n">train</span><span class="p">)</span>
    146. <span class="n">local_rank</span> <span class="o">=</span> <span class="n">get_local_rank</span><span class="p">()</span>
    147. <span class="k">with</span> <span class="n">wait_for_the_master</span><span class="p">(</span><span class="n">local_rank</span><span class="p">):</span>
    148. <span class="n">dataset</span> <span class="o">=</span> <span class="n">dataset_cls</span><span class="p">(</span><span class="o">**</span><span class="n">dataset_params</span><span class="p">)</span>
    149. <span class="k">if</span> <span class="ow">not</span> <span class="nb">hasattr</span><span class="p">(</span><span class="n">dataset</span><span class="p">,</span> <span class="s2">&quot;dataset_params&quot;</span><span class="p">):</span>
    150. <span class="n">dataset</span><span class="o">.</span><span class="n">dataset_params</span> <span class="o">=</span> <span class="n">dataset_params</span>
    151. <span class="n">dataloader_params</span> <span class="o">=</span> <span class="n">_process_dataloader_params</span><span class="p">(</span><span class="n">cfg</span><span class="p">,</span> <span class="n">dataloader_params</span><span class="p">,</span> <span class="n">dataset</span><span class="p">,</span> <span class="n">train</span><span class="p">)</span>
    152. <span class="n">dataloader</span> <span class="o">=</span> <span class="n">DataLoader</span><span class="p">(</span><span class="n">dataset</span><span class="o">=</span><span class="n">dataset</span><span class="p">,</span> <span class="o">**</span><span class="n">dataloader_params</span><span class="p">)</span>
    153. <span class="n">dataloader</span><span class="o">.</span><span class="n">dataloader_params</span> <span class="o">=</span> <span class="n">dataloader_params</span>
    154. <span class="k">return</span> <span class="n">dataloader</span></div>
    155. <span class="k">def</span> <span class="nf">_process_dataset_params</span><span class="p">(</span><span class="n">cfg</span><span class="p">,</span> <span class="n">dataset_params</span><span class="p">,</span> <span class="n">train</span><span class="p">):</span>
    156. <span class="n">default_dataset_params</span> <span class="o">=</span> <span class="n">cfg</span><span class="o">.</span><span class="n">dataset_params</span><span class="o">.</span><span class="n">train_dataset_params</span> <span class="k">if</span> <span class="n">train</span> <span class="k">else</span> <span class="n">cfg</span><span class="o">.</span><span class="n">dataset_params</span><span class="o">.</span><span class="n">val_dataset_params</span>
    157. <span class="n">default_dataset_params</span> <span class="o">=</span> <span class="n">hydra</span><span class="o">.</span><span class="n">utils</span><span class="o">.</span><span class="n">instantiate</span><span class="p">(</span><span class="n">default_dataset_params</span><span class="p">)</span>
    158. <span class="k">for</span> <span class="n">key</span><span class="p">,</span> <span class="n">val</span> <span class="ow">in</span> <span class="n">default_dataset_params</span><span class="o">.</span><span class="n">items</span><span class="p">():</span>
    159. <span class="k">if</span> <span class="n">key</span> <span class="ow">not</span> <span class="ow">in</span> <span class="n">dataset_params</span><span class="o">.</span><span class="n">keys</span><span class="p">()</span> <span class="ow">or</span> <span class="n">dataset_params</span><span class="p">[</span><span class="n">key</span><span class="p">]</span> <span class="ow">is</span> <span class="kc">None</span><span class="p">:</span>
    160. <span class="n">dataset_params</span><span class="p">[</span><span class="n">key</span><span class="p">]</span> <span class="o">=</span> <span class="n">val</span>
    161. <span class="k">return</span> <span class="n">dataset_params</span>
    162. <span class="k">def</span> <span class="nf">_process_dataloader_params</span><span class="p">(</span><span class="n">cfg</span><span class="p">,</span> <span class="n">dataloader_params</span><span class="p">,</span> <span class="n">dataset</span><span class="p">,</span> <span class="n">train</span><span class="p">):</span>
    163. <span class="n">default_dataloader_params</span> <span class="o">=</span> <span class="n">cfg</span><span class="o">.</span><span class="n">dataset_params</span><span class="o">.</span><span class="n">train_dataloader_params</span> <span class="k">if</span> <span class="n">train</span> <span class="k">else</span> <span class="n">cfg</span><span class="o">.</span><span class="n">dataset_params</span><span class="o">.</span><span class="n">val_dataloader_params</span>
    164. <span class="n">default_dataloader_params</span> <span class="o">=</span> <span class="n">hydra</span><span class="o">.</span><span class="n">utils</span><span class="o">.</span><span class="n">instantiate</span><span class="p">(</span><span class="n">default_dataloader_params</span><span class="p">)</span>
    165. <span class="n">dataloader_params</span> <span class="o">=</span> <span class="n">_process_sampler_params</span><span class="p">(</span><span class="n">dataloader_params</span><span class="p">,</span> <span class="n">dataset</span><span class="p">,</span> <span class="n">default_dataloader_params</span><span class="p">)</span>
    166. <span class="k">return</span> <span class="n">dataloader_params</span>
    167. <span class="k">def</span> <span class="nf">_process_sampler_params</span><span class="p">(</span><span class="n">dataloader_params</span><span class="p">,</span> <span class="n">dataset</span><span class="p">,</span> <span class="n">default_dataloader_params</span><span class="p">):</span>
    168. <span class="n">is_dist</span> <span class="o">=</span> <span class="n">super_gradients</span><span class="o">.</span><span class="n">is_distributed</span><span class="p">()</span>
    169. <span class="k">if</span> <span class="n">get_param</span><span class="p">(</span><span class="n">dataloader_params</span><span class="p">,</span> <span class="s2">&quot;sampler&quot;</span><span class="p">)</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span><span class="p">:</span>
    170. <span class="n">dataloader_params</span> <span class="o">=</span> <span class="n">_instantiate_sampler</span><span class="p">(</span><span class="n">dataset</span><span class="p">,</span> <span class="n">dataloader_params</span><span class="p">)</span>
    171. <span class="k">elif</span> <span class="n">get_param</span><span class="p">(</span><span class="n">default_dataloader_params</span><span class="p">,</span> <span class="s2">&quot;sampler&quot;</span><span class="p">)</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span><span class="p">:</span>
    172. <span class="n">default_dataloader_params</span> <span class="o">=</span> <span class="n">_instantiate_sampler</span><span class="p">(</span><span class="n">dataset</span><span class="p">,</span> <span class="n">default_dataloader_params</span><span class="p">)</span>
    173. <span class="k">elif</span> <span class="n">is_dist</span><span class="p">:</span>
    174. <span class="n">default_dataloader_params</span><span class="p">[</span><span class="s2">&quot;sampler&quot;</span><span class="p">]</span> <span class="o">=</span> <span class="p">{</span><span class="s2">&quot;DistributedSampler&quot;</span><span class="p">:</span> <span class="p">{}}</span>
    175. <span class="n">default_dataloader_params</span> <span class="o">=</span> <span class="n">_instantiate_sampler</span><span class="p">(</span><span class="n">dataset</span><span class="p">,</span> <span class="n">default_dataloader_params</span><span class="p">)</span>
    176. <span class="n">dataloader_params</span> <span class="o">=</span> <span class="n">override_default_params_without_nones</span><span class="p">(</span><span class="n">dataloader_params</span><span class="p">,</span> <span class="n">default_dataloader_params</span><span class="p">)</span>
    177. <span class="k">if</span> <span class="n">get_param</span><span class="p">(</span><span class="n">dataloader_params</span><span class="p">,</span> <span class="s2">&quot;batch_sampler&quot;</span><span class="p">):</span>
    178. <span class="n">sampler</span> <span class="o">=</span> <span class="n">dataloader_params</span><span class="o">.</span><span class="n">pop</span><span class="p">(</span><span class="s2">&quot;sampler&quot;</span><span class="p">)</span>
    179. <span class="n">batch_size</span> <span class="o">=</span> <span class="n">dataloader_params</span><span class="o">.</span><span class="n">pop</span><span class="p">(</span><span class="s2">&quot;batch_size&quot;</span><span class="p">)</span>
    180. <span class="k">if</span> <span class="s2">&quot;drop_last&quot;</span> <span class="ow">in</span> <span class="n">dataloader_params</span><span class="p">:</span>
    181. <span class="n">drop_last</span> <span class="o">=</span> <span class="n">dataloader_params</span><span class="o">.</span><span class="n">pop</span><span class="p">(</span><span class="s2">&quot;drop_last&quot;</span><span class="p">)</span>
    182. <span class="k">else</span><span class="p">:</span>
    183. <span class="n">drop_last</span> <span class="o">=</span> <span class="n">default_dataloader_params</span><span class="p">[</span><span class="s2">&quot;drop_last&quot;</span><span class="p">]</span>
    184. <span class="n">dataloader_params</span><span class="p">[</span><span class="s2">&quot;batch_sampler&quot;</span><span class="p">]</span> <span class="o">=</span> <span class="n">BatchSampler</span><span class="p">(</span><span class="n">sampler</span><span class="o">=</span><span class="n">sampler</span><span class="p">,</span> <span class="n">batch_size</span><span class="o">=</span><span class="n">batch_size</span><span class="p">,</span> <span class="n">drop_last</span><span class="o">=</span><span class="n">drop_last</span><span class="p">)</span>
    185. <span class="k">return</span> <span class="n">dataloader_params</span>
    186. <span class="k">def</span> <span class="nf">_instantiate_sampler</span><span class="p">(</span><span class="n">dataset</span><span class="p">,</span> <span class="n">dataloader_params</span><span class="p">):</span>
    187. <span class="n">sampler_name</span> <span class="o">=</span> <span class="nb">list</span><span class="p">(</span><span class="n">dataloader_params</span><span class="p">[</span><span class="s2">&quot;sampler&quot;</span><span class="p">]</span><span class="o">.</span><span class="n">keys</span><span class="p">())[</span><span class="mi">0</span><span class="p">]</span>
    188. <span class="n">dataloader_params</span><span class="p">[</span><span class="s2">&quot;sampler&quot;</span><span class="p">][</span><span class="n">sampler_name</span><span class="p">][</span><span class="s2">&quot;dataset&quot;</span><span class="p">]</span> <span class="o">=</span> <span class="n">dataset</span>
    189. <span class="n">dataloader_params</span><span class="p">[</span><span class="s2">&quot;sampler&quot;</span><span class="p">]</span> <span class="o">=</span> <span class="n">SamplersFactory</span><span class="p">()</span><span class="o">.</span><span class="n">get</span><span class="p">(</span><span class="n">dataloader_params</span><span class="p">[</span><span class="s2">&quot;sampler&quot;</span><span class="p">])</span>
    190. <span class="k">return</span> <span class="n">dataloader_params</span>
    191. <div class="viewcode-block" id="coco2017_train"><a class="viewcode-back" href="../../../../super_gradients.training.html#super_gradients.training.dataloaders.coco2017_train">[docs]</a><span class="k">def</span> <span class="nf">coco2017_train</span><span class="p">(</span><span class="n">dataset_params</span><span class="p">:</span> <span class="n">Dict</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span> <span class="n">dataloader_params</span><span class="p">:</span> <span class="n">Dict</span> <span class="o">=</span> <span class="kc">None</span><span class="p">):</span>
    192. <span class="k">return</span> <span class="n">get_data_loader</span><span class="p">(</span>
    193. <span class="n">config_name</span><span class="o">=</span><span class="s2">&quot;coco_detection_dataset_params&quot;</span><span class="p">,</span>
    194. <span class="n">dataset_cls</span><span class="o">=</span><span class="n">COCODetectionDataset</span><span class="p">,</span>
    195. <span class="n">train</span><span class="o">=</span><span class="kc">True</span><span class="p">,</span>
    196. <span class="n">dataset_params</span><span class="o">=</span><span class="n">dataset_params</span><span class="p">,</span>
    197. <span class="n">dataloader_params</span><span class="o">=</span><span class="n">dataloader_params</span><span class="p">,</span>
    198. <span class="p">)</span></div>
    199. <div class="viewcode-block" id="coco2017_val"><a class="viewcode-back" href="../../../../super_gradients.training.html#super_gradients.training.dataloaders.coco2017_val">[docs]</a><span class="k">def</span> <span class="nf">coco2017_val</span><span class="p">(</span><span class="n">dataset_params</span><span class="p">:</span> <span class="n">Dict</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span> <span class="n">dataloader_params</span><span class="p">:</span> <span class="n">Dict</span> <span class="o">=</span> <span class="kc">None</span><span class="p">):</span>
    200. <span class="k">return</span> <span class="n">get_data_loader</span><span class="p">(</span>
    201. <span class="n">config_name</span><span class="o">=</span><span class="s2">&quot;coco_detection_dataset_params&quot;</span><span class="p">,</span>
    202. <span class="n">dataset_cls</span><span class="o">=</span><span class="n">COCODetectionDataset</span><span class="p">,</span>
    203. <span class="n">train</span><span class="o">=</span><span class="kc">False</span><span class="p">,</span>
    204. <span class="n">dataset_params</span><span class="o">=</span><span class="n">dataset_params</span><span class="p">,</span>
    205. <span class="n">dataloader_params</span><span class="o">=</span><span class="n">dataloader_params</span><span class="p">,</span>
    206. <span class="p">)</span></div>
    207. <div class="viewcode-block" id="coco2017_train_yolox"><a class="viewcode-back" href="../../../../super_gradients.training.html#super_gradients.training.dataloaders.coco2017_train_yolox">[docs]</a><span class="k">def</span> <span class="nf">coco2017_train_yolox</span><span class="p">(</span><span class="n">dataset_params</span><span class="p">:</span> <span class="n">Dict</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span> <span class="n">dataloader_params</span><span class="p">:</span> <span class="n">Dict</span> <span class="o">=</span> <span class="kc">None</span><span class="p">):</span>
    208. <span class="k">return</span> <span class="n">coco2017_train</span><span class="p">(</span><span class="n">dataset_params</span><span class="p">,</span> <span class="n">dataloader_params</span><span class="p">)</span></div>
    209. <div class="viewcode-block" id="coco2017_val_yolox"><a class="viewcode-back" href="../../../../super_gradients.training.html#super_gradients.training.dataloaders.coco2017_val_yolox">[docs]</a><span class="k">def</span> <span class="nf">coco2017_val_yolox</span><span class="p">(</span><span class="n">dataset_params</span><span class="p">:</span> <span class="n">Dict</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span> <span class="n">dataloader_params</span><span class="p">:</span> <span class="n">Dict</span> <span class="o">=</span> <span class="kc">None</span><span class="p">):</span>
    210. <span class="k">return</span> <span class="n">coco2017_val</span><span class="p">(</span><span class="n">dataset_params</span><span class="p">,</span> <span class="n">dataloader_params</span><span class="p">)</span></div>
    211. <div class="viewcode-block" id="coco2017_train_ssd_lite_mobilenet_v2"><a class="viewcode-back" href="../../../../super_gradients.training.html#super_gradients.training.dataloaders.coco2017_train_ssd_lite_mobilenet_v2">[docs]</a><span class="k">def</span> <span class="nf">coco2017_train_ssd_lite_mobilenet_v2</span><span class="p">(</span><span class="n">dataset_params</span><span class="p">:</span> <span class="n">Dict</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span> <span class="n">dataloader_params</span><span class="p">:</span> <span class="n">Dict</span> <span class="o">=</span> <span class="kc">None</span><span class="p">):</span>
    212. <span class="k">return</span> <span class="n">get_data_loader</span><span class="p">(</span>
    213. <span class="n">config_name</span><span class="o">=</span><span class="s2">&quot;coco_detection_ssd_lite_mobilenet_v2_dataset_params&quot;</span><span class="p">,</span>
    214. <span class="n">dataset_cls</span><span class="o">=</span><span class="n">COCODetectionDataset</span><span class="p">,</span>
    215. <span class="n">train</span><span class="o">=</span><span class="kc">True</span><span class="p">,</span>
    216. <span class="n">dataset_params</span><span class="o">=</span><span class="n">dataset_params</span><span class="p">,</span>
    217. <span class="n">dataloader_params</span><span class="o">=</span><span class="n">dataloader_params</span><span class="p">,</span>
    218. <span class="p">)</span></div>
    219. <div class="viewcode-block" id="coco2017_val_ssd_lite_mobilenet_v2"><a class="viewcode-back" href="../../../../super_gradients.training.html#super_gradients.training.dataloaders.coco2017_val_ssd_lite_mobilenet_v2">[docs]</a><span class="k">def</span> <span class="nf">coco2017_val_ssd_lite_mobilenet_v2</span><span class="p">(</span><span class="n">dataset_params</span><span class="p">:</span> <span class="n">Dict</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span> <span class="n">dataloader_params</span><span class="p">:</span> <span class="n">Dict</span> <span class="o">=</span> <span class="kc">None</span><span class="p">):</span>
    220. <span class="k">return</span> <span class="n">get_data_loader</span><span class="p">(</span>
    221. <span class="n">config_name</span><span class="o">=</span><span class="s2">&quot;coco_detection_ssd_lite_mobilenet_v2_dataset_params&quot;</span><span class="p">,</span>
    222. <span class="n">dataset_cls</span><span class="o">=</span><span class="n">COCODetectionDataset</span><span class="p">,</span>
    223. <span class="n">train</span><span class="o">=</span><span class="kc">False</span><span class="p">,</span>
    224. <span class="n">dataset_params</span><span class="o">=</span><span class="n">dataset_params</span><span class="p">,</span>
    225. <span class="n">dataloader_params</span><span class="o">=</span><span class="n">dataloader_params</span><span class="p">,</span>
    226. <span class="p">)</span></div>
    227. <div class="viewcode-block" id="imagenet_train"><a class="viewcode-back" href="../../../../super_gradients.training.html#super_gradients.training.dataloaders.imagenet_train">[docs]</a><span class="k">def</span> <span class="nf">imagenet_train</span><span class="p">(</span><span class="n">dataset_params</span><span class="o">=</span><span class="kc">None</span><span class="p">,</span> <span class="n">dataloader_params</span><span class="o">=</span><span class="kc">None</span><span class="p">,</span> <span class="n">config_name</span><span class="o">=</span><span class="s2">&quot;imagenet_dataset_params&quot;</span><span class="p">):</span>
    228. <span class="k">return</span> <span class="n">get_data_loader</span><span class="p">(</span>
    229. <span class="n">config_name</span><span class="o">=</span><span class="n">config_name</span><span class="p">,</span>
    230. <span class="n">dataset_cls</span><span class="o">=</span><span class="n">ImageNetDataset</span><span class="p">,</span>
    231. <span class="n">train</span><span class="o">=</span><span class="kc">True</span><span class="p">,</span>
    232. <span class="n">dataset_params</span><span class="o">=</span><span class="n">dataset_params</span><span class="p">,</span>
    233. <span class="n">dataloader_params</span><span class="o">=</span><span class="n">dataloader_params</span><span class="p">,</span>
    234. <span class="p">)</span></div>
    235. <div class="viewcode-block" id="imagenet_val"><a class="viewcode-back" href="../../../../super_gradients.training.html#super_gradients.training.dataloaders.imagenet_val">[docs]</a><span class="k">def</span> <span class="nf">imagenet_val</span><span class="p">(</span><span class="n">dataset_params</span><span class="o">=</span><span class="kc">None</span><span class="p">,</span> <span class="n">dataloader_params</span><span class="o">=</span><span class="kc">None</span><span class="p">,</span> <span class="n">config_name</span><span class="o">=</span><span class="s2">&quot;imagenet_dataset_params&quot;</span><span class="p">):</span>
    236. <span class="k">return</span> <span class="n">get_data_loader</span><span class="p">(</span>
    237. <span class="n">config_name</span><span class="o">=</span><span class="n">config_name</span><span class="p">,</span>
    238. <span class="n">dataset_cls</span><span class="o">=</span><span class="n">ImageNetDataset</span><span class="p">,</span>
    239. <span class="n">train</span><span class="o">=</span><span class="kc">False</span><span class="p">,</span>
    240. <span class="n">dataset_params</span><span class="o">=</span><span class="n">dataset_params</span><span class="p">,</span>
    241. <span class="n">dataloader_params</span><span class="o">=</span><span class="n">dataloader_params</span><span class="p">,</span>
    242. <span class="p">)</span></div>
    243. <div class="viewcode-block" id="imagenet_efficientnet_train"><a class="viewcode-back" href="../../../../super_gradients.training.html#super_gradients.training.dataloaders.imagenet_efficientnet_train">[docs]</a><span class="k">def</span> <span class="nf">imagenet_efficientnet_train</span><span class="p">(</span><span class="n">dataset_params</span><span class="o">=</span><span class="kc">None</span><span class="p">,</span> <span class="n">dataloader_params</span><span class="o">=</span><span class="kc">None</span><span class="p">):</span>
    244. <span class="k">return</span> <span class="n">imagenet_train</span><span class="p">(</span>
    245. <span class="n">dataset_params</span><span class="p">,</span>
    246. <span class="n">dataloader_params</span><span class="p">,</span>
    247. <span class="n">config_name</span><span class="o">=</span><span class="s2">&quot;imagenet_efficientnet_dataset_params&quot;</span><span class="p">,</span>
    248. <span class="p">)</span></div>
    249. <div class="viewcode-block" id="imagenet_efficientnet_val"><a class="viewcode-back" href="../../../../super_gradients.training.html#super_gradients.training.dataloaders.imagenet_efficientnet_val">[docs]</a><span class="k">def</span> <span class="nf">imagenet_efficientnet_val</span><span class="p">(</span><span class="n">dataset_params</span><span class="o">=</span><span class="kc">None</span><span class="p">,</span> <span class="n">dataloader_params</span><span class="o">=</span><span class="kc">None</span><span class="p">):</span>
    250. <span class="k">return</span> <span class="n">imagenet_val</span><span class="p">(</span>
    251. <span class="n">dataset_params</span><span class="p">,</span>
    252. <span class="n">dataloader_params</span><span class="p">,</span>
    253. <span class="n">config_name</span><span class="o">=</span><span class="s2">&quot;imagenet_efficientnet_dataset_params&quot;</span><span class="p">,</span>
    254. <span class="p">)</span></div>
    255. <div class="viewcode-block" id="imagenet_mobilenetv2_train"><a class="viewcode-back" href="../../../../super_gradients.training.html#super_gradients.training.dataloaders.imagenet_mobilenetv2_train">[docs]</a><span class="k">def</span> <span class="nf">imagenet_mobilenetv2_train</span><span class="p">(</span><span class="n">dataset_params</span><span class="o">=</span><span class="kc">None</span><span class="p">,</span> <span class="n">dataloader_params</span><span class="o">=</span><span class="kc">None</span><span class="p">):</span>
    256. <span class="k">return</span> <span class="n">imagenet_train</span><span class="p">(</span>
    257. <span class="n">dataset_params</span><span class="p">,</span>
    258. <span class="n">dataloader_params</span><span class="p">,</span>
    259. <span class="n">config_name</span><span class="o">=</span><span class="s2">&quot;imagenet_mobilenetv2_dataset_params&quot;</span><span class="p">,</span>
    260. <span class="p">)</span></div>
    261. <div class="viewcode-block" id="imagenet_mobilenetv2_val"><a class="viewcode-back" href="../../../../super_gradients.training.html#super_gradients.training.dataloaders.imagenet_mobilenetv2_val">[docs]</a><span class="k">def</span> <span class="nf">imagenet_mobilenetv2_val</span><span class="p">(</span><span class="n">dataset_params</span><span class="o">=</span><span class="kc">None</span><span class="p">,</span> <span class="n">dataloader_params</span><span class="o">=</span><span class="kc">None</span><span class="p">):</span>
    262. <span class="k">return</span> <span class="n">imagenet_val</span><span class="p">(</span>
    263. <span class="n">dataset_params</span><span class="p">,</span>
    264. <span class="n">dataloader_params</span><span class="p">,</span>
    265. <span class="n">config_name</span><span class="o">=</span><span class="s2">&quot;imagenet_mobilenetv2_dataset_params&quot;</span><span class="p">,</span>
    266. <span class="p">)</span></div>
    267. <div class="viewcode-block" id="imagenet_mobilenetv3_train"><a class="viewcode-back" href="../../../../super_gradients.training.html#super_gradients.training.dataloaders.imagenet_mobilenetv3_train">[docs]</a><span class="k">def</span> <span class="nf">imagenet_mobilenetv3_train</span><span class="p">(</span><span class="n">dataset_params</span><span class="o">=</span><span class="kc">None</span><span class="p">,</span> <span class="n">dataloader_params</span><span class="o">=</span><span class="kc">None</span><span class="p">):</span>
    268. <span class="k">return</span> <span class="n">imagenet_train</span><span class="p">(</span>
    269. <span class="n">dataset_params</span><span class="p">,</span>
    270. <span class="n">dataloader_params</span><span class="p">,</span>
    271. <span class="n">config_name</span><span class="o">=</span><span class="s2">&quot;imagenet_mobilenetv3_dataset_params&quot;</span><span class="p">,</span>
    272. <span class="p">)</span></div>
    273. <div class="viewcode-block" id="imagenet_mobilenetv3_val"><a class="viewcode-back" href="../../../../super_gradients.training.html#super_gradients.training.dataloaders.imagenet_mobilenetv3_val">[docs]</a><span class="k">def</span> <span class="nf">imagenet_mobilenetv3_val</span><span class="p">(</span><span class="n">dataset_params</span><span class="o">=</span><span class="kc">None</span><span class="p">,</span> <span class="n">dataloader_params</span><span class="o">=</span><span class="kc">None</span><span class="p">):</span>
    274. <span class="k">return</span> <span class="n">imagenet_val</span><span class="p">(</span>
    275. <span class="n">dataset_params</span><span class="p">,</span>
    276. <span class="n">dataloader_params</span><span class="p">,</span>
    277. <span class="n">config_name</span><span class="o">=</span><span class="s2">&quot;imagenet_mobilenetv3_dataset_params&quot;</span><span class="p">,</span>
    278. <span class="p">)</span></div>
    279. <div class="viewcode-block" id="imagenet_regnetY_train"><a class="viewcode-back" href="../../../../super_gradients.training.html#super_gradients.training.dataloaders.imagenet_regnetY_train">[docs]</a><span class="k">def</span> <span class="nf">imagenet_regnetY_train</span><span class="p">(</span><span class="n">dataset_params</span><span class="o">=</span><span class="kc">None</span><span class="p">,</span> <span class="n">dataloader_params</span><span class="o">=</span><span class="kc">None</span><span class="p">):</span>
    280. <span class="k">return</span> <span class="n">imagenet_train</span><span class="p">(</span><span class="n">dataset_params</span><span class="p">,</span> <span class="n">dataloader_params</span><span class="p">,</span> <span class="n">config_name</span><span class="o">=</span><span class="s2">&quot;imagenet_regnetY_dataset_params&quot;</span><span class="p">)</span></div>
    281. <div class="viewcode-block" id="imagenet_regnetY_val"><a class="viewcode-back" href="../../../../super_gradients.training.html#super_gradients.training.dataloaders.imagenet_regnetY_val">[docs]</a><span class="k">def</span> <span class="nf">imagenet_regnetY_val</span><span class="p">(</span><span class="n">dataset_params</span><span class="o">=</span><span class="kc">None</span><span class="p">,</span> <span class="n">dataloader_params</span><span class="o">=</span><span class="kc">None</span><span class="p">):</span>
    282. <span class="k">return</span> <span class="n">imagenet_val</span><span class="p">(</span><span class="n">dataset_params</span><span class="p">,</span> <span class="n">dataloader_params</span><span class="p">,</span> <span class="n">config_name</span><span class="o">=</span><span class="s2">&quot;imagenet_regnetY_dataset_params&quot;</span><span class="p">)</span></div>
    283. <div class="viewcode-block" id="imagenet_resnet50_train"><a class="viewcode-back" href="../../../../super_gradients.training.html#super_gradients.training.dataloaders.imagenet_resnet50_train">[docs]</a><span class="k">def</span> <span class="nf">imagenet_resnet50_train</span><span class="p">(</span><span class="n">dataset_params</span><span class="o">=</span><span class="kc">None</span><span class="p">,</span> <span class="n">dataloader_params</span><span class="o">=</span><span class="kc">None</span><span class="p">):</span>
    284. <span class="k">return</span> <span class="n">imagenet_train</span><span class="p">(</span>
    285. <span class="n">dataset_params</span><span class="p">,</span>
    286. <span class="n">dataloader_params</span><span class="p">,</span>
    287. <span class="n">config_name</span><span class="o">=</span><span class="s2">&quot;imagenet_resnet50_dataset_params&quot;</span><span class="p">,</span>
    288. <span class="p">)</span></div>
    289. <div class="viewcode-block" id="imagenet_resnet50_val"><a class="viewcode-back" href="../../../../super_gradients.training.html#super_gradients.training.dataloaders.imagenet_resnet50_val">[docs]</a><span class="k">def</span> <span class="nf">imagenet_resnet50_val</span><span class="p">(</span><span class="n">dataset_params</span><span class="o">=</span><span class="kc">None</span><span class="p">,</span> <span class="n">dataloader_params</span><span class="o">=</span><span class="kc">None</span><span class="p">):</span>
    290. <span class="k">return</span> <span class="n">imagenet_val</span><span class="p">(</span>
    291. <span class="n">dataset_params</span><span class="p">,</span>
    292. <span class="n">dataloader_params</span><span class="p">,</span>
    293. <span class="n">config_name</span><span class="o">=</span><span class="s2">&quot;imagenet_resnet50_dataset_params&quot;</span><span class="p">,</span>
    294. <span class="p">)</span></div>
    295. <div class="viewcode-block" id="imagenet_resnet50_kd_train"><a class="viewcode-back" href="../../../../super_gradients.training.html#super_gradients.training.dataloaders.imagenet_resnet50_kd_train">[docs]</a><span class="k">def</span> <span class="nf">imagenet_resnet50_kd_train</span><span class="p">(</span><span class="n">dataset_params</span><span class="o">=</span><span class="kc">None</span><span class="p">,</span> <span class="n">dataloader_params</span><span class="o">=</span><span class="kc">None</span><span class="p">):</span>
    296. <span class="k">return</span> <span class="n">imagenet_train</span><span class="p">(</span>
    297. <span class="n">dataset_params</span><span class="p">,</span>
    298. <span class="n">dataloader_params</span><span class="p">,</span>
    299. <span class="n">config_name</span><span class="o">=</span><span class="s2">&quot;imagenet_resnet50_kd_dataset_params&quot;</span><span class="p">,</span>
    300. <span class="p">)</span></div>
    301. <div class="viewcode-block" id="imagenet_resnet50_kd_val"><a class="viewcode-back" href="../../../../super_gradients.training.html#super_gradients.training.dataloaders.imagenet_resnet50_kd_val">[docs]</a><span class="k">def</span> <span class="nf">imagenet_resnet50_kd_val</span><span class="p">(</span><span class="n">dataset_params</span><span class="o">=</span><span class="kc">None</span><span class="p">,</span> <span class="n">dataloader_params</span><span class="o">=</span><span class="kc">None</span><span class="p">):</span>
    302. <span class="k">return</span> <span class="n">imagenet_val</span><span class="p">(</span>
    303. <span class="n">dataset_params</span><span class="p">,</span>
    304. <span class="n">dataloader_params</span><span class="p">,</span>
    305. <span class="n">config_name</span><span class="o">=</span><span class="s2">&quot;imagenet_resnet50_kd_dataset_params&quot;</span><span class="p">,</span>
    306. <span class="p">)</span></div>
    307. <div class="viewcode-block" id="imagenet_vit_base_train"><a class="viewcode-back" href="../../../../super_gradients.training.html#super_gradients.training.dataloaders.imagenet_vit_base_train">[docs]</a><span class="k">def</span> <span class="nf">imagenet_vit_base_train</span><span class="p">(</span><span class="n">dataset_params</span><span class="o">=</span><span class="kc">None</span><span class="p">,</span> <span class="n">dataloader_params</span><span class="o">=</span><span class="kc">None</span><span class="p">):</span>
    308. <span class="k">return</span> <span class="n">imagenet_train</span><span class="p">(</span>
    309. <span class="n">dataset_params</span><span class="p">,</span>
    310. <span class="n">dataloader_params</span><span class="p">,</span>
    311. <span class="n">config_name</span><span class="o">=</span><span class="s2">&quot;imagenet_vit_base_dataset_params&quot;</span><span class="p">,</span>
    312. <span class="p">)</span></div>
    313. <div class="viewcode-block" id="imagenet_vit_base_val"><a class="viewcode-back" href="../../../../super_gradients.training.html#super_gradients.training.dataloaders.imagenet_vit_base_val">[docs]</a><span class="k">def</span> <span class="nf">imagenet_vit_base_val</span><span class="p">(</span><span class="n">dataset_params</span><span class="o">=</span><span class="kc">None</span><span class="p">,</span> <span class="n">dataloader_params</span><span class="o">=</span><span class="kc">None</span><span class="p">):</span>
    314. <span class="k">return</span> <span class="n">imagenet_val</span><span class="p">(</span>
    315. <span class="n">dataset_params</span><span class="p">,</span>
    316. <span class="n">dataloader_params</span><span class="p">,</span>
    317. <span class="n">config_name</span><span class="o">=</span><span class="s2">&quot;imagenet_vit_base_dataset_params&quot;</span><span class="p">,</span>
    318. <span class="p">)</span></div>
    319. <div class="viewcode-block" id="tiny_imagenet_train"><a class="viewcode-back" href="../../../../super_gradients.training.html#super_gradients.training.dataloaders.tiny_imagenet_train">[docs]</a><span class="k">def</span> <span class="nf">tiny_imagenet_train</span><span class="p">(</span>
    320. <span class="n">dataset_params</span><span class="o">=</span><span class="kc">None</span><span class="p">,</span>
    321. <span class="n">dataloader_params</span><span class="o">=</span><span class="kc">None</span><span class="p">,</span>
    322. <span class="n">config_name</span><span class="o">=</span><span class="s2">&quot;tiny_imagenet_dataset_params&quot;</span><span class="p">,</span>
    323. <span class="p">):</span>
    324. <span class="k">return</span> <span class="n">get_data_loader</span><span class="p">(</span>
    325. <span class="n">config_name</span><span class="o">=</span><span class="n">config_name</span><span class="p">,</span>
    326. <span class="n">dataset_cls</span><span class="o">=</span><span class="n">ImageNetDataset</span><span class="p">,</span>
    327. <span class="n">train</span><span class="o">=</span><span class="kc">True</span><span class="p">,</span>
    328. <span class="n">dataset_params</span><span class="o">=</span><span class="n">dataset_params</span><span class="p">,</span>
    329. <span class="n">dataloader_params</span><span class="o">=</span><span class="n">dataloader_params</span><span class="p">,</span>
    330. <span class="p">)</span></div>
    331. <div class="viewcode-block" id="tiny_imagenet_val"><a class="viewcode-back" href="../../../../super_gradients.training.html#super_gradients.training.dataloaders.tiny_imagenet_val">[docs]</a><span class="k">def</span> <span class="nf">tiny_imagenet_val</span><span class="p">(</span>
    332. <span class="n">dataset_params</span><span class="o">=</span><span class="kc">None</span><span class="p">,</span>
    333. <span class="n">dataloader_params</span><span class="o">=</span><span class="kc">None</span><span class="p">,</span>
    334. <span class="n">config_name</span><span class="o">=</span><span class="s2">&quot;tiny_imagenet_dataset_params&quot;</span><span class="p">,</span>
    335. <span class="p">):</span>
    336. <span class="k">return</span> <span class="n">get_data_loader</span><span class="p">(</span>
    337. <span class="n">config_name</span><span class="o">=</span><span class="n">config_name</span><span class="p">,</span>
    338. <span class="n">dataset_cls</span><span class="o">=</span><span class="n">ImageNetDataset</span><span class="p">,</span>
    339. <span class="n">train</span><span class="o">=</span><span class="kc">False</span><span class="p">,</span>
    340. <span class="n">dataset_params</span><span class="o">=</span><span class="n">dataset_params</span><span class="p">,</span>
    341. <span class="n">dataloader_params</span><span class="o">=</span><span class="n">dataloader_params</span><span class="p">,</span>
    342. <span class="p">)</span></div>
    343. <div class="viewcode-block" id="cifar10_train"><a class="viewcode-back" href="../../../../super_gradients.training.html#super_gradients.training.dataloaders.cifar10_train">[docs]</a><span class="k">def</span> <span class="nf">cifar10_train</span><span class="p">(</span><span class="n">dataset_params</span><span class="p">:</span> <span class="n">Dict</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span> <span class="n">dataloader_params</span><span class="p">:</span> <span class="n">Dict</span> <span class="o">=</span> <span class="kc">None</span><span class="p">):</span>
    344. <span class="k">return</span> <span class="n">get_data_loader</span><span class="p">(</span>
    345. <span class="n">config_name</span><span class="o">=</span><span class="s2">&quot;cifar10_dataset_params&quot;</span><span class="p">,</span>
    346. <span class="n">dataset_cls</span><span class="o">=</span><span class="n">Cifar10</span><span class="p">,</span>
    347. <span class="n">train</span><span class="o">=</span><span class="kc">True</span><span class="p">,</span>
    348. <span class="n">dataset_params</span><span class="o">=</span><span class="n">dataset_params</span><span class="p">,</span>
    349. <span class="n">dataloader_params</span><span class="o">=</span><span class="n">dataloader_params</span><span class="p">,</span>
    350. <span class="p">)</span></div>
    351. <div class="viewcode-block" id="cifar10_val"><a class="viewcode-back" href="../../../../super_gradients.training.html#super_gradients.training.dataloaders.cifar10_val">[docs]</a><span class="k">def</span> <span class="nf">cifar10_val</span><span class="p">(</span><span class="n">dataset_params</span><span class="p">:</span> <span class="n">Dict</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span> <span class="n">dataloader_params</span><span class="p">:</span> <span class="n">Dict</span> <span class="o">=</span> <span class="kc">None</span><span class="p">):</span>
    352. <span class="k">return</span> <span class="n">get_data_loader</span><span class="p">(</span>
    353. <span class="n">config_name</span><span class="o">=</span><span class="s2">&quot;cifar10_dataset_params&quot;</span><span class="p">,</span>
    354. <span class="n">dataset_cls</span><span class="o">=</span><span class="n">Cifar10</span><span class="p">,</span>
    355. <span class="n">train</span><span class="o">=</span><span class="kc">False</span><span class="p">,</span>
    356. <span class="n">dataset_params</span><span class="o">=</span><span class="n">dataset_params</span><span class="p">,</span>
    357. <span class="n">dataloader_params</span><span class="o">=</span><span class="n">dataloader_params</span><span class="p">,</span>
    358. <span class="p">)</span></div>
    359. <div class="viewcode-block" id="cifar100_train"><a class="viewcode-back" href="../../../../super_gradients.training.html#super_gradients.training.dataloaders.cifar100_train">[docs]</a><span class="k">def</span> <span class="nf">cifar100_train</span><span class="p">(</span><span class="n">dataset_params</span><span class="p">:</span> <span class="n">Dict</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span> <span class="n">dataloader_params</span><span class="p">:</span> <span class="n">Dict</span> <span class="o">=</span> <span class="kc">None</span><span class="p">):</span>
    360. <span class="k">return</span> <span class="n">get_data_loader</span><span class="p">(</span>
    361. <span class="n">config_name</span><span class="o">=</span><span class="s2">&quot;cifar100_dataset_params&quot;</span><span class="p">,</span>
    362. <span class="n">dataset_cls</span><span class="o">=</span><span class="n">Cifar100</span><span class="p">,</span>
    363. <span class="n">train</span><span class="o">=</span><span class="kc">True</span><span class="p">,</span>
    364. <span class="n">dataset_params</span><span class="o">=</span><span class="n">dataset_params</span><span class="p">,</span>
    365. <span class="n">dataloader_params</span><span class="o">=</span><span class="n">dataloader_params</span><span class="p">,</span>
    366. <span class="p">)</span></div>
    367. <div class="viewcode-block" id="cifar100_val"><a class="viewcode-back" href="../../../../super_gradients.training.html#super_gradients.training.dataloaders.cifar100_val">[docs]</a><span class="k">def</span> <span class="nf">cifar100_val</span><span class="p">(</span><span class="n">dataset_params</span><span class="p">:</span> <span class="n">Dict</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span> <span class="n">dataloader_params</span><span class="p">:</span> <span class="n">Dict</span> <span class="o">=</span> <span class="kc">None</span><span class="p">):</span>
    368. <span class="k">return</span> <span class="n">get_data_loader</span><span class="p">(</span>
    369. <span class="n">config_name</span><span class="o">=</span><span class="s2">&quot;cifar100_dataset_params&quot;</span><span class="p">,</span>
    370. <span class="n">dataset_cls</span><span class="o">=</span><span class="n">Cifar100</span><span class="p">,</span>
    371. <span class="n">train</span><span class="o">=</span><span class="kc">False</span><span class="p">,</span>
    372. <span class="n">dataset_params</span><span class="o">=</span><span class="n">dataset_params</span><span class="p">,</span>
    373. <span class="n">dataloader_params</span><span class="o">=</span><span class="n">dataloader_params</span><span class="p">,</span>
    374. <span class="p">)</span></div>
    375. <span class="k">def</span> <span class="nf">classification_test_dataloader</span><span class="p">(</span><span class="n">batch_size</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="mi">5</span><span class="p">,</span> <span class="n">image_size</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="mi">32</span><span class="p">,</span> <span class="n">dataset_size</span><span class="o">=</span><span class="kc">None</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="n">DataLoader</span><span class="p">:</span>
    376. <span class="n">dataset_size</span> <span class="o">=</span> <span class="n">dataset_size</span> <span class="ow">or</span> <span class="n">batch_size</span>
    377. <span class="n">images</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">(</span><span class="n">np</span><span class="o">.</span><span class="n">zeros</span><span class="p">((</span><span class="n">dataset_size</span><span class="p">,</span> <span class="mi">3</span><span class="p">,</span> <span class="n">image_size</span><span class="p">,</span> <span class="n">image_size</span><span class="p">)))</span>
    378. <span class="n">ground_truth</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">LongTensor</span><span class="p">(</span><span class="n">np</span><span class="o">.</span><span class="n">zeros</span><span class="p">((</span><span class="n">dataset_size</span><span class="p">)))</span>
    379. <span class="n">dataset</span> <span class="o">=</span> <span class="n">TensorDataset</span><span class="p">(</span><span class="n">images</span><span class="p">,</span> <span class="n">ground_truth</span><span class="p">)</span>
    380. <span class="k">return</span> <span class="n">DataLoader</span><span class="p">(</span><span class="n">dataset</span><span class="o">=</span><span class="n">dataset</span><span class="p">,</span> <span class="n">batch_size</span><span class="o">=</span><span class="n">batch_size</span><span class="p">)</span>
    381. <span class="k">def</span> <span class="nf">detection_test_dataloader</span><span class="p">(</span><span class="n">batch_size</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="mi">5</span><span class="p">,</span> <span class="n">image_size</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="mi">320</span><span class="p">,</span> <span class="n">dataset_size</span><span class="o">=</span><span class="kc">None</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="n">DataLoader</span><span class="p">:</span>
    382. <span class="n">dataset_size</span> <span class="o">=</span> <span class="n">dataset_size</span> <span class="ow">or</span> <span class="n">batch_size</span>
    383. <span class="n">images</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">(</span><span class="n">np</span><span class="o">.</span><span class="n">zeros</span><span class="p">((</span><span class="n">dataset_size</span><span class="p">,</span> <span class="mi">3</span><span class="p">,</span> <span class="n">image_size</span><span class="p">,</span> <span class="n">image_size</span><span class="p">)))</span>
    384. <span class="n">ground_truth</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">(</span><span class="n">np</span><span class="o">.</span><span class="n">zeros</span><span class="p">((</span><span class="n">dataset_size</span><span class="p">,</span> <span class="mi">6</span><span class="p">)))</span>
    385. <span class="n">dataset</span> <span class="o">=</span> <span class="n">TensorDataset</span><span class="p">(</span><span class="n">images</span><span class="p">,</span> <span class="n">ground_truth</span><span class="p">)</span>
    386. <span class="k">return</span> <span class="n">DataLoader</span><span class="p">(</span><span class="n">dataset</span><span class="o">=</span><span class="n">dataset</span><span class="p">,</span> <span class="n">batch_size</span><span class="o">=</span><span class="n">batch_size</span><span class="p">)</span>
    387. <span class="k">def</span> <span class="nf">segmentation_test_dataloader</span><span class="p">(</span><span class="n">batch_size</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="mi">5</span><span class="p">,</span> <span class="n">image_size</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="mi">512</span><span class="p">,</span> <span class="n">dataset_size</span><span class="o">=</span><span class="kc">None</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="n">DataLoader</span><span class="p">:</span>
    388. <span class="n">dataset_size</span> <span class="o">=</span> <span class="n">dataset_size</span> <span class="ow">or</span> <span class="n">batch_size</span>
    389. <span class="n">images</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">(</span><span class="n">np</span><span class="o">.</span><span class="n">zeros</span><span class="p">((</span><span class="n">dataset_size</span><span class="p">,</span> <span class="mi">3</span><span class="p">,</span> <span class="n">image_size</span><span class="p">,</span> <span class="n">image_size</span><span class="p">)))</span>
    390. <span class="n">ground_truth</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">LongTensor</span><span class="p">(</span><span class="n">np</span><span class="o">.</span><span class="n">zeros</span><span class="p">((</span><span class="n">dataset_size</span><span class="p">,</span> <span class="n">image_size</span><span class="p">,</span> <span class="n">image_size</span><span class="p">)))</span>
    391. <span class="n">dataset</span> <span class="o">=</span> <span class="n">TensorDataset</span><span class="p">(</span><span class="n">images</span><span class="p">,</span> <span class="n">ground_truth</span><span class="p">)</span>
    392. <span class="k">return</span> <span class="n">DataLoader</span><span class="p">(</span><span class="n">dataset</span><span class="o">=</span><span class="n">dataset</span><span class="p">,</span> <span class="n">batch_size</span><span class="o">=</span><span class="n">batch_size</span><span class="p">)</span>
    393. <div class="viewcode-block" id="cityscapes_train"><a class="viewcode-back" href="../../../../super_gradients.training.html#super_gradients.training.dataloaders.cityscapes_train">[docs]</a><span class="k">def</span> <span class="nf">cityscapes_train</span><span class="p">(</span><span class="n">dataset_params</span><span class="p">:</span> <span class="n">Dict</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span> <span class="n">dataloader_params</span><span class="p">:</span> <span class="n">Dict</span> <span class="o">=</span> <span class="kc">None</span><span class="p">):</span>
    394. <span class="k">return</span> <span class="n">get_data_loader</span><span class="p">(</span>
    395. <span class="n">config_name</span><span class="o">=</span><span class="s2">&quot;cityscapes_dataset_params&quot;</span><span class="p">,</span>
    396. <span class="n">dataset_cls</span><span class="o">=</span><span class="n">CityscapesDataset</span><span class="p">,</span>
    397. <span class="n">train</span><span class="o">=</span><span class="kc">True</span><span class="p">,</span>
    398. <span class="n">dataset_params</span><span class="o">=</span><span class="n">dataset_params</span><span class="p">,</span>
    399. <span class="n">dataloader_params</span><span class="o">=</span><span class="n">dataloader_params</span><span class="p">,</span>
    400. <span class="p">)</span></div>
    401. <div class="viewcode-block" id="cityscapes_val"><a class="viewcode-back" href="../../../../super_gradients.training.html#super_gradients.training.dataloaders.cityscapes_val">[docs]</a><span class="k">def</span> <span class="nf">cityscapes_val</span><span class="p">(</span><span class="n">dataset_params</span><span class="p">:</span> <span class="n">Dict</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span> <span class="n">dataloader_params</span><span class="p">:</span> <span class="n">Dict</span> <span class="o">=</span> <span class="kc">None</span><span class="p">):</span>
    402. <span class="k">return</span> <span class="n">get_data_loader</span><span class="p">(</span>
    403. <span class="n">config_name</span><span class="o">=</span><span class="s2">&quot;cityscapes_dataset_params&quot;</span><span class="p">,</span>
    404. <span class="n">dataset_cls</span><span class="o">=</span><span class="n">CityscapesDataset</span><span class="p">,</span>
    405. <span class="n">train</span><span class="o">=</span><span class="kc">False</span><span class="p">,</span>
    406. <span class="n">dataset_params</span><span class="o">=</span><span class="n">dataset_params</span><span class="p">,</span>
    407. <span class="n">dataloader_params</span><span class="o">=</span><span class="n">dataloader_params</span><span class="p">,</span>
    408. <span class="p">)</span></div>
    409. <div class="viewcode-block" id="cityscapes_stdc_seg50_train"><a class="viewcode-back" href="../../../../super_gradients.training.html#super_gradients.training.dataloaders.cityscapes_stdc_seg50_train">[docs]</a><span class="k">def</span> <span class="nf">cityscapes_stdc_seg50_train</span><span class="p">(</span><span class="n">dataset_params</span><span class="p">:</span> <span class="n">Dict</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span> <span class="n">dataloader_params</span><span class="p">:</span> <span class="n">Dict</span> <span class="o">=</span> <span class="kc">None</span><span class="p">):</span>
    410. <span class="k">return</span> <span class="n">get_data_loader</span><span class="p">(</span>
    411. <span class="n">config_name</span><span class="o">=</span><span class="s2">&quot;cityscapes_stdc_seg50_dataset_params&quot;</span><span class="p">,</span>
    412. <span class="n">dataset_cls</span><span class="o">=</span><span class="n">CityscapesDataset</span><span class="p">,</span>
    413. <span class="n">train</span><span class="o">=</span><span class="kc">True</span><span class="p">,</span>
    414. <span class="n">dataset_params</span><span class="o">=</span><span class="n">dataset_params</span><span class="p">,</span>
    415. <span class="n">dataloader_params</span><span class="o">=</span><span class="n">dataloader_params</span><span class="p">,</span>
    416. <span class="p">)</span></div>
    417. <div class="viewcode-block" id="cityscapes_stdc_seg50_val"><a class="viewcode-back" href="../../../../super_gradients.training.html#super_gradients.training.dataloaders.cityscapes_stdc_seg50_val">[docs]</a><span class="k">def</span> <span class="nf">cityscapes_stdc_seg50_val</span><span class="p">(</span><span class="n">dataset_params</span><span class="p">:</span> <span class="n">Dict</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span> <span class="n">dataloader_params</span><span class="p">:</span> <span class="n">Dict</span> <span class="o">=</span> <span class="kc">None</span><span class="p">):</span>
    418. <span class="k">return</span> <span class="n">get_data_loader</span><span class="p">(</span>
    419. <span class="n">config_name</span><span class="o">=</span><span class="s2">&quot;cityscapes_stdc_seg50_dataset_params&quot;</span><span class="p">,</span>
    420. <span class="n">dataset_cls</span><span class="o">=</span><span class="n">CityscapesDataset</span><span class="p">,</span>
    421. <span class="n">train</span><span class="o">=</span><span class="kc">False</span><span class="p">,</span>
    422. <span class="n">dataset_params</span><span class="o">=</span><span class="n">dataset_params</span><span class="p">,</span>
    423. <span class="n">dataloader_params</span><span class="o">=</span><span class="n">dataloader_params</span><span class="p">,</span>
    424. <span class="p">)</span></div>
    425. <div class="viewcode-block" id="cityscapes_stdc_seg75_train"><a class="viewcode-back" href="../../../../super_gradients.training.html#super_gradients.training.dataloaders.cityscapes_stdc_seg75_train">[docs]</a><span class="k">def</span> <span class="nf">cityscapes_stdc_seg75_train</span><span class="p">(</span><span class="n">dataset_params</span><span class="p">:</span> <span class="n">Dict</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span> <span class="n">dataloader_params</span><span class="p">:</span> <span class="n">Dict</span> <span class="o">=</span> <span class="kc">None</span><span class="p">):</span>
    426. <span class="k">return</span> <span class="n">get_data_loader</span><span class="p">(</span>
    427. <span class="n">config_name</span><span class="o">=</span><span class="s2">&quot;cityscapes_stdc_seg75_dataset_params&quot;</span><span class="p">,</span>
    428. <span class="n">dataset_cls</span><span class="o">=</span><span class="n">CityscapesDataset</span><span class="p">,</span>
    429. <span class="n">train</span><span class="o">=</span><span class="kc">True</span><span class="p">,</span>
    430. <span class="n">dataset_params</span><span class="o">=</span><span class="n">dataset_params</span><span class="p">,</span>
    431. <span class="n">dataloader_params</span><span class="o">=</span><span class="n">dataloader_params</span><span class="p">,</span>
    432. <span class="p">)</span></div>
    433. <div class="viewcode-block" id="cityscapes_stdc_seg75_val"><a class="viewcode-back" href="../../../../super_gradients.training.html#super_gradients.training.dataloaders.cityscapes_stdc_seg75_val">[docs]</a><span class="k">def</span> <span class="nf">cityscapes_stdc_seg75_val</span><span class="p">(</span><span class="n">dataset_params</span><span class="p">:</span> <span class="n">Dict</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span> <span class="n">dataloader_params</span><span class="p">:</span> <span class="n">Dict</span> <span class="o">=</span> <span class="kc">None</span><span class="p">):</span>
    434. <span class="k">return</span> <span class="n">get_data_loader</span><span class="p">(</span>
    435. <span class="n">config_name</span><span class="o">=</span><span class="s2">&quot;cityscapes_stdc_seg75_dataset_params&quot;</span><span class="p">,</span>
    436. <span class="n">dataset_cls</span><span class="o">=</span><span class="n">CityscapesDataset</span><span class="p">,</span>
    437. <span class="n">train</span><span class="o">=</span><span class="kc">False</span><span class="p">,</span>
    438. <span class="n">dataset_params</span><span class="o">=</span><span class="n">dataset_params</span><span class="p">,</span>
    439. <span class="n">dataloader_params</span><span class="o">=</span><span class="n">dataloader_params</span><span class="p">,</span>
    440. <span class="p">)</span></div>
    441. <div class="viewcode-block" id="cityscapes_regseg48_train"><a class="viewcode-back" href="../../../../super_gradients.training.html#super_gradients.training.dataloaders.cityscapes_regseg48_train">[docs]</a><span class="k">def</span> <span class="nf">cityscapes_regseg48_train</span><span class="p">(</span><span class="n">dataset_params</span><span class="p">:</span> <span class="n">Dict</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span> <span class="n">dataloader_params</span><span class="p">:</span> <span class="n">Dict</span> <span class="o">=</span> <span class="kc">None</span><span class="p">):</span>
    442. <span class="k">return</span> <span class="n">get_data_loader</span><span class="p">(</span>
    443. <span class="n">config_name</span><span class="o">=</span><span class="s2">&quot;cityscapes_regseg48_dataset_params&quot;</span><span class="p">,</span>
    444. <span class="n">dataset_cls</span><span class="o">=</span><span class="n">CityscapesDataset</span><span class="p">,</span>
    445. <span class="n">train</span><span class="o">=</span><span class="kc">True</span><span class="p">,</span>
    446. <span class="n">dataset_params</span><span class="o">=</span><span class="n">dataset_params</span><span class="p">,</span>
    447. <span class="n">dataloader_params</span><span class="o">=</span><span class="n">dataloader_params</span><span class="p">,</span>
    448. <span class="p">)</span></div>
    449. <div class="viewcode-block" id="cityscapes_regseg48_val"><a class="viewcode-back" href="../../../../super_gradients.training.html#super_gradients.training.dataloaders.cityscapes_regseg48_val">[docs]</a><span class="k">def</span> <span class="nf">cityscapes_regseg48_val</span><span class="p">(</span><span class="n">dataset_params</span><span class="p">:</span> <span class="n">Dict</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span> <span class="n">dataloader_params</span><span class="p">:</span> <span class="n">Dict</span> <span class="o">=</span> <span class="kc">None</span><span class="p">):</span>
    450. <span class="k">return</span> <span class="n">get_data_loader</span><span class="p">(</span>
    451. <span class="n">config_name</span><span class="o">=</span><span class="s2">&quot;cityscapes_regseg48_dataset_params&quot;</span><span class="p">,</span>
    452. <span class="n">dataset_cls</span><span class="o">=</span><span class="n">CityscapesDataset</span><span class="p">,</span>
    453. <span class="n">train</span><span class="o">=</span><span class="kc">False</span><span class="p">,</span>
    454. <span class="n">dataset_params</span><span class="o">=</span><span class="n">dataset_params</span><span class="p">,</span>
    455. <span class="n">dataloader_params</span><span class="o">=</span><span class="n">dataloader_params</span><span class="p">,</span>
    456. <span class="p">)</span></div>
    457. <div class="viewcode-block" id="cityscapes_ddrnet_train"><a class="viewcode-back" href="../../../../super_gradients.training.html#super_gradients.training.dataloaders.cityscapes_ddrnet_train">[docs]</a><span class="k">def</span> <span class="nf">cityscapes_ddrnet_train</span><span class="p">(</span><span class="n">dataset_params</span><span class="p">:</span> <span class="n">Dict</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span> <span class="n">dataloader_params</span><span class="p">:</span> <span class="n">Dict</span> <span class="o">=</span> <span class="kc">None</span><span class="p">):</span>
    458. <span class="k">return</span> <span class="n">get_data_loader</span><span class="p">(</span>
    459. <span class="n">config_name</span><span class="o">=</span><span class="s2">&quot;cityscapes_ddrnet_dataset_params&quot;</span><span class="p">,</span>
    460. <span class="n">dataset_cls</span><span class="o">=</span><span class="n">CityscapesDataset</span><span class="p">,</span>
    461. <span class="n">train</span><span class="o">=</span><span class="kc">True</span><span class="p">,</span>
    462. <span class="n">dataset_params</span><span class="o">=</span><span class="n">dataset_params</span><span class="p">,</span>
    463. <span class="n">dataloader_params</span><span class="o">=</span><span class="n">dataloader_params</span><span class="p">,</span>
    464. <span class="p">)</span></div>
    465. <div class="viewcode-block" id="cityscapes_ddrnet_val"><a class="viewcode-back" href="../../../../super_gradients.training.html#super_gradients.training.dataloaders.cityscapes_ddrnet_val">[docs]</a><span class="k">def</span> <span class="nf">cityscapes_ddrnet_val</span><span class="p">(</span><span class="n">dataset_params</span><span class="p">:</span> <span class="n">Dict</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span> <span class="n">dataloader_params</span><span class="p">:</span> <span class="n">Dict</span> <span class="o">=</span> <span class="kc">None</span><span class="p">):</span>
    466. <span class="k">return</span> <span class="n">get_data_loader</span><span class="p">(</span>
    467. <span class="n">config_name</span><span class="o">=</span><span class="s2">&quot;cityscapes_ddrnet_dataset_params&quot;</span><span class="p">,</span>
    468. <span class="n">dataset_cls</span><span class="o">=</span><span class="n">CityscapesDataset</span><span class="p">,</span>
    469. <span class="n">train</span><span class="o">=</span><span class="kc">False</span><span class="p">,</span>
    470. <span class="n">dataset_params</span><span class="o">=</span><span class="n">dataset_params</span><span class="p">,</span>
    471. <span class="n">dataloader_params</span><span class="o">=</span><span class="n">dataloader_params</span><span class="p">,</span>
    472. <span class="p">)</span></div>
    473. <div class="viewcode-block" id="coco_segmentation_train"><a class="viewcode-back" href="../../../../super_gradients.training.html#super_gradients.training.dataloaders.coco_segmentation_train">[docs]</a><span class="k">def</span> <span class="nf">coco_segmentation_train</span><span class="p">(</span><span class="n">dataset_params</span><span class="p">:</span> <span class="n">Dict</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span> <span class="n">dataloader_params</span><span class="p">:</span> <span class="n">Dict</span> <span class="o">=</span> <span class="kc">None</span><span class="p">):</span>
    474. <span class="k">return</span> <span class="n">get_data_loader</span><span class="p">(</span>
    475. <span class="n">config_name</span><span class="o">=</span><span class="s2">&quot;coco_segmentation_dataset_params&quot;</span><span class="p">,</span>
    476. <span class="n">dataset_cls</span><span class="o">=</span><span class="n">CoCoSegmentationDataSet</span><span class="p">,</span>
    477. <span class="n">train</span><span class="o">=</span><span class="kc">True</span><span class="p">,</span>
    478. <span class="n">dataset_params</span><span class="o">=</span><span class="n">dataset_params</span><span class="p">,</span>
    479. <span class="n">dataloader_params</span><span class="o">=</span><span class="n">dataloader_params</span><span class="p">,</span>
    480. <span class="p">)</span></div>
    481. <div class="viewcode-block" id="coco_segmentation_val"><a class="viewcode-back" href="../../../../super_gradients.training.html#super_gradients.training.dataloaders.coco_segmentation_val">[docs]</a><span class="k">def</span> <span class="nf">coco_segmentation_val</span><span class="p">(</span><span class="n">dataset_params</span><span class="p">:</span> <span class="n">Dict</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span> <span class="n">dataloader_params</span><span class="p">:</span> <span class="n">Dict</span> <span class="o">=</span> <span class="kc">None</span><span class="p">):</span>
    482. <span class="k">return</span> <span class="n">get_data_loader</span><span class="p">(</span>
    483. <span class="n">config_name</span><span class="o">=</span><span class="s2">&quot;coco_segmentation_dataset_params&quot;</span><span class="p">,</span>
    484. <span class="n">dataset_cls</span><span class="o">=</span><span class="n">CoCoSegmentationDataSet</span><span class="p">,</span>
    485. <span class="n">train</span><span class="o">=</span><span class="kc">False</span><span class="p">,</span>
    486. <span class="n">dataset_params</span><span class="o">=</span><span class="n">dataset_params</span><span class="p">,</span>
    487. <span class="n">dataloader_params</span><span class="o">=</span><span class="n">dataloader_params</span><span class="p">,</span>
    488. <span class="p">)</span></div>
    489. <div class="viewcode-block" id="pascal_aug_segmentation_train"><a class="viewcode-back" href="../../../../super_gradients.training.html#super_gradients.training.dataloaders.pascal_aug_segmentation_train">[docs]</a><span class="k">def</span> <span class="nf">pascal_aug_segmentation_train</span><span class="p">(</span><span class="n">dataset_params</span><span class="p">:</span> <span class="n">Dict</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span> <span class="n">dataloader_params</span><span class="p">:</span> <span class="n">Dict</span> <span class="o">=</span> <span class="kc">None</span><span class="p">):</span>
    490. <span class="k">return</span> <span class="n">get_data_loader</span><span class="p">(</span>
    491. <span class="n">config_name</span><span class="o">=</span><span class="s2">&quot;pascal_aug_segmentation_dataset_params&quot;</span><span class="p">,</span>
    492. <span class="n">dataset_cls</span><span class="o">=</span><span class="n">PascalVOCAndAUGUnifiedDataset</span><span class="p">,</span>
    493. <span class="n">train</span><span class="o">=</span><span class="kc">True</span><span class="p">,</span>
    494. <span class="n">dataset_params</span><span class="o">=</span><span class="n">dataset_params</span><span class="p">,</span>
    495. <span class="n">dataloader_params</span><span class="o">=</span><span class="n">dataloader_params</span><span class="p">,</span>
    496. <span class="p">)</span></div>
    497. <div class="viewcode-block" id="pascal_aug_segmentation_val"><a class="viewcode-back" href="../../../../super_gradients.training.html#super_gradients.training.dataloaders.pascal_aug_segmentation_val">[docs]</a><span class="k">def</span> <span class="nf">pascal_aug_segmentation_val</span><span class="p">(</span><span class="n">dataset_params</span><span class="p">:</span> <span class="n">Dict</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span> <span class="n">dataloader_params</span><span class="p">:</span> <span class="n">Dict</span> <span class="o">=</span> <span class="kc">None</span><span class="p">):</span>
    498. <span class="k">return</span> <span class="n">pascal_voc_segmentation_val</span><span class="p">(</span><span class="n">dataset_params</span><span class="o">=</span><span class="n">dataset_params</span><span class="p">,</span> <span class="n">dataloader_params</span><span class="o">=</span><span class="n">dataloader_params</span><span class="p">)</span></div>
    499. <div class="viewcode-block" id="pascal_voc_segmentation_train"><a class="viewcode-back" href="../../../../super_gradients.training.html#super_gradients.training.dataloaders.pascal_voc_segmentation_train">[docs]</a><span class="k">def</span> <span class="nf">pascal_voc_segmentation_train</span><span class="p">(</span><span class="n">dataset_params</span><span class="p">:</span> <span class="n">Dict</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span> <span class="n">dataloader_params</span><span class="p">:</span> <span class="n">Dict</span> <span class="o">=</span> <span class="kc">None</span><span class="p">):</span>
    500. <span class="k">return</span> <span class="n">get_data_loader</span><span class="p">(</span>
    501. <span class="n">config_name</span><span class="o">=</span><span class="s2">&quot;pascal_voc_segmentation_dataset_params&quot;</span><span class="p">,</span>
    502. <span class="n">dataset_cls</span><span class="o">=</span><span class="n">PascalVOC2012SegmentationDataSet</span><span class="p">,</span>
    503. <span class="n">train</span><span class="o">=</span><span class="kc">True</span><span class="p">,</span>
    504. <span class="n">dataset_params</span><span class="o">=</span><span class="n">dataset_params</span><span class="p">,</span>
    505. <span class="n">dataloader_params</span><span class="o">=</span><span class="n">dataloader_params</span><span class="p">,</span>
    506. <span class="p">)</span></div>
    507. <div class="viewcode-block" id="pascal_voc_segmentation_val"><a class="viewcode-back" href="../../../../super_gradients.training.html#super_gradients.training.dataloaders.pascal_voc_segmentation_val">[docs]</a><span class="k">def</span> <span class="nf">pascal_voc_segmentation_val</span><span class="p">(</span><span class="n">dataset_params</span><span class="p">:</span> <span class="n">Dict</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span> <span class="n">dataloader_params</span><span class="p">:</span> <span class="n">Dict</span> <span class="o">=</span> <span class="kc">None</span><span class="p">):</span>
    508. <span class="k">return</span> <span class="n">get_data_loader</span><span class="p">(</span>
    509. <span class="n">config_name</span><span class="o">=</span><span class="s2">&quot;pascal_voc_segmentation_dataset_params&quot;</span><span class="p">,</span>
    510. <span class="n">dataset_cls</span><span class="o">=</span><span class="n">PascalVOC2012SegmentationDataSet</span><span class="p">,</span>
    511. <span class="n">train</span><span class="o">=</span><span class="kc">False</span><span class="p">,</span>
    512. <span class="n">dataset_params</span><span class="o">=</span><span class="n">dataset_params</span><span class="p">,</span>
    513. <span class="n">dataloader_params</span><span class="o">=</span><span class="n">dataloader_params</span><span class="p">,</span>
    514. <span class="p">)</span></div>
    515. <div class="viewcode-block" id="supervisely_persons_train"><a class="viewcode-back" href="../../../../super_gradients.training.html#super_gradients.training.dataloaders.supervisely_persons_train">[docs]</a><span class="k">def</span> <span class="nf">supervisely_persons_train</span><span class="p">(</span><span class="n">dataset_params</span><span class="p">:</span> <span class="n">Dict</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span> <span class="n">dataloader_params</span><span class="p">:</span> <span class="n">Dict</span> <span class="o">=</span> <span class="kc">None</span><span class="p">):</span>
    516. <span class="k">return</span> <span class="n">get_data_loader</span><span class="p">(</span>
    517. <span class="n">config_name</span><span class="o">=</span><span class="s2">&quot;supervisely_persons_dataset_params&quot;</span><span class="p">,</span>
    518. <span class="n">dataset_cls</span><span class="o">=</span><span class="n">SuperviselyPersonsDataset</span><span class="p">,</span>
    519. <span class="n">train</span><span class="o">=</span><span class="kc">True</span><span class="p">,</span>
    520. <span class="n">dataset_params</span><span class="o">=</span><span class="n">dataset_params</span><span class="p">,</span>
    521. <span class="n">dataloader_params</span><span class="o">=</span><span class="n">dataloader_params</span><span class="p">,</span>
    522. <span class="p">)</span></div>
    523. <div class="viewcode-block" id="supervisely_persons_val"><a class="viewcode-back" href="../../../../super_gradients.training.html#super_gradients.training.dataloaders.supervisely_persons_val">[docs]</a><span class="k">def</span> <span class="nf">supervisely_persons_val</span><span class="p">(</span><span class="n">dataset_params</span><span class="p">:</span> <span class="n">Dict</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span> <span class="n">dataloader_params</span><span class="p">:</span> <span class="n">Dict</span> <span class="o">=</span> <span class="kc">None</span><span class="p">):</span>
    524. <span class="k">return</span> <span class="n">get_data_loader</span><span class="p">(</span>
    525. <span class="n">config_name</span><span class="o">=</span><span class="s2">&quot;supervisely_persons_dataset_params&quot;</span><span class="p">,</span>
    526. <span class="n">dataset_cls</span><span class="o">=</span><span class="n">SuperviselyPersonsDataset</span><span class="p">,</span>
    527. <span class="n">train</span><span class="o">=</span><span class="kc">False</span><span class="p">,</span>
    528. <span class="n">dataset_params</span><span class="o">=</span><span class="n">dataset_params</span><span class="p">,</span>
    529. <span class="n">dataloader_params</span><span class="o">=</span><span class="n">dataloader_params</span><span class="p">,</span>
    530. <span class="p">)</span></div>
    531. <div class="viewcode-block" id="pascal_voc_detection_train"><a class="viewcode-back" href="../../../../super_gradients.training.html#super_gradients.training.dataloaders.pascal_voc_detection_train">[docs]</a><span class="k">def</span> <span class="nf">pascal_voc_detection_train</span><span class="p">(</span><span class="n">dataset_params</span><span class="p">:</span> <span class="n">Dict</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span> <span class="n">dataloader_params</span><span class="p">:</span> <span class="n">Dict</span> <span class="o">=</span> <span class="kc">None</span><span class="p">):</span>
    532. <span class="k">return</span> <span class="n">get_data_loader</span><span class="p">(</span>
    533. <span class="n">config_name</span><span class="o">=</span><span class="s2">&quot;pascal_voc_detection_dataset_params&quot;</span><span class="p">,</span>
    534. <span class="n">dataset_cls</span><span class="o">=</span><span class="n">PascalVOCUnifiedDetectionTrainDataset</span><span class="p">,</span>
    535. <span class="n">train</span><span class="o">=</span><span class="kc">True</span><span class="p">,</span>
    536. <span class="n">dataset_params</span><span class="o">=</span><span class="n">dataset_params</span><span class="p">,</span>
    537. <span class="n">dataloader_params</span><span class="o">=</span><span class="n">dataloader_params</span><span class="p">,</span>
    538. <span class="p">)</span></div>
    539. <div class="viewcode-block" id="pascal_voc_detection_val"><a class="viewcode-back" href="../../../../super_gradients.training.html#super_gradients.training.dataloaders.pascal_voc_detection_val">[docs]</a><span class="k">def</span> <span class="nf">pascal_voc_detection_val</span><span class="p">(</span><span class="n">dataset_params</span><span class="p">:</span> <span class="n">Dict</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span> <span class="n">dataloader_params</span><span class="p">:</span> <span class="n">Dict</span> <span class="o">=</span> <span class="kc">None</span><span class="p">):</span>
    540. <span class="k">return</span> <span class="n">get_data_loader</span><span class="p">(</span>
    541. <span class="n">config_name</span><span class="o">=</span><span class="s2">&quot;pascal_voc_detection_dataset_params&quot;</span><span class="p">,</span>
    542. <span class="n">dataset_cls</span><span class="o">=</span><span class="n">PascalVOCDetectionDataset</span><span class="p">,</span>
    543. <span class="n">train</span><span class="o">=</span><span class="kc">False</span><span class="p">,</span>
    544. <span class="n">dataset_params</span><span class="o">=</span><span class="n">dataset_params</span><span class="p">,</span>
    545. <span class="n">dataloader_params</span><span class="o">=</span><span class="n">dataloader_params</span><span class="p">,</span>
    546. <span class="p">)</span></div>
    547. <span class="n">ALL_DATALOADERS</span> <span class="o">=</span> <span class="p">{</span>
    548. <span class="s2">&quot;coco2017_train&quot;</span><span class="p">:</span> <span class="n">coco2017_train</span><span class="p">,</span>
    549. <span class="s2">&quot;coco2017_val&quot;</span><span class="p">:</span> <span class="n">coco2017_val</span><span class="p">,</span>
    550. <span class="s2">&quot;coco2017_train_yolox&quot;</span><span class="p">:</span> <span class="n">coco2017_train_yolox</span><span class="p">,</span>
    551. <span class="s2">&quot;coco2017_val_yolox&quot;</span><span class="p">:</span> <span class="n">coco2017_val_yolox</span><span class="p">,</span>
    552. <span class="s2">&quot;coco2017_train_ssd_lite_mobilenet_v2&quot;</span><span class="p">:</span> <span class="n">coco2017_train_ssd_lite_mobilenet_v2</span><span class="p">,</span>
    553. <span class="s2">&quot;coco2017_val_ssd_lite_mobilenet_v2&quot;</span><span class="p">:</span> <span class="n">coco2017_val_ssd_lite_mobilenet_v2</span><span class="p">,</span>
    554. <span class="s2">&quot;imagenet_train&quot;</span><span class="p">:</span> <span class="n">imagenet_train</span><span class="p">,</span>
    555. <span class="s2">&quot;imagenet_val&quot;</span><span class="p">:</span> <span class="n">imagenet_val</span><span class="p">,</span>
    556. <span class="s2">&quot;imagenet_efficientnet_train&quot;</span><span class="p">:</span> <span class="n">imagenet_efficientnet_train</span><span class="p">,</span>
    557. <span class="s2">&quot;imagenet_efficientnet_val&quot;</span><span class="p">:</span> <span class="n">imagenet_efficientnet_val</span><span class="p">,</span>
    558. <span class="s2">&quot;imagenet_mobilenetv2_train&quot;</span><span class="p">:</span> <span class="n">imagenet_mobilenetv2_train</span><span class="p">,</span>
    559. <span class="s2">&quot;imagenet_mobilenetv2_val&quot;</span><span class="p">:</span> <span class="n">imagenet_mobilenetv2_val</span><span class="p">,</span>
    560. <span class="s2">&quot;imagenet_mobilenetv3_train&quot;</span><span class="p">:</span> <span class="n">imagenet_mobilenetv3_train</span><span class="p">,</span>
    561. <span class="s2">&quot;imagenet_mobilenetv3_val&quot;</span><span class="p">:</span> <span class="n">imagenet_mobilenetv3_val</span><span class="p">,</span>
    562. <span class="s2">&quot;imagenet_regnetY_train&quot;</span><span class="p">:</span> <span class="n">imagenet_regnetY_train</span><span class="p">,</span>
    563. <span class="s2">&quot;imagenet_regnetY_val&quot;</span><span class="p">:</span> <span class="n">imagenet_regnetY_val</span><span class="p">,</span>
    564. <span class="s2">&quot;imagenet_resnet50_train&quot;</span><span class="p">:</span> <span class="n">imagenet_resnet50_train</span><span class="p">,</span>
    565. <span class="s2">&quot;imagenet_resnet50_val&quot;</span><span class="p">:</span> <span class="n">imagenet_resnet50_val</span><span class="p">,</span>
    566. <span class="s2">&quot;imagenet_resnet50_kd_train&quot;</span><span class="p">:</span> <span class="n">imagenet_resnet50_kd_train</span><span class="p">,</span>
    567. <span class="s2">&quot;imagenet_resnet50_kd_val&quot;</span><span class="p">:</span> <span class="n">imagenet_resnet50_kd_val</span><span class="p">,</span>
    568. <span class="s2">&quot;imagenet_vit_base_train&quot;</span><span class="p">:</span> <span class="n">imagenet_vit_base_train</span><span class="p">,</span>
    569. <span class="s2">&quot;imagenet_vit_base_val&quot;</span><span class="p">:</span> <span class="n">imagenet_vit_base_val</span><span class="p">,</span>
    570. <span class="s2">&quot;tiny_imagenet_train&quot;</span><span class="p">:</span> <span class="n">tiny_imagenet_train</span><span class="p">,</span>
    571. <span class="s2">&quot;tiny_imagenet_val&quot;</span><span class="p">:</span> <span class="n">tiny_imagenet_val</span><span class="p">,</span>
    572. <span class="s2">&quot;cifar10_train&quot;</span><span class="p">:</span> <span class="n">cifar10_train</span><span class="p">,</span>
    573. <span class="s2">&quot;cifar10_val&quot;</span><span class="p">:</span> <span class="n">cifar10_val</span><span class="p">,</span>
    574. <span class="s2">&quot;cifar100_train&quot;</span><span class="p">:</span> <span class="n">cifar100_train</span><span class="p">,</span>
    575. <span class="s2">&quot;cifar100_val&quot;</span><span class="p">:</span> <span class="n">cifar100_val</span><span class="p">,</span>
    576. <span class="s2">&quot;cityscapes_train&quot;</span><span class="p">:</span> <span class="n">cityscapes_train</span><span class="p">,</span>
    577. <span class="s2">&quot;cityscapes_val&quot;</span><span class="p">:</span> <span class="n">cityscapes_val</span><span class="p">,</span>
    578. <span class="s2">&quot;cityscapes_stdc_seg50_train&quot;</span><span class="p">:</span> <span class="n">cityscapes_stdc_seg50_train</span><span class="p">,</span>
    579. <span class="s2">&quot;cityscapes_stdc_seg50_val&quot;</span><span class="p">:</span> <span class="n">cityscapes_stdc_seg50_val</span><span class="p">,</span>
    580. <span class="s2">&quot;cityscapes_stdc_seg75_train&quot;</span><span class="p">:</span> <span class="n">cityscapes_stdc_seg75_train</span><span class="p">,</span>
    581. <span class="s2">&quot;cityscapes_stdc_seg75_val&quot;</span><span class="p">:</span> <span class="n">cityscapes_stdc_seg75_val</span><span class="p">,</span>
    582. <span class="s2">&quot;cityscapes_regseg48_train&quot;</span><span class="p">:</span> <span class="n">cityscapes_regseg48_train</span><span class="p">,</span>
    583. <span class="s2">&quot;cityscapes_regseg48_val&quot;</span><span class="p">:</span> <span class="n">cityscapes_regseg48_val</span><span class="p">,</span>
    584. <span class="s2">&quot;cityscapes_ddrnet_train&quot;</span><span class="p">:</span> <span class="n">cityscapes_ddrnet_train</span><span class="p">,</span>
    585. <span class="s2">&quot;cityscapes_ddrnet_val&quot;</span><span class="p">:</span> <span class="n">cityscapes_ddrnet_val</span><span class="p">,</span>
    586. <span class="s2">&quot;coco_segmentation_train&quot;</span><span class="p">:</span> <span class="n">coco_segmentation_train</span><span class="p">,</span>
    587. <span class="s2">&quot;coco_segmentation_val&quot;</span><span class="p">:</span> <span class="n">coco_segmentation_val</span><span class="p">,</span>
    588. <span class="s2">&quot;pascal_aug_segmentation_train&quot;</span><span class="p">:</span> <span class="n">pascal_aug_segmentation_train</span><span class="p">,</span>
    589. <span class="s2">&quot;pascal_aug_segmentation_val&quot;</span><span class="p">:</span> <span class="n">pascal_aug_segmentation_val</span><span class="p">,</span>
    590. <span class="s2">&quot;pascal_voc_segmentation_train&quot;</span><span class="p">:</span> <span class="n">pascal_voc_segmentation_train</span><span class="p">,</span>
    591. <span class="s2">&quot;pascal_voc_segmentation_val&quot;</span><span class="p">:</span> <span class="n">pascal_voc_segmentation_val</span><span class="p">,</span>
    592. <span class="s2">&quot;supervisely_persons_train&quot;</span><span class="p">:</span> <span class="n">supervisely_persons_train</span><span class="p">,</span>
    593. <span class="s2">&quot;supervisely_persons_val&quot;</span><span class="p">:</span> <span class="n">supervisely_persons_val</span><span class="p">,</span>
    594. <span class="s2">&quot;pascal_voc_detection_train&quot;</span><span class="p">:</span> <span class="n">pascal_voc_detection_train</span><span class="p">,</span>
    595. <span class="s2">&quot;pascal_voc_detection_val&quot;</span><span class="p">:</span> <span class="n">pascal_voc_detection_val</span><span class="p">,</span>
    596. <span class="p">}</span>
    597. <div class="viewcode-block" id="get"><a class="viewcode-back" href="../../../../super_gradients.training.html#super_gradients.training.dataloaders.get">[docs]</a><span class="k">def</span> <span class="nf">get</span><span class="p">(</span><span class="n">name</span><span class="p">:</span> <span class="nb">str</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span> <span class="n">dataset_params</span><span class="p">:</span> <span class="n">Dict</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span> <span class="n">dataloader_params</span><span class="p">:</span> <span class="n">Dict</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span> <span class="n">dataset</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">utils</span><span class="o">.</span><span class="n">data</span><span class="o">.</span><span class="n">Dataset</span> <span class="o">=</span> <span class="kc">None</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="n">DataLoader</span><span class="p">:</span>
    598. <span class="sd">&quot;&quot;&quot;</span>
    599. <span class="sd"> Get DataLoader of the recipe-configured dataset defined by name in ALL_DATALOADERS.</span>
    600. <span class="sd"> :param name: dataset name in ALL_DATALOADERS.</span>
    601. <span class="sd"> :param dataset_params: dataset params that override the yaml configured defaults, then passed to the dataset_cls.__init__.</span>
    602. <span class="sd"> :param dataloader_params: DataLoader params that override the yaml configured defaults, then passed to the DataLoader.__init__</span>
    603. <span class="sd"> :param dataset: torch.utils.data.Dataset to be used instead of passing &quot;name&quot; (i.e for external dataset objects).</span>
    604. <span class="sd"> :return: initialized DataLoader.</span>
    605. <span class="sd"> &quot;&quot;&quot;</span>
    606. <span class="k">if</span> <span class="n">dataset</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span><span class="p">:</span>
    607. <span class="k">if</span> <span class="n">name</span> <span class="ow">or</span> <span class="n">dataset_params</span><span class="p">:</span>
    608. <span class="k">raise</span> <span class="ne">ValueError</span><span class="p">(</span><span class="s2">&quot;&#39;name&#39; and &#39;dataset_params&#39; cannot be passed with initialized dataset.&quot;</span><span class="p">)</span>
    609. <span class="n">dataloader_params</span> <span class="o">=</span> <span class="n">_process_sampler_params</span><span class="p">(</span><span class="n">dataloader_params</span><span class="p">,</span> <span class="n">dataset</span><span class="p">,</span> <span class="p">{})</span>
    610. <span class="n">dataloader</span> <span class="o">=</span> <span class="n">DataLoader</span><span class="p">(</span><span class="n">dataset</span><span class="o">=</span><span class="n">dataset</span><span class="p">,</span> <span class="o">**</span><span class="n">dataloader_params</span><span class="p">)</span>
    611. <span class="k">elif</span> <span class="n">name</span> <span class="ow">not</span> <span class="ow">in</span> <span class="n">ALL_DATALOADERS</span><span class="o">.</span><span class="n">keys</span><span class="p">():</span>
    612. <span class="k">raise</span> <span class="ne">ValueError</span><span class="p">(</span><span class="s2">&quot;Unsupported dataloader: &quot;</span> <span class="o">+</span> <span class="nb">str</span><span class="p">(</span><span class="n">name</span><span class="p">))</span>
    613. <span class="k">else</span><span class="p">:</span>
    614. <span class="n">dataloader_cls</span> <span class="o">=</span> <span class="n">ALL_DATALOADERS</span><span class="p">[</span><span class="n">name</span><span class="p">]</span>
    615. <span class="n">dataloader</span> <span class="o">=</span> <span class="n">dataloader_cls</span><span class="p">(</span><span class="n">dataset_params</span><span class="o">=</span><span class="n">dataset_params</span><span class="p">,</span> <span class="n">dataloader_params</span><span class="o">=</span><span class="n">dataloader_params</span><span class="p">)</span>
    616. <span class="k">return</span> <span class="n">dataloader</span></div>
    617. </pre></div>
    618. </div>
    619. </div>
    620. <footer>
    621. <hr/>
    622. <div role="contentinfo">
    623. <p>&#169; Copyright 2021, SuperGradients team.</p>
    624. </div>
    625. Built with <a href="https://www.sphinx-doc.org/">Sphinx</a> using a
    626. <a href="https://github.com/readthedocs/sphinx_rtd_theme">theme</a>
    627. provided by <a href="https://readthedocs.org">Read the Docs</a>.
    628. </footer>
    629. </div>
    630. </div>
    631. </section>
    632. </div>
    633. <script>
    634. jQuery(function () {
    635. SphinxRtdTheme.Navigation.enable(true);
    636. });
    637. </script>
    638. </body>
    639. </html>
    Discard
    1
    2
    3
    4
    5
    6
    7
    8
    9
    10
    11
    12
    13
    14
    15
    16
    17
    18
    19
    20
    21
    22
    23
    24
    25
    26
    27
    28
    29
    30
    31
    32
    33
    34
    35
    36
    37
    38
    39
    40
    41
    42
    43
    44
    45
    46
    47
    48
    49
    50
    51
    52
    53
    54
    55
    56
    57
    58
    59
    60
    61
    62
    63
    64
    65
    66
    67
    68
    69
    70
    71
    72
    73
    74
    75
    76
    77
    78
    79
    80
    81
    82
    83
    84
    85
    86
    87
    88
    89
    90
    91
    92
    93
    94
    95
    96
    97
    98
    99
    100
    101
    102
    103
    104
    105
    106
    107
    108
    109
    110
    111
    112
    113
    114
    115
    116
    117
    118
    119
    120
    121
    122
    123
    124
    125
    126
    127
    128
    129
    130
    131
    132
    133
    134
    135
    136
    137
    138
    139
    140
    141
    142
    143
    144
    145
    146
    147
    148
    149
    150
    151
    152
    153
    154
    155
    156
    157
    158
    159
    160
    161
    162
    163
    164
    165
    166
    167
    168
    169
    170
    171
    172
    173
    174
    175
    176
    177
    178
    179
    180
    181
    182
    183
    184
    185
    186
    187
    188
    189
    190
    191
    192
    193
    194
    195
    196
    197
    198
    199
    200
    201
    202
    203
    204
    205
    206
    207
    1. <!DOCTYPE html>
    2. <html class="writer-html5" lang="en" >
    3. <head>
    4. <meta charset="utf-8" />
    5. <meta name="viewport" content="width=device-width, initial-scale=1.0" />
    6. <title>super_gradients.training.datasets.all_datasets &mdash; SuperGradients 1.0 documentation</title>
    7. <link rel="stylesheet" href="../../../../_static/pygments.css" type="text/css" />
    8. <link rel="stylesheet" href="../../../../_static/css/theme.css" type="text/css" />
    9. <link rel="stylesheet" href="../../../../_static/graphviz.css" type="text/css" />
    10. <!--[if lt IE 9]>
    11. <script src="../../../../_static/js/html5shiv.min.js"></script>
    12. <![endif]-->
    13. <script data-url_root="../../../../" id="documentation_options" src="../../../../_static/documentation_options.js"></script>
    14. <script src="../../../../_static/jquery.js"></script>
    15. <script src="../../../../_static/underscore.js"></script>
    16. <script src="../../../../_static/doctools.js"></script>
    17. <script src="../../../../_static/js/theme.js"></script>
    18. <link rel="index" title="Index" href="../../../../genindex.html" />
    19. <link rel="search" title="Search" href="../../../../search.html" />
    20. </head>
    21. <body class="wy-body-for-nav">
    22. <div class="wy-grid-for-nav">
    23. <nav data-toggle="wy-nav-shift" class="wy-nav-side">
    24. <div class="wy-side-scroll">
    25. <div class="wy-side-nav-search" >
    26. <a href="../../../../index.html" class="icon icon-home"> SuperGradients
    27. </a>
    28. <div role="search">
    29. <form id="rtd-search-form" class="wy-form" action="../../../../search.html" method="get">
    30. <input type="text" name="q" placeholder="Search docs" />
    31. <input type="hidden" name="check_keywords" value="yes" />
    32. <input type="hidden" name="area" value="default" />
    33. </form>
    34. </div>
    35. </div><div class="wy-menu wy-menu-vertical" data-spy="affix" role="navigation" aria-label="Navigation menu">
    36. <p class="caption"><span class="caption-text">Welcome To SuperGradients</span></p>
    37. <ul>
    38. <li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html">Fill our 4-question quick survey! We will raffle free SuperGradients swag between those who will participate -&gt; Fill Survey</a></li>
    39. <li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#supergradients">SuperGradients</a></li>
    40. </ul>
    41. <p class="caption"><span class="caption-text">Technical Documentation</span></p>
    42. <ul>
    43. <li class="toctree-l1"><a class="reference internal" href="../../../../super_gradients.common.html">Common package</a></li>
    44. <li class="toctree-l1"><a class="reference internal" href="../../../../super_gradients.training.html">Training package</a></li>
    45. </ul>
    46. <p class="caption"><span class="caption-text">User Guide</span></p>
    47. <ul>
    48. <li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html">What is SuperGradients?</a></li>
    49. <li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#introducing-the-supergradients-library">Introducing the SuperGradients library</a></li>
    50. <li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#installation">Installation</a></li>
    51. <li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#integrating-your-training-code-complete-walkthrough">Integrating Your Training Code - Complete Walkthrough</a></li>
    52. <li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#training-parameters">Training Parameters</a></li>
    53. <li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#logs-and-checkpoints">Logs and Checkpoints</a></li>
    54. <li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#dataset-parameters">Dataset Parameters</a></li>
    55. <li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#network-architectures">Network Architectures</a></li>
    56. <li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#pretrained-models">Pretrained Models</a></li>
    57. <li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#how-to-reproduce-our-training-recipes">How To Reproduce Our Training Recipes</a></li>
    58. <li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#professional-tools-integration">Professional Tools Integration</a></li>
    59. <li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#supergradients-faq">SuperGradients FAQ</a></li>
    60. </ul>
    61. </div>
    62. </div>
    63. </nav>
    64. <section data-toggle="wy-nav-shift" class="wy-nav-content-wrap"><nav class="wy-nav-top" aria-label="Mobile navigation menu" >
    65. <i data-toggle="wy-nav-top" class="fa fa-bars"></i>
    66. <a href="../../../../index.html">SuperGradients</a>
    67. </nav>
    68. <div class="wy-nav-content">
    69. <div class="rst-content">
    70. <div role="navigation" aria-label="Page navigation">
    71. <ul class="wy-breadcrumbs">
    72. <li><a href="../../../../index.html" class="icon icon-home"></a> &raquo;</li>
    73. <li><a href="../../../index.html">Module code</a> &raquo;</li>
    74. <li>super_gradients.training.datasets.all_datasets</li>
    75. <li class="wy-breadcrumbs-aside">
    76. </li>
    77. </ul>
    78. <hr/>
    79. </div>
    80. <div role="main" class="document" itemscope="itemscope" itemtype="http://schema.org/Article">
    81. <div itemprop="articleBody">
    82. <h1>Source code for super_gradients.training.datasets.all_datasets</h1><div class="highlight"><pre>
    83. <span></span><span class="kn">from</span> <span class="nn">collections</span> <span class="kn">import</span> <span class="n">defaultdict</span>
    84. <span class="kn">from</span> <span class="nn">typing</span> <span class="kn">import</span> <span class="n">Dict</span><span class="p">,</span> <span class="n">List</span><span class="p">,</span> <span class="n">Type</span>
    85. <span class="kn">from</span> <span class="nn">super_gradients.training.datasets.dataset_interfaces</span> <span class="kn">import</span> <span class="n">DatasetInterface</span><span class="p">,</span> <span class="n">TestDatasetInterface</span><span class="p">,</span> \
    86. <span class="n">LibraryDatasetInterface</span><span class="p">,</span> \
    87. <span class="n">ClassificationDatasetInterface</span><span class="p">,</span> <span class="n">Cifar10DatasetInterface</span><span class="p">,</span> <span class="n">Cifar100DatasetInterface</span><span class="p">,</span> \
    88. <span class="n">ImageNetDatasetInterface</span><span class="p">,</span> <span class="n">TinyImageNetDatasetInterface</span><span class="p">,</span> <span class="n">CoCoSegmentationDatasetInterface</span><span class="p">,</span>\
    89. <span class="n">PascalAUG2012SegmentationDataSetInterface</span><span class="p">,</span> <span class="n">PascalVOC2012SegmentationDataSetInterface</span>
    90. <span class="kn">from</span> <span class="nn">super_gradients.common.data_types.enum.deep_learning_task</span> <span class="kn">import</span> <span class="n">DeepLearningTask</span>
    91. <span class="kn">from</span> <span class="nn">super_gradients.training.datasets.dataset_interfaces.dataset_interface</span> <span class="kn">import</span> <span class="n">CoCoDetectionDatasetInterface</span>
    92. <span class="n">CLASSIFICATION_DATASETS</span> <span class="o">=</span> <span class="p">{</span>
    93. <span class="s2">&quot;test_dataset&quot;</span><span class="p">:</span> <span class="n">TestDatasetInterface</span><span class="p">,</span>
    94. <span class="s2">&quot;library_dataset&quot;</span><span class="p">:</span> <span class="n">LibraryDatasetInterface</span><span class="p">,</span>
    95. <span class="s2">&quot;classification_dataset&quot;</span><span class="p">:</span> <span class="n">ClassificationDatasetInterface</span><span class="p">,</span>
    96. <span class="s2">&quot;cifar_10&quot;</span><span class="p">:</span> <span class="n">Cifar10DatasetInterface</span><span class="p">,</span>
    97. <span class="s2">&quot;cifar_100&quot;</span><span class="p">:</span> <span class="n">Cifar100DatasetInterface</span><span class="p">,</span>
    98. <span class="s2">&quot;imagenet&quot;</span><span class="p">:</span> <span class="n">ImageNetDatasetInterface</span><span class="p">,</span>
    99. <span class="s2">&quot;tiny_imagenet&quot;</span><span class="p">:</span> <span class="n">TinyImageNetDatasetInterface</span>
    100. <span class="p">}</span>
    101. <span class="n">OBJECT_DETECTION_DATASETS</span> <span class="o">=</span> <span class="p">{</span>
    102. <span class="s2">&quot;coco&quot;</span><span class="p">:</span> <span class="n">CoCoDetectionDatasetInterface</span><span class="p">,</span>
    103. <span class="p">}</span>
    104. <span class="n">SEMANTIC_SEGMENTATION_DATASETS</span> <span class="o">=</span> <span class="p">{</span>
    105. <span class="s2">&quot;coco&quot;</span><span class="p">:</span> <span class="n">CoCoSegmentationDatasetInterface</span><span class="p">,</span>
    106. <span class="s2">&quot;pascal_voc&quot;</span><span class="p">:</span> <span class="n">PascalVOC2012SegmentationDataSetInterface</span><span class="p">,</span>
    107. <span class="s2">&quot;pascal_aug&quot;</span><span class="p">:</span> <span class="n">PascalAUG2012SegmentationDataSetInterface</span>
    108. <span class="p">}</span>
    109. <div class="viewcode-block" id="DataSetDoesNotExistException"><a class="viewcode-back" href="../../../../super_gradients.training.datasets.html#super_gradients.training.datasets.all_datasets.DataSetDoesNotExistException">[docs]</a><span class="k">class</span> <span class="nc">DataSetDoesNotExistException</span><span class="p">(</span><span class="ne">Exception</span><span class="p">):</span>
    110. <span class="sd">&quot;&quot;&quot;</span>
    111. <span class="sd"> The requested dataset does not exist, or is not implemented.</span>
    112. <span class="sd"> &quot;&quot;&quot;</span>
    113. <span class="k">pass</span></div>
    114. <div class="viewcode-block" id="SgLibraryDatasets"><a class="viewcode-back" href="../../../../super_gradients.training.datasets.html#super_gradients.training.datasets.all_datasets.SgLibraryDatasets">[docs]</a><span class="k">class</span> <span class="nc">SgLibraryDatasets</span><span class="p">(</span><span class="nb">object</span><span class="p">):</span>
    115. <span class="sd">&quot;&quot;&quot;</span>
    116. <span class="sd"> Holds all of the different library dataset dictionaries, by DL Task mapping</span>
    117. <span class="sd"> Attributes:</span>
    118. <span class="sd"> CLASSIFICATION Dictionary of Classification Data sets</span>
    119. <span class="sd"> OBJECT_DETECTION Dictionary of Object Detection Data sets</span>
    120. <span class="sd"> SEMANTIC_SEGMENTATION Dictionary of Semantic Segmentation Data sets</span>
    121. <span class="sd"> &quot;&quot;&quot;</span>
    122. <span class="n">CLASSIFICATION</span> <span class="o">=</span> <span class="n">CLASSIFICATION_DATASETS</span>
    123. <span class="n">OBJECT_DETECTION</span> <span class="o">=</span> <span class="n">OBJECT_DETECTION_DATASETS</span>
    124. <span class="n">SEMANTIC_SEGMENTATION</span> <span class="o">=</span> <span class="n">SEMANTIC_SEGMENTATION_DATASETS</span>
    125. <span class="n">_datasets_mapping</span> <span class="o">=</span> <span class="p">{</span>
    126. <span class="n">DeepLearningTask</span><span class="o">.</span><span class="n">CLASSIFICATION</span><span class="p">:</span> <span class="n">CLASSIFICATION</span><span class="p">,</span>
    127. <span class="n">DeepLearningTask</span><span class="o">.</span><span class="n">SEMANTIC_SEGMENTATION</span><span class="p">:</span> <span class="n">SEMANTIC_SEGMENTATION</span><span class="p">,</span>
    128. <span class="n">DeepLearningTask</span><span class="o">.</span><span class="n">OBJECT_DETECTION</span><span class="p">:</span> <span class="n">OBJECT_DETECTION</span><span class="p">,</span>
    129. <span class="p">}</span>
    130. <div class="viewcode-block" id="SgLibraryDatasets.get_all_available_datasets"><a class="viewcode-back" href="../../../../super_gradients.training.datasets.html#super_gradients.training.datasets.all_datasets.SgLibraryDatasets.get_all_available_datasets">[docs]</a> <span class="nd">@staticmethod</span>
    131. <span class="k">def</span> <span class="nf">get_all_available_datasets</span><span class="p">()</span> <span class="o">-&gt;</span> <span class="n">Dict</span><span class="p">[</span><span class="nb">str</span><span class="p">,</span> <span class="n">List</span><span class="p">[</span><span class="nb">str</span><span class="p">]]:</span>
    132. <span class="sd">&quot;&quot;&quot;</span>
    133. <span class="sd"> Gets all the available datasets.</span>
    134. <span class="sd"> &quot;&quot;&quot;</span>
    135. <span class="n">all_datasets</span><span class="p">:</span> <span class="n">Dict</span><span class="p">[</span><span class="nb">str</span><span class="p">,</span> <span class="n">List</span><span class="p">[</span><span class="nb">str</span><span class="p">]]</span> <span class="o">=</span> <span class="n">defaultdict</span><span class="p">(</span><span class="nb">list</span><span class="p">)</span>
    136. <span class="k">for</span> <span class="n">dl_task</span><span class="p">,</span> <span class="n">task_datasets</span> <span class="ow">in</span> <span class="n">SgLibraryDatasets</span><span class="o">.</span><span class="n">_datasets_mapping</span><span class="o">.</span><span class="n">items</span><span class="p">():</span>
    137. <span class="k">for</span> <span class="n">dataset_name</span><span class="p">,</span> <span class="n">dataset_interface</span> <span class="ow">in</span> <span class="n">task_datasets</span><span class="o">.</span><span class="n">items</span><span class="p">():</span>
    138. <span class="n">all_datasets</span><span class="p">[</span><span class="n">dl_task</span><span class="p">]</span><span class="o">.</span><span class="n">append</span><span class="p">(</span><span class="n">dataset_name</span><span class="p">)</span>
    139. <span class="c1"># TODO: Return Dataset Metadata list from the dataset interfaces objects</span>
    140. <span class="c1"># TODO: Transform DatasetInterface -&gt; DataSetMetadata</span>
    141. <span class="k">return</span> <span class="n">all_datasets</span></div>
    142. <div class="viewcode-block" id="SgLibraryDatasets.get_dataset"><a class="viewcode-back" href="../../../../super_gradients.training.datasets.html#super_gradients.training.datasets.all_datasets.SgLibraryDatasets.get_dataset">[docs]</a> <span class="nd">@staticmethod</span>
    143. <span class="k">def</span> <span class="nf">get_dataset</span><span class="p">(</span><span class="n">dl_task</span><span class="p">:</span> <span class="nb">str</span><span class="p">,</span> <span class="n">dataset_name</span><span class="p">:</span> <span class="nb">str</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="n">Type</span><span class="p">[</span><span class="n">DatasetInterface</span><span class="p">]:</span>
    144. <span class="sd">&quot;&quot;&quot;</span>
    145. <span class="sd"> Get&#39;s a dataset with a given name for a given deep learning task.</span>
    146. <span class="sd"> examp:</span>
    147. <span class="sd"> &gt;&gt;&gt; SgLibraryDatasets.get_dataset(dl_task=&#39;classification&#39;, dataset_name=&#39;cifar_100&#39;)</span>
    148. <span class="sd"> &gt;&gt;&gt; &lt;Cifar100DatasetInterface instance&gt;</span>
    149. <span class="sd"> &quot;&quot;&quot;</span>
    150. <span class="n">task_datasets</span><span class="p">:</span> <span class="n">Dict</span><span class="p">[</span><span class="nb">str</span><span class="p">,</span> <span class="n">DatasetInterface</span><span class="p">]</span> <span class="o">=</span> <span class="n">SgLibraryDatasets</span><span class="o">.</span><span class="n">_datasets_mapping</span><span class="o">.</span><span class="n">get</span><span class="p">(</span><span class="n">dl_task</span><span class="p">)</span>
    151. <span class="k">if</span> <span class="ow">not</span> <span class="n">task_datasets</span><span class="p">:</span>
    152. <span class="k">raise</span> <span class="ne">ValueError</span><span class="p">(</span><span class="sa">f</span><span class="s2">&quot;Invalid Deep Learining Task: </span><span class="si">{</span><span class="n">dl_task</span><span class="si">}</span><span class="s2">&quot;</span><span class="p">)</span>
    153. <span class="n">dataset</span><span class="p">:</span> <span class="n">DatasetInterface</span> <span class="o">=</span> <span class="n">task_datasets</span><span class="o">.</span><span class="n">get</span><span class="p">(</span><span class="n">dataset_name</span><span class="p">)</span>
    154. <span class="k">if</span> <span class="ow">not</span> <span class="n">dataset</span><span class="p">:</span>
    155. <span class="k">raise</span> <span class="n">DataSetDoesNotExistException</span><span class="p">(</span><span class="n">dataset_name</span><span class="p">)</span>
    156. <span class="k">return</span> <span class="n">dataset</span></div></div>
    157. </pre></div>
    158. </div>
    159. </div>
    160. <footer>
    161. <hr/>
    162. <div role="contentinfo">
    163. <p>&#169; Copyright 2021, SuperGradients team.</p>
    164. </div>
    165. Built with <a href="https://www.sphinx-doc.org/">Sphinx</a> using a
    166. <a href="https://github.com/readthedocs/sphinx_rtd_theme">theme</a>
    167. provided by <a href="https://readthedocs.org">Read the Docs</a>.
    168. </footer>
    169. </div>
    170. </div>
    171. </section>
    172. </div>
    173. <script>
    174. jQuery(function () {
    175. SphinxRtdTheme.Navigation.enable(true);
    176. });
    177. </script>
    178. </body>
    179. </html>
    Discard
    1
    2
    3
    4
    5
    6
    7
    8
    9
    10
    11
    12
    13
    14
    15
    16
    17
    18
    19
    20
    21
    22
    23
    24
    25
    26
    27
    28
    29
    30
    31
    32
    33
    34
    35
    36
    37
    38
    39
    40
    41
    42
    43
    44
    45
    46
    47
    48
    49
    50
    51
    52
    53
    54
    55
    56
    57
    58
    59
    60
    61
    62
    63
    64
    65
    66
    67
    68
    69
    70
    71
    72
    73
    74
    75
    76
    77
    78
    79
    80
    81
    82
    83
    84
    85
    86
    87
    88
    89
    90
    91
    92
    93
    94
    95
    96
    97
    98
    99
    100
    101
    102
    103
    104
    105
    106
    107
    108
    109
    110
    111
    112
    113
    114
    115
    116
    117
    118
    119
    120
    121
    122
    123
    124
    125
    126
    127
    128
    129
    130
    131
    132
    133
    134
    135
    136
    137
    138
    139
    140
    141
    142
    143
    144
    145
    146
    147
    148
    149
    150
    151
    152
    153
    154
    155
    156
    157
    158
    159
    160
    161
    162
    163
    164
    165
    166
    167
    168
    169
    170
    171
    172
    173
    174
    175
    176
    177
    178
    179
    180
    181
    182
    183
    184
    185
    186
    187
    188
    189
    190
    191
    192
    193
    194
    195
    196
    197
    198
    199
    200
    201
    202
    203
    204
    205
    206
    207
    208
    209
    210
    211
    212
    213
    214
    215
    216
    217
    218
    219
    220
    221
    222
    223
    224
    225
    226
    227
    228
    229
    230
    231
    232
    233
    234
    235
    236
    237
    238
    239
    240
    241
    242
    243
    244
    245
    246
    247
    248
    249
    250
    251
    252
    253
    254
    255
    256
    257
    258
    259
    260
    261
    262
    263
    264
    265
    266
    267
    268
    269
    270
    271
    272
    273
    274
    275
    276
    277
    278
    279
    280
    281
    282
    283
    284
    285
    286
    287
    288
    289
    290
    291
    292
    293
    294
    295
    296
    297
    298
    299
    300
    301
    302
    303
    304
    305
    306
    307
    308
    309
    310
    311
    312
    313
    314
    315
    316
    317
    318
    319
    320
    321
    322
    323
    324
    325
    326
    327
    328
    329
    330
    331
    332
    333
    334
    335
    336
    337
    338
    339
    340
    341
    342
    343
    344
    345
    346
    347
    348
    349
    350
    351
    352
    353
    354
    355
    356
    357
    358
    359
    360
    361
    362
    363
    364
    365
    366
    367
    368
    369
    370
    371
    372
    373
    374
    375
    376
    377
    378
    379
    380
    381
    382
    383
    384
    385
    386
    387
    388
    389
    390
    391
    392
    393
    394
    395
    396
    397
    398
    399
    400
    401
    402
    403
    404
    405
    406
    407
    408
    409
    410
    411
    412
    413
    414
    415
    416
    417
    418
    419
    420
    421
    422
    423
    424
    425
    426
    427
    428
    429
    430
    431
    432
    433
    434
    435
    436
    437
    438
    439
    440
    441
    442
    443
    444
    445
    446
    447
    448
    449
    450
    451
    452
    453
    454
    455
    456
    457
    458
    459
    460
    461
    462
    463
    464
    465
    466
    467
    468
    469
    470
    471
    472
    473
    474
    475
    476
    477
    478
    479
    480
    481
    482
    483
    484
    485
    486
    487
    488
    489
    490
    491
    492
    493
    494
    495
    496
    497
    498
    499
    500
    501
    502
    503
    504
    505
    506
    507
    508
    509
    510
    511
    512
    513
    514
    515
    516
    517
    518
    519
    520
    521
    522
    523
    524
    525
    526
    527
    528
    529
    530
    531
    532
    533
    534
    535
    536
    537
    538
    539
    540
    541
    542
    543
    544
    545
    546
    547
    548
    549
    550
    551
    552
    553
    554
    555
    556
    557
    558
    1. <!DOCTYPE html>
    2. <html class="writer-html5" lang="en" >
    3. <head>
    4. <meta charset="utf-8" />
    5. <meta name="viewport" content="width=device-width, initial-scale=1.0" />
    6. <title>super_gradients.training.datasets.auto_augment &mdash; SuperGradients 1.0 documentation</title>
    7. <link rel="stylesheet" href="../../../../_static/pygments.css" type="text/css" />
    8. <link rel="stylesheet" href="../../../../_static/css/theme.css" type="text/css" />
    9. <link rel="stylesheet" href="../../../../_static/graphviz.css" type="text/css" />
    10. <!--[if lt IE 9]>
    11. <script src="../../../../_static/js/html5shiv.min.js"></script>
    12. <![endif]-->
    13. <script data-url_root="../../../../" id="documentation_options" src="../../../../_static/documentation_options.js"></script>
    14. <script src="../../../../_static/jquery.js"></script>
    15. <script src="../../../../_static/underscore.js"></script>
    16. <script src="../../../../_static/doctools.js"></script>
    17. <script src="../../../../_static/js/theme.js"></script>
    18. <link rel="index" title="Index" href="../../../../genindex.html" />
    19. <link rel="search" title="Search" href="../../../../search.html" />
    20. </head>
    21. <body class="wy-body-for-nav">
    22. <div class="wy-grid-for-nav">
    23. <nav data-toggle="wy-nav-shift" class="wy-nav-side">
    24. <div class="wy-side-scroll">
    25. <div class="wy-side-nav-search" >
    26. <a href="../../../../index.html" class="icon icon-home"> SuperGradients
    27. </a>
    28. <div role="search">
    29. <form id="rtd-search-form" class="wy-form" action="../../../../search.html" method="get">
    30. <input type="text" name="q" placeholder="Search docs" />
    31. <input type="hidden" name="check_keywords" value="yes" />
    32. <input type="hidden" name="area" value="default" />
    33. </form>
    34. </div>
    35. </div><div class="wy-menu wy-menu-vertical" data-spy="affix" role="navigation" aria-label="Navigation menu">
    36. <p class="caption"><span class="caption-text">Welcome To SuperGradients</span></p>
    37. <ul>
    38. <li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html">Fill our 4-question quick survey! We will raffle free SuperGradients swag between those who will participate -&gt; Fill Survey</a></li>
    39. <li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#supergradients">SuperGradients</a></li>
    40. </ul>
    41. <p class="caption"><span class="caption-text">Technical Documentation</span></p>
    42. <ul>
    43. <li class="toctree-l1"><a class="reference internal" href="../../../../super_gradients.common.html">Common package</a></li>
    44. <li class="toctree-l1"><a class="reference internal" href="../../../../super_gradients.training.html">Training package</a></li>
    45. </ul>
    46. <p class="caption"><span class="caption-text">User Guide</span></p>
    47. <ul>
    48. <li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html">What is SuperGradients?</a></li>
    49. <li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#introducing-the-supergradients-library">Introducing the SuperGradients library</a></li>
    50. <li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#installation">Installation</a></li>
    51. <li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#integrating-your-training-code-complete-walkthrough">Integrating Your Training Code - Complete Walkthrough</a></li>
    52. <li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#training-parameters">Training Parameters</a></li>
    53. <li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#logs-and-checkpoints">Logs and Checkpoints</a></li>
    54. <li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#dataset-parameters">Dataset Parameters</a></li>
    55. <li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#network-architectures">Network Architectures</a></li>
    56. <li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#pretrained-models">Pretrained Models</a></li>
    57. <li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#how-to-reproduce-our-training-recipes">How To Reproduce Our Training Recipes</a></li>
    58. <li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#professional-tools-integration">Professional Tools Integration</a></li>
    59. <li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#supergradients-faq">SuperGradients FAQ</a></li>
    60. </ul>
    61. </div>
    62. </div>
    63. </nav>
    64. <section data-toggle="wy-nav-shift" class="wy-nav-content-wrap"><nav class="wy-nav-top" aria-label="Mobile navigation menu" >
    65. <i data-toggle="wy-nav-top" class="fa fa-bars"></i>
    66. <a href="../../../../index.html">SuperGradients</a>
    67. </nav>
    68. <div class="wy-nav-content">
    69. <div class="rst-content">
    70. <div role="navigation" aria-label="Page navigation">
    71. <ul class="wy-breadcrumbs">
    72. <li><a href="../../../../index.html" class="icon icon-home"></a> &raquo;</li>
    73. <li><a href="../../../index.html">Module code</a> &raquo;</li>
    74. <li>super_gradients.training.datasets.auto_augment</li>
    75. <li class="wy-breadcrumbs-aside">
    76. </li>
    77. </ul>
    78. <hr/>
    79. </div>
    80. <div role="main" class="document" itemscope="itemscope" itemtype="http://schema.org/Article">
    81. <div itemprop="articleBody">
    82. <h1>Source code for super_gradients.training.datasets.auto_augment</h1><div class="highlight"><pre>
    83. <span></span><span class="sd">&quot;&quot;&quot; RandAugment</span>
    84. <span class="sd">RandAugment is a variant of AutoAugment which randomly selects transformations</span>
    85. <span class="sd"> from AutoAugment to be applied on an image.</span>
    86. <span class="sd">RandomAugmentation Implementation adapted from:</span>
    87. <span class="sd"> https://github.com/rwightman/pytorch-image-models/blob/master/timm/data/auto_augment.py</span>
    88. <span class="sd">Papers:</span>
    89. <span class="sd"> RandAugment: Practical automated data augmentation... - https://arxiv.org/abs/1909.13719</span>
    90. <span class="sd">&quot;&quot;&quot;</span>
    91. <span class="kn">import</span> <span class="nn">random</span>
    92. <span class="kn">import</span> <span class="nn">re</span>
    93. <span class="kn">from</span> <span class="nn">PIL</span> <span class="kn">import</span> <span class="n">Image</span><span class="p">,</span> <span class="n">ImageOps</span><span class="p">,</span> <span class="n">ImageEnhance</span>
    94. <span class="kn">import</span> <span class="nn">numpy</span> <span class="k">as</span> <span class="nn">np</span>
    95. <span class="n">_FILL</span> <span class="o">=</span> <span class="p">(</span><span class="mi">128</span><span class="p">,</span> <span class="mi">128</span><span class="p">,</span> <span class="mi">128</span><span class="p">)</span>
    96. <span class="c1"># to unify the calls of the different augmentations in terms of params, all augmentations are set to work with a single</span>
    97. <span class="c1"># magnitude params, normalized according to _MAX_MAGNITUDE</span>
    98. <span class="n">_MAX_MAGNITUDE</span> <span class="o">=</span> <span class="mf">10.</span>
    99. <span class="n">_HPARAMS_DEFAULT</span> <span class="o">=</span> <span class="nb">dict</span><span class="p">(</span>
    100. <span class="n">translate_const</span><span class="o">=</span><span class="mi">250</span><span class="p">,</span>
    101. <span class="n">img_mean</span><span class="o">=</span><span class="n">_FILL</span><span class="p">,</span>
    102. <span class="p">)</span>
    103. <span class="c1"># Define the interpolation types</span>
    104. <span class="n">_RANDOM_INTERPOLATION</span> <span class="o">=</span> <span class="n">Image</span><span class="o">.</span><span class="n">BILINEAR</span>
    105. <span class="k">def</span> <span class="nf">_interpolation</span><span class="p">(</span><span class="n">kwargs</span><span class="p">):</span>
    106. <span class="sd">&quot;&quot;&quot;</span>
    107. <span class="sd"> Performs Bi-Linear interpolation</span>
    108. <span class="sd"> &quot;&quot;&quot;</span>
    109. <span class="n">interpolation</span> <span class="o">=</span> <span class="n">kwargs</span><span class="o">.</span><span class="n">pop</span><span class="p">(</span><span class="s1">&#39;resample&#39;</span><span class="p">,</span> <span class="n">Image</span><span class="o">.</span><span class="n">BILINEAR</span><span class="p">)</span>
    110. <span class="k">if</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">interpolation</span><span class="p">,</span> <span class="p">(</span><span class="nb">list</span><span class="p">,</span> <span class="nb">tuple</span><span class="p">)):</span>
    111. <span class="k">return</span> <span class="n">random</span><span class="o">.</span><span class="n">choice</span><span class="p">(</span><span class="n">interpolation</span><span class="p">)</span>
    112. <span class="k">else</span><span class="p">:</span>
    113. <span class="k">return</span> <span class="n">interpolation</span>
    114. <div class="viewcode-block" id="shear_x"><a class="viewcode-back" href="../../../../super_gradients.training.datasets.html#super_gradients.training.datasets.auto_augment.shear_x">[docs]</a><span class="k">def</span> <span class="nf">shear_x</span><span class="p">(</span><span class="n">img</span><span class="p">,</span> <span class="n">factor</span><span class="p">,</span> <span class="o">**</span><span class="n">kwargs</span><span class="p">):</span>
    115. <span class="k">return</span> <span class="n">img</span><span class="o">.</span><span class="n">transform</span><span class="p">(</span><span class="n">img</span><span class="o">.</span><span class="n">size</span><span class="p">,</span> <span class="n">Image</span><span class="o">.</span><span class="n">AFFINE</span><span class="p">,</span> <span class="p">(</span><span class="mi">1</span><span class="p">,</span> <span class="n">factor</span><span class="p">,</span> <span class="mi">0</span><span class="p">,</span> <span class="mi">0</span><span class="p">,</span> <span class="mi">1</span><span class="p">,</span> <span class="mi">0</span><span class="p">),</span> <span class="o">**</span><span class="n">kwargs</span><span class="p">)</span></div>
    116. <div class="viewcode-block" id="shear_y"><a class="viewcode-back" href="../../../../super_gradients.training.datasets.html#super_gradients.training.datasets.auto_augment.shear_y">[docs]</a><span class="k">def</span> <span class="nf">shear_y</span><span class="p">(</span><span class="n">img</span><span class="p">,</span> <span class="n">factor</span><span class="p">,</span> <span class="o">**</span><span class="n">kwargs</span><span class="p">):</span>
    117. <span class="k">return</span> <span class="n">img</span><span class="o">.</span><span class="n">transform</span><span class="p">(</span><span class="n">img</span><span class="o">.</span><span class="n">size</span><span class="p">,</span> <span class="n">Image</span><span class="o">.</span><span class="n">AFFINE</span><span class="p">,</span> <span class="p">(</span><span class="mi">1</span><span class="p">,</span> <span class="mi">0</span><span class="p">,</span> <span class="mi">0</span><span class="p">,</span> <span class="n">factor</span><span class="p">,</span> <span class="mi">1</span><span class="p">,</span> <span class="mi">0</span><span class="p">),</span> <span class="o">**</span><span class="n">kwargs</span><span class="p">)</span></div>
    118. <div class="viewcode-block" id="translate_x_rel"><a class="viewcode-back" href="../../../../super_gradients.training.datasets.html#super_gradients.training.datasets.auto_augment.translate_x_rel">[docs]</a><span class="k">def</span> <span class="nf">translate_x_rel</span><span class="p">(</span><span class="n">img</span><span class="p">,</span> <span class="n">pct</span><span class="p">,</span> <span class="o">**</span><span class="n">kwargs</span><span class="p">):</span>
    119. <span class="n">pixels</span> <span class="o">=</span> <span class="n">pct</span> <span class="o">*</span> <span class="n">img</span><span class="o">.</span><span class="n">size</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span>
    120. <span class="k">return</span> <span class="n">img</span><span class="o">.</span><span class="n">transform</span><span class="p">(</span><span class="n">img</span><span class="o">.</span><span class="n">size</span><span class="p">,</span> <span class="n">Image</span><span class="o">.</span><span class="n">AFFINE</span><span class="p">,</span> <span class="p">(</span><span class="mi">1</span><span class="p">,</span> <span class="mi">0</span><span class="p">,</span> <span class="n">pixels</span><span class="p">,</span> <span class="mi">0</span><span class="p">,</span> <span class="mi">1</span><span class="p">,</span> <span class="mi">0</span><span class="p">),</span> <span class="o">**</span><span class="n">kwargs</span><span class="p">)</span></div>
    121. <div class="viewcode-block" id="translate_y_rel"><a class="viewcode-back" href="../../../../super_gradients.training.datasets.html#super_gradients.training.datasets.auto_augment.translate_y_rel">[docs]</a><span class="k">def</span> <span class="nf">translate_y_rel</span><span class="p">(</span><span class="n">img</span><span class="p">,</span> <span class="n">pct</span><span class="p">,</span> <span class="o">**</span><span class="n">kwargs</span><span class="p">):</span>
    122. <span class="n">pixels</span> <span class="o">=</span> <span class="n">pct</span> <span class="o">*</span> <span class="n">img</span><span class="o">.</span><span class="n">size</span><span class="p">[</span><span class="mi">1</span><span class="p">]</span>
    123. <span class="k">return</span> <span class="n">img</span><span class="o">.</span><span class="n">transform</span><span class="p">(</span><span class="n">img</span><span class="o">.</span><span class="n">size</span><span class="p">,</span> <span class="n">Image</span><span class="o">.</span><span class="n">AFFINE</span><span class="p">,</span> <span class="p">(</span><span class="mi">1</span><span class="p">,</span> <span class="mi">0</span><span class="p">,</span> <span class="mi">0</span><span class="p">,</span> <span class="mi">0</span><span class="p">,</span> <span class="mi">1</span><span class="p">,</span> <span class="n">pixels</span><span class="p">),</span> <span class="o">**</span><span class="n">kwargs</span><span class="p">)</span></div>
    124. <div class="viewcode-block" id="translate_x_abs"><a class="viewcode-back" href="../../../../super_gradients.training.datasets.html#super_gradients.training.datasets.auto_augment.translate_x_abs">[docs]</a><span class="k">def</span> <span class="nf">translate_x_abs</span><span class="p">(</span><span class="n">img</span><span class="p">,</span> <span class="n">pixels</span><span class="p">,</span> <span class="o">**</span><span class="n">kwargs</span><span class="p">):</span>
    125. <span class="k">return</span> <span class="n">img</span><span class="o">.</span><span class="n">transform</span><span class="p">(</span><span class="n">img</span><span class="o">.</span><span class="n">size</span><span class="p">,</span> <span class="n">Image</span><span class="o">.</span><span class="n">AFFINE</span><span class="p">,</span> <span class="p">(</span><span class="mi">1</span><span class="p">,</span> <span class="mi">0</span><span class="p">,</span> <span class="n">pixels</span><span class="p">,</span> <span class="mi">0</span><span class="p">,</span> <span class="mi">1</span><span class="p">,</span> <span class="mi">0</span><span class="p">),</span> <span class="o">**</span><span class="n">kwargs</span><span class="p">)</span></div>
    126. <div class="viewcode-block" id="translate_y_abs"><a class="viewcode-back" href="../../../../super_gradients.training.datasets.html#super_gradients.training.datasets.auto_augment.translate_y_abs">[docs]</a><span class="k">def</span> <span class="nf">translate_y_abs</span><span class="p">(</span><span class="n">img</span><span class="p">,</span> <span class="n">pixels</span><span class="p">,</span> <span class="o">**</span><span class="n">kwargs</span><span class="p">):</span>
    127. <span class="k">return</span> <span class="n">img</span><span class="o">.</span><span class="n">transform</span><span class="p">(</span><span class="n">img</span><span class="o">.</span><span class="n">size</span><span class="p">,</span> <span class="n">Image</span><span class="o">.</span><span class="n">AFFINE</span><span class="p">,</span> <span class="p">(</span><span class="mi">1</span><span class="p">,</span> <span class="mi">0</span><span class="p">,</span> <span class="mi">0</span><span class="p">,</span> <span class="mi">0</span><span class="p">,</span> <span class="mi">1</span><span class="p">,</span> <span class="n">pixels</span><span class="p">),</span> <span class="o">**</span><span class="n">kwargs</span><span class="p">)</span></div>
    128. <div class="viewcode-block" id="rotate"><a class="viewcode-back" href="../../../../super_gradients.training.datasets.html#super_gradients.training.datasets.auto_augment.rotate">[docs]</a><span class="k">def</span> <span class="nf">rotate</span><span class="p">(</span><span class="n">img</span><span class="p">,</span> <span class="n">degrees</span><span class="p">,</span> <span class="o">**</span><span class="n">kwargs</span><span class="p">):</span>
    129. <span class="k">return</span> <span class="n">img</span><span class="o">.</span><span class="n">rotate</span><span class="p">(</span><span class="n">degrees</span><span class="p">,</span> <span class="o">**</span><span class="n">kwargs</span><span class="p">)</span></div>
    130. <div class="viewcode-block" id="auto_contrast"><a class="viewcode-back" href="../../../../super_gradients.training.datasets.html#super_gradients.training.datasets.auto_augment.auto_contrast">[docs]</a><span class="k">def</span> <span class="nf">auto_contrast</span><span class="p">(</span><span class="n">img</span><span class="p">,</span> <span class="o">**</span><span class="n">__</span><span class="p">):</span>
    131. <span class="k">return</span> <span class="n">ImageOps</span><span class="o">.</span><span class="n">autocontrast</span><span class="p">(</span><span class="n">img</span><span class="p">)</span></div>
    132. <div class="viewcode-block" id="invert"><a class="viewcode-back" href="../../../../super_gradients.training.datasets.html#super_gradients.training.datasets.auto_augment.invert">[docs]</a><span class="k">def</span> <span class="nf">invert</span><span class="p">(</span><span class="n">img</span><span class="p">,</span> <span class="o">**</span><span class="n">__</span><span class="p">):</span>
    133. <span class="k">return</span> <span class="n">ImageOps</span><span class="o">.</span><span class="n">invert</span><span class="p">(</span><span class="n">img</span><span class="p">)</span></div>
    134. <div class="viewcode-block" id="equalize"><a class="viewcode-back" href="../../../../super_gradients.training.datasets.html#super_gradients.training.datasets.auto_augment.equalize">[docs]</a><span class="k">def</span> <span class="nf">equalize</span><span class="p">(</span><span class="n">img</span><span class="p">,</span> <span class="o">**</span><span class="n">__</span><span class="p">):</span>
    135. <span class="k">return</span> <span class="n">ImageOps</span><span class="o">.</span><span class="n">equalize</span><span class="p">(</span><span class="n">img</span><span class="p">)</span></div>
    136. <div class="viewcode-block" id="solarize"><a class="viewcode-back" href="../../../../super_gradients.training.datasets.html#super_gradients.training.datasets.auto_augment.solarize">[docs]</a><span class="k">def</span> <span class="nf">solarize</span><span class="p">(</span><span class="n">img</span><span class="p">,</span> <span class="n">thresh</span><span class="p">,</span> <span class="o">**</span><span class="n">__</span><span class="p">):</span>
    137. <span class="k">return</span> <span class="n">ImageOps</span><span class="o">.</span><span class="n">solarize</span><span class="p">(</span><span class="n">img</span><span class="p">,</span> <span class="n">thresh</span><span class="p">)</span></div>
    138. <div class="viewcode-block" id="solarize_add"><a class="viewcode-back" href="../../../../super_gradients.training.datasets.html#super_gradients.training.datasets.auto_augment.solarize_add">[docs]</a><span class="k">def</span> <span class="nf">solarize_add</span><span class="p">(</span><span class="n">img</span><span class="p">,</span> <span class="n">add</span><span class="p">,</span> <span class="n">thresh</span><span class="o">=</span><span class="mi">128</span><span class="p">,</span> <span class="o">**</span><span class="n">__</span><span class="p">):</span>
    139. <span class="n">lut</span> <span class="o">=</span> <span class="p">[]</span>
    140. <span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="mi">256</span><span class="p">):</span>
    141. <span class="k">if</span> <span class="n">i</span> <span class="o">&lt;</span> <span class="n">thresh</span><span class="p">:</span>
    142. <span class="n">lut</span><span class="o">.</span><span class="n">append</span><span class="p">(</span><span class="nb">min</span><span class="p">(</span><span class="mi">255</span><span class="p">,</span> <span class="n">i</span> <span class="o">+</span> <span class="n">add</span><span class="p">))</span>
    143. <span class="k">else</span><span class="p">:</span>
    144. <span class="n">lut</span><span class="o">.</span><span class="n">append</span><span class="p">(</span><span class="n">i</span><span class="p">)</span>
    145. <span class="k">if</span> <span class="n">img</span><span class="o">.</span><span class="n">mode</span> <span class="ow">in</span> <span class="p">(</span><span class="s2">&quot;L&quot;</span><span class="p">,</span> <span class="s2">&quot;RGB&quot;</span><span class="p">):</span>
    146. <span class="k">if</span> <span class="n">img</span><span class="o">.</span><span class="n">mode</span> <span class="o">==</span> <span class="s2">&quot;RGB&quot;</span> <span class="ow">and</span> <span class="nb">len</span><span class="p">(</span><span class="n">lut</span><span class="p">)</span> <span class="o">==</span> <span class="mi">256</span><span class="p">:</span>
    147. <span class="n">lut</span> <span class="o">=</span> <span class="n">lut</span> <span class="o">+</span> <span class="n">lut</span> <span class="o">+</span> <span class="n">lut</span>
    148. <span class="k">return</span> <span class="n">img</span><span class="o">.</span><span class="n">point</span><span class="p">(</span><span class="n">lut</span><span class="p">)</span>
    149. <span class="k">else</span><span class="p">:</span>
    150. <span class="k">return</span> <span class="n">img</span></div>
    151. <div class="viewcode-block" id="posterize"><a class="viewcode-back" href="../../../../super_gradients.training.datasets.html#super_gradients.training.datasets.auto_augment.posterize">[docs]</a><span class="k">def</span> <span class="nf">posterize</span><span class="p">(</span><span class="n">img</span><span class="p">,</span> <span class="n">bits_to_keep</span><span class="p">,</span> <span class="o">**</span><span class="n">__</span><span class="p">):</span>
    152. <span class="k">if</span> <span class="n">bits_to_keep</span> <span class="o">&gt;=</span> <span class="mi">8</span><span class="p">:</span>
    153. <span class="k">return</span> <span class="n">img</span>
    154. <span class="k">return</span> <span class="n">ImageOps</span><span class="o">.</span><span class="n">posterize</span><span class="p">(</span><span class="n">img</span><span class="p">,</span> <span class="n">bits_to_keep</span><span class="p">)</span></div>
    155. <div class="viewcode-block" id="contrast"><a class="viewcode-back" href="../../../../super_gradients.training.datasets.html#super_gradients.training.datasets.auto_augment.contrast">[docs]</a><span class="k">def</span> <span class="nf">contrast</span><span class="p">(</span><span class="n">img</span><span class="p">,</span> <span class="n">factor</span><span class="p">,</span> <span class="o">**</span><span class="n">__</span><span class="p">):</span>
    156. <span class="k">return</span> <span class="n">ImageEnhance</span><span class="o">.</span><span class="n">Contrast</span><span class="p">(</span><span class="n">img</span><span class="p">)</span><span class="o">.</span><span class="n">enhance</span><span class="p">(</span><span class="n">factor</span><span class="p">)</span></div>
    157. <div class="viewcode-block" id="color"><a class="viewcode-back" href="../../../../super_gradients.training.datasets.html#super_gradients.training.datasets.auto_augment.color">[docs]</a><span class="k">def</span> <span class="nf">color</span><span class="p">(</span><span class="n">img</span><span class="p">,</span> <span class="n">factor</span><span class="p">,</span> <span class="o">**</span><span class="n">__</span><span class="p">):</span>
    158. <span class="k">return</span> <span class="n">ImageEnhance</span><span class="o">.</span><span class="n">Color</span><span class="p">(</span><span class="n">img</span><span class="p">)</span><span class="o">.</span><span class="n">enhance</span><span class="p">(</span><span class="n">factor</span><span class="p">)</span></div>
    159. <div class="viewcode-block" id="brightness"><a class="viewcode-back" href="../../../../super_gradients.training.datasets.html#super_gradients.training.datasets.auto_augment.brightness">[docs]</a><span class="k">def</span> <span class="nf">brightness</span><span class="p">(</span><span class="n">img</span><span class="p">,</span> <span class="n">factor</span><span class="p">,</span> <span class="o">**</span><span class="n">__</span><span class="p">):</span>
    160. <span class="k">return</span> <span class="n">ImageEnhance</span><span class="o">.</span><span class="n">Brightness</span><span class="p">(</span><span class="n">img</span><span class="p">)</span><span class="o">.</span><span class="n">enhance</span><span class="p">(</span><span class="n">factor</span><span class="p">)</span></div>
    161. <div class="viewcode-block" id="sharpness"><a class="viewcode-back" href="../../../../super_gradients.training.datasets.html#super_gradients.training.datasets.auto_augment.sharpness">[docs]</a><span class="k">def</span> <span class="nf">sharpness</span><span class="p">(</span><span class="n">img</span><span class="p">,</span> <span class="n">factor</span><span class="p">,</span> <span class="o">**</span><span class="n">__</span><span class="p">):</span>
    162. <span class="k">return</span> <span class="n">ImageEnhance</span><span class="o">.</span><span class="n">Sharpness</span><span class="p">(</span><span class="n">img</span><span class="p">)</span><span class="o">.</span><span class="n">enhance</span><span class="p">(</span><span class="n">factor</span><span class="p">)</span></div>
    163. <span class="k">def</span> <span class="nf">_randomly_negate</span><span class="p">(</span><span class="n">v</span><span class="p">):</span>
    164. <span class="sd">&quot;&quot;&quot;With 50% prob, negate the value&quot;&quot;&quot;</span>
    165. <span class="k">return</span> <span class="o">-</span><span class="n">v</span> <span class="k">if</span> <span class="n">random</span><span class="o">.</span><span class="n">random</span><span class="p">()</span> <span class="o">&gt;</span> <span class="mf">0.5</span> <span class="k">else</span> <span class="n">v</span>
    166. <span class="k">def</span> <span class="nf">_rotate_level_to_arg</span><span class="p">(</span><span class="n">level</span><span class="p">,</span> <span class="n">_hparams</span><span class="p">):</span>
    167. <span class="c1"># range [-30, 30]</span>
    168. <span class="n">level</span> <span class="o">=</span> <span class="p">(</span><span class="n">level</span> <span class="o">/</span> <span class="n">_MAX_MAGNITUDE</span><span class="p">)</span> <span class="o">*</span> <span class="mf">30.</span>
    169. <span class="n">level</span> <span class="o">=</span> <span class="n">_randomly_negate</span><span class="p">(</span><span class="n">level</span><span class="p">)</span>
    170. <span class="k">return</span> <span class="n">level</span><span class="p">,</span>
    171. <span class="k">def</span> <span class="nf">_enhance_level_to_arg</span><span class="p">(</span><span class="n">level</span><span class="p">,</span> <span class="n">_hparams</span><span class="p">):</span>
    172. <span class="c1"># range [0.1, 1.9]</span>
    173. <span class="k">return</span> <span class="p">(</span><span class="n">level</span> <span class="o">/</span> <span class="n">_MAX_MAGNITUDE</span><span class="p">)</span> <span class="o">*</span> <span class="mf">1.8</span> <span class="o">+</span> <span class="mf">0.1</span><span class="p">,</span>
    174. <span class="k">def</span> <span class="nf">_enhance_increasing_level_to_arg</span><span class="p">(</span><span class="n">level</span><span class="p">,</span> <span class="n">_hparams</span><span class="p">):</span>
    175. <span class="c1"># range [0.1, 1.9]</span>
    176. <span class="n">level</span> <span class="o">=</span> <span class="p">(</span><span class="n">level</span> <span class="o">/</span> <span class="n">_MAX_MAGNITUDE</span><span class="p">)</span> <span class="o">*</span> <span class="mf">.9</span>
    177. <span class="n">level</span> <span class="o">=</span> <span class="mf">1.0</span> <span class="o">+</span> <span class="n">_randomly_negate</span><span class="p">(</span><span class="n">level</span><span class="p">)</span>
    178. <span class="k">return</span> <span class="n">level</span><span class="p">,</span>
    179. <span class="k">def</span> <span class="nf">_shear_level_to_arg</span><span class="p">(</span><span class="n">level</span><span class="p">,</span> <span class="n">_hparams</span><span class="p">):</span>
    180. <span class="c1"># range [-0.3, 0.3]</span>
    181. <span class="n">level</span> <span class="o">=</span> <span class="p">(</span><span class="n">level</span> <span class="o">/</span> <span class="n">_MAX_MAGNITUDE</span><span class="p">)</span> <span class="o">*</span> <span class="mf">0.3</span>
    182. <span class="n">level</span> <span class="o">=</span> <span class="n">_randomly_negate</span><span class="p">(</span><span class="n">level</span><span class="p">)</span>
    183. <span class="k">return</span> <span class="n">level</span><span class="p">,</span>
    184. <span class="k">def</span> <span class="nf">_translate_abs_level_to_arg</span><span class="p">(</span><span class="n">level</span><span class="p">,</span> <span class="n">hparams</span><span class="p">):</span>
    185. <span class="n">translate_const</span> <span class="o">=</span> <span class="n">hparams</span><span class="p">[</span><span class="s1">&#39;translate_const&#39;</span><span class="p">]</span>
    186. <span class="n">level</span> <span class="o">=</span> <span class="p">(</span><span class="n">level</span> <span class="o">/</span> <span class="n">_MAX_MAGNITUDE</span><span class="p">)</span> <span class="o">*</span> <span class="nb">float</span><span class="p">(</span><span class="n">translate_const</span><span class="p">)</span>
    187. <span class="n">level</span> <span class="o">=</span> <span class="n">_randomly_negate</span><span class="p">(</span><span class="n">level</span><span class="p">)</span>
    188. <span class="k">return</span> <span class="n">level</span><span class="p">,</span>
    189. <span class="k">def</span> <span class="nf">_translate_rel_level_to_arg</span><span class="p">(</span><span class="n">level</span><span class="p">,</span> <span class="n">hparams</span><span class="p">):</span>
    190. <span class="c1"># default range [-0.45, 0.45]</span>
    191. <span class="n">translate_pct</span> <span class="o">=</span> <span class="n">hparams</span><span class="o">.</span><span class="n">get</span><span class="p">(</span><span class="s1">&#39;translate_pct&#39;</span><span class="p">,</span> <span class="mf">0.45</span><span class="p">)</span>
    192. <span class="n">level</span> <span class="o">=</span> <span class="p">(</span><span class="n">level</span> <span class="o">/</span> <span class="n">_MAX_MAGNITUDE</span><span class="p">)</span> <span class="o">*</span> <span class="n">translate_pct</span>
    193. <span class="n">level</span> <span class="o">=</span> <span class="n">_randomly_negate</span><span class="p">(</span><span class="n">level</span><span class="p">)</span>
    194. <span class="k">return</span> <span class="n">level</span><span class="p">,</span>
    195. <span class="k">def</span> <span class="nf">_posterize_level_to_arg</span><span class="p">(</span><span class="n">level</span><span class="p">,</span> <span class="n">_hparams</span><span class="p">):</span>
    196. <span class="c1"># As per Tensorflow TPU EfficientNet impl</span>
    197. <span class="c1"># range [0, 4], &#39;keep 0 up to 4 MSB of original image&#39;</span>
    198. <span class="c1"># intensity/severity of augmentation decreases with level</span>
    199. <span class="k">return</span> <span class="nb">int</span><span class="p">((</span><span class="n">level</span> <span class="o">/</span> <span class="n">_MAX_MAGNITUDE</span><span class="p">)</span> <span class="o">*</span> <span class="mi">4</span><span class="p">),</span>
    200. <span class="k">def</span> <span class="nf">_posterize_increasing_level_to_arg</span><span class="p">(</span><span class="n">level</span><span class="p">,</span> <span class="n">hparams</span><span class="p">):</span>
    201. <span class="c1"># As per Tensorflow models research and UDA impl</span>
    202. <span class="c1"># range [4, 0], &#39;keep 4 down to 0 MSB of original image&#39;,</span>
    203. <span class="c1"># intensity/severity of augmentation increases with level</span>
    204. <span class="k">return</span> <span class="mi">4</span> <span class="o">-</span> <span class="n">_posterize_level_to_arg</span><span class="p">(</span><span class="n">level</span><span class="p">,</span> <span class="n">hparams</span><span class="p">)[</span><span class="mi">0</span><span class="p">],</span>
    205. <span class="k">def</span> <span class="nf">_posterize_original_level_to_arg</span><span class="p">(</span><span class="n">level</span><span class="p">,</span> <span class="n">_hparams</span><span class="p">):</span>
    206. <span class="c1"># As per original AutoAugment paper description</span>
    207. <span class="c1"># range [4, 8], &#39;keep 4 up to 8 MSB of image&#39;</span>
    208. <span class="c1"># intensity/severity of augmentation decreases with level</span>
    209. <span class="k">return</span> <span class="nb">int</span><span class="p">((</span><span class="n">level</span> <span class="o">/</span> <span class="n">_MAX_MAGNITUDE</span><span class="p">)</span> <span class="o">*</span> <span class="mi">4</span><span class="p">)</span> <span class="o">+</span> <span class="mi">4</span><span class="p">,</span>
    210. <span class="k">def</span> <span class="nf">_solarize_level_to_arg</span><span class="p">(</span><span class="n">level</span><span class="p">,</span> <span class="n">_hparams</span><span class="p">):</span>
    211. <span class="c1"># range [0, 256]</span>
    212. <span class="c1"># intensity/severity of augmentation decreases with level</span>
    213. <span class="k">return</span> <span class="nb">int</span><span class="p">((</span><span class="n">level</span> <span class="o">/</span> <span class="n">_MAX_MAGNITUDE</span><span class="p">)</span> <span class="o">*</span> <span class="mi">256</span><span class="p">),</span>
    214. <span class="k">def</span> <span class="nf">_solarize_increasing_level_to_arg</span><span class="p">(</span><span class="n">level</span><span class="p">,</span> <span class="n">_hparams</span><span class="p">):</span>
    215. <span class="c1"># range [0, 256]</span>
    216. <span class="c1"># intensity/severity of augmentation increases with level</span>
    217. <span class="k">return</span> <span class="mi">256</span> <span class="o">-</span> <span class="n">_solarize_level_to_arg</span><span class="p">(</span><span class="n">level</span><span class="p">,</span> <span class="n">_hparams</span><span class="p">)[</span><span class="mi">0</span><span class="p">],</span>
    218. <span class="k">def</span> <span class="nf">_solarize_add_level_to_arg</span><span class="p">(</span><span class="n">level</span><span class="p">,</span> <span class="n">_hparams</span><span class="p">):</span>
    219. <span class="c1"># range [0, 110]</span>
    220. <span class="k">return</span> <span class="nb">int</span><span class="p">((</span><span class="n">level</span> <span class="o">/</span> <span class="n">_MAX_MAGNITUDE</span><span class="p">)</span> <span class="o">*</span> <span class="mi">110</span><span class="p">),</span>
    221. <span class="n">LEVEL_TO_ARG</span> <span class="o">=</span> <span class="p">{</span>
    222. <span class="s1">&#39;AutoContrast&#39;</span><span class="p">:</span> <span class="kc">None</span><span class="p">,</span>
    223. <span class="s1">&#39;Equalize&#39;</span><span class="p">:</span> <span class="kc">None</span><span class="p">,</span>
    224. <span class="s1">&#39;Invert&#39;</span><span class="p">:</span> <span class="kc">None</span><span class="p">,</span>
    225. <span class="s1">&#39;Rotate&#39;</span><span class="p">:</span> <span class="n">_rotate_level_to_arg</span><span class="p">,</span>
    226. <span class="c1"># There are several variations of the posterize level scaling in various Tensorflow/Google repositories/papers</span>
    227. <span class="s1">&#39;Posterize&#39;</span><span class="p">:</span> <span class="n">_posterize_level_to_arg</span><span class="p">,</span>
    228. <span class="s1">&#39;PosterizeIncreasing&#39;</span><span class="p">:</span> <span class="n">_posterize_increasing_level_to_arg</span><span class="p">,</span>
    229. <span class="s1">&#39;PosterizeOriginal&#39;</span><span class="p">:</span> <span class="n">_posterize_original_level_to_arg</span><span class="p">,</span>
    230. <span class="s1">&#39;Solarize&#39;</span><span class="p">:</span> <span class="n">_solarize_level_to_arg</span><span class="p">,</span>
    231. <span class="s1">&#39;SolarizeIncreasing&#39;</span><span class="p">:</span> <span class="n">_solarize_increasing_level_to_arg</span><span class="p">,</span>
    232. <span class="s1">&#39;SolarizeAdd&#39;</span><span class="p">:</span> <span class="n">_solarize_add_level_to_arg</span><span class="p">,</span>
    233. <span class="s1">&#39;Color&#39;</span><span class="p">:</span> <span class="n">_enhance_level_to_arg</span><span class="p">,</span>
    234. <span class="s1">&#39;ColorIncreasing&#39;</span><span class="p">:</span> <span class="n">_enhance_increasing_level_to_arg</span><span class="p">,</span>
    235. <span class="s1">&#39;Contrast&#39;</span><span class="p">:</span> <span class="n">_enhance_level_to_arg</span><span class="p">,</span>
    236. <span class="s1">&#39;ContrastIncreasing&#39;</span><span class="p">:</span> <span class="n">_enhance_increasing_level_to_arg</span><span class="p">,</span>
    237. <span class="s1">&#39;Brightness&#39;</span><span class="p">:</span> <span class="n">_enhance_level_to_arg</span><span class="p">,</span>
    238. <span class="s1">&#39;BrightnessIncreasing&#39;</span><span class="p">:</span> <span class="n">_enhance_increasing_level_to_arg</span><span class="p">,</span>
    239. <span class="s1">&#39;Sharpness&#39;</span><span class="p">:</span> <span class="n">_enhance_level_to_arg</span><span class="p">,</span>
    240. <span class="s1">&#39;SharpnessIncreasing&#39;</span><span class="p">:</span> <span class="n">_enhance_increasing_level_to_arg</span><span class="p">,</span>
    241. <span class="s1">&#39;ShearX&#39;</span><span class="p">:</span> <span class="n">_shear_level_to_arg</span><span class="p">,</span>
    242. <span class="s1">&#39;ShearY&#39;</span><span class="p">:</span> <span class="n">_shear_level_to_arg</span><span class="p">,</span>
    243. <span class="s1">&#39;TranslateX&#39;</span><span class="p">:</span> <span class="n">_translate_abs_level_to_arg</span><span class="p">,</span>
    244. <span class="s1">&#39;TranslateY&#39;</span><span class="p">:</span> <span class="n">_translate_abs_level_to_arg</span><span class="p">,</span>
    245. <span class="s1">&#39;TranslateXRel&#39;</span><span class="p">:</span> <span class="n">_translate_rel_level_to_arg</span><span class="p">,</span>
    246. <span class="s1">&#39;TranslateYRel&#39;</span><span class="p">:</span> <span class="n">_translate_rel_level_to_arg</span><span class="p">,</span>
    247. <span class="p">}</span>
    248. <span class="n">NAME_TO_OP</span> <span class="o">=</span> <span class="p">{</span>
    249. <span class="s1">&#39;AutoContrast&#39;</span><span class="p">:</span> <span class="n">auto_contrast</span><span class="p">,</span>
    250. <span class="s1">&#39;Equalize&#39;</span><span class="p">:</span> <span class="n">equalize</span><span class="p">,</span>
    251. <span class="s1">&#39;Invert&#39;</span><span class="p">:</span> <span class="n">invert</span><span class="p">,</span>
    252. <span class="s1">&#39;Rotate&#39;</span><span class="p">:</span> <span class="n">rotate</span><span class="p">,</span>
    253. <span class="s1">&#39;Posterize&#39;</span><span class="p">:</span> <span class="n">posterize</span><span class="p">,</span>
    254. <span class="s1">&#39;PosterizeIncreasing&#39;</span><span class="p">:</span> <span class="n">posterize</span><span class="p">,</span>
    255. <span class="s1">&#39;PosterizeOriginal&#39;</span><span class="p">:</span> <span class="n">posterize</span><span class="p">,</span>
    256. <span class="s1">&#39;Solarize&#39;</span><span class="p">:</span> <span class="n">solarize</span><span class="p">,</span>
    257. <span class="s1">&#39;SolarizeIncreasing&#39;</span><span class="p">:</span> <span class="n">solarize</span><span class="p">,</span>
    258. <span class="s1">&#39;SolarizeAdd&#39;</span><span class="p">:</span> <span class="n">solarize_add</span><span class="p">,</span>
    259. <span class="s1">&#39;Color&#39;</span><span class="p">:</span> <span class="n">color</span><span class="p">,</span>
    260. <span class="s1">&#39;ColorIncreasing&#39;</span><span class="p">:</span> <span class="n">color</span><span class="p">,</span>
    261. <span class="s1">&#39;Contrast&#39;</span><span class="p">:</span> <span class="n">contrast</span><span class="p">,</span>
    262. <span class="s1">&#39;ContrastIncreasing&#39;</span><span class="p">:</span> <span class="n">contrast</span><span class="p">,</span>
    263. <span class="s1">&#39;Brightness&#39;</span><span class="p">:</span> <span class="n">brightness</span><span class="p">,</span>
    264. <span class="s1">&#39;BrightnessIncreasing&#39;</span><span class="p">:</span> <span class="n">brightness</span><span class="p">,</span>
    265. <span class="s1">&#39;Sharpness&#39;</span><span class="p">:</span> <span class="n">sharpness</span><span class="p">,</span>
    266. <span class="s1">&#39;SharpnessIncreasing&#39;</span><span class="p">:</span> <span class="n">sharpness</span><span class="p">,</span>
    267. <span class="s1">&#39;ShearX&#39;</span><span class="p">:</span> <span class="n">shear_x</span><span class="p">,</span>
    268. <span class="s1">&#39;ShearY&#39;</span><span class="p">:</span> <span class="n">shear_y</span><span class="p">,</span>
    269. <span class="s1">&#39;TranslateX&#39;</span><span class="p">:</span> <span class="n">translate_x_abs</span><span class="p">,</span>
    270. <span class="s1">&#39;TranslateY&#39;</span><span class="p">:</span> <span class="n">translate_y_abs</span><span class="p">,</span>
    271. <span class="s1">&#39;TranslateXRel&#39;</span><span class="p">:</span> <span class="n">translate_x_rel</span><span class="p">,</span>
    272. <span class="s1">&#39;TranslateYRel&#39;</span><span class="p">:</span> <span class="n">translate_y_rel</span><span class="p">,</span>
    273. <span class="p">}</span>
    274. <div class="viewcode-block" id="AugmentOp"><a class="viewcode-back" href="../../../../super_gradients.training.datasets.html#super_gradients.training.datasets.auto_augment.AugmentOp">[docs]</a><span class="k">class</span> <span class="nc">AugmentOp</span><span class="p">:</span>
    275. <span class="sd">&quot;&quot;&quot;</span>
    276. <span class="sd"> single auto augment operations</span>
    277. <span class="sd"> &quot;&quot;&quot;</span>
    278. <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">name</span><span class="p">,</span> <span class="n">prob</span><span class="o">=</span><span class="mf">0.5</span><span class="p">,</span> <span class="n">magnitude</span><span class="o">=</span><span class="mi">10</span><span class="p">,</span> <span class="n">hparams</span><span class="o">=</span><span class="kc">None</span><span class="p">):</span>
    279. <span class="n">hparams</span> <span class="o">=</span> <span class="n">hparams</span> <span class="ow">or</span> <span class="n">_HPARAMS_DEFAULT</span>
    280. <span class="bp">self</span><span class="o">.</span><span class="n">aug_fn</span> <span class="o">=</span> <span class="n">NAME_TO_OP</span><span class="p">[</span><span class="n">name</span><span class="p">]</span>
    281. <span class="bp">self</span><span class="o">.</span><span class="n">level_fn</span> <span class="o">=</span> <span class="n">LEVEL_TO_ARG</span><span class="p">[</span><span class="n">name</span><span class="p">]</span>
    282. <span class="bp">self</span><span class="o">.</span><span class="n">prob</span> <span class="o">=</span> <span class="n">prob</span>
    283. <span class="bp">self</span><span class="o">.</span><span class="n">magnitude</span> <span class="o">=</span> <span class="n">magnitude</span>
    284. <span class="bp">self</span><span class="o">.</span><span class="n">hparams</span> <span class="o">=</span> <span class="n">hparams</span><span class="o">.</span><span class="n">copy</span><span class="p">()</span>
    285. <span class="bp">self</span><span class="o">.</span><span class="n">kwargs</span> <span class="o">=</span> <span class="nb">dict</span><span class="p">(</span>
    286. <span class="n">fillcolor</span><span class="o">=</span><span class="n">hparams</span><span class="p">[</span><span class="s1">&#39;img_mean&#39;</span><span class="p">]</span> <span class="k">if</span> <span class="s1">&#39;img_mean&#39;</span> <span class="ow">in</span> <span class="n">hparams</span> <span class="k">else</span> <span class="n">_FILL</span><span class="p">,</span>
    287. <span class="n">resample</span><span class="o">=</span><span class="n">hparams</span><span class="p">[</span><span class="s1">&#39;interpolation&#39;</span><span class="p">]</span> <span class="k">if</span> <span class="s1">&#39;interpolation&#39;</span> <span class="ow">in</span> <span class="n">hparams</span> <span class="k">else</span> <span class="n">_RANDOM_INTERPOLATION</span><span class="p">,</span>
    288. <span class="p">)</span>
    289. <span class="c1"># If magnitude_std is &gt; 0, introduce some randomness</span>
    290. <span class="bp">self</span><span class="o">.</span><span class="n">magnitude_std</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">hparams</span><span class="o">.</span><span class="n">get</span><span class="p">(</span><span class="s1">&#39;magnitude_std&#39;</span><span class="p">,</span> <span class="mi">0</span><span class="p">)</span>
    291. <span class="k">def</span> <span class="fm">__call__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">img</span><span class="p">):</span>
    292. <span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">prob</span> <span class="o">&lt;</span> <span class="mf">1.0</span> <span class="ow">and</span> <span class="n">random</span><span class="o">.</span><span class="n">random</span><span class="p">()</span> <span class="o">&gt;</span> <span class="bp">self</span><span class="o">.</span><span class="n">prob</span><span class="p">:</span>
    293. <span class="k">return</span> <span class="n">img</span>
    294. <span class="n">magnitude</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">magnitude</span>
    295. <span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">magnitude_std</span><span class="p">:</span>
    296. <span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">magnitude_std</span> <span class="o">==</span> <span class="nb">float</span><span class="p">(</span><span class="s1">&#39;inf&#39;</span><span class="p">):</span>
    297. <span class="n">magnitude</span> <span class="o">=</span> <span class="n">random</span><span class="o">.</span><span class="n">uniform</span><span class="p">(</span><span class="mi">0</span><span class="p">,</span> <span class="n">magnitude</span><span class="p">)</span>
    298. <span class="k">elif</span> <span class="bp">self</span><span class="o">.</span><span class="n">magnitude_std</span> <span class="o">&gt;</span> <span class="mi">0</span><span class="p">:</span>
    299. <span class="n">magnitude</span> <span class="o">=</span> <span class="n">random</span><span class="o">.</span><span class="n">gauss</span><span class="p">(</span><span class="n">magnitude</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">magnitude_std</span><span class="p">)</span>
    300. <span class="n">magnitude</span> <span class="o">=</span> <span class="nb">min</span><span class="p">(</span><span class="n">_MAX_MAGNITUDE</span><span class="p">,</span> <span class="nb">max</span><span class="p">(</span><span class="mi">0</span><span class="p">,</span> <span class="n">magnitude</span><span class="p">))</span> <span class="c1"># clip to valid range</span>
    301. <span class="n">level_args</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">level_fn</span><span class="p">(</span><span class="n">magnitude</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">hparams</span><span class="p">)</span> <span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">level_fn</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span> <span class="k">else</span> <span class="nb">tuple</span><span class="p">()</span>
    302. <span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">aug_fn</span><span class="p">(</span><span class="n">img</span><span class="p">,</span> <span class="o">*</span><span class="n">level_args</span><span class="p">,</span> <span class="o">**</span><span class="bp">self</span><span class="o">.</span><span class="n">kwargs</span><span class="p">)</span></div>
    303. <span class="n">_RAND_TRANSFORMS</span> <span class="o">=</span> <span class="p">[</span>
    304. <span class="s1">&#39;AutoContrast&#39;</span><span class="p">,</span>
    305. <span class="s1">&#39;Equalize&#39;</span><span class="p">,</span>
    306. <span class="s1">&#39;Invert&#39;</span><span class="p">,</span>
    307. <span class="s1">&#39;Rotate&#39;</span><span class="p">,</span>
    308. <span class="s1">&#39;Posterize&#39;</span><span class="p">,</span>
    309. <span class="s1">&#39;Solarize&#39;</span><span class="p">,</span>
    310. <span class="s1">&#39;SolarizeAdd&#39;</span><span class="p">,</span>
    311. <span class="s1">&#39;Color&#39;</span><span class="p">,</span>
    312. <span class="s1">&#39;Contrast&#39;</span><span class="p">,</span>
    313. <span class="s1">&#39;Brightness&#39;</span><span class="p">,</span>
    314. <span class="s1">&#39;Sharpness&#39;</span><span class="p">,</span>
    315. <span class="s1">&#39;ShearX&#39;</span><span class="p">,</span>
    316. <span class="s1">&#39;ShearY&#39;</span><span class="p">,</span>
    317. <span class="s1">&#39;TranslateXRel&#39;</span><span class="p">,</span>
    318. <span class="s1">&#39;TranslateYRel&#39;</span><span class="p">,</span>
    319. <span class="c1"># &#39;Cutout&#39; # NOTE I&#39;ve implement this as random erasing separately</span>
    320. <span class="p">]</span>
    321. <span class="n">_RAND_INCREASING_TRANSFORMS</span> <span class="o">=</span> <span class="p">[</span>
    322. <span class="s1">&#39;AutoContrast&#39;</span><span class="p">,</span>
    323. <span class="s1">&#39;Equalize&#39;</span><span class="p">,</span>
    324. <span class="s1">&#39;Invert&#39;</span><span class="p">,</span>
    325. <span class="s1">&#39;Rotate&#39;</span><span class="p">,</span>
    326. <span class="s1">&#39;PosterizeIncreasing&#39;</span><span class="p">,</span>
    327. <span class="s1">&#39;SolarizeIncreasing&#39;</span><span class="p">,</span>
    328. <span class="s1">&#39;SolarizeAdd&#39;</span><span class="p">,</span>
    329. <span class="s1">&#39;ColorIncreasing&#39;</span><span class="p">,</span>
    330. <span class="s1">&#39;ContrastIncreasing&#39;</span><span class="p">,</span>
    331. <span class="s1">&#39;BrightnessIncreasing&#39;</span><span class="p">,</span>
    332. <span class="s1">&#39;SharpnessIncreasing&#39;</span><span class="p">,</span>
    333. <span class="s1">&#39;ShearX&#39;</span><span class="p">,</span>
    334. <span class="s1">&#39;ShearY&#39;</span><span class="p">,</span>
    335. <span class="s1">&#39;TranslateXRel&#39;</span><span class="p">,</span>
    336. <span class="s1">&#39;TranslateYRel&#39;</span><span class="p">,</span>
    337. <span class="c1"># &#39;Cutout&#39; # NOTE I&#39;ve implement this as random erasing separately</span>
    338. <span class="p">]</span>
    339. <span class="c1"># These experimental weights are based loosely on the relative improvements mentioned in paper.</span>
    340. <span class="c1"># They may not result in increased performance, but could likely be tuned to so.</span>
    341. <span class="n">_RAND_CHOICE_WEIGHTS_0</span> <span class="o">=</span> <span class="p">{</span>
    342. <span class="s1">&#39;Rotate&#39;</span><span class="p">:</span> <span class="mf">0.3</span><span class="p">,</span>
    343. <span class="s1">&#39;ShearX&#39;</span><span class="p">:</span> <span class="mf">0.2</span><span class="p">,</span>
    344. <span class="s1">&#39;ShearY&#39;</span><span class="p">:</span> <span class="mf">0.2</span><span class="p">,</span>
    345. <span class="s1">&#39;TranslateXRel&#39;</span><span class="p">:</span> <span class="mf">0.1</span><span class="p">,</span>
    346. <span class="s1">&#39;TranslateYRel&#39;</span><span class="p">:</span> <span class="mf">0.1</span><span class="p">,</span>
    347. <span class="s1">&#39;Color&#39;</span><span class="p">:</span> <span class="mf">.025</span><span class="p">,</span>
    348. <span class="s1">&#39;Sharpness&#39;</span><span class="p">:</span> <span class="mf">0.025</span><span class="p">,</span>
    349. <span class="s1">&#39;AutoContrast&#39;</span><span class="p">:</span> <span class="mf">0.025</span><span class="p">,</span>
    350. <span class="s1">&#39;Solarize&#39;</span><span class="p">:</span> <span class="mf">.005</span><span class="p">,</span>
    351. <span class="s1">&#39;SolarizeAdd&#39;</span><span class="p">:</span> <span class="mf">.005</span><span class="p">,</span>
    352. <span class="s1">&#39;Contrast&#39;</span><span class="p">:</span> <span class="mf">.005</span><span class="p">,</span>
    353. <span class="s1">&#39;Brightness&#39;</span><span class="p">:</span> <span class="mf">.005</span><span class="p">,</span>
    354. <span class="s1">&#39;Equalize&#39;</span><span class="p">:</span> <span class="mf">.005</span><span class="p">,</span>
    355. <span class="s1">&#39;Posterize&#39;</span><span class="p">:</span> <span class="mi">0</span><span class="p">,</span>
    356. <span class="s1">&#39;Invert&#39;</span><span class="p">:</span> <span class="mi">0</span><span class="p">,</span>
    357. <span class="p">}</span>
    358. <span class="k">def</span> <span class="nf">_select_rand_weights</span><span class="p">(</span><span class="n">weight_idx</span><span class="o">=</span><span class="mi">0</span><span class="p">,</span> <span class="n">transforms</span><span class="o">=</span><span class="kc">None</span><span class="p">):</span>
    359. <span class="n">transforms</span> <span class="o">=</span> <span class="n">transforms</span> <span class="ow">or</span> <span class="n">_RAND_TRANSFORMS</span>
    360. <span class="k">assert</span> <span class="n">weight_idx</span> <span class="o">==</span> <span class="mi">0</span> <span class="c1"># only one set of weights currently</span>
    361. <span class="n">rand_weights</span> <span class="o">=</span> <span class="n">_RAND_CHOICE_WEIGHTS_0</span>
    362. <span class="n">probs</span> <span class="o">=</span> <span class="p">[</span><span class="n">rand_weights</span><span class="p">[</span><span class="n">k</span><span class="p">]</span> <span class="k">for</span> <span class="n">k</span> <span class="ow">in</span> <span class="n">transforms</span><span class="p">]</span>
    363. <span class="n">probs</span> <span class="o">/=</span> <span class="n">np</span><span class="o">.</span><span class="n">sum</span><span class="p">(</span><span class="n">probs</span><span class="p">)</span>
    364. <span class="k">return</span> <span class="n">probs</span>
    365. <div class="viewcode-block" id="rand_augment_ops"><a class="viewcode-back" href="../../../../super_gradients.training.datasets.html#super_gradients.training.datasets.auto_augment.rand_augment_ops">[docs]</a><span class="k">def</span> <span class="nf">rand_augment_ops</span><span class="p">(</span><span class="n">magnitude</span><span class="o">=</span><span class="mi">10</span><span class="p">,</span> <span class="n">hparams</span><span class="o">=</span><span class="kc">None</span><span class="p">,</span> <span class="n">transforms</span><span class="o">=</span><span class="kc">None</span><span class="p">):</span>
    366. <span class="n">hparams</span> <span class="o">=</span> <span class="n">hparams</span> <span class="ow">or</span> <span class="n">_HPARAMS_DEFAULT</span>
    367. <span class="n">transforms</span> <span class="o">=</span> <span class="n">transforms</span> <span class="ow">or</span> <span class="n">_RAND_TRANSFORMS</span>
    368. <span class="k">return</span> <span class="p">[</span><span class="n">AugmentOp</span><span class="p">(</span>
    369. <span class="n">name</span><span class="p">,</span> <span class="n">prob</span><span class="o">=</span><span class="mf">0.5</span><span class="p">,</span> <span class="n">magnitude</span><span class="o">=</span><span class="n">magnitude</span><span class="p">,</span> <span class="n">hparams</span><span class="o">=</span><span class="n">hparams</span><span class="p">)</span> <span class="k">for</span> <span class="n">name</span> <span class="ow">in</span> <span class="n">transforms</span><span class="p">]</span></div>
    370. <div class="viewcode-block" id="RandAugment"><a class="viewcode-back" href="../../../../super_gradients.training.datasets.html#super_gradients.training.datasets.auto_augment.RandAugment">[docs]</a><span class="k">class</span> <span class="nc">RandAugment</span><span class="p">:</span>
    371. <span class="sd">&quot;&quot;&quot;</span>
    372. <span class="sd"> Random auto augment class, will select auto augment transforms according to probability weights for each op</span>
    373. <span class="sd"> &quot;&quot;&quot;</span>
    374. <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">ops</span><span class="p">,</span> <span class="n">num_layers</span><span class="o">=</span><span class="mi">2</span><span class="p">,</span> <span class="n">choice_weights</span><span class="o">=</span><span class="kc">None</span><span class="p">):</span>
    375. <span class="bp">self</span><span class="o">.</span><span class="n">ops</span> <span class="o">=</span> <span class="n">ops</span>
    376. <span class="bp">self</span><span class="o">.</span><span class="n">num_layers</span> <span class="o">=</span> <span class="n">num_layers</span>
    377. <span class="bp">self</span><span class="o">.</span><span class="n">choice_weights</span> <span class="o">=</span> <span class="n">choice_weights</span>
    378. <span class="k">def</span> <span class="fm">__call__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">img</span><span class="p">):</span>
    379. <span class="c1"># no replacement when using weighted choice</span>
    380. <span class="n">ops</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">random</span><span class="o">.</span><span class="n">choice</span><span class="p">(</span>
    381. <span class="bp">self</span><span class="o">.</span><span class="n">ops</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">num_layers</span><span class="p">,</span> <span class="n">replace</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">choice_weights</span> <span class="ow">is</span> <span class="kc">None</span><span class="p">,</span> <span class="n">p</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">choice_weights</span><span class="p">)</span>
    382. <span class="k">for</span> <span class="n">op</span> <span class="ow">in</span> <span class="n">ops</span><span class="p">:</span>
    383. <span class="n">img</span> <span class="o">=</span> <span class="n">op</span><span class="p">(</span><span class="n">img</span><span class="p">)</span>
    384. <span class="k">return</span> <span class="n">img</span></div>
    385. <div class="viewcode-block" id="rand_augment_transform"><a class="viewcode-back" href="../../../../super_gradients.training.datasets.html#super_gradients.training.datasets.auto_augment.rand_augment_transform">[docs]</a><span class="k">def</span> <span class="nf">rand_augment_transform</span><span class="p">(</span><span class="n">config_str</span><span class="p">,</span> <span class="n">hparams</span><span class="p">):</span>
    386. <span class="sd">&quot;&quot;&quot;</span>
    387. <span class="sd"> Create a RandAugment transform</span>
    388. <span class="sd"> :param config_str: String defining configuration of random augmentation. Consists of multiple sections separated by</span>
    389. <span class="sd"> dashes (&#39;-&#39;). The first section defines the specific variant of rand augment (currently only &#39;rand&#39;). The remaining</span>
    390. <span class="sd"> sections, not order sepecific determine</span>
    391. <span class="sd"> &#39;m&#39; - integer magnitude of rand augment</span>
    392. <span class="sd"> &#39;n&#39; - integer num layers (number of transform ops selected per image)</span>
    393. <span class="sd"> &#39;w&#39; - integer probabiliy weight index (index of a set of weights to influence choice of op)</span>
    394. <span class="sd"> &#39;mstd&#39; - float std deviation of magnitude noise applied</span>
    395. <span class="sd"> &#39;inc&#39; - integer (bool), use augmentations that increase in severity with magnitude (default: 0)</span>
    396. <span class="sd"> Ex &#39;rand-m9-n3-mstd0.5&#39; results in RandAugment with magnitude 9, num_layers 3, magnitude_std 0.5</span>
    397. <span class="sd"> &#39;rand-mstd1-w0&#39; results in magnitude_std 1.0, weights 0, default magnitude of 10 and num_layers 2</span>
    398. <span class="sd"> :param hparams: Other hparams (kwargs) for the RandAugmentation scheme</span>
    399. <span class="sd"> :return: A PyTorch compatible Transform</span>
    400. <span class="sd"> &quot;&quot;&quot;</span>
    401. <span class="n">magnitude</span> <span class="o">=</span> <span class="n">_MAX_MAGNITUDE</span> <span class="c1"># default to _MAX_MAGNITUDE for magnitude (currently 10)</span>
    402. <span class="n">num_layers</span> <span class="o">=</span> <span class="mi">2</span> <span class="c1"># default to 2 ops per image</span>
    403. <span class="n">weight_idx</span> <span class="o">=</span> <span class="kc">None</span> <span class="c1"># default to no probability weights for op choice</span>
    404. <span class="n">transforms</span> <span class="o">=</span> <span class="n">_RAND_TRANSFORMS</span>
    405. <span class="n">config</span> <span class="o">=</span> <span class="n">config_str</span><span class="o">.</span><span class="n">split</span><span class="p">(</span><span class="s1">&#39;-&#39;</span><span class="p">)</span>
    406. <span class="k">for</span> <span class="n">c</span> <span class="ow">in</span> <span class="n">config</span><span class="p">:</span>
    407. <span class="n">cs</span> <span class="o">=</span> <span class="n">re</span><span class="o">.</span><span class="n">split</span><span class="p">(</span><span class="sa">r</span><span class="s1">&#39;(\d.*)&#39;</span><span class="p">,</span> <span class="n">c</span><span class="p">)</span>
    408. <span class="k">if</span> <span class="nb">len</span><span class="p">(</span><span class="n">cs</span><span class="p">)</span> <span class="o">&lt;</span> <span class="mi">2</span><span class="p">:</span>
    409. <span class="k">continue</span>
    410. <span class="n">key</span><span class="p">,</span> <span class="n">val</span> <span class="o">=</span> <span class="n">cs</span><span class="p">[:</span><span class="mi">2</span><span class="p">]</span>
    411. <span class="k">if</span> <span class="n">key</span> <span class="o">==</span> <span class="s1">&#39;mstd&#39;</span><span class="p">:</span>
    412. <span class="c1"># noise param injected via hparams for now</span>
    413. <span class="n">hparams</span><span class="o">.</span><span class="n">setdefault</span><span class="p">(</span><span class="s1">&#39;magnitude_std&#39;</span><span class="p">,</span> <span class="nb">float</span><span class="p">(</span><span class="n">val</span><span class="p">))</span>
    414. <span class="k">elif</span> <span class="n">key</span> <span class="o">==</span> <span class="s1">&#39;inc&#39;</span><span class="p">:</span>
    415. <span class="k">if</span> <span class="nb">bool</span><span class="p">(</span><span class="n">val</span><span class="p">):</span>
    416. <span class="n">transforms</span> <span class="o">=</span> <span class="n">_RAND_INCREASING_TRANSFORMS</span>
    417. <span class="k">elif</span> <span class="n">key</span> <span class="o">==</span> <span class="s1">&#39;m&#39;</span><span class="p">:</span>
    418. <span class="n">magnitude</span> <span class="o">=</span> <span class="nb">int</span><span class="p">(</span><span class="n">val</span><span class="p">)</span>
    419. <span class="k">elif</span> <span class="n">key</span> <span class="o">==</span> <span class="s1">&#39;n&#39;</span><span class="p">:</span>
    420. <span class="n">num_layers</span> <span class="o">=</span> <span class="nb">int</span><span class="p">(</span><span class="n">val</span><span class="p">)</span>
    421. <span class="k">elif</span> <span class="n">key</span> <span class="o">==</span> <span class="s1">&#39;w&#39;</span><span class="p">:</span>
    422. <span class="n">weight_idx</span> <span class="o">=</span> <span class="nb">int</span><span class="p">(</span><span class="n">val</span><span class="p">)</span>
    423. <span class="k">else</span><span class="p">:</span>
    424. <span class="k">assert</span> <span class="kc">False</span><span class="p">,</span> <span class="s1">&#39;Unknown RandAugment config section&#39;</span>
    425. <span class="n">ra_ops</span> <span class="o">=</span> <span class="n">rand_augment_ops</span><span class="p">(</span><span class="n">magnitude</span><span class="o">=</span><span class="n">magnitude</span><span class="p">,</span> <span class="n">hparams</span><span class="o">=</span><span class="n">hparams</span><span class="p">,</span> <span class="n">transforms</span><span class="o">=</span><span class="n">transforms</span><span class="p">)</span>
    426. <span class="n">choice_weights</span> <span class="o">=</span> <span class="kc">None</span> <span class="k">if</span> <span class="n">weight_idx</span> <span class="ow">is</span> <span class="kc">None</span> <span class="k">else</span> <span class="n">_select_rand_weights</span><span class="p">(</span><span class="n">weight_idx</span><span class="p">)</span>
    427. <span class="k">return</span> <span class="n">RandAugment</span><span class="p">(</span><span class="n">ra_ops</span><span class="p">,</span> <span class="n">num_layers</span><span class="p">,</span> <span class="n">choice_weights</span><span class="o">=</span><span class="n">choice_weights</span><span class="p">)</span></div>
    428. </pre></div>
    429. </div>
    430. </div>
    431. <footer>
    432. <hr/>
    433. <div role="contentinfo">
    434. <p>&#169; Copyright 2021, SuperGradients team.</p>
    435. </div>
    436. Built with <a href="https://www.sphinx-doc.org/">Sphinx</a> using a
    437. <a href="https://github.com/readthedocs/sphinx_rtd_theme">theme</a>
    438. provided by <a href="https://readthedocs.org">Read the Docs</a>.
    439. </footer>
    440. </div>
    441. </div>
    442. </section>
    443. </div>
    444. <script>
    445. jQuery(function () {
    446. SphinxRtdTheme.Navigation.enable(true);
    447. });
    448. </script>
    449. </body>
    450. </html>
    Discard
    1
    2
    3
    4
    5
    6
    7
    8
    9
    10
    11
    12
    13
    14
    15
    16
    17
    18
    19
    20
    21
    22
    23
    24
    25
    26
    27
    28
    29
    30
    31
    32
    33
    34
    35
    36
    37
    38
    39
    40
    41
    42
    43
    44
    45
    46
    47
    48
    49
    50
    51
    52
    53
    54
    55
    56
    57
    58
    59
    60
    61
    62
    63
    64
    65
    66
    67
    68
    69
    70
    71
    72
    73
    74
    75
    76
    77
    78
    79
    80
    81
    82
    83
    84
    85
    86
    87
    88
    89
    90
    91
    92
    93
    94
    95
    96
    97
    98
    99
    100
    101
    102
    103
    104
    105
    106
    107
    108
    109
    110
    111
    112
    113
    114
    115
    116
    117
    118
    119
    120
    121
    122
    123
    124
    125
    126
    127
    128
    129
    130
    131
    132
    133
    134
    135
    136
    137
    138
    139
    140
    141
    142
    143
    144
    145
    146
    147
    148
    149
    150
    151
    152
    153
    154
    155
    156
    157
    158
    159
    160
    161
    162
    163
    164
    165
    166
    167
    168
    169
    170
    171
    172
    173
    174
    175
    176
    177
    178
    179
    180
    181
    182
    183
    184
    185
    186
    187
    188
    189
    190
    191
    192
    193
    194
    1. <!DOCTYPE html>
    2. <html class="writer-html5" lang="en" >
    3. <head>
    4. <meta charset="utf-8" />
    5. <meta name="viewport" content="width=device-width, initial-scale=1.0" />
    6. <title>super_gradients.training.datasets.classification_datasets.cifar &mdash; SuperGradients 3.0.3 documentation</title>
    7. <link rel="stylesheet" href="../../../../../_static/pygments.css" type="text/css" />
    8. <link rel="stylesheet" href="../../../../../_static/css/theme.css" type="text/css" />
    9. <link rel="stylesheet" href="../../../../../_static/graphviz.css" type="text/css" />
    10. <link rel="stylesheet" href="../../../../../_static/custom.css" type="text/css" />
    11. <!--[if lt IE 9]>
    12. <script src="../../../../../_static/js/html5shiv.min.js"></script>
    13. <![endif]-->
    14. <script data-url_root="../../../../../" id="documentation_options" src="../../../../../_static/documentation_options.js"></script>
    15. <script src="../../../../../_static/jquery.js"></script>
    16. <script src="../../../../../_static/underscore.js"></script>
    17. <script src="../../../../../_static/_sphinx_javascript_frameworks_compat.js"></script>
    18. <script src="../../../../../_static/doctools.js"></script>
    19. <script src="../../../../../_static/sphinx_highlight.js"></script>
    20. <script src="../../../../../_static/js/theme.js"></script>
    21. <link rel="index" title="Index" href="../../../../../genindex.html" />
    22. <link rel="search" title="Search" href="../../../../../search.html" />
    23. </head>
    24. <body class="wy-body-for-nav">
    25. <div class="wy-grid-for-nav">
    26. <nav data-toggle="wy-nav-shift" class="wy-nav-side">
    27. <div class="wy-side-scroll">
    28. <div class="wy-side-nav-search" >
    29. <a href="../../../../../index.html" class="icon icon-home"> SuperGradients
    30. </a>
    31. <div role="search">
    32. <form id="rtd-search-form" class="wy-form" action="../../../../../search.html" method="get">
    33. <input type="text" name="q" placeholder="Search docs" />
    34. <input type="hidden" name="check_keywords" value="yes" />
    35. <input type="hidden" name="area" value="default" />
    36. </form>
    37. </div>
    38. </div><div class="wy-menu wy-menu-vertical" data-spy="affix" role="navigation" aria-label="Navigation menu">
    39. <p class="caption" role="heading"><span class="caption-text">Welcome To SuperGradients</span></p>
    40. <ul>
    41. <li class="toctree-l1"><a class="reference internal" href="../../../../../welcome.html">Version 3 is out! Notebooks have been updated!</a></li>
    42. <li class="toctree-l1"><a class="reference internal" href="../../../../../welcome.html#build-with-supergradients">Build with SuperGradients</a></li>
    43. <li class="toctree-l1"><a class="reference internal" href="../../../../../welcome.html#quick-installation">Quick Installation</a></li>
    44. <li class="toctree-l1"><a class="reference internal" href="../../../../../welcome.html#what-s-new">What’s New</a></li>
    45. <li class="toctree-l1"><a class="reference internal" href="../../../../../welcome.html#coming-soon">Coming soon</a></li>
    46. <li class="toctree-l1"><a class="reference internal" href="../../../../../welcome.html#table-of-content">Table of Content</a></li>
    47. <li class="toctree-l1"><a class="reference internal" href="../../../../../welcome.html#getting-started">Getting Started</a></li>
    48. <li class="toctree-l1"><a class="reference internal" href="../../../../../welcome.html#advanced-features">Advanced Features</a></li>
    49. <li class="toctree-l1"><a class="reference internal" href="../../../../../welcome.html#installation-methods">Installation Methods</a></li>
    50. <li class="toctree-l1"><a class="reference internal" href="../../../../../welcome.html#implemented-model-architectures">Implemented Model Architectures</a></li>
    51. <li class="toctree-l1"><a class="reference internal" href="../../../../../welcome.html#documentation">Documentation</a></li>
    52. <li class="toctree-l1"><a class="reference internal" href="../../../../../welcome.html#contributing">Contributing</a></li>
    53. <li class="toctree-l1"><a class="reference internal" href="../../../../../welcome.html#citation">Citation</a></li>
    54. <li class="toctree-l1"><a class="reference internal" href="../../../../../welcome.html#community">Community</a></li>
    55. <li class="toctree-l1"><a class="reference internal" href="../../../../../welcome.html#license">License</a></li>
    56. <li class="toctree-l1"><a class="reference internal" href="../../../../../welcome.html#deci-platform">Deci Platform</a></li>
    57. </ul>
    58. <p class="caption" role="heading"><span class="caption-text">Technical Documentation</span></p>
    59. <ul>
    60. <li class="toctree-l1"><a class="reference internal" href="../../../../../super_gradients.common.html">Common package</a></li>
    61. <li class="toctree-l1"><a class="reference internal" href="../../../../../super_gradients.training.html">Training package</a></li>
    62. </ul>
    63. </div>
    64. </div>
    65. </nav>
    66. <section data-toggle="wy-nav-shift" class="wy-nav-content-wrap"><nav class="wy-nav-top" aria-label="Mobile navigation menu" >
    67. <i data-toggle="wy-nav-top" class="fa fa-bars"></i>
    68. <a href="../../../../../index.html">SuperGradients</a>
    69. </nav>
    70. <div class="wy-nav-content">
    71. <div class="rst-content">
    72. <div role="navigation" aria-label="Page navigation">
    73. <ul class="wy-breadcrumbs">
    74. <li><a href="../../../../../index.html" class="icon icon-home"></a> &raquo;</li>
    75. <li><a href="../../../../index.html">Module code</a> &raquo;</li>
    76. <li>super_gradients.training.datasets.classification_datasets.cifar</li>
    77. <li class="wy-breadcrumbs-aside">
    78. </li>
    79. </ul>
    80. <hr/>
    81. </div>
    82. <div role="main" class="document" itemscope="itemscope" itemtype="http://schema.org/Article">
    83. <div itemprop="articleBody">
    84. <h1>Source code for super_gradients.training.datasets.classification_datasets.cifar</h1><div class="highlight"><pre>
    85. <span></span><span class="kn">from</span> <span class="nn">typing</span> <span class="kn">import</span> <span class="n">Optional</span><span class="p">,</span> <span class="n">Callable</span><span class="p">,</span> <span class="n">Union</span>
    86. <span class="kn">from</span> <span class="nn">torchvision.transforms</span> <span class="kn">import</span> <span class="n">Compose</span>
    87. <span class="kn">from</span> <span class="nn">super_gradients.common.factories.transforms_factory</span> <span class="kn">import</span> <span class="n">TransformsFactory</span>
    88. <span class="kn">from</span> <span class="nn">super_gradients.common.decorators.factory_decorator</span> <span class="kn">import</span> <span class="n">resolve_param</span>
    89. <span class="kn">from</span> <span class="nn">torchvision.datasets</span> <span class="kn">import</span> <span class="n">CIFAR10</span><span class="p">,</span> <span class="n">CIFAR100</span>
    90. <div class="viewcode-block" id="Cifar10"><a class="viewcode-back" href="../../../../../super_gradients.training.html#super_gradients.training.datasets.Cifar10">[docs]</a><span class="k">class</span> <span class="nc">Cifar10</span><span class="p">(</span><span class="n">CIFAR10</span><span class="p">):</span>
    91. <span class="sd">&quot;&quot;&quot;</span>
    92. <span class="sd"> CIFAR10 Dataset</span>
    93. <span class="sd"> :param root: Path for the data to be extracted</span>
    94. <span class="sd"> :param train: Bool to load training (True) or validation (False) part of the dataset</span>
    95. <span class="sd"> :param transforms: List of transforms to apply sequentially on sample. Wrapped internally with torchvision.Compose</span>
    96. <span class="sd"> :param target_transform: Transform to apply to target output</span>
    97. <span class="sd"> :param download: Download (True) the dataset from source</span>
    98. <span class="sd"> &quot;&quot;&quot;</span>
    99. <span class="nd">@resolve_param</span><span class="p">(</span><span class="s2">&quot;transforms&quot;</span><span class="p">,</span> <span class="n">TransformsFactory</span><span class="p">())</span>
    100. <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span>
    101. <span class="bp">self</span><span class="p">,</span>
    102. <span class="n">root</span><span class="p">:</span> <span class="nb">str</span><span class="p">,</span>
    103. <span class="n">train</span><span class="p">:</span> <span class="nb">bool</span> <span class="o">=</span> <span class="kc">True</span><span class="p">,</span>
    104. <span class="n">transforms</span><span class="p">:</span> <span class="n">Union</span><span class="p">[</span><span class="nb">list</span><span class="p">,</span> <span class="nb">dict</span><span class="p">]</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span>
    105. <span class="n">target_transform</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="n">Callable</span><span class="p">]</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span>
    106. <span class="n">download</span><span class="p">:</span> <span class="nb">bool</span> <span class="o">=</span> <span class="kc">False</span><span class="p">,</span>
    107. <span class="p">)</span> <span class="o">-&gt;</span> <span class="kc">None</span><span class="p">:</span>
    108. <span class="c1"># TO KEEP BACKWARD COMPATABILITY, WILL BE REMOVED IN THE FUTURE ONCE WE ALLIGN TORCHVISION/NATIVE TRANSFORMS</span>
    109. <span class="c1"># TREATMENT IN FACTORIES (I.E STATING COMPOSE IN CONFIGS)</span>
    110. <span class="k">if</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">transforms</span><span class="p">,</span> <span class="nb">list</span><span class="p">):</span>
    111. <span class="n">transforms</span> <span class="o">=</span> <span class="n">Compose</span><span class="p">(</span><span class="n">transforms</span><span class="p">)</span>
    112. <span class="nb">super</span><span class="p">(</span><span class="n">Cifar10</span><span class="p">,</span> <span class="bp">self</span><span class="p">)</span><span class="o">.</span><span class="fm">__init__</span><span class="p">(</span>
    113. <span class="n">root</span><span class="o">=</span><span class="n">root</span><span class="p">,</span>
    114. <span class="n">train</span><span class="o">=</span><span class="n">train</span><span class="p">,</span>
    115. <span class="n">transform</span><span class="o">=</span><span class="n">transforms</span><span class="p">,</span>
    116. <span class="n">target_transform</span><span class="o">=</span><span class="n">target_transform</span><span class="p">,</span>
    117. <span class="n">download</span><span class="o">=</span><span class="n">download</span><span class="p">,</span>
    118. <span class="p">)</span></div>
    119. <div class="viewcode-block" id="Cifar100"><a class="viewcode-back" href="../../../../../super_gradients.training.html#super_gradients.training.datasets.Cifar100">[docs]</a><span class="k">class</span> <span class="nc">Cifar100</span><span class="p">(</span><span class="n">CIFAR100</span><span class="p">):</span>
    120. <span class="nd">@resolve_param</span><span class="p">(</span><span class="s2">&quot;transforms&quot;</span><span class="p">,</span> <span class="n">TransformsFactory</span><span class="p">())</span>
    121. <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span>
    122. <span class="bp">self</span><span class="p">,</span>
    123. <span class="n">root</span><span class="p">:</span> <span class="nb">str</span><span class="p">,</span>
    124. <span class="n">train</span><span class="p">:</span> <span class="nb">bool</span> <span class="o">=</span> <span class="kc">True</span><span class="p">,</span>
    125. <span class="n">transforms</span><span class="p">:</span> <span class="n">Union</span><span class="p">[</span><span class="nb">list</span><span class="p">,</span> <span class="nb">dict</span><span class="p">]</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span>
    126. <span class="n">target_transform</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="n">Callable</span><span class="p">]</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span>
    127. <span class="n">download</span><span class="p">:</span> <span class="nb">bool</span> <span class="o">=</span> <span class="kc">False</span><span class="p">,</span>
    128. <span class="p">)</span> <span class="o">-&gt;</span> <span class="kc">None</span><span class="p">:</span>
    129. <span class="sd">&quot;&quot;&quot;</span>
    130. <span class="sd"> CIFAR100 Dataset</span>
    131. <span class="sd"> :param root: Path for the data to be extracted</span>
    132. <span class="sd"> :param train: Bool to load training (True) or validation (False) part of the dataset</span>
    133. <span class="sd"> :param transforms: List of transforms to apply sequentially on sample. Wrapped internally with torchvision.Compose</span>
    134. <span class="sd"> :param target_transform: Transform to apply to target output</span>
    135. <span class="sd"> :param download: Download (True) the dataset from source</span>
    136. <span class="sd"> &quot;&quot;&quot;</span>
    137. <span class="c1"># TO KEEP BACKWARD COMPATABILITY, WILL BE REMOVED IN THE FUTURE ONCE WE ALLIGN TORCHVISION/NATIVE TRANSFORMS</span>
    138. <span class="c1"># TREATMENT IN FACTORIES (I.E STATING COMPOSE IN CONFIGS)</span>
    139. <span class="k">if</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">transforms</span><span class="p">,</span> <span class="nb">list</span><span class="p">):</span>
    140. <span class="n">transforms</span> <span class="o">=</span> <span class="n">Compose</span><span class="p">(</span><span class="n">transforms</span><span class="p">)</span>
    141. <span class="nb">super</span><span class="p">(</span><span class="n">Cifar100</span><span class="p">,</span> <span class="bp">self</span><span class="p">)</span><span class="o">.</span><span class="fm">__init__</span><span class="p">(</span>
    142. <span class="n">root</span><span class="o">=</span><span class="n">root</span><span class="p">,</span>
    143. <span class="n">train</span><span class="o">=</span><span class="n">train</span><span class="p">,</span>
    144. <span class="n">transform</span><span class="o">=</span><span class="n">transforms</span><span class="p">,</span>
    145. <span class="n">target_transform</span><span class="o">=</span><span class="n">target_transform</span><span class="p">,</span>
    146. <span class="n">download</span><span class="o">=</span><span class="n">download</span><span class="p">,</span>
    147. <span class="p">)</span></div>
    148. </pre></div>
    149. </div>
    150. </div>
    151. <footer>
    152. <hr/>
    153. <div role="contentinfo">
    154. <p>&#169; Copyright 2021, SuperGradients team.</p>
    155. </div>
    156. Built with <a href="https://www.sphinx-doc.org/">Sphinx</a> using a
    157. <a href="https://github.com/readthedocs/sphinx_rtd_theme">theme</a>
    158. provided by <a href="https://readthedocs.org">Read the Docs</a>.
    159. </footer>
    160. </div>
    161. </div>
    162. </section>
    163. </div>
    164. <script>
    165. jQuery(function () {
    166. SphinxRtdTheme.Navigation.enable(true);
    167. });
    168. </script>
    169. </body>
    170. </html>
    Discard
    1
    2
    3
    4
    5
    6
    7
    8
    9
    10
    11
    12
    13
    14
    15
    16
    17
    18
    19
    20
    21
    22
    23
    24
    25
    26
    27
    28
    29
    30
    31
    32
    33
    34
    35
    36
    37
    38
    39
    40
    41
    42
    43
    44
    45
    46
    47
    48
    49
    50
    51
    52
    53
    54
    55
    56
    57
    58
    59
    60
    61
    62
    63
    64
    65
    66
    67
    68
    69
    70
    71
    72
    73
    74
    75
    76
    77
    78
    79
    80
    81
    82
    83
    84
    85
    86
    87
    88
    89
    90
    91
    92
    93
    94
    95
    96
    97
    98
    99
    100
    101
    102
    103
    104
    105
    106
    107
    108
    109
    110
    111
    112
    113
    114
    115
    116
    117
    118
    119
    120
    121
    122
    123
    124
    125
    126
    127
    128
    129
    130
    131
    132
    133
    134
    135
    136
    137
    138
    139
    1. <!DOCTYPE html>
    2. <html class="writer-html5" lang="en" >
    3. <head>
    4. <meta charset="utf-8" />
    5. <meta name="viewport" content="width=device-width, initial-scale=1.0" />
    6. <title>super_gradients.training.datasets.classification_datasets.imagenet_dataset &mdash; SuperGradients 3.0.3 documentation</title>
    7. <link rel="stylesheet" href="../../../../../_static/pygments.css" type="text/css" />
    8. <link rel="stylesheet" href="../../../../../_static/css/theme.css" type="text/css" />
    9. <link rel="stylesheet" href="../../../../../_static/graphviz.css" type="text/css" />
    10. <link rel="stylesheet" href="../../../../../_static/custom.css" type="text/css" />
    11. <!--[if lt IE 9]>
    12. <script src="../../../../../_static/js/html5shiv.min.js"></script>
    13. <![endif]-->
    14. <script data-url_root="../../../../../" id="documentation_options" src="../../../../../_static/documentation_options.js"></script>
    15. <script src="../../../../../_static/jquery.js"></script>
    16. <script src="../../../../../_static/underscore.js"></script>
    17. <script src="../../../../../_static/_sphinx_javascript_frameworks_compat.js"></script>
    18. <script src="../../../../../_static/doctools.js"></script>
    19. <script src="../../../../../_static/sphinx_highlight.js"></script>
    20. <script src="../../../../../_static/js/theme.js"></script>
    21. <link rel="index" title="Index" href="../../../../../genindex.html" />
    22. <link rel="search" title="Search" href="../../../../../search.html" />
    23. </head>
    24. <body class="wy-body-for-nav">
    25. <div class="wy-grid-for-nav">
    26. <nav data-toggle="wy-nav-shift" class="wy-nav-side">
    27. <div class="wy-side-scroll">
    28. <div class="wy-side-nav-search" >
    29. <a href="../../../../../index.html" class="icon icon-home"> SuperGradients
    30. </a>
    31. <div role="search">
    32. <form id="rtd-search-form" class="wy-form" action="../../../../../search.html" method="get">
    33. <input type="text" name="q" placeholder="Search docs" />
    34. <input type="hidden" name="check_keywords" value="yes" />
    35. <input type="hidden" name="area" value="default" />
    36. </form>
    37. </div>
    38. </div><div class="wy-menu wy-menu-vertical" data-spy="affix" role="navigation" aria-label="Navigation menu">
    39. <p class="caption" role="heading"><span class="caption-text">Welcome To SuperGradients</span></p>
    40. <ul>
    41. <li class="toctree-l1"><a class="reference internal" href="../../../../../welcome.html">Version 3 is out! Notebooks have been updated!</a></li>
    42. <li class="toctree-l1"><a class="reference internal" href="../../../../../welcome.html#build-with-supergradients">Build with SuperGradients</a></li>
    43. <li class="toctree-l1"><a class="reference internal" href="../../../../../welcome.html#quick-installation">Quick Installation</a></li>
    44. <li class="toctree-l1"><a class="reference internal" href="../../../../../welcome.html#what-s-new">What’s New</a></li>
    45. <li class="toctree-l1"><a class="reference internal" href="../../../../../welcome.html#coming-soon">Coming soon</a></li>
    46. <li class="toctree-l1"><a class="reference internal" href="../../../../../welcome.html#table-of-content">Table of Content</a></li>
    47. <li class="toctree-l1"><a class="reference internal" href="../../../../../welcome.html#getting-started">Getting Started</a></li>
    48. <li class="toctree-l1"><a class="reference internal" href="../../../../../welcome.html#advanced-features">Advanced Features</a></li>
    49. <li class="toctree-l1"><a class="reference internal" href="../../../../../welcome.html#installation-methods">Installation Methods</a></li>
    50. <li class="toctree-l1"><a class="reference internal" href="../../../../../welcome.html#implemented-model-architectures">Implemented Model Architectures</a></li>
    51. <li class="toctree-l1"><a class="reference internal" href="../../../../../welcome.html#documentation">Documentation</a></li>
    52. <li class="toctree-l1"><a class="reference internal" href="../../../../../welcome.html#contributing">Contributing</a></li>
    53. <li class="toctree-l1"><a class="reference internal" href="../../../../../welcome.html#citation">Citation</a></li>
    54. <li class="toctree-l1"><a class="reference internal" href="../../../../../welcome.html#community">Community</a></li>
    55. <li class="toctree-l1"><a class="reference internal" href="../../../../../welcome.html#license">License</a></li>
    56. <li class="toctree-l1"><a class="reference internal" href="../../../../../welcome.html#deci-platform">Deci Platform</a></li>
    57. </ul>
    58. <p class="caption" role="heading"><span class="caption-text">Technical Documentation</span></p>
    59. <ul>
    60. <li class="toctree-l1"><a class="reference internal" href="../../../../../super_gradients.common.html">Common package</a></li>
    61. <li class="toctree-l1"><a class="reference internal" href="../../../../../super_gradients.training.html">Training package</a></li>
    62. </ul>
    63. </div>
    64. </div>
    65. </nav>
    66. <section data-toggle="wy-nav-shift" class="wy-nav-content-wrap"><nav class="wy-nav-top" aria-label="Mobile navigation menu" >
    67. <i data-toggle="wy-nav-top" class="fa fa-bars"></i>
    68. <a href="../../../../../index.html">SuperGradients</a>
    69. </nav>
    70. <div class="wy-nav-content">
    71. <div class="rst-content">
    72. <div role="navigation" aria-label="Page navigation">
    73. <ul class="wy-breadcrumbs">
    74. <li><a href="../../../../../index.html" class="icon icon-home"></a> &raquo;</li>
    75. <li><a href="../../../../index.html">Module code</a> &raquo;</li>
    76. <li>super_gradients.training.datasets.classification_datasets.imagenet_dataset</li>
    77. <li class="wy-breadcrumbs-aside">
    78. </li>
    79. </ul>
    80. <hr/>
    81. </div>
    82. <div role="main" class="document" itemscope="itemscope" itemtype="http://schema.org/Article">
    83. <div itemprop="articleBody">
    84. <h1>Source code for super_gradients.training.datasets.classification_datasets.imagenet_dataset</h1><div class="highlight"><pre>
    85. <span></span><span class="kn">from</span> <span class="nn">typing</span> <span class="kn">import</span> <span class="n">Union</span>
    86. <span class="kn">import</span> <span class="nn">torchvision.datasets</span> <span class="k">as</span> <span class="nn">torch_datasets</span>
    87. <span class="kn">from</span> <span class="nn">torchvision.transforms</span> <span class="kn">import</span> <span class="n">Compose</span>
    88. <span class="kn">from</span> <span class="nn">super_gradients.common.decorators.factory_decorator</span> <span class="kn">import</span> <span class="n">resolve_param</span>
    89. <span class="kn">from</span> <span class="nn">super_gradients.common.factories.transforms_factory</span> <span class="kn">import</span> <span class="n">TransformsFactory</span>
    90. <div class="viewcode-block" id="ImageNetDataset"><a class="viewcode-back" href="../../../../../super_gradients.training.html#super_gradients.training.datasets.ImageNetDataset">[docs]</a><span class="k">class</span> <span class="nc">ImageNetDataset</span><span class="p">(</span><span class="n">torch_datasets</span><span class="o">.</span><span class="n">ImageFolder</span><span class="p">):</span>
    91. <span class="sd">&quot;&quot;&quot;ImageNetDataset dataset&quot;&quot;&quot;</span>
    92. <span class="nd">@resolve_param</span><span class="p">(</span><span class="s2">&quot;transforms&quot;</span><span class="p">,</span> <span class="n">factory</span><span class="o">=</span><span class="n">TransformsFactory</span><span class="p">())</span>
    93. <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">root</span><span class="p">:</span> <span class="nb">str</span><span class="p">,</span> <span class="n">transforms</span><span class="p">:</span> <span class="n">Union</span><span class="p">[</span><span class="nb">list</span><span class="p">,</span> <span class="nb">dict</span><span class="p">]</span> <span class="o">=</span> <span class="p">[],</span> <span class="o">*</span><span class="n">args</span><span class="p">,</span> <span class="o">**</span><span class="n">kwargs</span><span class="p">):</span>
    94. <span class="c1"># TO KEEP BACKWARD COMPATABILITY, WILL BE REMOVED IN THE FUTURE ONCE WE ALLIGN TORCHVISION/NATIVE TRANSFORMS</span>
    95. <span class="c1"># TREATMENT IN FACTORIES (I.E STATING COMPOSE IN CONFIGS)</span>
    96. <span class="k">if</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">transforms</span><span class="p">,</span> <span class="nb">list</span><span class="p">):</span>
    97. <span class="n">transforms</span> <span class="o">=</span> <span class="n">Compose</span><span class="p">(</span><span class="n">transforms</span><span class="p">)</span>
    98. <span class="nb">super</span><span class="p">(</span><span class="n">ImageNetDataset</span><span class="p">,</span> <span class="bp">self</span><span class="p">)</span><span class="o">.</span><span class="fm">__init__</span><span class="p">(</span><span class="n">root</span><span class="p">,</span> <span class="n">transform</span><span class="o">=</span><span class="n">transforms</span><span class="p">,</span> <span class="o">*</span><span class="n">args</span><span class="p">,</span> <span class="o">**</span><span class="n">kwargs</span><span class="p">)</span></div>
    99. </pre></div>
    100. </div>
    101. </div>
    102. <footer>
    103. <hr/>
    104. <div role="contentinfo">
    105. <p>&#169; Copyright 2021, SuperGradients team.</p>
    106. </div>
    107. Built with <a href="https://www.sphinx-doc.org/">Sphinx</a> using a
    108. <a href="https://github.com/readthedocs/sphinx_rtd_theme">theme</a>
    109. provided by <a href="https://readthedocs.org">Read the Docs</a>.
    110. </footer>
    111. </div>
    112. </div>
    113. </section>
    114. </div>
    115. <script>
    116. jQuery(function () {
    117. SphinxRtdTheme.Navigation.enable(true);
    118. });
    119. </script>
    120. </body>
    121. </html>
    Discard
    Some lines were truncated since they exceed the maximum allowed length of 500, please use a local Git client to see the full diff.
    @@ -3,10 +3,11 @@
     <head>
     <head>
       <meta charset="utf-8" />
       <meta charset="utf-8" />
       <meta name="viewport" content="width=device-width, initial-scale=1.0" />
       <meta name="viewport" content="width=device-width, initial-scale=1.0" />
    -  <title>super_gradients.training.datasets.data_augmentation &mdash; SuperGradients 1.0 documentation</title>
    +  <title>super_gradients.training.datasets.data_augmentation &mdash; SuperGradients 3.0.3 documentation</title>
           <link rel="stylesheet" href="../../../../_static/pygments.css" type="text/css" />
           <link rel="stylesheet" href="../../../../_static/pygments.css" type="text/css" />
           <link rel="stylesheet" href="../../../../_static/css/theme.css" type="text/css" />
           <link rel="stylesheet" href="../../../../_static/css/theme.css" type="text/css" />
           <link rel="stylesheet" href="../../../../_static/graphviz.css" type="text/css" />
           <link rel="stylesheet" href="../../../../_static/graphviz.css" type="text/css" />
    +      <link rel="stylesheet" href="../../../../_static/custom.css" type="text/css" />
       <!--[if lt IE 9]>
       <!--[if lt IE 9]>
         <script src="../../../../_static/js/html5shiv.min.js"></script>
         <script src="../../../../_static/js/html5shiv.min.js"></script>
       <![endif]-->
       <![endif]-->
    @@ -14,7 +15,9 @@
             <script data-url_root="../../../../" id="documentation_options" src="../../../../_static/documentation_options.js"></script>
             <script data-url_root="../../../../" id="documentation_options" src="../../../../_static/documentation_options.js"></script>
             <script src="../../../../_static/jquery.js"></script>
             <script src="../../../../_static/jquery.js"></script>
             <script src="../../../../_static/underscore.js"></script>
             <script src="../../../../_static/underscore.js"></script>
    +        <script src="../../../../_static/_sphinx_javascript_frameworks_compat.js"></script>
             <script src="../../../../_static/doctools.js"></script>
             <script src="../../../../_static/doctools.js"></script>
    +        <script src="../../../../_static/sphinx_highlight.js"></script>
         <script src="../../../../_static/js/theme.js"></script>
         <script src="../../../../_static/js/theme.js"></script>
         <link rel="index" title="Index" href="../../../../genindex.html" />
         <link rel="index" title="Index" href="../../../../genindex.html" />
         <link rel="search" title="Search" href="../../../../search.html" /> 
         <link rel="search" title="Search" href="../../../../search.html" /> 
    @@ -35,30 +38,29 @@
       </form>
       </form>
     </div>
     </div>
             </div><div class="wy-menu wy-menu-vertical" data-spy="affix" role="navigation" aria-label="Navigation menu">
             </div><div class="wy-menu wy-menu-vertical" data-spy="affix" role="navigation" aria-label="Navigation menu">
    -              <p class="caption"><span class="caption-text">Welcome To SuperGradients</span></p>
    +              <p class="caption" role="heading"><span class="caption-text">Welcome To SuperGradients</span></p>
     <ul>
     <ul>
    -<li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html">Fill our 4-question quick survey! We will raffle free SuperGradients swag between those who will participate -&gt; Fill Survey</a></li>
    -<li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#supergradients">SuperGradients</a></li>
    +<li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html">Version 3 is out! Notebooks have been updated!</a></li>
    +<li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#build-with-supergradients">Build with SuperGradients</a></li>
    +<li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#quick-installation">Quick Installation</a></li>
    +<li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#what-s-new">What’s New</a></li>
    +<li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#coming-soon">Coming soon</a></li>
    +<li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#table-of-content">Table of Content</a></li>
    +<li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#getting-started">Getting Started</a></li>
    +<li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#advanced-features">Advanced Features</a></li>
    +<li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#installation-methods">Installation Methods</a></li>
    +<li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#implemented-model-architectures">Implemented Model Architectures</a></li>
    +<li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#documentation">Documentation</a></li>
    +<li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#contributing">Contributing</a></li>
    +<li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#citation">Citation</a></li>
    +<li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#community">Community</a></li>
    +<li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#license">License</a></li>
    +<li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#deci-platform">Deci Platform</a></li>
     </ul>
     </ul>
    -<p class="caption"><span class="caption-text">Technical Documentation</span></p>
    +<p class="caption" role="heading"><span class="caption-text">Technical Documentation</span></p>
     <ul>
     <ul>
     <li class="toctree-l1"><a class="reference internal" href="../../../../super_gradients.common.html">Common package</a></li>
     <li class="toctree-l1"><a class="reference internal" href="../../../../super_gradients.common.html">Common package</a></li>
     <li class="toctree-l1"><a class="reference internal" href="../../../../super_gradients.training.html">Training package</a></li>
     <li class="toctree-l1"><a class="reference internal" href="../../../../super_gradients.training.html">Training package</a></li>
    -</ul>
    -<p class="caption"><span class="caption-text">User Guide</span></p>
    -<ul>
    -<li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html">What is SuperGradients?</a></li>
    -<li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#introducing-the-supergradients-library">Introducing the SuperGradients library</a></li>
    -<li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#installation">Installation</a></li>
    -<li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#integrating-your-training-code-complete-walkthrough">Integrating Your Training Code - Complete Walkthrough</a></li>
    -<li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#training-parameters">Training Parameters</a></li>
    -<li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#logs-and-checkpoints">Logs and Checkpoints</a></li>
    -<li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#dataset-parameters">Dataset Parameters</a></li>
    -<li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#network-architectures">Network Architectures</a></li>
    -<li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#pretrained-models">Pretrained Models</a></li>
    -<li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#how-to-reproduce-our-training-recipes">How To Reproduce Our Training Recipes</a></li>
    -<li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#professional-tools-integration">Professional Tools Integration</a></li>
    -<li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#supergradients-faq">SuperGradients FAQ</a></li>
     </ul>
     </ul>
     
     
             </div>
             </div>
    @@ -91,9 +93,9 @@
     <span class="kn">from</span> <span class="nn">torchvision.transforms</span> <span class="kn">import</span> <span class="n">RandomErasing</span>
     <span class="kn">from</span> <span class="nn">torchvision.transforms</span> <span class="kn">import</span> <span class="n">RandomErasing</span>
     
     
     
     
    -<div class="viewcode-block" id="DataAugmentation"><a class="viewcode-back" href="../../../../super_gradients.training.datasets.html#super_gradients.training.DataAugmentation">[docs]</a><span class="k">class</span> <span class="nc">DataAugmentation</span><span class="p">:</span>
    +<div class="viewcode-block" id="DataAugmentation"><a class="viewcode-back" href="../../../../super_gradients.training.html#super_gradients.training.DataAugmentation">[docs]</a><span class="k">class</span> <span class="nc">DataAugmentation</span><span class="p">:</span>
     
     
    -<div class="viewcode-block" id="DataAugmentation.to_tensor"><a class="viewcode-back" href="../../../../super_gradients.training.datasets.html#super_gradients.training.DataAugmentation.to_tensor">[docs]</a>    <span class="nd">@staticmethod</span>
    +<div class="viewcode-block" id="DataAugmentation.to_tensor"><a class="viewcode-back" href="../../../../super_gradients.training.html#super_gradients.training.DataAugmentation.to_tensor">[docs]</a>    <span class="nd">@staticmethod</span>
         <span class="k">def</span> <span class="nf">to_tensor</span><span class="p">():</span>
         <span class="k">def</span> <span class="nf">to_tensor</span><span class="p">():</span>
             <span class="k">def</span> <span class="nf">_to_tensor</span><span class="p">(</span><span class="n">image</span><span class="p">):</span>
             <span class="k">def</span> <span class="nf">_to_tensor</span><span class="p">(</span><span class="n">image</span><span class="p">):</span>
                 <span class="k">if</span> <span class="nb">len</span><span class="p">(</span><span class="n">image</span><span class="o">.</span><span class="n">shape</span><span class="p">)</span> <span class="o">==</span> <span class="mi">3</span><span class="p">:</span>
                 <span class="k">if</span> <span class="nb">len</span><span class="p">(</span><span class="n">image</span><span class="o">.</span><span class="n">shape</span><span class="p">)</span> <span class="o">==</span> <span class="mi">3</span><span class="p">:</span>
    @@ -104,7 +106,7 @@
     
     
             <span class="k">return</span> <span class="n">_to_tensor</span></div>
             <span class="k">return</span> <span class="n">_to_tensor</span></div>
     
     
    -<div class="viewcode-block" id="DataAugmentation.normalize"><a class="viewcode-back" href="../../../../super_gradients.training.datasets.html#super_gradients.training.DataAugmentation.normalize">[docs]</a>    <span class="nd">@staticmethod</span>
    +<div class="viewcode-block" id="DataAugmentation.normalize"><a class="viewcode-back" href="../../../../super_gradients.training.html#super_gradients.training.DataAugmentation.normalize">[docs]</a>    <span class="nd">@staticmethod</span>
         <span class="k">def</span> <span class="nf">normalize</span><span class="p">(</span><span class="n">mean</span><span class="p">,</span> <span class="n">std</span><span class="p">):</span>
         <span class="k">def</span> <span class="nf">normalize</span><span class="p">(</span><span class="n">mean</span><span class="p">,</span> <span class="n">std</span><span class="p">):</span>
             <span class="n">mean</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">array</span><span class="p">(</span><span class="n">mean</span><span class="p">)</span>
             <span class="n">mean</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">array</span><span class="p">(</span><span class="n">mean</span><span class="p">)</span>
             <span class="n">std</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">array</span><span class="p">(</span><span class="n">std</span><span class="p">)</span>
             <span class="n">std</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">array</span><span class="p">(</span><span class="n">std</span><span class="p">)</span>
    @@ -116,7 +118,7 @@
     
     
             <span class="k">return</span> <span class="n">_normalize</span></div>
             <span class="k">return</span> <span class="n">_normalize</span></div>
     
     
    -<div class="viewcode-block" id="DataAugmentation.cutout"><a class="viewcode-back" href="../../../../super_gradients.training.datasets.html#super_gradients.training.DataAugmentation.cutout">[docs]</a>    <span class="nd">@staticmethod</span>
    +<div class="viewcode-block" id="DataAugmentation.cutout"><a class="viewcode-back" href="../../../../super_gradients.training.html#super_gradients.training.DataAugmentation.cutout">[docs]</a>    <span class="nd">@staticmethod</span>
         <span class="k">def</span> <span class="nf">cutout</span><span class="p">(</span><span class="n">mask_size</span><span class="p">,</span> <span class="n">p</span><span class="o">=</span><span class="mi">1</span><span class="p">,</span> <span class="n">cutout_inside</span><span class="o">=</span><span class="kc">False</span><span class="p">,</span> <span class="n">mask_color</span><span class="o">=</span><span class="p">(</span><span class="mi">0</span><span class="p">,</span> <span class="m
         <span class="k">def</span> <span class="nf">cutout</span><span class="p">(</span><span class="n">mask_size</span><span class="p">,</span> <span class="n">p</span><span class="o">=</span><span class="mi">1</span><span class="p">,</span> <span class="n">cutout_inside</span><span class="o">=</span><span class="kc">False</span><span class="p">,</span> <span class="n">mask_color</span><span class="o">=</span><span class="p">(</span><span class="mi">0</span><span class="p">,</span> <span class="m
             <span class="n">mask_size_half</span> <span class="o">=</span> <span class="n">mask_size</span> <span class="o">//</span> <span class="mi">2</span>
             <span class="n">mask_size_half</span> <span class="o">=</span> <span class="n">mask_size</span> <span class="o">//</span> <span class="mi">2</span>
             <span class="n">offset</span> <span class="o">=</span> <span class="mi">1</span> <span class="k">if</span> <span class="n">mask_size</span> <span class="o">%</span> <span class="mi">2</span> <span class="o">==</span> <span class="mi">0</span> <span class="k">else</span> <span class="mi">0</span>
             <span class="n">offset</span> <span class="o">=</span> <span class="mi">1</span> <span class="k">if</span> <span class="n">mask_size</span> <span class="o">%</span> <span class="mi">2</span> <span class="o">==</span> <span class="mi">0</span> <span class="k">else</span> <span class="mi">0</span>
    @@ -159,7 +161,7 @@
                                 <span class="p">[</span><span class="o">-</span><span class="mf">0.5836</span><span class="p">,</span> <span class="o">-</span><span class="mf">0.6948</span><span class="p">,</span> <span class="mf">0.4203</span><span class="p">]])}</span>
                                 <span class="p">[</span><span class="o">-</span><span class="mf">0.5836</span><span class="p">,</span> <span class="o">-</span><span class="mf">0.6948</span><span class="p">,</span> <span class="mf">0.4203</span><span class="p">]])}</span>
     
     
     
     
    -<div class="viewcode-block" id="Lighting"><a class="viewcode-back" href="../../../../super_gradients.training.datasets.html#super_gradients.training.Lighting">[docs]</a><span class="k">class</span> <span class="nc">Lighting</span><span class="p">(</span><span class="nb">object</span><span class="p">):</span>
    +<span class="k">class</span> <span class="nc">Lighting</span><span class="p">(</span><span class="nb">object</span><span class="p">):</span>
         <span class="sd">&quot;&quot;&quot;</span>
         <span class="sd">&quot;&quot;&quot;</span>
     <span class="sd">    Lighting noise(AlexNet - style PCA - based noise)</span>
     <span class="sd">    Lighting noise(AlexNet - style PCA - based noise)</span>
     <span class="sd">    Taken from fastai Imagenet training -</span>
     <span class="sd">    Taken from fastai Imagenet training -</span>
    @@ -183,10 +185,10 @@
                 <span class="o">.</span><span class="n">mul</span><span class="p">(</span><span class="n">alpha</span><span class="o">.</span><span class="n">view</span><span class="p">(</span><span class="mi">1</span><span class="p">,</span> <span class="mi">3</span><span class="p">)</span><span class="o">.</span><span class="n">expand</span><span class="p">(</span><span class="mi">3</span><span class="p">,</span> <span class="mi">3</span><span class="p">))</span> \
                 <span class="o">.</span><span class="n">mul</span><span class="p">(</span><span class="n">alpha</span><span class="o">.</span><span class="n">view</span><span class="p">(</span><span class="mi">1</span><span class="p">,</span> <span class="mi">3</span><span class="p">)</span><span class="o">.</span><span class="n">expand</span><span class="p">(</span><span class="mi">3</span><span class="p">,</span> <span class="mi">3</span><span class="p">))</span> \
                 <span class="o">.</span><span class="n">mul</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">eigval</span><span class="o">.</span><span class="n">view</span><span class="p">(</span><span class="mi">1</span><span class="p">,</span> <span class="mi">3</span><span class="p">)</span><span class="o">.</span><span class="n">expand</span><span class="p">(</span><span class="mi">3</span><span class="p">,</span> <span class="mi">3</span><span 
                 <span class="o">.</span><span class="n">mul</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">eigval</span><span class="o">.</span><span class="n">view</span><span class="p">(</span><span class="mi">1</span><span class="p">,</span> <span class="mi">3</span><span class="p">)</span><span class="o">.</span><span class="n">expand</span><span class="p">(</span><span class="mi">3</span><span class="p">,</span> <span class="mi">3</span><span 
                 <span class="o">.</span><span class="n">sum</span><span class="p">(</span><span class="mi">1</span><span class="p">)</span><span class="o">.</span><span class="n">squeeze</span><span class="p">()</span>
                 <span class="o">.</span><span class="n">sum</span><span class="p">(</span><span class="mi">1</span><span class="p">)</span><span class="o">.</span><span class="n">squeeze</span><span class="p">()</span>
    -        <span class="k">return</span> <span class="n">img</span><span class="o">.</span><span class="n">add</span><span class="p">(</span><span class="n">rgb</span><span class="o">.</span><span class="n">view</span><span class="p">(</span><span class="mi">3</span><span class="p">,</span> <span class="mi">1</span><span class="p">,</span> <span class="mi">1</span><span class="p">)</span><span class="o">.</span><span class="n">expand_as</span><span class="p">(</span><span class="n">img</span><span
    +        <span class="k">return</span> <span class="n">img</span><span class="o">.</span><span class="n">add</span><span class="p">(</span><span class="n">rgb</span><span class="o">.</span><span class="n">view</span><span class="p">(</span><span class="mi">3</span><span class="p">,</span> <span class="mi">1</span><span class="p">,</span> <span class="mi">1</span><span class="p">)</span><span class="o">.</span><span class="n">expand_as</span><span class="p">(</span><span class="n">img</span><span
     
     
     
     
    -<div class="viewcode-block" id="RandomErase"><a class="viewcode-back" href="../../../../super_gradients.training.datasets.html#super_gradients.training.RandomErase">[docs]</a><span class="k">class</span> <span class="nc">RandomErase</span><span class="p">(</span><span class="n">RandomErasing</span><span class="p">):</span>
    +<span class="k">class</span> <span class="nc">RandomErase</span><span class="p">(</span><span class="n">RandomErasing</span><span class="p">):</span>
         <span class="sd">&quot;&quot;&quot;</span>
         <span class="sd">&quot;&quot;&quot;</span>
     <span class="sd">    A simple class that translates the parameters supported in SuperGradient&#39;s code base</span>
     <span class="sd">    A simple class that translates the parameters supported in SuperGradient&#39;s code base</span>
     <span class="sd">    &quot;&quot;&quot;</span>
     <span class="sd">    &quot;&quot;&quot;</span>
    @@ -197,7 +199,7 @@
                 <span class="n">value</span> <span class="o">=</span> <span class="nb">float</span><span class="p">(</span><span class="n">value</span><span class="p">)</span>
                 <span class="n">value</span> <span class="o">=</span> <span class="nb">float</span><span class="p">(</span><span class="n">value</span><span class="p">)</span>
             <span class="k">except</span> <span class="ne">ValueError</span><span class="p">:</span>
             <span class="k">except</span> <span class="ne">ValueError</span><span class="p">:</span>
                 <span class="k">pass</span>
                 <span class="k">pass</span>
    -        <span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="fm">__init__</span><span class="p">(</span><span class="n">p</span><span class="o">=</span><span class="n">probability</span><span class="p">,</span> <span class="n">value</span><span class="o">=</span><span class="n">value</span><span class="p">)</span></div>
    +        <span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="fm">__init__</span><span class="p">(</span><span class="n">p</span><span class="o">=</span><span class="n">probability</span><span class="p">,</span> <span class="n">value</span><span class="o">=</span><span class="n">value</span><span class="p">)</span>
     </pre></div>
     </pre></div>
     
     
                </div>
                </div>
    @@ -227,4 +229,4 @@
       </script> 
       </script> 
     
     
     </body>
     </body>
    -</html>
    +</html>
    Discard
    1
    2
    3
    4
    5
    6
    7
    8
    9
    10
    11
    12
    13
    14
    15
    16
    17
    18
    19
    20
    21
    22
    23
    24
    25
    26
    27
    28
    29
    30
    31
    32
    33
    34
    35
    36
    37
    38
    39
    40
    41
    42
    43
    44
    45
    46
    47
    48
    49
    50
    51
    52
    53
    54
    55
    56
    57
    58
    59
    60
    61
    62
    63
    64
    65
    66
    67
    68
    69
    70
    71
    72
    73
    74
    75
    76
    77
    78
    79
    80
    81
    82
    83
    84
    85
    86
    87
    88
    89
    90
    91
    92
    93
    94
    95
    96
    97
    98
    99
    100
    101
    102
    103
    104
    105
    106
    107
    108
    109
    110
    111
    112
    113
    114
    115
    116
    117
    118
    119
    120
    121
    122
    123
    124
    125
    126
    127
    128
    129
    130
    131
    132
    133
    134
    135
    136
    137
    138
    139
    140
    141
    142
    143
    144
    145
    146
    147
    148
    149
    150
    151
    152
    153
    154
    155
    156
    157
    158
    159
    160
    161
    162
    163
    164
    165
    166
    167
    168
    169
    170
    171
    172
    173
    174
    175
    176
    177
    178
    179
    180
    181
    182
    183
    184
    185
    186
    187
    188
    189
    190
    191
    192
    193
    194
    195
    196
    197
    198
    199
    200
    201
    202
    203
    204
    205
    206
    207
    208
    209
    210
    211
    212
    213
    214
    215
    216
    217
    218
    219
    220
    221
    222
    223
    224
    225
    226
    227
    228
    229
    230
    231
    232
    233
    234
    235
    236
    237
    238
    239
    240
    241
    242
    243
    244
    245
    246
    247
    248
    249
    250
    251
    252
    253
    254
    255
    256
    257
    258
    259
    260
    261
    262
    263
    264
    265
    266
    267
    268
    269
    270
    271
    272
    273
    274
    275
    276
    277
    278
    279
    280
    281
    282
    283
    284
    285
    286
    287
    288
    289
    290
    291
    292
    293
    294
    295
    296
    297
    298
    299
    300
    301
    302
    303
    304
    305
    306
    307
    308
    309
    310
    311
    312
    313
    314
    315
    316
    317
    318
    319
    320
    321
    322
    323
    324
    325
    326
    327
    328
    329
    330
    331
    332
    333
    334
    335
    336
    337
    338
    339
    340
    341
    342
    343
    344
    345
    346
    347
    348
    349
    350
    351
    352
    353
    354
    355
    356
    357
    358
    359
    360
    361
    362
    363
    364
    365
    366
    367
    368
    369
    370
    371
    372
    373
    374
    375
    376
    377
    378
    379
    380
    381
    382
    383
    384
    385
    386
    387
    388
    389
    390
    391
    392
    393
    394
    395
    396
    397
    398
    399
    400
    401
    402
    403
    404
    405
    406
    407
    408
    409
    410
    411
    412
    413
    414
    415
    416
    417
    418
    419
    420
    421
    422
    423
    424
    425
    426
    427
    428
    429
    430
    431
    432
    433
    434
    435
    436
    437
    438
    439
    440
    441
    442
    443
    444
    445
    446
    447
    448
    449
    450
    451
    452
    453
    454
    455
    456
    457
    458
    459
    460
    461
    462
    463
    464
    465
    466
    467
    468
    469
    470
    471
    472
    473
    474
    475
    476
    477
    478
    479
    480
    481
    482
    483
    484
    485
    486
    487
    488
    489
    490
    491
    492
    493
    494
    495
    496
    497
    498
    499
    500
    501
    502
    503
    504
    505
    506
    507
    508
    509
    510
    511
    512
    513
    514
    515
    516
    517
    518
    519
    520
    521
    522
    523
    524
    525
    526
    527
    528
    529
    530
    531
    532
    533
    534
    535
    536
    537
    538
    539
    540
    541
    542
    543
    544
    545
    546
    547
    548
    549
    550
    551
    552
    553
    554
    555
    556
    557
    558
    559
    560
    561
    562
    563
    564
    565
    566
    567
    568
    569
    570
    571
    572
    573
    574
    575
    576
    577
    578
    579
    580
    581
    582
    583
    584
    585
    586
    587
    588
    589
    590
    591
    592
    593
    594
    595
    596
    597
    598
    599
    600
    601
    602
    603
    604
    605
    606
    607
    608
    609
    610
    611
    612
    613
    614
    615
    616
    617
    618
    619
    620
    621
    622
    623
    624
    625
    626
    627
    628
    629
    630
    631
    632
    633
    634
    635
    636
    637
    638
    639
    640
    641
    642
    643
    644
    645
    646
    647
    648
    649
    650
    651
    652
    653
    654
    655
    656
    657
    658
    659
    660
    661
    662
    663
    664
    665
    666
    667
    668
    669
    670
    671
    672
    673
    674
    675
    676
    677
    678
    679
    680
    681
    682
    683
    684
    685
    686
    687
    688
    689
    690
    691
    692
    693
    694
    695
    696
    697
    698
    699
    700
    701
    702
    703
    704
    705
    706
    707
    708
    709
    710
    711
    712
    713
    714
    715
    716
    717
    718
    719
    720
    721
    722
    723
    724
    725
    726
    727
    728
    729
    730
    731
    732
    733
    734
    735
    736
    737
    738
    739
    740
    741
    742
    743
    744
    745
    746
    747
    748
    749
    750
    751
    752
    753
    754
    755
    756
    757
    758
    759
    760
    761
    762
    763
    764
    765
    766
    767
    768
    769
    770
    771
    772
    773
    774
    775
    776
    777
    778
    779
    780
    781
    782
    783
    784
    785
    786
    787
    788
    789
    790
    791
    792
    793
    794
    795
    796
    797
    798
    799
    800
    801
    802
    803
    804
    805
    806
    807
    808
    809
    810
    811
    812
    813
    814
    815
    816
    817
    818
    819
    820
    821
    822
    823
    824
    825
    826
    827
    828
    829
    830
    831
    832
    833
    834
    835
    836
    837
    838
    839
    840
    841
    842
    843
    844
    845
    846
    847
    848
    849
    850
    851
    852
    853
    854
    855
    856
    857
    858
    859
    860
    861
    862
    863
    864
    865
    866
    867
    868
    869
    870
    871
    872
    873
    874
    875
    876
    877
    878
    879
    880
    881
    882
    883
    884
    885
    886
    887
    888
    889
    890
    891
    892
    893
    894
    895
    896
    897
    898
    899
    900
    901
    902
    903
    904
    905
    906
    907
    908
    909
    910
    911
    912
    913
    914
    915
    916
    917
    918
    919
    920
    921
    922
    923
    924
    925
    926
    927
    928
    929
    930
    931
    932
    933
    934
    935
    936
    937
    938
    939
    940
    941
    942
    943
    944
    945
    946
    947
    948
    949
    950
    951
    952
    953
    954
    955
    956
    957
    958
    959
    1. <!DOCTYPE html>
    2. <html class="writer-html5" lang="en" >
    3. <head>
    4. <meta charset="utf-8" />
    5. <meta name="viewport" content="width=device-width, initial-scale=1.0" />
    6. <title>super_gradients.training.datasets.dataset_interfaces.dataset_interface &mdash; SuperGradients 1.0 documentation</title>
    7. <link rel="stylesheet" href="../../../../../_static/pygments.css" type="text/css" />
    8. <link rel="stylesheet" href="../../../../../_static/css/theme.css" type="text/css" />
    9. <link rel="stylesheet" href="../../../../../_static/graphviz.css" type="text/css" />
    10. <!--[if lt IE 9]>
    11. <script src="../../../../../_static/js/html5shiv.min.js"></script>
    12. <![endif]-->
    13. <script data-url_root="../../../../../" id="documentation_options" src="../../../../../_static/documentation_options.js"></script>
    14. <script src="../../../../../_static/jquery.js"></script>
    15. <script src="../../../../../_static/underscore.js"></script>
    16. <script src="../../../../../_static/doctools.js"></script>
    17. <script src="../../../../../_static/js/theme.js"></script>
    18. <link rel="index" title="Index" href="../../../../../genindex.html" />
    19. <link rel="search" title="Search" href="../../../../../search.html" />
    20. </head>
    21. <body class="wy-body-for-nav">
    22. <div class="wy-grid-for-nav">
    23. <nav data-toggle="wy-nav-shift" class="wy-nav-side">
    24. <div class="wy-side-scroll">
    25. <div class="wy-side-nav-search" >
    26. <a href="../../../../../index.html" class="icon icon-home"> SuperGradients
    27. </a>
    28. <div role="search">
    29. <form id="rtd-search-form" class="wy-form" action="../../../../../search.html" method="get">
    30. <input type="text" name="q" placeholder="Search docs" />
    31. <input type="hidden" name="check_keywords" value="yes" />
    32. <input type="hidden" name="area" value="default" />
    33. </form>
    34. </div>
    35. </div><div class="wy-menu wy-menu-vertical" data-spy="affix" role="navigation" aria-label="Navigation menu">
    36. <p class="caption"><span class="caption-text">Welcome To SuperGradients</span></p>
    37. <ul>
    38. <li class="toctree-l1"><a class="reference internal" href="../../../../../welcome.html">Fill our 4-question quick survey! We will raffle free SuperGradients swag between those who will participate -&gt; Fill Survey</a></li>
    39. <li class="toctree-l1"><a class="reference internal" href="../../../../../welcome.html#supergradients">SuperGradients</a></li>
    40. </ul>
    41. <p class="caption"><span class="caption-text">Technical Documentation</span></p>
    42. <ul>
    43. <li class="toctree-l1"><a class="reference internal" href="../../../../../super_gradients.common.html">Common package</a></li>
    44. <li class="toctree-l1"><a class="reference internal" href="../../../../../super_gradients.training.html">Training package</a></li>
    45. </ul>
    46. <p class="caption"><span class="caption-text">User Guide</span></p>
    47. <ul>
    48. <li class="toctree-l1"><a class="reference internal" href="../../../../../user_guide.html">What is SuperGradients?</a></li>
    49. <li class="toctree-l1"><a class="reference internal" href="../../../../../user_guide.html#introducing-the-supergradients-library">Introducing the SuperGradients library</a></li>
    50. <li class="toctree-l1"><a class="reference internal" href="../../../../../user_guide.html#installation">Installation</a></li>
    51. <li class="toctree-l1"><a class="reference internal" href="../../../../../user_guide.html#integrating-your-training-code-complete-walkthrough">Integrating Your Training Code - Complete Walkthrough</a></li>
    52. <li class="toctree-l1"><a class="reference internal" href="../../../../../user_guide.html#training-parameters">Training Parameters</a></li>
    53. <li class="toctree-l1"><a class="reference internal" href="../../../../../user_guide.html#logs-and-checkpoints">Logs and Checkpoints</a></li>
    54. <li class="toctree-l1"><a class="reference internal" href="../../../../../user_guide.html#dataset-parameters">Dataset Parameters</a></li>
    55. <li class="toctree-l1"><a class="reference internal" href="../../../../../user_guide.html#network-architectures">Network Architectures</a></li>
    56. <li class="toctree-l1"><a class="reference internal" href="../../../../../user_guide.html#pretrained-models">Pretrained Models</a></li>
    57. <li class="toctree-l1"><a class="reference internal" href="../../../../../user_guide.html#how-to-reproduce-our-training-recipes">How To Reproduce Our Training Recipes</a></li>
    58. <li class="toctree-l1"><a class="reference internal" href="../../../../../user_guide.html#professional-tools-integration">Professional Tools Integration</a></li>
    59. <li class="toctree-l1"><a class="reference internal" href="../../../../../user_guide.html#supergradients-faq">SuperGradients FAQ</a></li>
    60. </ul>
    61. </div>
    62. </div>
    63. </nav>
    64. <section data-toggle="wy-nav-shift" class="wy-nav-content-wrap"><nav class="wy-nav-top" aria-label="Mobile navigation menu" >
    65. <i data-toggle="wy-nav-top" class="fa fa-bars"></i>
    66. <a href="../../../../../index.html">SuperGradients</a>
    67. </nav>
    68. <div class="wy-nav-content">
    69. <div class="rst-content">
    70. <div role="navigation" aria-label="Page navigation">
    71. <ul class="wy-breadcrumbs">
    72. <li><a href="../../../../../index.html" class="icon icon-home"></a> &raquo;</li>
    73. <li><a href="../../../../index.html">Module code</a> &raquo;</li>
    74. <li>super_gradients.training.datasets.dataset_interfaces.dataset_interface</li>
    75. <li class="wy-breadcrumbs-aside">
    76. </li>
    77. </ul>
    78. <hr/>
    79. </div>
    80. <div role="main" class="document" itemscope="itemscope" itemtype="http://schema.org/Article">
    81. <div itemprop="articleBody">
    82. <h1>Source code for super_gradients.training.datasets.dataset_interfaces.dataset_interface</h1><div class="highlight"><pre>
    83. <span></span><span class="kn">import</span> <span class="nn">os</span>
    84. <span class="kn">import</span> <span class="nn">numpy</span> <span class="k">as</span> <span class="nn">np</span>
    85. <span class="kn">import</span> <span class="nn">torch</span>
    86. <span class="kn">import</span> <span class="nn">torchvision</span>
    87. <span class="kn">import</span> <span class="nn">torchvision.datasets</span> <span class="k">as</span> <span class="nn">datasets</span>
    88. <span class="kn">from</span> <span class="nn">torch.utils.data.distributed</span> <span class="kn">import</span> <span class="n">DistributedSampler</span>
    89. <span class="kn">from</span> <span class="nn">torch.utils.data</span> <span class="kn">import</span> <span class="n">ConcatDataset</span><span class="p">,</span> <span class="n">BatchSampler</span><span class="p">,</span> <span class="n">DataLoader</span>
    90. <span class="kn">import</span> <span class="nn">torchvision.transforms</span> <span class="k">as</span> <span class="nn">transforms</span>
    91. <span class="kn">from</span> <span class="nn">super_gradients.common</span> <span class="kn">import</span> <span class="n">DatasetDataInterface</span>
    92. <span class="kn">from</span> <span class="nn">super_gradients.common.environment</span> <span class="kn">import</span> <span class="n">AWS_ENV_NAME</span>
    93. <span class="kn">from</span> <span class="nn">super_gradients.common.abstractions.abstract_logger</span> <span class="kn">import</span> <span class="n">get_logger</span>
    94. <span class="kn">from</span> <span class="nn">super_gradients.training</span> <span class="kn">import</span> <span class="n">utils</span> <span class="k">as</span> <span class="n">core_utils</span>
    95. <span class="kn">from</span> <span class="nn">super_gradients.training.utils.distributed_training_utils</span> <span class="kn">import</span> <span class="n">get_local_rank</span><span class="p">,</span> <span class="n">wait_for_the_master</span>
    96. <span class="kn">from</span> <span class="nn">super_gradients.training.utils</span> <span class="kn">import</span> <span class="n">get_param</span>
    97. <span class="kn">from</span> <span class="nn">super_gradients.training.utils.detection_utils</span> <span class="kn">import</span> <span class="n">DetectionTargetsFormat</span>
    98. <span class="kn">from</span> <span class="nn">super_gradients.training.datasets</span> <span class="kn">import</span> <span class="n">datasets_utils</span><span class="p">,</span> <span class="n">DataAugmentation</span>
    99. <span class="kn">from</span> <span class="nn">super_gradients.training.datasets.datasets_conf</span> <span class="kn">import</span> <span class="n">COCO_DETECTION_CLASSES_LIST</span>
    100. <span class="kn">from</span> <span class="nn">super_gradients.training.datasets.data_augmentation</span> <span class="kn">import</span> <span class="n">Lighting</span><span class="p">,</span> <span class="n">RandomErase</span>
    101. <span class="kn">from</span> <span class="nn">super_gradients.training.datasets.mixup</span> <span class="kn">import</span> <span class="n">CollateMixup</span>
    102. <span class="kn">from</span> <span class="nn">super_gradients.training.datasets.detection_datasets</span> <span class="kn">import</span> <span class="n">COCODetectionDataset</span><span class="p">,</span> <span class="n">PascalVOCDetectionDataset</span>
    103. <span class="kn">from</span> <span class="nn">super_gradients.training.datasets.samplers.infinite_sampler</span> <span class="kn">import</span> <span class="n">InfiniteSampler</span>
    104. <span class="kn">from</span> <span class="nn">super_gradients.training.datasets.segmentation_datasets</span> <span class="kn">import</span> <span class="n">PascalVOC2012SegmentationDataSet</span><span class="p">,</span> \
    105. <span class="n">PascalAUG2012SegmentationDataSet</span><span class="p">,</span> <span class="n">CoCoSegmentationDataSet</span>
    106. <span class="kn">from</span> <span class="nn">super_gradients.training.datasets.segmentation_datasets.cityscape_segmentation</span> <span class="kn">import</span> <span class="n">CityscapesDataset</span>
    107. <span class="kn">from</span> <span class="nn">super_gradients.training.datasets.segmentation_datasets.supervisely_persons_segmentation</span> <span class="kn">import</span> \
    108. <span class="n">SuperviselyPersonsDataset</span>
    109. <span class="kn">from</span> <span class="nn">super_gradients.training.datasets.samplers.repeated_augmentation_sampler</span> <span class="kn">import</span> <span class="n">RepeatAugSampler</span>
    110. <span class="kn">from</span> <span class="nn">super_gradients.training.datasets.datasets_utils</span> <span class="kn">import</span> <span class="n">RandomResizedCropAndInterpolation</span><span class="p">,</span> <span class="n">worker_init_reset_seed</span>
    111. <span class="kn">from</span> <span class="nn">super_gradients.training.transforms.transforms</span> <span class="kn">import</span> <span class="n">DetectionMosaic</span><span class="p">,</span> <span class="n">DetectionMixup</span><span class="p">,</span> <span class="n">DetectionRandomAffine</span><span class="p">,</span>\
    112. <span class="n">DetectionTargetsFormatTransform</span><span class="p">,</span> <span class="n">DetectionPaddedRescale</span><span class="p">,</span> <span class="n">DetectionHSV</span><span class="p">,</span> <span class="n">DetectionHorizontalFlip</span>
    113. <span class="kn">from</span> <span class="nn">super_gradients.training.exceptions.dataset_exceptions</span> <span class="kn">import</span> <span class="n">IllegalDatasetParameterException</span>
    114. <span class="n">default_dataset_params</span> <span class="o">=</span> <span class="p">{</span><span class="s2">&quot;batch_size&quot;</span><span class="p">:</span> <span class="mi">64</span><span class="p">,</span> <span class="s2">&quot;val_batch_size&quot;</span><span class="p">:</span> <span class="mi">200</span><span class="p">,</span> <span class="s2">&quot;test_batch_size&quot;</span><span class="p">:</span> <span class="mi">200</span><span class="p">,</span> <span class="s2">&quot;dataset_dir&quot;</span><span class="p">:</span> <span class="s2">&quot;./data/&quot;</span><span class="p">,</span>
    115. <span class="s2">&quot;s3_link&quot;</span><span class="p">:</span> <span class="kc">None</span><span class="p">}</span>
    116. <span class="n">LIBRARY_DATASETS</span> <span class="o">=</span> <span class="p">{</span>
    117. <span class="s2">&quot;cifar10&quot;</span><span class="p">:</span> <span class="p">{</span><span class="s1">&#39;class&#39;</span><span class="p">:</span> <span class="n">datasets</span><span class="o">.</span><span class="n">CIFAR10</span><span class="p">,</span> <span class="s1">&#39;mean&#39;</span><span class="p">:</span> <span class="p">(</span><span class="mf">0.4914</span><span class="p">,</span> <span class="mf">0.4822</span><span class="p">,</span> <span class="mf">0.4465</span><span class="p">),</span> <span class="s1">&#39;std&#39;</span><span class="p">:</span> <span class="p">(</span><span class="mf">0.2023</span><span class="p">,</span> <span class="mf">0.1994</span><span class="p">,</span> <span class="mf">0.2010</span><span class="p">)},</span>
    118. <span class="s2">&quot;cifar100&quot;</span><span class="p">:</span> <span class="p">{</span><span class="s1">&#39;class&#39;</span><span class="p">:</span> <span class="n">datasets</span><span class="o">.</span><span class="n">CIFAR100</span><span class="p">,</span> <span class="s1">&#39;mean&#39;</span><span class="p">:</span> <span class="p">(</span><span class="mf">0.5071</span><span class="p">,</span> <span class="mf">0.4865</span><span class="p">,</span> <span class="mf">0.4409</span><span class="p">),</span> <span class="s1">&#39;std&#39;</span><span class="p">:</span> <span class="p">(</span><span class="mf">0.2673</span><span class="p">,</span> <span class="mf">0.2564</span><span class="p">,</span> <span class="mf">0.2762</span><span class="p">)},</span>
    119. <span class="s2">&quot;SVHN&quot;</span><span class="p">:</span> <span class="p">{</span><span class="s1">&#39;class&#39;</span><span class="p">:</span> <span class="n">datasets</span><span class="o">.</span><span class="n">SVHN</span><span class="p">,</span> <span class="s1">&#39;mean&#39;</span><span class="p">:</span> <span class="kc">None</span><span class="p">,</span> <span class="s1">&#39;std&#39;</span><span class="p">:</span> <span class="kc">None</span><span class="p">}</span>
    120. <span class="p">}</span>
    121. <span class="n">logger</span> <span class="o">=</span> <span class="n">get_logger</span><span class="p">(</span><span class="vm">__name__</span><span class="p">)</span>
    122. <div class="viewcode-block" id="DatasetInterface"><a class="viewcode-back" href="../../../../../super_gradients.training.datasets.dataset_interfaces.html#super_gradients.training.DatasetInterface">[docs]</a><span class="k">class</span> <span class="nc">DatasetInterface</span><span class="p">:</span>
    123. <span class="sd">&quot;&quot;&quot;</span>
    124. <span class="sd"> DatasetInterface - This class manages all of the &quot;communiation&quot; the Model has with the Data Sets</span>
    125. <span class="sd"> &quot;&quot;&quot;</span>
    126. <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">dataset_params</span><span class="o">=</span><span class="p">{},</span> <span class="n">train_loader</span><span class="o">=</span><span class="kc">None</span><span class="p">,</span> <span class="n">val_loader</span><span class="o">=</span><span class="kc">None</span><span class="p">,</span> <span class="n">test_loader</span><span class="o">=</span><span class="kc">None</span><span class="p">,</span> <span class="n">classes</span><span class="o">=</span><span class="kc">None</span><span class="p">):</span>
    127. <span class="sd">&quot;&quot;&quot;</span>
    128. <span class="sd"> @param train_loader: torch.utils.data.Dataloader (optional) dataloader for training.</span>
    129. <span class="sd"> @param test_loader: torch.utils.data.Dataloader (optional) dataloader for testing.</span>
    130. <span class="sd"> @param classes: list of classes.</span>
    131. <span class="sd"> Note: the above parameters will be discarded in case dataset_params is passed.</span>
    132. <span class="sd"> @param dataset_params:</span>
    133. <span class="sd"> - `batch_size` : int (default=64)</span>
    134. <span class="sd"> Number of examples per batch for training. Large batch sizes are recommended.</span>
    135. <span class="sd"> - `val_batch_size` : int (default=200)</span>
    136. <span class="sd"> Number of examples per batch for validation. Large batch sizes are recommended.</span>
    137. <span class="sd"> - `dataset_dir` : str (default=&quot;./data/&quot;)</span>
    138. <span class="sd"> Directory location for the data. Data will be downloaded to this directory when getting it from a</span>
    139. <span class="sd"> remote url.</span>
    140. <span class="sd"> - `s3_link` : str (default=None)</span>
    141. <span class="sd"> remote s3 link to download the data (optional).</span>
    142. <span class="sd"> - `aug_repeat_count` : int (default=0)</span>
    143. <span class="sd"> amount of repetitions (each repetition of an example is augmented differently) for each</span>
    144. <span class="sd"> example for the trainset.</span>
    145. <span class="sd"> &quot;&quot;&quot;</span>
    146. <span class="bp">self</span><span class="o">.</span><span class="n">dataset_params</span> <span class="o">=</span> <span class="n">core_utils</span><span class="o">.</span><span class="n">HpmStruct</span><span class="p">(</span><span class="o">**</span><span class="n">default_dataset_params</span><span class="p">)</span>
    147. <span class="bp">self</span><span class="o">.</span><span class="n">dataset_params</span><span class="o">.</span><span class="n">override</span><span class="p">(</span><span class="o">**</span><span class="n">dataset_params</span><span class="p">)</span>
    148. <span class="bp">self</span><span class="o">.</span><span class="n">trainset</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">valset</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">testset</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span> <span class="kc">None</span><span class="p">,</span> <span class="kc">None</span>
    149. <span class="bp">self</span><span class="o">.</span><span class="n">train_loader</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">val_loader</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">test_loader</span> <span class="o">=</span> <span class="n">train_loader</span><span class="p">,</span> <span class="n">val_loader</span><span class="p">,</span> <span class="n">test_loader</span>
    150. <span class="bp">self</span><span class="o">.</span><span class="n">classes</span> <span class="o">=</span> <span class="n">classes</span>
    151. <span class="bp">self</span><span class="o">.</span><span class="n">batch_size_factor</span> <span class="o">=</span> <span class="mi">1</span>
    152. <span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">dataset_params</span><span class="o">.</span><span class="n">s3_link</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span><span class="p">:</span>
    153. <span class="bp">self</span><span class="o">.</span><span class="n">download_from_cloud</span><span class="p">()</span>
    154. <div class="viewcode-block" id="DatasetInterface.download_from_cloud"><a class="viewcode-back" href="../../../../../super_gradients.training.datasets.dataset_interfaces.html#super_gradients.training.DatasetInterface.download_from_cloud">[docs]</a> <span class="k">def</span> <span class="nf">download_from_cloud</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span>
    155. <span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">dataset_params</span><span class="o">.</span><span class="n">s3_link</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span><span class="p">:</span>
    156. <span class="n">env_name</span> <span class="o">=</span> <span class="n">AWS_ENV_NAME</span>
    157. <span class="n">downloader</span> <span class="o">=</span> <span class="n">DatasetDataInterface</span><span class="p">(</span><span class="n">env</span><span class="o">=</span><span class="n">env_name</span><span class="p">)</span>
    158. <span class="n">target_dir</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">dataset_params</span><span class="o">.</span><span class="n">dataset_dir</span>
    159. <span class="k">if</span> <span class="ow">not</span> <span class="n">os</span><span class="o">.</span><span class="n">path</span><span class="o">.</span><span class="n">exists</span><span class="p">(</span><span class="n">target_dir</span><span class="p">):</span>
    160. <span class="n">os</span><span class="o">.</span><span class="n">mkdir</span><span class="p">(</span><span class="n">target_dir</span><span class="p">)</span>
    161. <span class="n">downloader</span><span class="o">.</span><span class="n">load_remote_dataset_file</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">dataset_params</span><span class="o">.</span><span class="n">s3_link</span><span class="p">,</span> <span class="n">target_dir</span><span class="p">)</span></div>
    162. <div class="viewcode-block" id="DatasetInterface.build_data_loaders"><a class="viewcode-back" href="../../../../../super_gradients.training.datasets.dataset_interfaces.html#super_gradients.training.DatasetInterface.build_data_loaders">[docs]</a> <span class="k">def</span> <span class="nf">build_data_loaders</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">batch_size_factor</span><span class="o">=</span><span class="mi">1</span><span class="p">,</span> <span class="n">num_workers</span><span class="o">=</span><span class="mi">8</span><span class="p">,</span> <span class="n">train_batch_size</span><span class="o">=</span><span class="kc">None</span><span class="p">,</span> <span class="n">val_batch_size</span><span class="o">=</span><span class="kc">None</span><span class="p">,</span>
    163. <span class="n">test_batch_size</span><span class="o">=</span><span class="kc">None</span><span class="p">,</span> <span class="n">distributed_sampler</span><span class="p">:</span> <span class="nb">bool</span> <span class="o">=</span> <span class="kc">False</span><span class="p">):</span>
    164. <span class="sd">&quot;&quot;&quot;</span>
    165. <span class="sd"> define train, val (and optionally test) loaders. The method deals separately with distributed training and standard</span>
    166. <span class="sd"> (non distributed, or parallel training). In the case of distributed training we need to rely on distributed</span>
    167. <span class="sd"> samplers.</span>
    168. <span class="sd"> :param batch_size_factor: int - factor to multiply the batch size (usually for multi gpu)</span>
    169. <span class="sd"> :param num_workers: int - number of workers (parallel processes) for dataloaders</span>
    170. <span class="sd"> :param train_batch_size: int - batch size for train loader, if None will be taken from dataset_params</span>
    171. <span class="sd"> :param val_batch_size: int - batch size for val loader, if None will be taken from dataset_params</span>
    172. <span class="sd"> :param distributed_sampler: boolean flag for distributed training mode</span>
    173. <span class="sd"> :return: train_loader, val_loader, classes: list of classes</span>
    174. <span class="sd"> &quot;&quot;&quot;</span>
    175. <span class="c1"># CHANGE THE BATCH SIZE ACCORDING TO THE NUMBER OF DEVICES - ONLY IN NON-DISTRIBUTED TRAINING MODE</span>
    176. <span class="c1"># IN DISTRIBUTED MODE WE NEED DISTRIBUTED SAMPLERS</span>
    177. <span class="c1"># NO SHUFFLE IN DISTRIBUTED TRAINING</span>
    178. <span class="n">aug_repeat_count</span> <span class="o">=</span> <span class="n">get_param</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">dataset_params</span><span class="p">,</span> <span class="s2">&quot;aug_repeat_count&quot;</span><span class="p">,</span> <span class="mi">0</span><span class="p">)</span>
    179. <span class="k">if</span> <span class="n">aug_repeat_count</span> <span class="o">&gt;</span> <span class="mi">0</span> <span class="ow">and</span> <span class="ow">not</span> <span class="n">distributed_sampler</span><span class="p">:</span>
    180. <span class="k">raise</span> <span class="n">IllegalDatasetParameterException</span><span class="p">(</span><span class="s2">&quot;repeated augmentation is only supported with DDP.&quot;</span><span class="p">)</span>
    181. <span class="k">if</span> <span class="n">distributed_sampler</span><span class="p">:</span>
    182. <span class="bp">self</span><span class="o">.</span><span class="n">batch_size_factor</span> <span class="o">=</span> <span class="mi">1</span>
    183. <span class="n">train_sampler</span> <span class="o">=</span> <span class="n">RepeatAugSampler</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">trainset</span><span class="p">,</span>
    184. <span class="n">num_repeats</span><span class="o">=</span><span class="n">aug_repeat_count</span><span class="p">)</span> <span class="k">if</span> <span class="n">aug_repeat_count</span> <span class="o">&gt;</span> <span class="mi">0</span> <span class="k">else</span> <span class="n">DistributedSampler</span><span class="p">(</span>
    185. <span class="bp">self</span><span class="o">.</span><span class="n">trainset</span><span class="p">)</span>
    186. <span class="n">val_sampler</span> <span class="o">=</span> <span class="n">DistributedSampler</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">valset</span><span class="p">)</span>
    187. <span class="n">test_sampler</span> <span class="o">=</span> <span class="n">DistributedSampler</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">testset</span><span class="p">)</span> <span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">testset</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span> <span class="k">else</span> <span class="kc">None</span>
    188. <span class="n">train_shuffle</span> <span class="o">=</span> <span class="kc">False</span>
    189. <span class="k">else</span><span class="p">:</span>
    190. <span class="bp">self</span><span class="o">.</span><span class="n">batch_size_factor</span> <span class="o">=</span> <span class="n">batch_size_factor</span>
    191. <span class="n">train_sampler</span> <span class="o">=</span> <span class="kc">None</span>
    192. <span class="n">val_sampler</span> <span class="o">=</span> <span class="kc">None</span>
    193. <span class="n">test_sampler</span> <span class="o">=</span> <span class="kc">None</span>
    194. <span class="n">train_shuffle</span> <span class="o">=</span> <span class="kc">True</span>
    195. <span class="k">if</span> <span class="n">train_batch_size</span> <span class="ow">is</span> <span class="kc">None</span><span class="p">:</span>
    196. <span class="n">train_batch_size</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">dataset_params</span><span class="o">.</span><span class="n">batch_size</span> <span class="o">*</span> <span class="bp">self</span><span class="o">.</span><span class="n">batch_size_factor</span>
    197. <span class="k">if</span> <span class="n">val_batch_size</span> <span class="ow">is</span> <span class="kc">None</span><span class="p">:</span>
    198. <span class="n">val_batch_size</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">dataset_params</span><span class="o">.</span><span class="n">val_batch_size</span> <span class="o">*</span> <span class="bp">self</span><span class="o">.</span><span class="n">batch_size_factor</span>
    199. <span class="k">if</span> <span class="n">test_batch_size</span> <span class="ow">is</span> <span class="kc">None</span><span class="p">:</span>
    200. <span class="n">test_batch_size</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">dataset_params</span><span class="o">.</span><span class="n">test_batch_size</span> <span class="o">*</span> <span class="bp">self</span><span class="o">.</span><span class="n">batch_size_factor</span>
    201. <span class="n">train_loader_drop_last</span> <span class="o">=</span> <span class="n">core_utils</span><span class="o">.</span><span class="n">get_param</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">dataset_params</span><span class="p">,</span> <span class="s1">&#39;train_loader_drop_last&#39;</span><span class="p">,</span> <span class="n">default_val</span><span class="o">=</span><span class="kc">False</span><span class="p">)</span>
    202. <span class="n">cutmix</span> <span class="o">=</span> <span class="n">core_utils</span><span class="o">.</span><span class="n">get_param</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">dataset_params</span><span class="p">,</span> <span class="s1">&#39;cutmix&#39;</span><span class="p">,</span> <span class="kc">False</span><span class="p">)</span>
    203. <span class="n">cutmix_params</span> <span class="o">=</span> <span class="n">core_utils</span><span class="o">.</span><span class="n">get_param</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">dataset_params</span><span class="p">,</span> <span class="s1">&#39;cutmix_params&#39;</span><span class="p">)</span>
    204. <span class="c1"># WRAPPING collate_fn</span>
    205. <span class="n">train_collate_fn</span> <span class="o">=</span> <span class="n">core_utils</span><span class="o">.</span><span class="n">get_param</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">trainset</span><span class="p">,</span> <span class="s1">&#39;collate_fn&#39;</span><span class="p">)</span>
    206. <span class="n">val_collate_fn</span> <span class="o">=</span> <span class="n">core_utils</span><span class="o">.</span><span class="n">get_param</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">valset</span><span class="p">,</span> <span class="s1">&#39;collate_fn&#39;</span><span class="p">)</span>
    207. <span class="n">test_collate_fn</span> <span class="o">=</span> <span class="n">core_utils</span><span class="o">.</span><span class="n">get_param</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">testset</span><span class="p">,</span> <span class="s1">&#39;collate_fn&#39;</span><span class="p">)</span>
    208. <span class="k">if</span> <span class="n">cutmix</span> <span class="ow">and</span> <span class="n">train_collate_fn</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span><span class="p">:</span>
    209. <span class="k">raise</span> <span class="n">IllegalDatasetParameterException</span><span class="p">(</span><span class="s2">&quot;cutmix and collate function cannot be used together&quot;</span><span class="p">)</span>
    210. <span class="k">if</span> <span class="n">cutmix</span><span class="p">:</span>
    211. <span class="c1"># FIXME - cutmix should be available only in classification dataset. once we make sure all classification</span>
    212. <span class="c1"># datasets inherit from the same super class, we should move cutmix code to that class</span>
    213. <span class="n">logger</span><span class="o">.</span><span class="n">warning</span><span class="p">(</span><span class="s2">&quot;Cutmix/mixup was enabled. This feature is currently supported only &quot;</span>
    214. <span class="s2">&quot;for classification datasets.&quot;</span><span class="p">)</span>
    215. <span class="n">train_collate_fn</span> <span class="o">=</span> <span class="n">CollateMixup</span><span class="p">(</span><span class="o">**</span><span class="n">cutmix_params</span><span class="p">)</span>
    216. <span class="c1"># FIXME - UNDERSTAND IF THE num_replicas VARIBALE IS NEEDED</span>
    217. <span class="c1"># train_sampler = DistributedSampler(self.trainset,</span>
    218. <span class="c1"># num_replicas=distributed_gpus_num) if distributed_sampler else None</span>
    219. <span class="c1"># val_sampler = DistributedSampler(self.valset,</span>
    220. <span class="c1"># num_replicas=distributed_gpus_num) if distributed_sampler else None</span>
    221. <span class="bp">self</span><span class="o">.</span><span class="n">train_loader</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">utils</span><span class="o">.</span><span class="n">data</span><span class="o">.</span><span class="n">DataLoader</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">trainset</span><span class="p">,</span>
    222. <span class="n">batch_size</span><span class="o">=</span><span class="n">train_batch_size</span><span class="p">,</span>
    223. <span class="n">shuffle</span><span class="o">=</span><span class="n">train_shuffle</span><span class="p">,</span>
    224. <span class="n">num_workers</span><span class="o">=</span><span class="n">num_workers</span><span class="p">,</span>
    225. <span class="n">pin_memory</span><span class="o">=</span><span class="kc">True</span><span class="p">,</span>
    226. <span class="n">sampler</span><span class="o">=</span><span class="n">train_sampler</span><span class="p">,</span>
    227. <span class="n">collate_fn</span><span class="o">=</span><span class="n">train_collate_fn</span><span class="p">,</span>
    228. <span class="n">drop_last</span><span class="o">=</span><span class="n">train_loader_drop_last</span><span class="p">)</span>
    229. <span class="bp">self</span><span class="o">.</span><span class="n">val_loader</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">utils</span><span class="o">.</span><span class="n">data</span><span class="o">.</span><span class="n">DataLoader</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">valset</span><span class="p">,</span>
    230. <span class="n">batch_size</span><span class="o">=</span><span class="n">val_batch_size</span><span class="p">,</span>
    231. <span class="n">shuffle</span><span class="o">=</span><span class="kc">False</span><span class="p">,</span>
    232. <span class="n">num_workers</span><span class="o">=</span><span class="n">num_workers</span><span class="p">,</span>
    233. <span class="n">pin_memory</span><span class="o">=</span><span class="kc">True</span><span class="p">,</span>
    234. <span class="n">sampler</span><span class="o">=</span><span class="n">val_sampler</span><span class="p">,</span>
    235. <span class="n">collate_fn</span><span class="o">=</span><span class="n">val_collate_fn</span><span class="p">)</span>
    236. <span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">testset</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span><span class="p">:</span>
    237. <span class="bp">self</span><span class="o">.</span><span class="n">test_loader</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">utils</span><span class="o">.</span><span class="n">data</span><span class="o">.</span><span class="n">DataLoader</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">testset</span><span class="p">,</span>
    238. <span class="n">batch_size</span><span class="o">=</span><span class="n">test_batch_size</span><span class="p">,</span>
    239. <span class="n">shuffle</span><span class="o">=</span><span class="kc">False</span><span class="p">,</span>
    240. <span class="n">num_workers</span><span class="o">=</span><span class="n">num_workers</span><span class="p">,</span>
    241. <span class="n">pin_memory</span><span class="o">=</span><span class="kc">True</span><span class="p">,</span>
    242. <span class="n">sampler</span><span class="o">=</span><span class="n">test_sampler</span><span class="p">,</span>
    243. <span class="n">collate_fn</span><span class="o">=</span><span class="n">test_collate_fn</span><span class="p">)</span>
    244. <span class="bp">self</span><span class="o">.</span><span class="n">classes</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">trainset</span><span class="o">.</span><span class="n">classes</span></div>
    245. <div class="viewcode-block" id="DatasetInterface.get_data_loaders"><a class="viewcode-back" href="../../../../../super_gradients.training.datasets.dataset_interfaces.html#super_gradients.training.DatasetInterface.get_data_loaders">[docs]</a> <span class="k">def</span> <span class="nf">get_data_loaders</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="o">**</span><span class="n">kwargs</span><span class="p">):</span>
    246. <span class="sd">&quot;&quot;&quot;</span>
    247. <span class="sd"> Get self.train_loader, self.val_loader, self.test_loader, self.classes.</span>
    248. <span class="sd"> If the data loaders haven&#39;t been initialized yet, build them first.</span>
    249. <span class="sd"> :param kwargs: kwargs are passed to build_data_loaders.</span>
    250. <span class="sd"> &quot;&quot;&quot;</span>
    251. <span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">train_loader</span> <span class="ow">is</span> <span class="kc">None</span> <span class="ow">and</span> <span class="bp">self</span><span class="o">.</span><span class="n">val_loader</span> <span class="ow">is</span> <span class="kc">None</span><span class="p">:</span>
    252. <span class="bp">self</span><span class="o">.</span><span class="n">build_data_loaders</span><span class="p">(</span><span class="o">**</span><span class="n">kwargs</span><span class="p">)</span>
    253. <span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">train_loader</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">val_loader</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">test_loader</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">classes</span></div>
    254. <div class="viewcode-block" id="DatasetInterface.get_val_sample"><a class="viewcode-back" href="../../../../../super_gradients.training.datasets.dataset_interfaces.html#super_gradients.training.DatasetInterface.get_val_sample">[docs]</a> <span class="k">def</span> <span class="nf">get_val_sample</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">num_samples</span><span class="o">=</span><span class="mi">1</span><span class="p">):</span>
    255. <span class="k">if</span> <span class="n">num_samples</span> <span class="o">&gt;</span> <span class="nb">len</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">valset</span><span class="p">):</span>
    256. <span class="k">raise</span> <span class="ne">Exception</span><span class="p">(</span><span class="s2">&quot;Tried to load more samples than val-set size&quot;</span><span class="p">)</span>
    257. <span class="k">if</span> <span class="n">num_samples</span> <span class="o">==</span> <span class="mi">1</span><span class="p">:</span>
    258. <span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">valset</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span>
    259. <span class="k">else</span><span class="p">:</span>
    260. <span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">valset</span><span class="p">[</span><span class="mi">0</span><span class="p">:</span><span class="n">num_samples</span><span class="p">]</span></div>
    261. <div class="viewcode-block" id="DatasetInterface.get_dataset_params"><a class="viewcode-back" href="../../../../../super_gradients.training.datasets.dataset_interfaces.html#super_gradients.training.DatasetInterface.get_dataset_params">[docs]</a> <span class="k">def</span> <span class="nf">get_dataset_params</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span>
    262. <span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">dataset_params</span></div>
    263. <div class="viewcode-block" id="DatasetInterface.print_dataset_details"><a class="viewcode-back" href="../../../../../super_gradients.training.datasets.dataset_interfaces.html#super_gradients.training.DatasetInterface.print_dataset_details">[docs]</a> <span class="k">def</span> <span class="nf">print_dataset_details</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span>
    264. <span class="n">logger</span><span class="o">.</span><span class="n">info</span><span class="p">(</span><span class="s2">&quot;</span><span class="si">{}</span><span class="s2"> training samples, </span><span class="si">{}</span><span class="s2"> val samples, </span><span class="si">{}</span><span class="s2"> classes&quot;</span><span class="o">.</span><span class="n">format</span><span class="p">(</span><span class="nb">len</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">trainset</span><span class="p">),</span> <span class="nb">len</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">valset</span><span class="p">),</span>
    265. <span class="nb">len</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">trainset</span><span class="o">.</span><span class="n">classes</span><span class="p">)))</span></div></div>
    266. <div class="viewcode-block" id="ExternalDatasetInterface"><a class="viewcode-back" href="../../../../../super_gradients.training.datasets.dataset_interfaces.html#super_gradients.training.ExternalDatasetInterface">[docs]</a><span class="k">class</span> <span class="nc">ExternalDatasetInterface</span><span class="p">(</span><span class="n">DatasetInterface</span><span class="p">):</span>
    267. <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">train_loader</span><span class="p">,</span> <span class="n">val_loader</span><span class="p">,</span> <span class="n">num_classes</span><span class="p">,</span> <span class="n">dataset_params</span><span class="o">=</span><span class="p">{}):</span>
    268. <span class="sd">&quot;&quot;&quot;</span>
    269. <span class="sd"> ExternalDatasetInterface - A wrapper for external dataset interface that gets dataloaders from keras/TF</span>
    270. <span class="sd"> and converts them to Torch-like dataloaders that return torch.Tensors after</span>
    271. <span class="sd"> optional collate_fn while maintaining the same interface (connect_dataset_interface etc.)</span>
    272. <span class="sd"> :train_loader: The external train_loader</span>
    273. <span class="sd"> :val_loader: The external val_loader</span>
    274. <span class="sd"> :num_classes: The number of classes</span>
    275. <span class="sd"> :dataset_params The dict that includes the batch_size and/or the collate_fn</span>
    276. <span class="sd"> :return: DataLoaders that generate torch.Tensors batches after collate_fn</span>
    277. <span class="sd"> &quot;&quot;&quot;</span>
    278. <span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="fm">__init__</span><span class="p">(</span><span class="n">dataset_params</span><span class="p">)</span>
    279. <span class="bp">self</span><span class="o">.</span><span class="n">train_loader</span> <span class="o">=</span> <span class="n">train_loader</span>
    280. <span class="bp">self</span><span class="o">.</span><span class="n">val_loader</span> <span class="o">=</span> <span class="n">val_loader</span>
    281. <span class="bp">self</span><span class="o">.</span><span class="n">classes</span> <span class="o">=</span> <span class="n">num_classes</span>
    282. <div class="viewcode-block" id="ExternalDatasetInterface.get_data_loaders"><a class="viewcode-back" href="../../../../../super_gradients.training.datasets.dataset_interfaces.html#super_gradients.training.ExternalDatasetInterface.get_data_loaders">[docs]</a> <span class="k">def</span> <span class="nf">get_data_loaders</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">batch_size_factor</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="mi">1</span><span class="p">,</span> <span class="n">num_workers</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="mi">8</span><span class="p">,</span> <span class="n">train_batch_size</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span>
    283. <span class="n">val_batch_size</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span> <span class="n">distributed_sampler</span><span class="p">:</span> <span class="nb">bool</span> <span class="o">=</span> <span class="kc">False</span><span class="p">):</span>
    284. <span class="c1"># CHANGE THE BATCH SIZE ACCORDING TO THE NUMBER OF DEVICES - ONLY IN NON-DISTRIBUED TRAINING MODE</span>
    285. <span class="c1"># IN DISTRIBUTED MODE WE NEED DISTRIBUTED SAMPLERS</span>
    286. <span class="c1"># NO SHUFFLE IN DISTRIBUTED TRAINING</span>
    287. <span class="k">if</span> <span class="n">distributed_sampler</span><span class="p">:</span>
    288. <span class="bp">self</span><span class="o">.</span><span class="n">batch_size_factor</span> <span class="o">=</span> <span class="mi">1</span>
    289. <span class="n">train_sampler</span> <span class="o">=</span> <span class="n">DistributedSampler</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">trainset</span><span class="p">,</span> <span class="n">shuffle</span><span class="o">=</span><span class="kc">True</span><span class="p">)</span>
    290. <span class="n">val_sampler</span> <span class="o">=</span> <span class="n">DistributedSampler</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">valset</span><span class="p">)</span>
    291. <span class="n">train_shuffle</span> <span class="o">=</span> <span class="kc">False</span>
    292. <span class="k">else</span><span class="p">:</span>
    293. <span class="bp">self</span><span class="o">.</span><span class="n">batch_size_factor</span> <span class="o">=</span> <span class="n">batch_size_factor</span>
    294. <span class="n">train_sampler</span> <span class="o">=</span> <span class="kc">None</span>
    295. <span class="n">val_sampler</span> <span class="o">=</span> <span class="kc">None</span>
    296. <span class="n">train_shuffle</span> <span class="o">=</span> <span class="kc">True</span>
    297. <span class="k">if</span> <span class="n">train_batch_size</span> <span class="ow">is</span> <span class="kc">None</span><span class="p">:</span>
    298. <span class="n">train_batch_size</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">dataset_params</span><span class="o">.</span><span class="n">batch_size</span> <span class="o">*</span> <span class="bp">self</span><span class="o">.</span><span class="n">batch_size_factor</span>
    299. <span class="k">if</span> <span class="n">val_batch_size</span> <span class="ow">is</span> <span class="kc">None</span><span class="p">:</span>
    300. <span class="n">val_batch_size</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">dataset_params</span><span class="o">.</span><span class="n">val_batch_size</span> <span class="o">*</span> <span class="bp">self</span><span class="o">.</span><span class="n">batch_size_factor</span>
    301. <span class="n">train_loader_drop_last</span> <span class="o">=</span> <span class="n">core_utils</span><span class="o">.</span><span class="n">get_param</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">dataset_params</span><span class="p">,</span> <span class="s1">&#39;train_loader_drop_last&#39;</span><span class="p">,</span> <span class="n">default_val</span><span class="o">=</span><span class="kc">False</span><span class="p">)</span>
    302. <span class="c1"># WRAPPING collate_fn</span>
    303. <span class="n">train_collate_fn</span> <span class="o">=</span> <span class="n">core_utils</span><span class="o">.</span><span class="n">get_param</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">dataset_params</span><span class="p">,</span> <span class="s1">&#39;train_collate_fn&#39;</span><span class="p">)</span>
    304. <span class="n">val_collate_fn</span> <span class="o">=</span> <span class="n">core_utils</span><span class="o">.</span><span class="n">get_param</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">dataset_params</span><span class="p">,</span> <span class="s1">&#39;val_collate_fn&#39;</span><span class="p">)</span>
    305. <span class="c1"># FIXME - UNDERSTAND IF THE num_replicas VARIBALE IS NEEDED</span>
    306. <span class="c1"># train_sampler = DistributedSampler(self.trainset,</span>
    307. <span class="c1"># num_replicas=distributed_gpus_num) if distributed_sampler else None</span>
    308. <span class="c1"># val_sampler = DistributedSampler(self.valset,</span>
    309. <span class="c1"># num_replicas=distributed_gpus_num) if distributed_sampler else None</span>
    310. <span class="bp">self</span><span class="o">.</span><span class="n">torch_train_loader</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">utils</span><span class="o">.</span><span class="n">data</span><span class="o">.</span><span class="n">DataLoader</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">train_loader</span><span class="p">,</span>
    311. <span class="n">batch_size</span><span class="o">=</span><span class="n">train_batch_size</span><span class="p">,</span>
    312. <span class="n">shuffle</span><span class="o">=</span><span class="n">train_shuffle</span><span class="p">,</span>
    313. <span class="n">num_workers</span><span class="o">=</span><span class="n">num_workers</span><span class="p">,</span>
    314. <span class="n">pin_memory</span><span class="o">=</span><span class="kc">True</span><span class="p">,</span>
    315. <span class="n">sampler</span><span class="o">=</span><span class="n">train_sampler</span><span class="p">,</span>
    316. <span class="n">collate_fn</span><span class="o">=</span><span class="n">train_collate_fn</span><span class="p">,</span>
    317. <span class="n">drop_last</span><span class="o">=</span><span class="n">train_loader_drop_last</span><span class="p">)</span>
    318. <span class="bp">self</span><span class="o">.</span><span class="n">torch_val_loader</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">utils</span><span class="o">.</span><span class="n">data</span><span class="o">.</span><span class="n">DataLoader</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">val_loader</span><span class="p">,</span>
    319. <span class="n">batch_size</span><span class="o">=</span><span class="n">val_batch_size</span><span class="p">,</span>
    320. <span class="n">shuffle</span><span class="o">=</span><span class="kc">False</span><span class="p">,</span>
    321. <span class="n">num_workers</span><span class="o">=</span><span class="n">num_workers</span><span class="p">,</span>
    322. <span class="n">pin_memory</span><span class="o">=</span><span class="kc">True</span><span class="p">,</span>
    323. <span class="n">sampler</span><span class="o">=</span><span class="n">val_sampler</span><span class="p">,</span>
    324. <span class="n">collate_fn</span><span class="o">=</span><span class="n">val_collate_fn</span><span class="p">)</span>
    325. <span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">torch_train_loader</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">torch_val_loader</span><span class="p">,</span> <span class="kc">None</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">classes</span></div></div>
    326. <div class="viewcode-block" id="LibraryDatasetInterface"><a class="viewcode-back" href="../../../../../super_gradients.training.datasets.dataset_interfaces.html#super_gradients.training.LibraryDatasetInterface">[docs]</a><span class="k">class</span> <span class="nc">LibraryDatasetInterface</span><span class="p">(</span><span class="n">DatasetInterface</span><span class="p">):</span>
    327. <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">name</span><span class="o">=</span><span class="s2">&quot;cifar10&quot;</span><span class="p">,</span> <span class="n">dataset_params</span><span class="o">=</span><span class="p">{},</span> <span class="n">to_cutout</span><span class="o">=</span><span class="kc">False</span><span class="p">):</span>
    328. <span class="nb">super</span><span class="p">(</span><span class="n">LibraryDatasetInterface</span><span class="p">,</span> <span class="bp">self</span><span class="p">)</span><span class="o">.</span><span class="fm">__init__</span><span class="p">(</span><span class="n">dataset_params</span><span class="p">)</span>
    329. <span class="bp">self</span><span class="o">.</span><span class="n">dataset_name</span> <span class="o">=</span> <span class="n">name</span>
    330. <span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">dataset_name</span> <span class="ow">not</span> <span class="ow">in</span> <span class="n">LIBRARY_DATASETS</span><span class="o">.</span><span class="n">keys</span><span class="p">():</span>
    331. <span class="k">raise</span> <span class="ne">Exception</span><span class="p">(</span><span class="s1">&#39;dataset not found&#39;</span><span class="p">)</span>
    332. <span class="bp">self</span><span class="o">.</span><span class="n">lib_dataset_params</span> <span class="o">=</span> <span class="n">LIBRARY_DATASETS</span><span class="p">[</span><span class="bp">self</span><span class="o">.</span><span class="n">dataset_name</span><span class="p">]</span>
    333. <span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">lib_dataset_params</span><span class="p">[</span><span class="s1">&#39;mean&#39;</span><span class="p">]</span> <span class="ow">is</span> <span class="kc">None</span><span class="p">:</span>
    334. <span class="n">trainset</span> <span class="o">=</span> <span class="n">torchvision</span><span class="o">.</span><span class="n">datasets</span><span class="o">.</span><span class="n">SVHN</span><span class="p">(</span><span class="n">root</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">dataset_params</span><span class="o">.</span><span class="n">dataset_dir</span><span class="p">,</span> <span class="n">split</span><span class="o">=</span><span class="s1">&#39;train&#39;</span><span class="p">,</span> <span class="n">download</span><span class="o">=</span><span class="kc">True</span><span class="p">,</span>
    335. <span class="n">transform</span><span class="o">=</span><span class="n">transforms</span><span class="o">.</span><span class="n">ToTensor</span><span class="p">())</span>
    336. <span class="bp">self</span><span class="o">.</span><span class="n">lib_dataset_params</span><span class="p">[</span><span class="s1">&#39;mean&#39;</span><span class="p">],</span> <span class="bp">self</span><span class="o">.</span><span class="n">lib_dataset_params</span><span class="p">[</span><span class="s1">&#39;std&#39;</span><span class="p">]</span> <span class="o">=</span> <span class="n">datasets_utils</span><span class="o">.</span><span class="n">get_mean_and_std</span><span class="p">(</span><span class="n">trainset</span><span class="p">)</span>
    337. <span class="c1"># OVERWRITE MEAN AND STD IF DEFINED IN DATASET PARAMS</span>
    338. <span class="bp">self</span><span class="o">.</span><span class="n">lib_dataset_params</span><span class="p">[</span><span class="s1">&#39;mean&#39;</span><span class="p">]</span> <span class="o">=</span> <span class="n">core_utils</span><span class="o">.</span><span class="n">get_param</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">dataset_params</span><span class="p">,</span> <span class="s1">&#39;img_mean&#39;</span><span class="p">,</span>
    339. <span class="n">default_val</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">lib_dataset_params</span><span class="p">[</span><span class="s1">&#39;mean&#39;</span><span class="p">])</span>
    340. <span class="bp">self</span><span class="o">.</span><span class="n">lib_dataset_params</span><span class="p">[</span><span class="s1">&#39;std&#39;</span><span class="p">]</span> <span class="o">=</span> <span class="n">core_utils</span><span class="o">.</span><span class="n">get_param</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">dataset_params</span><span class="p">,</span> <span class="s1">&#39;img_std&#39;</span><span class="p">,</span>
    341. <span class="n">default_val</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">lib_dataset_params</span><span class="p">[</span><span class="s1">&#39;std&#39;</span><span class="p">])</span>
    342. <span class="n">crop_size</span> <span class="o">=</span> <span class="n">core_utils</span><span class="o">.</span><span class="n">get_param</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">dataset_params</span><span class="p">,</span> <span class="s1">&#39;crop_size&#39;</span><span class="p">,</span> <span class="n">default_val</span><span class="o">=</span><span class="mi">32</span><span class="p">)</span>
    343. <span class="k">if</span> <span class="n">to_cutout</span><span class="p">:</span>
    344. <span class="n">transform_train</span> <span class="o">=</span> <span class="n">transforms</span><span class="o">.</span><span class="n">Compose</span><span class="p">([</span>
    345. <span class="n">transforms</span><span class="o">.</span><span class="n">RandomCrop</span><span class="p">(</span><span class="n">crop_size</span><span class="p">,</span> <span class="n">padding</span><span class="o">=</span><span class="mi">4</span><span class="p">),</span>
    346. <span class="n">transforms</span><span class="o">.</span><span class="n">RandomHorizontalFlip</span><span class="p">(),</span>
    347. <span class="n">DataAugmentation</span><span class="o">.</span><span class="n">normalize</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">lib_dataset_params</span><span class="p">[</span><span class="s1">&#39;mean&#39;</span><span class="p">],</span> <span class="bp">self</span><span class="o">.</span><span class="n">lib_dataset_params</span><span class="p">[</span><span class="s1">&#39;std&#39;</span><span class="p">]),</span>
    348. <span class="n">DataAugmentation</span><span class="o">.</span><span class="n">cutout</span><span class="p">(</span><span class="mi">16</span><span class="p">),</span>
    349. <span class="n">DataAugmentation</span><span class="o">.</span><span class="n">to_tensor</span><span class="p">()</span>
    350. <span class="p">])</span>
    351. <span class="k">else</span><span class="p">:</span>
    352. <span class="n">transform_train</span> <span class="o">=</span> <span class="n">transforms</span><span class="o">.</span><span class="n">Compose</span><span class="p">([</span>
    353. <span class="n">transforms</span><span class="o">.</span><span class="n">RandomCrop</span><span class="p">(</span><span class="n">crop_size</span><span class="p">,</span> <span class="n">padding</span><span class="o">=</span><span class="mi">4</span><span class="p">),</span>
    354. <span class="n">transforms</span><span class="o">.</span><span class="n">RandomHorizontalFlip</span><span class="p">(),</span>
    355. <span class="n">transforms</span><span class="o">.</span><span class="n">ToTensor</span><span class="p">(),</span>
    356. <span class="n">transforms</span><span class="o">.</span><span class="n">Normalize</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">lib_dataset_params</span><span class="p">[</span><span class="s1">&#39;mean&#39;</span><span class="p">],</span> <span class="bp">self</span><span class="o">.</span><span class="n">lib_dataset_params</span><span class="p">[</span><span class="s1">&#39;std&#39;</span><span class="p">]),</span>
    357. <span class="p">])</span>
    358. <span class="n">transform_val</span> <span class="o">=</span> <span class="n">transforms</span><span class="o">.</span><span class="n">Compose</span><span class="p">([</span>
    359. <span class="n">transforms</span><span class="o">.</span><span class="n">ToTensor</span><span class="p">(),</span>
    360. <span class="n">transforms</span><span class="o">.</span><span class="n">Normalize</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">lib_dataset_params</span><span class="p">[</span><span class="s1">&#39;mean&#39;</span><span class="p">],</span> <span class="bp">self</span><span class="o">.</span><span class="n">lib_dataset_params</span><span class="p">[</span><span class="s1">&#39;std&#39;</span><span class="p">]),</span>
    361. <span class="p">])</span>
    362. <span class="n">dataset_cls</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">lib_dataset_params</span><span class="p">[</span><span class="s2">&quot;class&quot;</span><span class="p">]</span>
    363. <span class="bp">self</span><span class="o">.</span><span class="n">trainset</span> <span class="o">=</span> <span class="n">dataset_cls</span><span class="p">(</span><span class="n">root</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">dataset_params</span><span class="o">.</span><span class="n">dataset_dir</span><span class="p">,</span> <span class="n">train</span><span class="o">=</span><span class="kc">True</span><span class="p">,</span> <span class="n">download</span><span class="o">=</span><span class="kc">True</span><span class="p">,</span>
    364. <span class="n">transform</span><span class="o">=</span><span class="n">transform_train</span><span class="p">)</span>
    365. <span class="bp">self</span><span class="o">.</span><span class="n">valset</span> <span class="o">=</span> <span class="n">dataset_cls</span><span class="p">(</span><span class="n">root</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">dataset_params</span><span class="o">.</span><span class="n">dataset_dir</span><span class="p">,</span> <span class="n">train</span><span class="o">=</span><span class="kc">False</span><span class="p">,</span> <span class="n">download</span><span class="o">=</span><span class="kc">True</span><span class="p">,</span>
    366. <span class="n">transform</span><span class="o">=</span><span class="n">transform_val</span><span class="p">)</span></div>
    367. <div class="viewcode-block" id="Cifar10DatasetInterface"><a class="viewcode-back" href="../../../../../super_gradients.training.datasets.dataset_interfaces.html#super_gradients.training.Cifar10DatasetInterface">[docs]</a><span class="k">class</span> <span class="nc">Cifar10DatasetInterface</span><span class="p">(</span><span class="n">LibraryDatasetInterface</span><span class="p">):</span>
    368. <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">dataset_params</span><span class="o">=</span><span class="p">{}):</span>
    369. <span class="nb">super</span><span class="p">(</span><span class="n">Cifar10DatasetInterface</span><span class="p">,</span> <span class="bp">self</span><span class="p">)</span><span class="o">.</span><span class="fm">__init__</span><span class="p">(</span><span class="n">name</span><span class="o">=</span><span class="s2">&quot;cifar10&quot;</span><span class="p">,</span> <span class="n">dataset_params</span><span class="o">=</span><span class="n">dataset_params</span><span class="p">)</span></div>
    370. <div class="viewcode-block" id="Cifar100DatasetInterface"><a class="viewcode-back" href="../../../../../super_gradients.training.datasets.dataset_interfaces.html#super_gradients.training.Cifar100DatasetInterface">[docs]</a><span class="k">class</span> <span class="nc">Cifar100DatasetInterface</span><span class="p">(</span><span class="n">LibraryDatasetInterface</span><span class="p">):</span>
    371. <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">dataset_params</span><span class="o">=</span><span class="p">{}):</span>
    372. <span class="nb">super</span><span class="p">(</span><span class="n">Cifar100DatasetInterface</span><span class="p">,</span> <span class="bp">self</span><span class="p">)</span><span class="o">.</span><span class="fm">__init__</span><span class="p">(</span><span class="n">name</span><span class="o">=</span><span class="s2">&quot;cifar100&quot;</span><span class="p">,</span> <span class="n">dataset_params</span><span class="o">=</span><span class="n">dataset_params</span><span class="p">)</span></div>
    373. <div class="viewcode-block" id="TestDatasetInterface"><a class="viewcode-back" href="../../../../../super_gradients.training.datasets.dataset_interfaces.html#super_gradients.training.TestDatasetInterface">[docs]</a><span class="k">class</span> <span class="nc">TestDatasetInterface</span><span class="p">(</span><span class="n">DatasetInterface</span><span class="p">):</span>
    374. <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">trainset</span><span class="p">,</span> <span class="n">dataset_params</span><span class="o">=</span><span class="p">{},</span> <span class="n">classes</span><span class="o">=</span><span class="kc">None</span><span class="p">):</span>
    375. <span class="nb">super</span><span class="p">(</span><span class="n">TestDatasetInterface</span><span class="p">,</span> <span class="bp">self</span><span class="p">)</span><span class="o">.</span><span class="fm">__init__</span><span class="p">(</span><span class="n">dataset_params</span><span class="p">)</span>
    376. <span class="bp">self</span><span class="o">.</span><span class="n">trainset</span> <span class="o">=</span> <span class="n">trainset</span>
    377. <span class="bp">self</span><span class="o">.</span><span class="n">valset</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">trainset</span>
    378. <span class="bp">self</span><span class="o">.</span><span class="n">testset</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">trainset</span>
    379. <span class="bp">self</span><span class="o">.</span><span class="n">classes</span> <span class="o">=</span> <span class="n">classes</span>
    380. <div class="viewcode-block" id="TestDatasetInterface.get_data_loaders"><a class="viewcode-back" href="../../../../../super_gradients.training.datasets.dataset_interfaces.html#super_gradients.training.TestDatasetInterface.get_data_loaders">[docs]</a> <span class="k">def</span> <span class="nf">get_data_loaders</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">batch_size_factor</span><span class="o">=</span><span class="mi">1</span><span class="p">,</span> <span class="n">num_workers</span><span class="o">=</span><span class="mi">8</span><span class="p">,</span> <span class="n">train_batch_size</span><span class="o">=</span><span class="kc">None</span><span class="p">,</span> <span class="n">val_batch_size</span><span class="o">=</span><span class="kc">None</span><span class="p">,</span>
    381. <span class="n">distributed_sampler</span><span class="o">=</span><span class="kc">False</span><span class="p">):</span>
    382. <span class="bp">self</span><span class="o">.</span><span class="n">trainset</span><span class="o">.</span><span class="n">classes</span> <span class="o">=</span> <span class="p">[</span><span class="mi">0</span><span class="p">,</span> <span class="mi">1</span><span class="p">,</span> <span class="mi">2</span><span class="p">,</span> <span class="mi">3</span><span class="p">,</span> <span class="mi">4</span><span class="p">]</span> <span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">classes</span> <span class="ow">is</span> <span class="kc">None</span> <span class="k">else</span> <span class="bp">self</span><span class="o">.</span><span class="n">classes</span>
    383. <span class="k">return</span> <span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="n">get_data_loaders</span><span class="p">(</span><span class="n">batch_size_factor</span><span class="o">=</span><span class="n">batch_size_factor</span><span class="p">,</span>
    384. <span class="n">num_workers</span><span class="o">=</span><span class="n">num_workers</span><span class="p">,</span>
    385. <span class="n">train_batch_size</span><span class="o">=</span><span class="n">train_batch_size</span><span class="p">,</span>
    386. <span class="n">val_batch_size</span><span class="o">=</span><span class="n">val_batch_size</span><span class="p">,</span>
    387. <span class="n">distributed_sampler</span><span class="o">=</span><span class="n">distributed_sampler</span><span class="p">)</span></div></div>
    388. <div class="viewcode-block" id="ClassificationTestDatasetInterface"><a class="viewcode-back" href="../../../../../super_gradients.training.datasets.dataset_interfaces.html#super_gradients.training.ClassificationTestDatasetInterface">[docs]</a><span class="k">class</span> <span class="nc">ClassificationTestDatasetInterface</span><span class="p">(</span><span class="n">TestDatasetInterface</span><span class="p">):</span>
    389. <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">dataset_params</span><span class="o">=</span><span class="p">{},</span> <span class="n">image_size</span><span class="o">=</span><span class="mi">32</span><span class="p">,</span> <span class="n">batch_size</span><span class="o">=</span><span class="mi">5</span><span class="p">,</span> <span class="n">classes</span><span class="o">=</span><span class="kc">None</span><span class="p">):</span>
    390. <span class="n">trainset</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">utils</span><span class="o">.</span><span class="n">data</span><span class="o">.</span><span class="n">TensorDataset</span><span class="p">(</span><span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">(</span><span class="n">np</span><span class="o">.</span><span class="n">zeros</span><span class="p">((</span><span class="n">batch_size</span><span class="p">,</span> <span class="mi">3</span><span class="p">,</span> <span class="n">image_size</span><span class="p">,</span> <span class="n">image_size</span><span class="p">))),</span>
    391. <span class="n">torch</span><span class="o">.</span><span class="n">LongTensor</span><span class="p">(</span><span class="n">np</span><span class="o">.</span><span class="n">zeros</span><span class="p">((</span><span class="n">batch_size</span><span class="p">))))</span>
    392. <span class="nb">super</span><span class="p">(</span><span class="n">ClassificationTestDatasetInterface</span><span class="p">,</span> <span class="bp">self</span><span class="p">)</span><span class="o">.</span><span class="fm">__init__</span><span class="p">(</span><span class="n">trainset</span><span class="o">=</span><span class="n">trainset</span><span class="p">,</span> <span class="n">dataset_params</span><span class="o">=</span><span class="n">dataset_params</span><span class="p">,</span>
    393. <span class="n">classes</span><span class="o">=</span><span class="n">classes</span><span class="p">)</span></div>
    394. <div class="viewcode-block" id="SegmentationTestDatasetInterface"><a class="viewcode-back" href="../../../../../super_gradients.training.datasets.dataset_interfaces.html#super_gradients.training.SegmentationTestDatasetInterface">[docs]</a><span class="k">class</span> <span class="nc">SegmentationTestDatasetInterface</span><span class="p">(</span><span class="n">TestDatasetInterface</span><span class="p">):</span>
    395. <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">dataset_params</span><span class="o">=</span><span class="p">{},</span> <span class="n">image_size</span><span class="o">=</span><span class="mi">512</span><span class="p">,</span> <span class="n">batch_size</span><span class="o">=</span><span class="mi">4</span><span class="p">):</span>
    396. <span class="n">trainset</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">utils</span><span class="o">.</span><span class="n">data</span><span class="o">.</span><span class="n">TensorDataset</span><span class="p">(</span><span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">(</span><span class="n">np</span><span class="o">.</span><span class="n">zeros</span><span class="p">((</span><span class="n">batch_size</span><span class="p">,</span> <span class="mi">3</span><span class="p">,</span> <span class="n">image_size</span><span class="p">,</span> <span class="n">image_size</span><span class="p">))),</span>
    397. <span class="n">torch</span><span class="o">.</span><span class="n">LongTensor</span><span class="p">(</span><span class="n">np</span><span class="o">.</span><span class="n">zeros</span><span class="p">((</span><span class="n">batch_size</span><span class="p">,</span> <span class="n">image_size</span><span class="p">,</span> <span class="n">image_size</span><span class="p">))))</span>
    398. <span class="nb">super</span><span class="p">(</span><span class="n">SegmentationTestDatasetInterface</span><span class="p">,</span> <span class="bp">self</span><span class="p">)</span><span class="o">.</span><span class="fm">__init__</span><span class="p">(</span><span class="n">trainset</span><span class="o">=</span><span class="n">trainset</span><span class="p">,</span> <span class="n">dataset_params</span><span class="o">=</span><span class="n">dataset_params</span><span class="p">)</span></div>
    399. <div class="viewcode-block" id="DetectionTestDatasetInterface"><a class="viewcode-back" href="../../../../../super_gradients.training.datasets.dataset_interfaces.html#super_gradients.training.DetectionTestDatasetInterface">[docs]</a><span class="k">class</span> <span class="nc">DetectionTestDatasetInterface</span><span class="p">(</span><span class="n">TestDatasetInterface</span><span class="p">):</span>
    400. <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">dataset_params</span><span class="o">=</span><span class="p">{},</span> <span class="n">image_size</span><span class="o">=</span><span class="mi">320</span><span class="p">,</span> <span class="n">batch_size</span><span class="o">=</span><span class="mi">4</span><span class="p">,</span> <span class="n">classes</span><span class="o">=</span><span class="kc">None</span><span class="p">):</span>
    401. <span class="n">trainset</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">utils</span><span class="o">.</span><span class="n">data</span><span class="o">.</span><span class="n">TensorDataset</span><span class="p">(</span><span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">(</span><span class="n">np</span><span class="o">.</span><span class="n">zeros</span><span class="p">((</span><span class="n">batch_size</span><span class="p">,</span> <span class="mi">3</span><span class="p">,</span> <span class="n">image_size</span><span class="p">,</span> <span class="n">image_size</span><span class="p">))),</span>
    402. <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">(</span><span class="n">np</span><span class="o">.</span><span class="n">zeros</span><span class="p">((</span><span class="n">batch_size</span><span class="p">,</span> <span class="mi">6</span><span class="p">))))</span>
    403. <span class="nb">super</span><span class="p">(</span><span class="n">DetectionTestDatasetInterface</span><span class="p">,</span> <span class="bp">self</span><span class="p">)</span><span class="o">.</span><span class="fm">__init__</span><span class="p">(</span><span class="n">trainset</span><span class="o">=</span><span class="n">trainset</span><span class="p">,</span> <span class="n">dataset_params</span><span class="o">=</span><span class="n">dataset_params</span><span class="p">,</span>
    404. <span class="n">classes</span><span class="o">=</span><span class="n">classes</span><span class="p">)</span></div>
    405. <div class="viewcode-block" id="TestYoloDetectionDatasetInterface"><a class="viewcode-back" href="../../../../../super_gradients.training.datasets.dataset_interfaces.html#super_gradients.training.TestYoloDetectionDatasetInterface">[docs]</a><span class="k">class</span> <span class="nc">TestYoloDetectionDatasetInterface</span><span class="p">(</span><span class="n">DatasetInterface</span><span class="p">):</span>
    406. <span class="sd">&quot;&quot;&quot;</span>
    407. <span class="sd"> note: the output size is (batch_size, 6) in the test while in real training</span>
    408. <span class="sd"> the size of axis 0 can vary (the number of bounding boxes)</span>
    409. <span class="sd"> &quot;&quot;&quot;</span>
    410. <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">dataset_params</span><span class="o">=</span><span class="p">{},</span> <span class="n">input_dims</span><span class="o">=</span><span class="p">(</span><span class="mi">3</span><span class="p">,</span> <span class="mi">32</span><span class="p">,</span> <span class="mi">32</span><span class="p">),</span> <span class="n">batch_size</span><span class="o">=</span><span class="mi">5</span><span class="p">):</span>
    411. <span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="fm">__init__</span><span class="p">(</span><span class="n">dataset_params</span><span class="p">)</span>
    412. <span class="bp">self</span><span class="o">.</span><span class="n">trainset</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">utils</span><span class="o">.</span><span class="n">data</span><span class="o">.</span><span class="n">TensorDataset</span><span class="p">(</span><span class="n">torch</span><span class="o">.</span><span class="n">ones</span><span class="p">((</span><span class="n">batch_size</span><span class="p">,</span> <span class="o">*</span><span class="n">input_dims</span><span class="p">)),</span>
    413. <span class="n">torch</span><span class="o">.</span><span class="n">ones</span><span class="p">((</span><span class="n">batch_size</span><span class="p">,</span> <span class="mi">6</span><span class="p">)))</span>
    414. <span class="bp">self</span><span class="o">.</span><span class="n">trainset</span><span class="o">.</span><span class="n">classes</span> <span class="o">=</span> <span class="p">[</span><span class="mi">0</span><span class="p">,</span> <span class="mi">1</span><span class="p">,</span> <span class="mi">2</span><span class="p">,</span> <span class="mi">3</span><span class="p">,</span> <span class="mi">4</span><span class="p">]</span>
    415. <span class="bp">self</span><span class="o">.</span><span class="n">valset</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">trainset</span></div>
    416. <div class="viewcode-block" id="ImageNetDatasetInterface"><a class="viewcode-back" href="../../../../../super_gradients.training.datasets.dataset_interfaces.html#super_gradients.training.ImageNetDatasetInterface">[docs]</a><span class="k">class</span> <span class="nc">ImageNetDatasetInterface</span><span class="p">(</span><span class="n">DatasetInterface</span><span class="p">):</span>
    417. <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">dataset_params</span><span class="o">=</span><span class="p">{},</span> <span class="n">data_dir</span><span class="o">=</span><span class="s2">&quot;/data/Imagenet&quot;</span><span class="p">):</span>
    418. <span class="nb">super</span><span class="p">(</span><span class="n">ImageNetDatasetInterface</span><span class="p">,</span> <span class="bp">self</span><span class="p">)</span><span class="o">.</span><span class="fm">__init__</span><span class="p">(</span><span class="n">dataset_params</span><span class="p">)</span>
    419. <span class="n">data_dir</span> <span class="o">=</span> <span class="n">dataset_params</span><span class="p">[</span><span class="s1">&#39;dataset_dir&#39;</span><span class="p">]</span> <span class="k">if</span> <span class="s1">&#39;dataset_dir&#39;</span> <span class="ow">in</span> <span class="n">dataset_params</span><span class="o">.</span><span class="n">keys</span><span class="p">()</span> <span class="k">else</span> <span class="n">data_dir</span>
    420. <span class="n">traindir</span> <span class="o">=</span> <span class="n">os</span><span class="o">.</span><span class="n">path</span><span class="o">.</span><span class="n">join</span><span class="p">(</span><span class="n">os</span><span class="o">.</span><span class="n">path</span><span class="o">.</span><span class="n">abspath</span><span class="p">(</span><span class="n">data_dir</span><span class="p">),</span> <span class="s1">&#39;train&#39;</span><span class="p">)</span>
    421. <span class="n">valdir</span> <span class="o">=</span> <span class="n">os</span><span class="o">.</span><span class="n">path</span><span class="o">.</span><span class="n">join</span><span class="p">(</span><span class="n">data_dir</span><span class="p">,</span> <span class="s1">&#39;val&#39;</span><span class="p">)</span>
    422. <span class="n">img_mean</span> <span class="o">=</span> <span class="n">core_utils</span><span class="o">.</span><span class="n">get_param</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">dataset_params</span><span class="p">,</span> <span class="s1">&#39;img_mean&#39;</span><span class="p">,</span> <span class="n">default_val</span><span class="o">=</span><span class="p">[</span><span class="mf">0.485</span><span class="p">,</span> <span class="mf">0.456</span><span class="p">,</span> <span class="mf">0.406</span><span class="p">])</span>
    423. <span class="n">img_std</span> <span class="o">=</span> <span class="n">core_utils</span><span class="o">.</span><span class="n">get_param</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">dataset_params</span><span class="p">,</span> <span class="s1">&#39;img_std&#39;</span><span class="p">,</span> <span class="n">default_val</span><span class="o">=</span><span class="p">[</span><span class="mf">0.229</span><span class="p">,</span> <span class="mf">0.224</span><span class="p">,</span> <span class="mf">0.225</span><span class="p">])</span>
    424. <span class="n">normalize</span> <span class="o">=</span> <span class="n">transforms</span><span class="o">.</span><span class="n">Normalize</span><span class="p">(</span><span class="n">mean</span><span class="o">=</span><span class="n">img_mean</span><span class="p">,</span>
    425. <span class="n">std</span><span class="o">=</span><span class="n">img_std</span><span class="p">)</span>
    426. <span class="n">crop_size</span> <span class="o">=</span> <span class="n">core_utils</span><span class="o">.</span><span class="n">get_param</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">dataset_params</span><span class="p">,</span> <span class="s1">&#39;crop_size&#39;</span><span class="p">,</span> <span class="n">default_val</span><span class="o">=</span><span class="mi">224</span><span class="p">)</span>
    427. <span class="n">resize_size</span> <span class="o">=</span> <span class="n">core_utils</span><span class="o">.</span><span class="n">get_param</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">dataset_params</span><span class="p">,</span> <span class="s1">&#39;resize_size&#39;</span><span class="p">,</span> <span class="n">default_val</span><span class="o">=</span><span class="mi">256</span><span class="p">)</span>
    428. <span class="n">color_jitter</span> <span class="o">=</span> <span class="n">core_utils</span><span class="o">.</span><span class="n">get_param</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">dataset_params</span><span class="p">,</span> <span class="s1">&#39;color_jitter&#39;</span><span class="p">,</span> <span class="n">default_val</span><span class="o">=</span><span class="mf">0.0</span><span class="p">)</span>
    429. <span class="n">imagenet_pca_aug</span> <span class="o">=</span> <span class="n">core_utils</span><span class="o">.</span><span class="n">get_param</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">dataset_params</span><span class="p">,</span> <span class="s1">&#39;imagenet_pca_aug&#39;</span><span class="p">,</span> <span class="n">default_val</span><span class="o">=</span><span class="mf">0.0</span><span class="p">)</span>
    430. <span class="n">train_interpolation</span> <span class="o">=</span> <span class="n">core_utils</span><span class="o">.</span><span class="n">get_param</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">dataset_params</span><span class="p">,</span> <span class="s1">&#39;train_interpolation&#39;</span><span class="p">,</span> <span class="n">default_val</span><span class="o">=</span><span class="s1">&#39;default&#39;</span><span class="p">)</span>
    431. <span class="n">rand_augment_config_string</span> <span class="o">=</span> <span class="n">core_utils</span><span class="o">.</span><span class="n">get_param</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">dataset_params</span><span class="p">,</span> <span class="s1">&#39;rand_augment_config_string&#39;</span><span class="p">,</span>
    432. <span class="n">default_val</span><span class="o">=</span><span class="kc">None</span><span class="p">)</span>
    433. <span class="n">color_jitter</span> <span class="o">=</span> <span class="p">(</span><span class="nb">float</span><span class="p">(</span><span class="n">color_jitter</span><span class="p">),)</span> <span class="o">*</span> <span class="mi">3</span> <span class="k">if</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">color_jitter</span><span class="p">,</span> <span class="nb">float</span><span class="p">)</span> <span class="k">else</span> <span class="n">color_jitter</span>
    434. <span class="k">assert</span> <span class="nb">len</span><span class="p">(</span><span class="n">color_jitter</span><span class="p">)</span> <span class="ow">in</span> <span class="p">(</span><span class="mi">3</span><span class="p">,</span> <span class="mi">4</span><span class="p">),</span> <span class="s2">&quot;color_jitter must be a scalar or tuple of len 3 or 4&quot;</span>
    435. <span class="n">color_augmentation</span> <span class="o">=</span> <span class="n">datasets_utils</span><span class="o">.</span><span class="n">get_color_augmentation</span><span class="p">(</span><span class="n">rand_augment_config_string</span><span class="p">,</span> <span class="n">color_jitter</span><span class="p">,</span>
    436. <span class="n">crop_size</span><span class="o">=</span><span class="n">crop_size</span><span class="p">,</span> <span class="n">img_mean</span><span class="o">=</span><span class="n">img_mean</span><span class="p">)</span>
    437. <span class="n">train_transformation_list</span> <span class="o">=</span> <span class="p">[</span>
    438. <span class="n">RandomResizedCropAndInterpolation</span><span class="p">(</span><span class="n">crop_size</span><span class="p">,</span> <span class="n">interpolation</span><span class="o">=</span><span class="n">train_interpolation</span><span class="p">),</span>
    439. <span class="n">transforms</span><span class="o">.</span><span class="n">RandomHorizontalFlip</span><span class="p">(),</span>
    440. <span class="n">color_augmentation</span><span class="p">,</span>
    441. <span class="n">transforms</span><span class="o">.</span><span class="n">ToTensor</span><span class="p">(),</span>
    442. <span class="n">Lighting</span><span class="p">(</span><span class="n">imagenet_pca_aug</span><span class="p">),</span>
    443. <span class="n">normalize</span><span class="p">]</span>
    444. <span class="n">rndm_erase_prob</span> <span class="o">=</span> <span class="n">core_utils</span><span class="o">.</span><span class="n">get_param</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">dataset_params</span><span class="p">,</span> <span class="s1">&#39;random_erase_prob&#39;</span><span class="p">,</span> <span class="n">default_val</span><span class="o">=</span><span class="mf">0.</span><span class="p">)</span>
    445. <span class="k">if</span> <span class="n">rndm_erase_prob</span><span class="p">:</span>
    446. <span class="n">train_transformation_list</span><span class="o">.</span><span class="n">append</span><span class="p">(</span><span class="n">RandomErase</span><span class="p">(</span><span class="n">rndm_erase_prob</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">dataset_params</span><span class="o">.</span><span class="n">random_erase_value</span><span class="p">))</span>
    447. <span class="bp">self</span><span class="o">.</span><span class="n">trainset</span> <span class="o">=</span> <span class="n">datasets</span><span class="o">.</span><span class="n">ImageFolder</span><span class="p">(</span><span class="n">traindir</span><span class="p">,</span> <span class="n">transforms</span><span class="o">.</span><span class="n">Compose</span><span class="p">(</span><span class="n">train_transformation_list</span><span class="p">))</span>
    448. <span class="bp">self</span><span class="o">.</span><span class="n">valset</span> <span class="o">=</span> <span class="n">datasets</span><span class="o">.</span><span class="n">ImageFolder</span><span class="p">(</span><span class="n">valdir</span><span class="p">,</span> <span class="n">transforms</span><span class="o">.</span><span class="n">Compose</span><span class="p">([</span>
    449. <span class="n">transforms</span><span class="o">.</span><span class="n">Resize</span><span class="p">(</span><span class="n">resize_size</span><span class="p">),</span>
    450. <span class="n">transforms</span><span class="o">.</span><span class="n">CenterCrop</span><span class="p">(</span><span class="n">crop_size</span><span class="p">),</span>
    451. <span class="n">transforms</span><span class="o">.</span><span class="n">ToTensor</span><span class="p">(),</span>
    452. <span class="n">normalize</span><span class="p">,</span>
    453. <span class="p">]))</span></div>
    454. <div class="viewcode-block" id="TinyImageNetDatasetInterface"><a class="viewcode-back" href="../../../../../super_gradients.training.datasets.dataset_interfaces.html#super_gradients.training.TinyImageNetDatasetInterface">[docs]</a><span class="k">class</span> <span class="nc">TinyImageNetDatasetInterface</span><span class="p">(</span><span class="n">DatasetInterface</span><span class="p">):</span>
    455. <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">dataset_params</span><span class="o">=</span><span class="p">{},</span> <span class="n">data_dir</span><span class="o">=</span><span class="s2">&quot;/data/TinyImagenet&quot;</span><span class="p">):</span>
    456. <span class="nb">super</span><span class="p">(</span><span class="n">TinyImageNetDatasetInterface</span><span class="p">,</span> <span class="bp">self</span><span class="p">)</span><span class="o">.</span><span class="fm">__init__</span><span class="p">(</span><span class="n">dataset_params</span><span class="p">)</span>
    457. <span class="n">data_dir</span> <span class="o">=</span> <span class="n">dataset_params</span><span class="p">[</span><span class="s1">&#39;dataset_dir&#39;</span><span class="p">]</span> <span class="k">if</span> <span class="s1">&#39;dataset_dir&#39;</span> <span class="ow">in</span> <span class="n">dataset_params</span><span class="o">.</span><span class="n">keys</span><span class="p">()</span> <span class="k">else</span> <span class="n">data_dir</span>
    458. <span class="n">traindir</span> <span class="o">=</span> <span class="n">os</span><span class="o">.</span><span class="n">path</span><span class="o">.</span><span class="n">join</span><span class="p">(</span><span class="n">os</span><span class="o">.</span><span class="n">path</span><span class="o">.</span><span class="n">abspath</span><span class="p">(</span><span class="n">data_dir</span><span class="p">),</span> <span class="s1">&#39;train&#39;</span><span class="p">)</span>
    459. <span class="n">valdir</span> <span class="o">=</span> <span class="n">os</span><span class="o">.</span><span class="n">path</span><span class="o">.</span><span class="n">join</span><span class="p">(</span><span class="n">data_dir</span><span class="p">,</span> <span class="s1">&#39;val&#39;</span><span class="p">)</span>
    460. <span class="n">img_mean</span> <span class="o">=</span> <span class="n">core_utils</span><span class="o">.</span><span class="n">get_param</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">dataset_params</span><span class="p">,</span> <span class="s1">&#39;img_mean&#39;</span><span class="p">,</span> <span class="n">default_val</span><span class="o">=</span><span class="p">[</span><span class="mf">0.4802</span><span class="p">,</span> <span class="mf">0.4481</span><span class="p">,</span> <span class="mf">0.3975</span><span class="p">])</span>
    461. <span class="n">img_std</span> <span class="o">=</span> <span class="n">core_utils</span><span class="o">.</span><span class="n">get_param</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">dataset_params</span><span class="p">,</span> <span class="s1">&#39;img_std&#39;</span><span class="p">,</span> <span class="n">default_val</span><span class="o">=</span><span class="p">[</span><span class="mf">0.2770</span><span class="p">,</span> <span class="mf">0.2691</span><span class="p">,</span> <span class="mf">0.2821</span><span class="p">])</span>
    462. <span class="n">normalize</span> <span class="o">=</span> <span class="n">transforms</span><span class="o">.</span><span class="n">Normalize</span><span class="p">(</span><span class="n">mean</span><span class="o">=</span><span class="n">img_mean</span><span class="p">,</span>
    463. <span class="n">std</span><span class="o">=</span><span class="n">img_std</span><span class="p">)</span>
    464. <span class="n">crop_size</span> <span class="o">=</span> <span class="n">core_utils</span><span class="o">.</span><span class="n">get_param</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">dataset_params</span><span class="p">,</span> <span class="s1">&#39;crop_size&#39;</span><span class="p">,</span> <span class="n">default_val</span><span class="o">=</span><span class="mi">56</span><span class="p">)</span>
    465. <span class="n">resize_size</span> <span class="o">=</span> <span class="n">core_utils</span><span class="o">.</span><span class="n">get_param</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">dataset_params</span><span class="p">,</span> <span class="s1">&#39;resize_size&#39;</span><span class="p">,</span> <span class="n">default_val</span><span class="o">=</span><span class="mi">64</span><span class="p">)</span>
    466. <span class="bp">self</span><span class="o">.</span><span class="n">trainset</span> <span class="o">=</span> <span class="n">datasets</span><span class="o">.</span><span class="n">ImageFolder</span><span class="p">(</span>
    467. <span class="n">traindir</span><span class="p">,</span>
    468. <span class="n">transforms</span><span class="o">.</span><span class="n">Compose</span><span class="p">([</span>
    469. <span class="n">transforms</span><span class="o">.</span><span class="n">RandomResizedCrop</span><span class="p">(</span><span class="n">crop_size</span><span class="p">),</span>
    470. <span class="n">transforms</span><span class="o">.</span><span class="n">RandomHorizontalFlip</span><span class="p">(),</span>
    471. <span class="n">transforms</span><span class="o">.</span><span class="n">ToTensor</span><span class="p">(),</span>
    472. <span class="n">normalize</span><span class="p">,</span>
    473. <span class="p">]))</span>
    474. <span class="bp">self</span><span class="o">.</span><span class="n">valset</span> <span class="o">=</span> <span class="n">datasets</span><span class="o">.</span><span class="n">ImageFolder</span><span class="p">(</span><span class="n">valdir</span><span class="p">,</span> <span class="n">transforms</span><span class="o">.</span><span class="n">Compose</span><span class="p">([</span>
    475. <span class="n">transforms</span><span class="o">.</span><span class="n">Resize</span><span class="p">(</span><span class="n">resize_size</span><span class="p">),</span>
    476. <span class="n">transforms</span><span class="o">.</span><span class="n">CenterCrop</span><span class="p">(</span><span class="n">crop_size</span><span class="p">),</span>
    477. <span class="n">transforms</span><span class="o">.</span><span class="n">ToTensor</span><span class="p">(),</span>
    478. <span class="n">normalize</span><span class="p">,</span>
    479. <span class="p">]))</span></div>
    480. <div class="viewcode-block" id="ClassificationDatasetInterface"><a class="viewcode-back" href="../../../../../super_gradients.training.datasets.dataset_interfaces.html#super_gradients.training.ClassificationDatasetInterface">[docs]</a><span class="k">class</span> <span class="nc">ClassificationDatasetInterface</span><span class="p">(</span><span class="n">DatasetInterface</span><span class="p">):</span>
    481. <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">normalization_mean</span><span class="o">=</span><span class="p">(</span><span class="mi">0</span><span class="p">,</span> <span class="mi">0</span><span class="p">,</span> <span class="mi">0</span><span class="p">),</span> <span class="n">normalization_std</span><span class="o">=</span><span class="p">(</span><span class="mi">1</span><span class="p">,</span> <span class="mi">1</span><span class="p">,</span> <span class="mi">1</span><span class="p">),</span> <span class="n">resolution</span><span class="o">=</span><span class="mi">64</span><span class="p">,</span>
    482. <span class="n">dataset_params</span><span class="o">=</span><span class="p">{}):</span>
    483. <span class="nb">super</span><span class="p">(</span><span class="n">ClassificationDatasetInterface</span><span class="p">,</span> <span class="bp">self</span><span class="p">)</span><span class="o">.</span><span class="fm">__init__</span><span class="p">(</span><span class="n">dataset_params</span><span class="p">)</span>
    484. <span class="n">data_dir</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">dataset_params</span><span class="o">.</span><span class="n">dataset_dir</span>
    485. <span class="n">traindir</span> <span class="o">=</span> <span class="n">os</span><span class="o">.</span><span class="n">path</span><span class="o">.</span><span class="n">join</span><span class="p">(</span><span class="n">os</span><span class="o">.</span><span class="n">path</span><span class="o">.</span><span class="n">abspath</span><span class="p">(</span><span class="n">data_dir</span><span class="p">),</span> <span class="s1">&#39;train&#39;</span><span class="p">)</span>
    486. <span class="n">valdir</span> <span class="o">=</span> <span class="n">os</span><span class="o">.</span><span class="n">path</span><span class="o">.</span><span class="n">join</span><span class="p">(</span><span class="n">data_dir</span><span class="p">,</span> <span class="s1">&#39;val&#39;</span><span class="p">)</span>
    487. <span class="n">normalize</span> <span class="o">=</span> <span class="n">transforms</span><span class="o">.</span><span class="n">Normalize</span><span class="p">(</span><span class="n">mean</span><span class="o">=</span><span class="n">normalization_mean</span><span class="p">,</span>
    488. <span class="n">std</span><span class="o">=</span><span class="n">normalization_std</span><span class="p">)</span>
    489. <span class="bp">self</span><span class="o">.</span><span class="n">trainset</span> <span class="o">=</span> <span class="n">datasets</span><span class="o">.</span><span class="n">ImageFolder</span><span class="p">(</span>
    490. <span class="n">traindir</span><span class="p">,</span>
    491. <span class="n">transforms</span><span class="o">.</span><span class="n">Compose</span><span class="p">([</span>
    492. <span class="n">transforms</span><span class="o">.</span><span class="n">RandomResizedCrop</span><span class="p">(</span><span class="n">resolution</span><span class="p">),</span>
    493. <span class="n">transforms</span><span class="o">.</span><span class="n">RandomHorizontalFlip</span><span class="p">(),</span>
    494. <span class="n">transforms</span><span class="o">.</span><span class="n">ToTensor</span><span class="p">(),</span>
    495. <span class="n">normalize</span><span class="p">,</span>
    496. <span class="p">]))</span>
    497. <span class="bp">self</span><span class="o">.</span><span class="n">valset</span> <span class="o">=</span> <span class="n">datasets</span><span class="o">.</span><span class="n">ImageFolder</span><span class="p">(</span><span class="n">valdir</span><span class="p">,</span> <span class="n">transforms</span><span class="o">.</span><span class="n">Compose</span><span class="p">([</span>
    498. <span class="n">transforms</span><span class="o">.</span><span class="n">Resize</span><span class="p">(</span><span class="nb">int</span><span class="p">(</span><span class="n">resolution</span> <span class="o">*</span> <span class="mf">1.15</span><span class="p">)),</span>
    499. <span class="n">transforms</span><span class="o">.</span><span class="n">CenterCrop</span><span class="p">(</span><span class="n">resolution</span><span class="p">),</span>
    500. <span class="n">transforms</span><span class="o">.</span><span class="n">ToTensor</span><span class="p">(),</span>
    501. <span class="n">normalize</span><span class="p">,</span>
    502. <span class="p">]))</span>
    503. <span class="bp">self</span><span class="o">.</span><span class="n">data_dir</span> <span class="o">=</span> <span class="n">data_dir</span>
    504. <span class="bp">self</span><span class="o">.</span><span class="n">normalization_mean</span> <span class="o">=</span> <span class="n">normalization_mean</span>
    505. <span class="bp">self</span><span class="o">.</span><span class="n">normalization_std</span> <span class="o">=</span> <span class="n">normalization_std</span></div>
    506. <div class="viewcode-block" id="PascalVOC2012SegmentationDataSetInterface"><a class="viewcode-back" href="../../../../../super_gradients.training.datasets.dataset_interfaces.html#super_gradients.training.PascalVOC2012SegmentationDataSetInterface">[docs]</a><span class="k">class</span> <span class="nc">PascalVOC2012SegmentationDataSetInterface</span><span class="p">(</span><span class="n">DatasetInterface</span><span class="p">):</span>
    507. <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">dataset_params</span><span class="o">=</span><span class="kc">None</span><span class="p">,</span> <span class="n">cache_labels</span><span class="o">=</span><span class="kc">False</span><span class="p">,</span> <span class="n">cache_images</span><span class="o">=</span><span class="kc">False</span><span class="p">):</span>
    508. <span class="k">if</span> <span class="n">dataset_params</span> <span class="ow">is</span> <span class="kc">None</span><span class="p">:</span>
    509. <span class="n">dataset_params</span> <span class="o">=</span> <span class="nb">dict</span><span class="p">()</span>
    510. <span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="fm">__init__</span><span class="p">(</span><span class="n">dataset_params</span><span class="o">=</span><span class="n">dataset_params</span><span class="p">)</span>
    511. <span class="bp">self</span><span class="o">.</span><span class="n">root_dir</span> <span class="o">=</span> <span class="n">dataset_params</span><span class="p">[</span><span class="s1">&#39;dataset_dir&#39;</span><span class="p">]</span> <span class="k">if</span> <span class="s1">&#39;dataset_dir&#39;</span> <span class="ow">in</span> <span class="n">dataset_params</span><span class="o">.</span><span class="n">keys</span><span class="p">()</span> \
    512. <span class="k">else</span> <span class="s1">&#39;/data/pascal_voc_2012/VOCdevkit/VOC2012/&#39;</span>
    513. <span class="bp">self</span><span class="o">.</span><span class="n">trainset</span> <span class="o">=</span> <span class="n">PascalVOC2012SegmentationDataSet</span><span class="p">(</span><span class="n">root</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">root_dir</span><span class="p">,</span>
    514. <span class="n">list_file</span><span class="o">=</span><span class="s1">&#39;ImageSets/Segmentation/train.txt&#39;</span><span class="p">,</span>
    515. <span class="n">samples_sub_directory</span><span class="o">=</span><span class="s1">&#39;JPEGImages&#39;</span><span class="p">,</span>
    516. <span class="n">targets_sub_directory</span><span class="o">=</span><span class="s1">&#39;SegmentationClass&#39;</span><span class="p">,</span> <span class="n">augment</span><span class="o">=</span><span class="kc">True</span><span class="p">,</span>
    517. <span class="n">dataset_hyper_params</span><span class="o">=</span><span class="n">dataset_params</span><span class="p">,</span> <span class="n">cache_labels</span><span class="o">=</span><span class="n">cache_labels</span><span class="p">,</span>
    518. <span class="n">cache_images</span><span class="o">=</span><span class="n">cache_images</span><span class="p">)</span>
    519. <span class="bp">self</span><span class="o">.</span><span class="n">valset</span> <span class="o">=</span> <span class="n">PascalVOC2012SegmentationDataSet</span><span class="p">(</span><span class="n">root</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">root_dir</span><span class="p">,</span>
    520. <span class="n">list_file</span><span class="o">=</span><span class="s1">&#39;ImageSets/Segmentation/val.txt&#39;</span><span class="p">,</span>
    521. <span class="n">samples_sub_directory</span><span class="o">=</span><span class="s1">&#39;JPEGImages&#39;</span><span class="p">,</span>
    522. <span class="n">targets_sub_directory</span><span class="o">=</span><span class="s1">&#39;SegmentationClass&#39;</span><span class="p">,</span> <span class="n">augment</span><span class="o">=</span><span class="kc">True</span><span class="p">,</span>
    523. <span class="n">dataset_hyper_params</span><span class="o">=</span><span class="n">dataset_params</span><span class="p">,</span> <span class="n">cache_labels</span><span class="o">=</span><span class="n">cache_labels</span><span class="p">,</span>
    524. <span class="n">cache_images</span><span class="o">=</span><span class="n">cache_images</span><span class="p">)</span>
    525. <span class="bp">self</span><span class="o">.</span><span class="n">classes</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">trainset</span><span class="o">.</span><span class="n">classes</span></div>
    526. <div class="viewcode-block" id="PascalAUG2012SegmentationDataSetInterface"><a class="viewcode-back" href="../../../../../super_gradients.training.datasets.dataset_interfaces.html#super_gradients.training.PascalAUG2012SegmentationDataSetInterface">[docs]</a><span class="k">class</span> <span class="nc">PascalAUG2012SegmentationDataSetInterface</span><span class="p">(</span><span class="n">DatasetInterface</span><span class="p">):</span>
    527. <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">dataset_params</span><span class="o">=</span><span class="kc">None</span><span class="p">,</span> <span class="n">cache_labels</span><span class="o">=</span><span class="kc">False</span><span class="p">,</span> <span class="n">cache_images</span><span class="o">=</span><span class="kc">False</span><span class="p">):</span>
    528. <span class="k">if</span> <span class="n">dataset_params</span> <span class="ow">is</span> <span class="kc">None</span><span class="p">:</span>
    529. <span class="n">dataset_params</span> <span class="o">=</span> <span class="nb">dict</span><span class="p">()</span>
    530. <span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="fm">__init__</span><span class="p">(</span><span class="n">dataset_params</span><span class="o">=</span><span class="n">dataset_params</span><span class="p">)</span>
    531. <span class="bp">self</span><span class="o">.</span><span class="n">root_dir</span> <span class="o">=</span> <span class="n">dataset_params</span><span class="p">[</span><span class="s1">&#39;dataset_dir&#39;</span><span class="p">]</span> <span class="k">if</span> <span class="s1">&#39;dataset_dir&#39;</span> <span class="ow">in</span> <span class="n">dataset_params</span><span class="o">.</span><span class="n">keys</span><span class="p">()</span> \
    532. <span class="k">else</span> <span class="s1">&#39;/data/pascal_voc_2012/VOCaug/dataset/&#39;</span>
    533. <span class="bp">self</span><span class="o">.</span><span class="n">trainset</span> <span class="o">=</span> <span class="n">PascalAUG2012SegmentationDataSet</span><span class="p">(</span>
    534. <span class="n">root</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">root_dir</span><span class="p">,</span>
    535. <span class="n">list_file</span><span class="o">=</span><span class="s1">&#39;trainval.txt&#39;</span><span class="p">,</span>
    536. <span class="n">samples_sub_directory</span><span class="o">=</span><span class="s1">&#39;img&#39;</span><span class="p">,</span>
    537. <span class="n">targets_sub_directory</span><span class="o">=</span><span class="s1">&#39;cls&#39;</span><span class="p">,</span> <span class="n">augment</span><span class="o">=</span><span class="kc">True</span><span class="p">,</span>
    538. <span class="n">dataset_hyper_params</span><span class="o">=</span><span class="n">dataset_params</span><span class="p">,</span> <span class="n">cache_labels</span><span class="o">=</span><span class="n">cache_labels</span><span class="p">,</span>
    539. <span class="n">cache_images</span><span class="o">=</span><span class="n">cache_images</span><span class="p">)</span>
    540. <span class="bp">self</span><span class="o">.</span><span class="n">valset</span> <span class="o">=</span> <span class="n">PascalAUG2012SegmentationDataSet</span><span class="p">(</span>
    541. <span class="n">root</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">root_dir</span><span class="p">,</span>
    542. <span class="n">list_file</span><span class="o">=</span><span class="s1">&#39;val.txt&#39;</span><span class="p">,</span>
    543. <span class="n">samples_sub_directory</span><span class="o">=</span><span class="s1">&#39;img&#39;</span><span class="p">,</span>
    544. <span class="n">targets_sub_directory</span><span class="o">=</span><span class="s1">&#39;cls&#39;</span><span class="p">,</span> <span class="n">augment</span><span class="o">=</span><span class="kc">False</span><span class="p">,</span>
    545. <span class="n">dataset_hyper_params</span><span class="o">=</span><span class="n">dataset_params</span><span class="p">,</span> <span class="n">cache_labels</span><span class="o">=</span><span class="n">cache_labels</span><span class="p">,</span>
    546. <span class="n">cache_images</span><span class="o">=</span><span class="n">cache_images</span><span class="p">)</span>
    547. <span class="bp">self</span><span class="o">.</span><span class="n">classes</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">trainset</span><span class="o">.</span><span class="n">classes</span></div>
    548. <div class="viewcode-block" id="CoCoDataSetInterfaceBase"><a class="viewcode-back" href="../../../../../super_gradients.training.datasets.dataset_interfaces.html#super_gradients.training.CoCoDataSetInterfaceBase">[docs]</a><span class="k">class</span> <span class="nc">CoCoDataSetInterfaceBase</span><span class="p">(</span><span class="n">DatasetInterface</span><span class="p">):</span>
    549. <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">dataset_params</span><span class="o">=</span><span class="kc">None</span><span class="p">):</span>
    550. <span class="k">if</span> <span class="n">dataset_params</span> <span class="ow">is</span> <span class="kc">None</span><span class="p">:</span>
    551. <span class="n">dataset_params</span> <span class="o">=</span> <span class="nb">dict</span><span class="p">()</span>
    552. <span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="fm">__init__</span><span class="p">(</span><span class="n">dataset_params</span><span class="o">=</span><span class="n">dataset_params</span><span class="p">)</span>
    553. <span class="bp">self</span><span class="o">.</span><span class="n">root_dir</span> <span class="o">=</span> <span class="n">dataset_params</span><span class="p">[</span><span class="s1">&#39;dataset_dir&#39;</span><span class="p">]</span> <span class="k">if</span> <span class="s1">&#39;dataset_dir&#39;</span> <span class="ow">in</span> <span class="n">dataset_params</span><span class="o">.</span><span class="n">keys</span><span class="p">()</span> <span class="k">else</span> <span class="s1">&#39;/data/coco/&#39;</span></div>
    554. <div class="viewcode-block" id="CoCoSegmentationDatasetInterface"><a class="viewcode-back" href="../../../../../super_gradients.training.datasets.dataset_interfaces.html#super_gradients.training.CoCoSegmentationDatasetInterface">[docs]</a><span class="k">class</span> <span class="nc">CoCoSegmentationDatasetInterface</span><span class="p">(</span><span class="n">CoCoDataSetInterfaceBase</span><span class="p">):</span>
    555. <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">dataset_params</span><span class="o">=</span><span class="kc">None</span><span class="p">,</span> <span class="n">cache_labels</span><span class="p">:</span> <span class="nb">bool</span> <span class="o">=</span> <span class="kc">False</span><span class="p">,</span> <span class="n">cache_images</span><span class="p">:</span> <span class="nb">bool</span> <span class="o">=</span> <span class="kc">False</span><span class="p">,</span>
    556. <span class="n">dataset_classes_inclusion_tuples_list</span><span class="p">:</span> <span class="nb">list</span> <span class="o">=</span> <span class="kc">None</span><span class="p">):</span>
    557. <span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="fm">__init__</span><span class="p">(</span><span class="n">dataset_params</span><span class="o">=</span><span class="n">dataset_params</span><span class="p">)</span>
    558. <span class="bp">self</span><span class="o">.</span><span class="n">trainset</span> <span class="o">=</span> <span class="n">CoCoSegmentationDataSet</span><span class="p">(</span>
    559. <span class="n">root</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">root_dir</span><span class="p">,</span>
    560. <span class="n">list_file</span><span class="o">=</span><span class="s1">&#39;instances_train2017.json&#39;</span><span class="p">,</span>
    561. <span class="n">samples_sub_directory</span><span class="o">=</span><span class="s1">&#39;images/train2017&#39;</span><span class="p">,</span>
    562. <span class="n">targets_sub_directory</span><span class="o">=</span><span class="s1">&#39;annotations&#39;</span><span class="p">,</span> <span class="n">augment</span><span class="o">=</span><span class="kc">True</span><span class="p">,</span>
    563. <span class="n">dataset_hyper_params</span><span class="o">=</span><span class="n">dataset_params</span><span class="p">,</span>
    564. <span class="n">cache_labels</span><span class="o">=</span><span class="n">cache_labels</span><span class="p">,</span>
    565. <span class="n">cache_images</span><span class="o">=</span><span class="n">cache_images</span><span class="p">,</span>
    566. <span class="n">dataset_classes_inclusion_tuples_list</span><span class="o">=</span><span class="n">dataset_classes_inclusion_tuples_list</span><span class="p">)</span>
    567. <span class="bp">self</span><span class="o">.</span><span class="n">valset</span> <span class="o">=</span> <span class="n">CoCoSegmentationDataSet</span><span class="p">(</span>
    568. <span class="n">root</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">root_dir</span><span class="p">,</span>
    569. <span class="n">list_file</span><span class="o">=</span><span class="s1">&#39;instances_val2017.json&#39;</span><span class="p">,</span>
    570. <span class="n">samples_sub_directory</span><span class="o">=</span><span class="s1">&#39;images/val2017&#39;</span><span class="p">,</span>
    571. <span class="n">targets_sub_directory</span><span class="o">=</span><span class="s1">&#39;annotations&#39;</span><span class="p">,</span> <span class="n">augment</span><span class="o">=</span><span class="kc">False</span><span class="p">,</span>
    572. <span class="n">dataset_hyper_params</span><span class="o">=</span><span class="n">dataset_params</span><span class="p">,</span>
    573. <span class="n">cache_labels</span><span class="o">=</span><span class="n">cache_labels</span><span class="p">,</span>
    574. <span class="n">cache_images</span><span class="o">=</span><span class="n">cache_images</span><span class="p">,</span>
    575. <span class="n">dataset_classes_inclusion_tuples_list</span><span class="o">=</span><span class="n">dataset_classes_inclusion_tuples_list</span><span class="p">)</span>
    576. <span class="bp">self</span><span class="o">.</span><span class="n">coco_classes</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">trainset</span><span class="o">.</span><span class="n">classes</span></div>
    577. <div class="viewcode-block" id="CityscapesDatasetInterface"><a class="viewcode-back" href="../../../../../super_gradients.training.datasets.dataset_interfaces.html#super_gradients.training.CityscapesDatasetInterface">[docs]</a><span class="k">class</span> <span class="nc">CityscapesDatasetInterface</span><span class="p">(</span><span class="n">DatasetInterface</span><span class="p">):</span>
    578. <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">dataset_params</span><span class="o">=</span><span class="kc">None</span><span class="p">,</span> <span class="n">cache_labels</span><span class="p">:</span> <span class="nb">bool</span> <span class="o">=</span> <span class="kc">False</span><span class="p">,</span> <span class="n">cache_images</span><span class="p">:</span> <span class="nb">bool</span> <span class="o">=</span> <span class="kc">False</span><span class="p">):</span>
    579. <span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="fm">__init__</span><span class="p">(</span><span class="n">dataset_params</span><span class="o">=</span><span class="n">dataset_params</span><span class="p">)</span>
    580. <span class="n">root_dir</span> <span class="o">=</span> <span class="n">core_utils</span><span class="o">.</span><span class="n">get_param</span><span class="p">(</span><span class="n">dataset_params</span><span class="p">,</span> <span class="s2">&quot;dataset_dir&quot;</span><span class="p">,</span> <span class="s2">&quot;/data/cityscapes&quot;</span><span class="p">)</span>
    581. <span class="n">img_size</span> <span class="o">=</span> <span class="n">core_utils</span><span class="o">.</span><span class="n">get_param</span><span class="p">(</span><span class="n">dataset_params</span><span class="p">,</span> <span class="s2">&quot;img_size&quot;</span><span class="p">,</span> <span class="mi">1024</span><span class="p">)</span>
    582. <span class="n">crop_size</span> <span class="o">=</span> <span class="n">core_utils</span><span class="o">.</span><span class="n">get_param</span><span class="p">(</span><span class="n">dataset_params</span><span class="p">,</span> <span class="s2">&quot;crop_size&quot;</span><span class="p">,</span> <span class="mi">512</span><span class="p">)</span>
    583. <span class="n">image_mask_transforms</span> <span class="o">=</span> <span class="n">core_utils</span><span class="o">.</span><span class="n">get_param</span><span class="p">(</span><span class="n">dataset_params</span><span class="p">,</span> <span class="s2">&quot;image_mask_transforms&quot;</span><span class="p">)</span>
    584. <span class="n">image_mask_transforms_aug</span> <span class="o">=</span> <span class="n">core_utils</span><span class="o">.</span><span class="n">get_param</span><span class="p">(</span><span class="n">dataset_params</span><span class="p">,</span> <span class="s2">&quot;image_mask_transforms_aug&quot;</span><span class="p">)</span>
    585. <span class="bp">self</span><span class="o">.</span><span class="n">trainset</span> <span class="o">=</span> <span class="n">CityscapesDataset</span><span class="p">(</span>
    586. <span class="n">root_dir</span><span class="o">=</span><span class="n">root_dir</span><span class="p">,</span>
    587. <span class="n">list_file</span><span class="o">=</span><span class="s1">&#39;lists/train.lst&#39;</span><span class="p">,</span>
    588. <span class="n">labels_csv_path</span><span class="o">=</span><span class="s2">&quot;lists/labels.csv&quot;</span><span class="p">,</span>
    589. <span class="n">img_size</span><span class="o">=</span><span class="n">img_size</span><span class="p">,</span>
    590. <span class="n">crop_size</span><span class="o">=</span><span class="n">crop_size</span><span class="p">,</span>
    591. <span class="n">augment</span><span class="o">=</span><span class="kc">True</span><span class="p">,</span>
    592. <span class="n">dataset_hyper_params</span><span class="o">=</span><span class="n">dataset_params</span><span class="p">,</span>
    593. <span class="n">cache_labels</span><span class="o">=</span><span class="n">cache_labels</span><span class="p">,</span>
    594. <span class="n">cache_images</span><span class="o">=</span><span class="n">cache_images</span><span class="p">,</span>
    595. <span class="n">image_mask_transforms</span><span class="o">=</span><span class="n">image_mask_transforms</span><span class="p">,</span>
    596. <span class="n">image_mask_transforms_aug</span><span class="o">=</span><span class="n">image_mask_transforms_aug</span><span class="p">)</span>
    597. <span class="bp">self</span><span class="o">.</span><span class="n">valset</span> <span class="o">=</span> <span class="n">CityscapesDataset</span><span class="p">(</span>
    598. <span class="n">root_dir</span><span class="o">=</span><span class="n">root_dir</span><span class="p">,</span>
    599. <span class="n">list_file</span><span class="o">=</span><span class="s1">&#39;lists/val.lst&#39;</span><span class="p">,</span>
    600. <span class="n">labels_csv_path</span><span class="o">=</span><span class="s2">&quot;lists/labels.csv&quot;</span><span class="p">,</span>
    601. <span class="n">img_size</span><span class="o">=</span><span class="n">img_size</span><span class="p">,</span>
    602. <span class="n">crop_size</span><span class="o">=</span><span class="n">crop_size</span><span class="p">,</span>
    603. <span class="n">augment</span><span class="o">=</span><span class="kc">False</span><span class="p">,</span>
    604. <span class="n">dataset_hyper_params</span><span class="o">=</span><span class="n">dataset_params</span><span class="p">,</span>
    605. <span class="n">cache_labels</span><span class="o">=</span><span class="n">cache_labels</span><span class="p">,</span>
    606. <span class="n">cache_images</span><span class="o">=</span><span class="n">cache_images</span><span class="p">,</span>
    607. <span class="n">image_mask_transforms</span><span class="o">=</span><span class="n">image_mask_transforms</span><span class="p">)</span>
    608. <span class="bp">self</span><span class="o">.</span><span class="n">classes</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">trainset</span><span class="o">.</span><span class="n">classes</span></div>
    609. <div class="viewcode-block" id="SuperviselyPersonsDatasetInterface"><a class="viewcode-back" href="../../../../../super_gradients.training.datasets.dataset_interfaces.html#super_gradients.training.SuperviselyPersonsDatasetInterface">[docs]</a><span class="k">class</span> <span class="nc">SuperviselyPersonsDatasetInterface</span><span class="p">(</span><span class="n">DatasetInterface</span><span class="p">):</span>
    610. <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">dataset_params</span><span class="o">=</span><span class="kc">None</span><span class="p">,</span> <span class="n">cache_labels</span><span class="p">:</span> <span class="nb">bool</span> <span class="o">=</span> <span class="kc">False</span><span class="p">,</span> <span class="n">cache_images</span><span class="p">:</span> <span class="nb">bool</span> <span class="o">=</span> <span class="kc">False</span><span class="p">):</span>
    611. <span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="fm">__init__</span><span class="p">(</span><span class="n">dataset_params</span><span class="o">=</span><span class="n">dataset_params</span><span class="p">)</span>
    612. <span class="n">root_dir</span> <span class="o">=</span> <span class="n">get_param</span><span class="p">(</span><span class="n">dataset_params</span><span class="p">,</span> <span class="s2">&quot;dataset_dir&quot;</span><span class="p">,</span> <span class="s2">&quot;/data/supervisely-persons&quot;</span><span class="p">)</span>
    613. <span class="bp">self</span><span class="o">.</span><span class="n">trainset</span> <span class="o">=</span> <span class="n">SuperviselyPersonsDataset</span><span class="p">(</span>
    614. <span class="n">root_dir</span><span class="o">=</span><span class="n">root_dir</span><span class="p">,</span>
    615. <span class="n">list_file</span><span class="o">=</span><span class="s1">&#39;train.csv&#39;</span><span class="p">,</span>
    616. <span class="n">dataset_hyper_params</span><span class="o">=</span><span class="n">dataset_params</span><span class="p">,</span>
    617. <span class="n">cache_labels</span><span class="o">=</span><span class="n">cache_labels</span><span class="p">,</span>
    618. <span class="n">cache_images</span><span class="o">=</span><span class="n">cache_images</span><span class="p">,</span>
    619. <span class="n">image_mask_transforms_aug</span><span class="o">=</span><span class="n">get_param</span><span class="p">(</span><span class="n">dataset_params</span><span class="p">,</span> <span class="s2">&quot;image_mask_transforms_aug&quot;</span><span class="p">,</span> <span class="n">transforms</span><span class="o">.</span><span class="n">Compose</span><span class="p">([])),</span>
    620. <span class="n">augment</span><span class="o">=</span><span class="kc">True</span>
    621. <span class="p">)</span>
    622. <span class="bp">self</span><span class="o">.</span><span class="n">valset</span> <span class="o">=</span> <span class="n">SuperviselyPersonsDataset</span><span class="p">(</span>
    623. <span class="n">root_dir</span><span class="o">=</span><span class="n">root_dir</span><span class="p">,</span>
    624. <span class="n">list_file</span><span class="o">=</span><span class="s1">&#39;val.csv&#39;</span><span class="p">,</span>
    625. <span class="n">dataset_hyper_params</span><span class="o">=</span><span class="n">dataset_params</span><span class="p">,</span>
    626. <span class="n">cache_labels</span><span class="o">=</span><span class="n">cache_labels</span><span class="p">,</span>
    627. <span class="n">cache_images</span><span class="o">=</span><span class="n">cache_images</span><span class="p">,</span>
    628. <span class="n">image_mask_transforms</span><span class="o">=</span><span class="n">get_param</span><span class="p">(</span><span class="n">dataset_params</span><span class="p">,</span> <span class="s2">&quot;image_mask_transforms&quot;</span><span class="p">,</span> <span class="n">transforms</span><span class="o">.</span><span class="n">Compose</span><span class="p">([])),</span>
    629. <span class="n">augment</span><span class="o">=</span><span class="kc">False</span>
    630. <span class="p">)</span>
    631. <span class="bp">self</span><span class="o">.</span><span class="n">classes</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">trainset</span><span class="o">.</span><span class="n">classes</span></div>
    632. <div class="viewcode-block" id="DetectionDatasetInterface"><a class="viewcode-back" href="../../../../../super_gradients.training.datasets.dataset_interfaces.html#super_gradients.training.DetectionDatasetInterface">[docs]</a><span class="k">class</span> <span class="nc">DetectionDatasetInterface</span><span class="p">(</span><span class="n">DatasetInterface</span><span class="p">):</span>
    633. <div class="viewcode-block" id="DetectionDatasetInterface.build_data_loaders"><a class="viewcode-back" href="../../../../../super_gradients.training.datasets.dataset_interfaces.html#super_gradients.training.DetectionDatasetInterface.build_data_loaders">[docs]</a> <span class="k">def</span> <span class="nf">build_data_loaders</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">batch_size_factor</span><span class="o">=</span><span class="mi">1</span><span class="p">,</span> <span class="n">num_workers</span><span class="o">=</span><span class="mi">8</span><span class="p">,</span> <span class="n">train_batch_size</span><span class="o">=</span><span class="kc">None</span><span class="p">,</span> <span class="n">val_batch_size</span><span class="o">=</span><span class="kc">None</span><span class="p">,</span>
    634. <span class="n">test_batch_size</span><span class="o">=</span><span class="kc">None</span><span class="p">,</span> <span class="n">distributed_sampler</span><span class="p">:</span> <span class="nb">bool</span> <span class="o">=</span> <span class="kc">False</span><span class="p">):</span>
    635. <span class="n">train_sampler</span> <span class="o">=</span> <span class="n">InfiniteSampler</span><span class="p">(</span><span class="nb">len</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">trainset</span><span class="p">),</span> <span class="n">seed</span><span class="o">=</span><span class="mi">0</span><span class="p">)</span>
    636. <span class="n">train_batch_sampler</span> <span class="o">=</span> <span class="n">BatchSampler</span><span class="p">(</span>
    637. <span class="n">sampler</span><span class="o">=</span><span class="n">train_sampler</span><span class="p">,</span>
    638. <span class="n">batch_size</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">dataset_params</span><span class="o">.</span><span class="n">batch_size</span><span class="p">,</span>
    639. <span class="n">drop_last</span><span class="o">=</span><span class="kc">False</span><span class="p">,</span>
    640. <span class="p">)</span>
    641. <span class="bp">self</span><span class="o">.</span><span class="n">train_loader</span> <span class="o">=</span> <span class="n">DataLoader</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">trainset</span><span class="p">,</span>
    642. <span class="n">batch_sampler</span><span class="o">=</span><span class="n">train_batch_sampler</span><span class="p">,</span>
    643. <span class="n">num_workers</span><span class="o">=</span><span class="n">num_workers</span><span class="p">,</span>
    644. <span class="n">pin_memory</span><span class="o">=</span><span class="kc">True</span><span class="p">,</span>
    645. <span class="n">worker_init_fn</span><span class="o">=</span><span class="n">worker_init_reset_seed</span><span class="p">,</span>
    646. <span class="n">collate_fn</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">dataset_params</span><span class="o">.</span><span class="n">train_collate_fn</span><span class="p">)</span>
    647. <span class="k">if</span> <span class="n">distributed_sampler</span><span class="p">:</span>
    648. <span class="n">sampler</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">utils</span><span class="o">.</span><span class="n">data</span><span class="o">.</span><span class="n">distributed</span><span class="o">.</span><span class="n">DistributedSampler</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">valset</span><span class="p">,</span> <span class="n">shuffle</span><span class="o">=</span><span class="kc">False</span><span class="p">)</span>
    649. <span class="k">else</span><span class="p">:</span>
    650. <span class="n">sampler</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">utils</span><span class="o">.</span><span class="n">data</span><span class="o">.</span><span class="n">SequentialSampler</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">valset</span><span class="p">)</span>
    651. <span class="n">val_loader</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">utils</span><span class="o">.</span><span class="n">data</span><span class="o">.</span><span class="n">DataLoader</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">valset</span><span class="p">,</span>
    652. <span class="n">num_workers</span><span class="o">=</span><span class="n">num_workers</span><span class="p">,</span>
    653. <span class="n">pin_memory</span><span class="o">=</span><span class="kc">True</span><span class="p">,</span>
    654. <span class="n">sampler</span><span class="o">=</span><span class="n">sampler</span><span class="p">,</span>
    655. <span class="n">batch_size</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">dataset_params</span><span class="o">.</span><span class="n">val_batch_size</span><span class="p">,</span>
    656. <span class="n">collate_fn</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">dataset_params</span><span class="o">.</span><span class="n">val_collate_fn</span><span class="p">)</span>
    657. <span class="bp">self</span><span class="o">.</span><span class="n">val_loader</span> <span class="o">=</span> <span class="n">val_loader</span></div></div>
    658. <div class="viewcode-block" id="PascalVOCUnifiedDetectionDatasetInterface"><a class="viewcode-back" href="../../../../../super_gradients.training.datasets.dataset_interfaces.html#super_gradients.training.PascalVOCUnifiedDetectionDatasetInterface">[docs]</a><span class="k">class</span> <span class="nc">PascalVOCUnifiedDetectionDatasetInterface</span><span class="p">(</span><span class="n">DetectionDatasetInterface</span><span class="p">):</span>
    659. <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">dataset_params</span><span class="o">=</span><span class="kc">None</span><span class="p">):</span>
    660. <span class="k">if</span> <span class="n">dataset_params</span> <span class="ow">is</span> <span class="kc">None</span><span class="p">:</span>
    661. <span class="n">dataset_params</span> <span class="o">=</span> <span class="nb">dict</span><span class="p">()</span>
    662. <span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="fm">__init__</span><span class="p">(</span><span class="n">dataset_params</span><span class="o">=</span><span class="n">dataset_params</span><span class="p">)</span>
    663. <span class="bp">self</span><span class="o">.</span><span class="n">data_dir</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">dataset_params</span><span class="o">.</span><span class="n">data_dir</span>
    664. <span class="n">train_input_dim</span> <span class="o">=</span> <span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">dataset_params</span><span class="o">.</span><span class="n">train_image_size</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">dataset_params</span><span class="o">.</span><span class="n">train_image_size</span><span class="p">)</span>
    665. <span class="n">val_input_dim</span> <span class="o">=</span> <span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">dataset_params</span><span class="o">.</span><span class="n">val_image_size</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">dataset_params</span><span class="o">.</span><span class="n">val_image_size</span><span class="p">)</span>
    666. <span class="n">train_max_num_samples</span> <span class="o">=</span> <span class="n">get_param</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">dataset_params</span><span class="p">,</span> <span class="s2">&quot;train_max_num_samples&quot;</span><span class="p">)</span>
    667. <span class="n">val_max_num_samples</span> <span class="o">=</span> <span class="n">get_param</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">dataset_params</span><span class="p">,</span> <span class="s2">&quot;val_max_num_samples&quot;</span><span class="p">)</span>
    668. <span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">dataset_params</span><span class="o">.</span><span class="n">download</span><span class="p">:</span>
    669. <span class="n">PascalVOCDetectionDataset</span><span class="o">.</span><span class="n">download</span><span class="p">(</span><span class="n">data_dir</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">data_dir</span><span class="p">)</span>
    670. <span class="n">train_dataset_names</span> <span class="o">=</span> <span class="p">[</span><span class="s2">&quot;train2007&quot;</span><span class="p">,</span> <span class="s2">&quot;val2007&quot;</span><span class="p">,</span> <span class="s2">&quot;train2012&quot;</span><span class="p">,</span> <span class="s2">&quot;val2012&quot;</span><span class="p">]</span>
    671. <span class="c1"># We divide train_max_num_samples between the datasets</span>
    672. <span class="k">if</span> <span class="n">train_max_num_samples</span><span class="p">:</span>
    673. <span class="n">max_num_samples_per_train_dataset</span> <span class="o">=</span> <span class="p">[</span><span class="nb">len</span><span class="p">(</span><span class="n">segment</span><span class="p">)</span> <span class="k">for</span> <span class="n">segment</span> <span class="ow">in</span> <span class="n">np</span><span class="o">.</span><span class="n">array_split</span><span class="p">(</span><span class="nb">range</span><span class="p">(</span><span class="n">train_max_num_samples</span><span class="p">),</span> <span class="nb">len</span><span class="p">(</span><span class="n">train_dataset_names</span><span class="p">))]</span>
    674. <span class="k">else</span><span class="p">:</span>
    675. <span class="n">max_num_samples_per_train_dataset</span> <span class="o">=</span> <span class="p">[</span><span class="kc">None</span><span class="p">]</span> <span class="o">*</span> <span class="nb">len</span><span class="p">(</span><span class="n">train_dataset_names</span><span class="p">)</span>
    676. <span class="n">train_sets</span> <span class="o">=</span> <span class="p">[</span><span class="n">PascalVOCDetectionDataset</span><span class="p">(</span><span class="n">data_dir</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">data_dir</span><span class="p">,</span>
    677. <span class="n">input_dim</span><span class="o">=</span><span class="n">train_input_dim</span><span class="p">,</span>
    678. <span class="n">cache</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">dataset_params</span><span class="o">.</span><span class="n">cache_train_images</span><span class="p">,</span>
    679. <span class="n">cache_path</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">dataset_params</span><span class="o">.</span><span class="n">cache_dir</span> <span class="o">+</span> <span class="s2">&quot;cache_train&quot;</span><span class="p">,</span>
    680. <span class="n">transforms</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">dataset_params</span><span class="o">.</span><span class="n">train_transforms</span><span class="p">,</span>
    681. <span class="n">images_sub_directory</span><span class="o">=</span><span class="s1">&#39;images/&#39;</span> <span class="o">+</span> <span class="n">trainset_name</span> <span class="o">+</span> <span class="s1">&#39;/&#39;</span><span class="p">,</span>
    682. <span class="n">class_inclusion_list</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">dataset_params</span><span class="o">.</span><span class="n">class_inclusion_list</span><span class="p">,</span>
    683. <span class="n">max_num_samples</span><span class="o">=</span><span class="n">max_num_samples_per_train_dataset</span><span class="p">[</span><span class="n">i</span><span class="p">])</span>
    684. <span class="k">for</span> <span class="n">i</span><span class="p">,</span> <span class="n">trainset_name</span> <span class="ow">in</span> <span class="nb">enumerate</span><span class="p">(</span><span class="n">train_dataset_names</span><span class="p">)]</span>
    685. <span class="n">testset2007</span> <span class="o">=</span> <span class="n">PascalVOCDetectionDataset</span><span class="p">(</span><span class="n">data_dir</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">data_dir</span><span class="p">,</span>
    686. <span class="n">input_dim</span><span class="o">=</span><span class="n">val_input_dim</span><span class="p">,</span>
    687. <span class="n">cache</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">dataset_params</span><span class="o">.</span><span class="n">cache_val_images</span><span class="p">,</span>
    688. <span class="n">cache_path</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">dataset_params</span><span class="o">.</span><span class="n">cache_dir</span> <span class="o">+</span> <span class="s2">&quot;cache_valid&quot;</span><span class="p">,</span>
    689. <span class="n">transforms</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">dataset_params</span><span class="o">.</span><span class="n">val_transforms</span><span class="p">,</span>
    690. <span class="n">images_sub_directory</span><span class="o">=</span><span class="s1">&#39;images/test2007/&#39;</span><span class="p">,</span>
    691. <span class="n">class_inclusion_list</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">dataset_params</span><span class="o">.</span><span class="n">class_inclusion_list</span><span class="p">,</span>
    692. <span class="n">max_num_samples</span><span class="o">=</span><span class="n">val_max_num_samples</span><span class="p">)</span>
    693. <span class="bp">self</span><span class="o">.</span><span class="n">classes</span> <span class="o">=</span> <span class="n">train_sets</span><span class="p">[</span><span class="mi">1</span><span class="p">]</span><span class="o">.</span><span class="n">classes</span>
    694. <span class="bp">self</span><span class="o">.</span><span class="n">trainset</span> <span class="o">=</span> <span class="n">ConcatDataset</span><span class="p">(</span><span class="n">train_sets</span><span class="p">)</span>
    695. <span class="bp">self</span><span class="o">.</span><span class="n">valset</span> <span class="o">=</span> <span class="n">testset2007</span>
    696. <span class="bp">self</span><span class="o">.</span><span class="n">trainset</span><span class="o">.</span><span class="n">collate_fn</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">dataset_params</span><span class="o">.</span><span class="n">train_collate_fn</span>
    697. <span class="bp">self</span><span class="o">.</span><span class="n">trainset</span><span class="o">.</span><span class="n">classes</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">classes</span>
    698. <span class="bp">self</span><span class="o">.</span><span class="n">trainset</span><span class="o">.</span><span class="n">img_size</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">dataset_params</span><span class="o">.</span><span class="n">train_image_size</span>
    699. <span class="bp">self</span><span class="o">.</span><span class="n">trainset</span><span class="o">.</span><span class="n">cache_labels</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">dataset_params</span><span class="o">.</span><span class="n">cache_train_images</span></div>
    700. <div class="viewcode-block" id="CoCoDetectionDatasetInterface"><a class="viewcode-back" href="../../../../../super_gradients.training.datasets.dataset_interfaces.html#super_gradients.training.CoCoDetectionDatasetInterface">[docs]</a><span class="k">class</span> <span class="nc">CoCoDetectionDatasetInterface</span><span class="p">(</span><span class="n">DetectionDatasetInterface</span><span class="p">):</span>
    701. <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">dataset_params</span><span class="o">=</span><span class="p">{}):</span>
    702. <span class="nb">super</span><span class="p">(</span><span class="n">CoCoDetectionDatasetInterface</span><span class="p">,</span> <span class="bp">self</span><span class="p">)</span><span class="o">.</span><span class="fm">__init__</span><span class="p">(</span><span class="n">dataset_params</span><span class="o">=</span><span class="n">dataset_params</span><span class="p">)</span>
    703. <span class="n">train_input_dim</span> <span class="o">=</span> <span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">dataset_params</span><span class="o">.</span><span class="n">train_image_size</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">dataset_params</span><span class="o">.</span><span class="n">train_image_size</span><span class="p">)</span>
    704. <span class="n">targets_format</span> <span class="o">=</span> <span class="n">get_param</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">dataset_params</span><span class="p">,</span> <span class="s2">&quot;targets_format&quot;</span><span class="p">,</span> <span class="n">DetectionTargetsFormat</span><span class="o">.</span><span class="n">LABEL_CXCYWH</span><span class="p">)</span>
    705. <span class="n">train_transforms</span> <span class="o">=</span> <span class="p">[</span><span class="n">DetectionMosaic</span><span class="p">(</span><span class="n">input_dim</span><span class="o">=</span><span class="n">train_input_dim</span><span class="p">,</span>
    706. <span class="n">prob</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">dataset_params</span><span class="o">.</span><span class="n">mosaic_prob</span><span class="p">),</span>
    707. <span class="n">DetectionRandomAffine</span><span class="p">(</span><span class="n">degrees</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">dataset_params</span><span class="o">.</span><span class="n">degrees</span><span class="p">,</span>
    708. <span class="n">translate</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">dataset_params</span><span class="o">.</span><span class="n">translate</span><span class="p">,</span>
    709. <span class="n">scales</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">dataset_params</span><span class="o">.</span><span class="n">mosaic_scale</span><span class="p">,</span>
    710. <span class="n">shear</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">dataset_params</span><span class="o">.</span><span class="n">shear</span><span class="p">,</span>
    711. <span class="n">target_size</span><span class="o">=</span><span class="n">train_input_dim</span><span class="p">,</span>
    712. <span class="n">filter_box_candidates</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">dataset_params</span><span class="o">.</span><span class="n">filter_box_candidates</span><span class="p">,</span>
    713. <span class="n">wh_thr</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">dataset_params</span><span class="o">.</span><span class="n">wh_thr</span><span class="p">,</span>
    714. <span class="n">area_thr</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">dataset_params</span><span class="o">.</span><span class="n">area_thr</span><span class="p">,</span>
    715. <span class="n">ar_thr</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">dataset_params</span><span class="o">.</span><span class="n">ar_thr</span>
    716. <span class="p">),</span>
    717. <span class="n">DetectionMixup</span><span class="p">(</span><span class="n">input_dim</span><span class="o">=</span><span class="n">train_input_dim</span><span class="p">,</span>
    718. <span class="n">mixup_scale</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">dataset_params</span><span class="o">.</span><span class="n">mixup_scale</span><span class="p">,</span>
    719. <span class="n">prob</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">dataset_params</span><span class="o">.</span><span class="n">mixup_prob</span><span class="p">,</span>
    720. <span class="n">flip_prob</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">dataset_params</span><span class="o">.</span><span class="n">flip_prob</span><span class="p">),</span>
    721. <span class="n">DetectionHSV</span><span class="p">(</span><span class="n">prob</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">dataset_params</span><span class="o">.</span><span class="n">hsv_prob</span><span class="p">,</span>
    722. <span class="n">hgain</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">dataset_params</span><span class="o">.</span><span class="n">hgain</span><span class="p">,</span>
    723. <span class="n">sgain</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">dataset_params</span><span class="o">.</span><span class="n">sgain</span><span class="p">,</span>
    724. <span class="n">vgain</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">dataset_params</span><span class="o">.</span><span class="n">vgain</span>
    725. <span class="p">),</span>
    726. <span class="n">DetectionHorizontalFlip</span><span class="p">(</span><span class="n">prob</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">dataset_params</span><span class="o">.</span><span class="n">flip_prob</span><span class="p">),</span>
    727. <span class="n">DetectionPaddedRescale</span><span class="p">(</span><span class="n">input_dim</span><span class="o">=</span><span class="n">train_input_dim</span><span class="p">,</span> <span class="n">max_targets</span><span class="o">=</span><span class="mi">120</span><span class="p">),</span>
    728. <span class="n">DetectionTargetsFormatTransform</span><span class="p">(</span><span class="n">output_format</span><span class="o">=</span><span class="n">targets_format</span><span class="p">)</span>
    729. <span class="p">]</span>
    730. <span class="c1"># IF CACHE- CREATING THE CACHE FILE WILL HAPPEN ONLY FOR RANK 0, THEN ALL THE OTHER RANKS SIMPLY READ FROM IT.</span>
    731. <span class="n">local_rank</span> <span class="o">=</span> <span class="n">get_local_rank</span><span class="p">()</span>
    732. <span class="k">with</span> <span class="n">wait_for_the_master</span><span class="p">(</span><span class="n">local_rank</span><span class="p">):</span>
    733. <span class="bp">self</span><span class="o">.</span><span class="n">trainset</span> <span class="o">=</span> <span class="n">COCODetectionDataset</span><span class="p">(</span><span class="n">data_dir</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">dataset_params</span><span class="o">.</span><span class="n">data_dir</span><span class="p">,</span>
    734. <span class="n">name</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">dataset_params</span><span class="o">.</span><span class="n">train_subdir</span><span class="p">,</span>
    735. <span class="n">json_file</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">dataset_params</span><span class="o">.</span><span class="n">train_json_file</span><span class="p">,</span>
    736. <span class="n">img_size</span><span class="o">=</span><span class="n">train_input_dim</span><span class="p">,</span>
    737. <span class="n">cache</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">dataset_params</span><span class="o">.</span><span class="n">cache_train_images</span><span class="p">,</span>
    738. <span class="n">cache_dir_path</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">dataset_params</span><span class="o">.</span><span class="n">cache_dir_path</span><span class="p">,</span>
    739. <span class="n">transforms</span><span class="o">=</span><span class="n">train_transforms</span><span class="p">,</span>
    740. <span class="n">with_crowd</span><span class="o">=</span><span class="kc">False</span><span class="p">)</span>
    741. <span class="n">val_input_dim</span> <span class="o">=</span> <span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">dataset_params</span><span class="o">.</span><span class="n">val_image_size</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">dataset_params</span><span class="o">.</span><span class="n">val_image_size</span><span class="p">)</span>
    742. <span class="n">with_crowd</span> <span class="o">=</span> <span class="n">core_utils</span><span class="o">.</span><span class="n">get_param</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">dataset_params</span><span class="p">,</span> <span class="s1">&#39;with_crowd&#39;</span><span class="p">,</span> <span class="n">default_val</span><span class="o">=</span><span class="kc">True</span><span class="p">)</span>
    743. <span class="c1"># IF CACHE- CREATING THE CACHE FILE WILL HAPPEN ONLY FOR RANK 0, THEN ALL THE OTHER RANKS SIMPLY READ FROM IT.</span>
    744. <span class="k">with</span> <span class="n">wait_for_the_master</span><span class="p">(</span><span class="n">local_rank</span><span class="p">):</span>
    745. <span class="bp">self</span><span class="o">.</span><span class="n">valset</span> <span class="o">=</span> <span class="n">COCODetectionDataset</span><span class="p">(</span>
    746. <span class="n">data_dir</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">dataset_params</span><span class="o">.</span><span class="n">data_dir</span><span class="p">,</span>
    747. <span class="n">json_file</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">dataset_params</span><span class="o">.</span><span class="n">val_json_file</span><span class="p">,</span>
    748. <span class="n">name</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">dataset_params</span><span class="o">.</span><span class="n">val_subdir</span><span class="p">,</span>
    749. <span class="n">img_size</span><span class="o">=</span><span class="n">val_input_dim</span><span class="p">,</span>
    750. <span class="n">transforms</span><span class="o">=</span><span class="p">[</span><span class="n">DetectionPaddedRescale</span><span class="p">(</span><span class="n">input_dim</span><span class="o">=</span><span class="n">val_input_dim</span><span class="p">),</span>
    751. <span class="n">DetectionTargetsFormatTransform</span><span class="p">(</span><span class="n">max_targets</span><span class="o">=</span><span class="mi">50</span><span class="p">,</span> <span class="n">output_format</span><span class="o">=</span><span class="n">targets_format</span><span class="p">)],</span>
    752. <span class="n">cache</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">dataset_params</span><span class="o">.</span><span class="n">cache_val_images</span><span class="p">,</span>
    753. <span class="n">cache_dir_path</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">dataset_params</span><span class="o">.</span><span class="n">cache_dir_path</span><span class="p">,</span>
    754. <span class="n">with_crowd</span><span class="o">=</span><span class="n">with_crowd</span><span class="p">)</span>
    755. <span class="bp">self</span><span class="o">.</span><span class="n">classes</span> <span class="o">=</span> <span class="n">COCO_DETECTION_CLASSES_LIST</span></div>
    756. </pre></div>
    757. </div>
    758. </div>
    759. <footer>
    760. <hr/>
    761. <div role="contentinfo">
    762. <p>&#169; Copyright 2021, SuperGradients team.</p>
    763. </div>
    764. Built with <a href="https://www.sphinx-doc.org/">Sphinx</a> using a
    765. <a href="https://github.com/readthedocs/sphinx_rtd_theme">theme</a>
    766. provided by <a href="https://readthedocs.org">Read the Docs</a>.
    767. </footer>
    768. </div>
    769. </div>
    770. </section>
    771. </div>
    772. <script>
    773. jQuery(function () {
    774. SphinxRtdTheme.Navigation.enable(true);
    775. });
    776. </script>
    777. </body>
    778. </html>
    Discard
    1
    2
    3
    4
    5
    6
    7
    8
    9
    10
    11
    12
    13
    14
    15
    16
    17
    18
    19
    20
    21
    22
    23
    24
    25
    26
    27
    28
    29
    30
    31
    32
    33
    34
    35
    36
    37
    38
    39
    40
    41
    42
    43
    44
    45
    46
    47
    48
    49
    50
    51
    52
    53
    54
    55
    56
    57
    58
    59
    60
    61
    62
    63
    64
    65
    66
    67
    68
    69
    70
    71
    72
    73
    74
    75
    76
    77
    78
    79
    80
    81
    82
    83
    84
    85
    86
    87
    88
    89
    90
    91
    92
    93
    94
    95
    96
    97
    98
    99
    100
    101
    102
    103
    104
    105
    106
    107
    108
    109
    110
    111
    112
    113
    114
    115
    116
    117
    118
    119
    120
    121
    122
    123
    124
    125
    126
    127
    128
    129
    130
    131
    132
    133
    134
    135
    136
    137
    138
    139
    140
    141
    142
    143
    144
    145
    146
    147
    148
    149
    150
    151
    152
    153
    154
    155
    156
    157
    158
    159
    160
    161
    162
    163
    164
    165
    166
    167
    168
    169
    170
    171
    172
    173
    174
    175
    176
    177
    178
    179
    180
    181
    182
    183
    184
    185
    186
    187
    188
    189
    190
    191
    192
    193
    194
    195
    196
    197
    198
    199
    200
    201
    202
    203
    204
    205
    206
    207
    208
    209
    210
    211
    212
    213
    214
    215
    216
    217
    218
    219
    220
    221
    222
    223
    224
    225
    226
    227
    228
    229
    230
    231
    232
    233
    234
    235
    236
    237
    238
    239
    240
    241
    242
    243
    244
    245
    246
    247
    248
    249
    250
    251
    252
    253
    254
    255
    256
    257
    258
    259
    260
    261
    262
    263
    264
    265
    266
    267
    268
    269
    270
    271
    272
    273
    274
    275
    276
    277
    278
    279
    280
    281
    282
    283
    284
    285
    286
    287
    288
    289
    290
    291
    292
    293
    294
    295
    296
    297
    298
    299
    300
    301
    302
    303
    304
    305
    306
    307
    308
    309
    310
    311
    312
    313
    314
    315
    316
    317
    318
    319
    320
    321
    322
    323
    324
    325
    326
    327
    328
    329
    330
    331
    332
    333
    334
    335
    336
    337
    338
    339
    340
    341
    342
    343
    344
    345
    346
    347
    348
    349
    350
    351
    352
    353
    354
    355
    356
    357
    358
    359
    360
    361
    362
    363
    364
    365
    366
    367
    368
    369
    370
    371
    372
    373
    374
    375
    376
    377
    378
    379
    380
    381
    382
    383
    384
    385
    386
    387
    388
    389
    390
    391
    392
    393
    394
    395
    396
    397
    398
    399
    400
    401
    402
    403
    404
    405
    406
    407
    408
    409
    410
    411
    412
    413
    414
    415
    416
    417
    418
    419
    420
    421
    422
    423
    424
    425
    426
    427
    428
    429
    430
    431
    432
    433
    434
    435
    436
    437
    438
    439
    440
    441
    442
    443
    444
    445
    446
    447
    448
    449
    450
    451
    452
    453
    454
    455
    456
    457
    458
    459
    460
    461
    462
    463
    464
    465
    466
    467
    468
    469
    470
    471
    472
    473
    474
    475
    476
    477
    478
    479
    480
    481
    482
    483
    484
    485
    486
    487
    488
    489
    490
    491
    492
    493
    494
    495
    496
    497
    498
    499
    500
    501
    502
    503
    504
    505
    506
    507
    508
    509
    510
    511
    512
    513
    514
    515
    516
    517
    518
    519
    520
    521
    522
    523
    524
    525
    526
    527
    528
    529
    530
    531
    532
    533
    534
    535
    536
    537
    538
    539
    540
    541
    542
    543
    544
    545
    546
    547
    548
    549
    550
    551
    552
    553
    554
    555
    556
    557
    558
    559
    560
    561
    562
    563
    564
    565
    566
    567
    568
    569
    570
    571
    572
    573
    574
    575
    576
    577
    578
    579
    580
    581
    582
    583
    584
    585
    586
    587
    588
    589
    590
    591
    592
    593
    594
    595
    596
    597
    598
    599
    600
    601
    602
    603
    604
    605
    606
    607
    608
    609
    610
    611
    612
    613
    614
    615
    616
    617
    618
    619
    620
    621
    622
    623
    624
    625
    626
    627
    628
    629
    630
    631
    632
    633
    634
    635
    636
    637
    638
    639
    640
    641
    642
    643
    644
    645
    646
    647
    648
    649
    650
    651
    652
    653
    654
    655
    656
    657
    658
    659
    660
    661
    662
    663
    664
    665
    666
    667
    668
    669
    670
    671
    672
    673
    674
    675
    676
    677
    678
    679
    680
    681
    682
    683
    684
    685
    686
    687
    688
    689
    690
    691
    692
    693
    694
    695
    696
    697
    698
    699
    700
    701
    702
    703
    704
    705
    706
    707
    708
    709
    710
    711
    712
    713
    714
    715
    716
    717
    718
    719
    720
    721
    722
    723
    724
    725
    726
    727
    728
    729
    730
    731
    732
    733
    734
    735
    736
    737
    738
    739
    740
    741
    742
    743
    744
    745
    746
    747
    748
    749
    750
    751
    752
    753
    754
    755
    756
    757
    758
    759
    760
    761
    762
    763
    764
    765
    766
    767
    768
    769
    770
    771
    772
    773
    774
    775
    776
    777
    778
    779
    780
    781
    782
    783
    784
    785
    786
    787
    788
    789
    790
    791
    792
    793
    794
    795
    796
    797
    798
    799
    800
    801
    802
    803
    804
    805
    806
    807
    808
    809
    810
    811
    1. <!DOCTYPE html>
    2. <html class="writer-html5" lang="en" >
    3. <head>
    4. <meta charset="utf-8" />
    5. <meta name="viewport" content="width=device-width, initial-scale=1.0" />
    6. <title>super_gradients.training.datasets.datasets_utils &mdash; SuperGradients 1.0 documentation</title>
    7. <link rel="stylesheet" href="../../../../_static/pygments.css" type="text/css" />
    8. <link rel="stylesheet" href="../../../../_static/css/theme.css" type="text/css" />
    9. <link rel="stylesheet" href="../../../../_static/graphviz.css" type="text/css" />
    10. <!--[if lt IE 9]>
    11. <script src="../../../../_static/js/html5shiv.min.js"></script>
    12. <![endif]-->
    13. <script data-url_root="../../../../" id="documentation_options" src="../../../../_static/documentation_options.js"></script>
    14. <script src="../../../../_static/jquery.js"></script>
    15. <script src="../../../../_static/underscore.js"></script>
    16. <script src="../../../../_static/doctools.js"></script>
    17. <script src="../../../../_static/js/theme.js"></script>
    18. <link rel="index" title="Index" href="../../../../genindex.html" />
    19. <link rel="search" title="Search" href="../../../../search.html" />
    20. </head>
    21. <body class="wy-body-for-nav">
    22. <div class="wy-grid-for-nav">
    23. <nav data-toggle="wy-nav-shift" class="wy-nav-side">
    24. <div class="wy-side-scroll">
    25. <div class="wy-side-nav-search" >
    26. <a href="../../../../index.html" class="icon icon-home"> SuperGradients
    27. </a>
    28. <div role="search">
    29. <form id="rtd-search-form" class="wy-form" action="../../../../search.html" method="get">
    30. <input type="text" name="q" placeholder="Search docs" />
    31. <input type="hidden" name="check_keywords" value="yes" />
    32. <input type="hidden" name="area" value="default" />
    33. </form>
    34. </div>
    35. </div><div class="wy-menu wy-menu-vertical" data-spy="affix" role="navigation" aria-label="Navigation menu">
    36. <p class="caption"><span class="caption-text">Welcome To SuperGradients</span></p>
    37. <ul>
    38. <li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html">Fill our 4-question quick survey! We will raffle free SuperGradients swag between those who will participate -&gt; Fill Survey</a></li>
    39. <li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#supergradients">SuperGradients</a></li>
    40. </ul>
    41. <p class="caption"><span class="caption-text">Technical Documentation</span></p>
    42. <ul>
    43. <li class="toctree-l1"><a class="reference internal" href="../../../../super_gradients.common.html">Common package</a></li>
    44. <li class="toctree-l1"><a class="reference internal" href="../../../../super_gradients.training.html">Training package</a></li>
    45. </ul>
    46. <p class="caption"><span class="caption-text">User Guide</span></p>
    47. <ul>
    48. <li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html">What is SuperGradients?</a></li>
    49. <li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#introducing-the-supergradients-library">Introducing the SuperGradients library</a></li>
    50. <li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#installation">Installation</a></li>
    51. <li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#integrating-your-training-code-complete-walkthrough">Integrating Your Training Code - Complete Walkthrough</a></li>
    52. <li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#training-parameters">Training Parameters</a></li>
    53. <li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#logs-and-checkpoints">Logs and Checkpoints</a></li>
    54. <li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#dataset-parameters">Dataset Parameters</a></li>
    55. <li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#network-architectures">Network Architectures</a></li>
    56. <li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#pretrained-models">Pretrained Models</a></li>
    57. <li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#how-to-reproduce-our-training-recipes">How To Reproduce Our Training Recipes</a></li>
    58. <li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#professional-tools-integration">Professional Tools Integration</a></li>
    59. <li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#supergradients-faq">SuperGradients FAQ</a></li>
    60. </ul>
    61. </div>
    62. </div>
    63. </nav>
    64. <section data-toggle="wy-nav-shift" class="wy-nav-content-wrap"><nav class="wy-nav-top" aria-label="Mobile navigation menu" >
    65. <i data-toggle="wy-nav-top" class="fa fa-bars"></i>
    66. <a href="../../../../index.html">SuperGradients</a>
    67. </nav>
    68. <div class="wy-nav-content">
    69. <div class="rst-content">
    70. <div role="navigation" aria-label="Page navigation">
    71. <ul class="wy-breadcrumbs">
    72. <li><a href="../../../../index.html" class="icon icon-home"></a> &raquo;</li>
    73. <li><a href="../../../index.html">Module code</a> &raquo;</li>
    74. <li>super_gradients.training.datasets.datasets_utils</li>
    75. <li class="wy-breadcrumbs-aside">
    76. </li>
    77. </ul>
    78. <hr/>
    79. </div>
    80. <div role="main" class="document" itemscope="itemscope" itemtype="http://schema.org/Article">
    81. <div itemprop="articleBody">
    82. <h1>Source code for super_gradients.training.datasets.datasets_utils</h1><div class="highlight"><pre>
    83. <span></span><span class="kn">import</span> <span class="nn">copy</span>
    84. <span class="kn">import</span> <span class="nn">os</span>
    85. <span class="kn">from</span> <span class="nn">abc</span> <span class="kn">import</span> <span class="n">ABC</span><span class="p">,</span> <span class="n">abstractmethod</span>
    86. <span class="kn">from</span> <span class="nn">multiprocessing</span> <span class="kn">import</span> <span class="n">Value</span><span class="p">,</span> <span class="n">Lock</span>
    87. <span class="kn">import</span> <span class="nn">random</span>
    88. <span class="kn">from</span> <span class="nn">typing</span> <span class="kn">import</span> <span class="n">List</span>
    89. <span class="kn">import</span> <span class="nn">numpy</span> <span class="k">as</span> <span class="nn">np</span>
    90. <span class="kn">import</span> <span class="nn">torch.nn.functional</span> <span class="k">as</span> <span class="nn">F</span>
    91. <span class="kn">import</span> <span class="nn">torchvision</span>
    92. <span class="kn">from</span> <span class="nn">PIL</span> <span class="kn">import</span> <span class="n">Image</span>
    93. <span class="kn">import</span> <span class="nn">torch</span>
    94. <span class="kn">import</span> <span class="nn">torch.distributed</span> <span class="k">as</span> <span class="nn">dist</span>
    95. <span class="kn">from</span> <span class="nn">super_gradients.common.sg_loggers.abstract_sg_logger</span> <span class="kn">import</span> <span class="n">AbstractSGLogger</span>
    96. <span class="kn">from</span> <span class="nn">super_gradients.common.abstractions.abstract_logger</span> <span class="kn">import</span> <span class="n">get_logger</span>
    97. <span class="kn">from</span> <span class="nn">deprecated</span> <span class="kn">import</span> <span class="n">deprecated</span>
    98. <span class="kn">from</span> <span class="nn">matplotlib.patches</span> <span class="kn">import</span> <span class="n">Rectangle</span>
    99. <span class="kn">from</span> <span class="nn">torchvision.datasets</span> <span class="kn">import</span> <span class="n">ImageFolder</span>
    100. <span class="kn">from</span> <span class="nn">super_gradients.training.datasets.auto_augment</span> <span class="kn">import</span> <span class="n">rand_augment_transform</span>
    101. <span class="kn">from</span> <span class="nn">torchvision.transforms</span> <span class="kn">import</span> <span class="n">transforms</span><span class="p">,</span> <span class="n">InterpolationMode</span><span class="p">,</span> <span class="n">RandomResizedCrop</span>
    102. <span class="kn">from</span> <span class="nn">tqdm</span> <span class="kn">import</span> <span class="n">tqdm</span>
    103. <span class="kn">from</span> <span class="nn">super_gradients.training.utils.utils</span> <span class="kn">import</span> <span class="n">AverageMeter</span>
    104. <span class="kn">from</span> <span class="nn">super_gradients.training.utils.detection_utils</span> <span class="kn">import</span> <span class="n">DetectionVisualization</span><span class="p">,</span> <span class="n">Anchors</span>
    105. <span class="kn">import</span> <span class="nn">uuid</span>
    106. <span class="kn">from</span> <span class="nn">super_gradients.training.utils.distributed_training_utils</span> <span class="kn">import</span> <span class="n">get_local_rank</span><span class="p">,</span> <span class="n">get_world_size</span>
    107. <span class="kn">import</span> <span class="nn">matplotlib.pyplot</span> <span class="k">as</span> <span class="nn">plt</span>
    108. <div class="viewcode-block" id="get_mean_and_std_torch"><a class="viewcode-back" href="../../../../super_gradients.training.datasets.html#super_gradients.training.datasets.datasets_utils.get_mean_and_std_torch">[docs]</a><span class="k">def</span> <span class="nf">get_mean_and_std_torch</span><span class="p">(</span><span class="n">data_dir</span><span class="o">=</span><span class="kc">None</span><span class="p">,</span> <span class="n">dataloader</span><span class="o">=</span><span class="kc">None</span><span class="p">,</span> <span class="n">num_workers</span><span class="o">=</span><span class="mi">4</span><span class="p">,</span> <span class="n">RandomResizeSize</span><span class="o">=</span><span class="mi">224</span><span class="p">):</span>
    109. <span class="sd">&quot;&quot;&quot;</span>
    110. <span class="sd"> A function for getting the mean and std of large datasets using pytorch dataloader and gpu functionality.</span>
    111. <span class="sd"> :param data_dir: String, path to none-library dataset folder. For example &quot;/data/Imagenette&quot; or &quot;/data/TinyImagenet&quot;</span>
    112. <span class="sd"> :param dataloader: a torch DataLoader, as it would feed the data into the trainer (including transforms etc).</span>
    113. <span class="sd"> :param RandomResizeSize: Int, the size of the RandomResizeCrop as it appears in the DataInterface (for example, for Imagenet,</span>
    114. <span class="sd"> this value should be 224).</span>
    115. <span class="sd"> :return: 2 lists,mean and std, each one of len 3 (1 for each channel)</span>
    116. <span class="sd"> &quot;&quot;&quot;</span>
    117. <span class="k">assert</span> <span class="n">data_dir</span> <span class="ow">is</span> <span class="kc">None</span> <span class="ow">or</span> <span class="n">dataloader</span> <span class="ow">is</span> <span class="kc">None</span><span class="p">,</span> <span class="s1">&#39;Please provide either path to data folder or DataLoader, not both.&#39;</span>
    118. <span class="k">if</span> <span class="n">dataloader</span> <span class="ow">is</span> <span class="kc">None</span><span class="p">:</span>
    119. <span class="n">traindir</span> <span class="o">=</span> <span class="n">os</span><span class="o">.</span><span class="n">path</span><span class="o">.</span><span class="n">join</span><span class="p">(</span><span class="n">os</span><span class="o">.</span><span class="n">path</span><span class="o">.</span><span class="n">abspath</span><span class="p">(</span><span class="n">data_dir</span><span class="p">),</span> <span class="s1">&#39;train&#39;</span><span class="p">)</span>
    120. <span class="n">trainset</span> <span class="o">=</span> <span class="n">ImageFolder</span><span class="p">(</span><span class="n">traindir</span><span class="p">,</span> <span class="n">transforms</span><span class="o">.</span><span class="n">Compose</span><span class="p">([</span><span class="n">transforms</span><span class="o">.</span><span class="n">RandomResizedCrop</span><span class="p">(</span><span class="n">RandomResizeSize</span><span class="p">),</span>
    121. <span class="n">transforms</span><span class="o">.</span><span class="n">RandomHorizontalFlip</span><span class="p">(),</span>
    122. <span class="n">transforms</span><span class="o">.</span><span class="n">ToTensor</span><span class="p">()]))</span>
    123. <span class="n">dataloader</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">utils</span><span class="o">.</span><span class="n">data</span><span class="o">.</span><span class="n">DataLoader</span><span class="p">(</span><span class="n">trainset</span><span class="p">,</span> <span class="n">batch_size</span><span class="o">=</span><span class="mi">1</span><span class="p">,</span> <span class="n">num_workers</span><span class="o">=</span><span class="n">num_workers</span><span class="p">)</span>
    124. <span class="nb">print</span><span class="p">(</span><span class="sa">f</span><span class="s1">&#39;Calculating on </span><span class="si">{</span><span class="nb">len</span><span class="p">(</span><span class="n">dataloader</span><span class="o">.</span><span class="n">dataset</span><span class="o">.</span><span class="n">targets</span><span class="p">)</span><span class="si">}</span><span class="s1"> Training Samples&#39;</span><span class="p">)</span>
    125. <span class="n">device</span> <span class="o">=</span> <span class="s1">&#39;cuda:0&#39;</span> <span class="k">if</span> <span class="n">torch</span><span class="o">.</span><span class="n">cuda</span><span class="o">.</span><span class="n">is_available</span><span class="p">()</span> <span class="k">else</span> <span class="s1">&#39;cpu&#39;</span>
    126. <span class="n">h</span><span class="p">,</span> <span class="n">w</span> <span class="o">=</span> <span class="mi">0</span><span class="p">,</span> <span class="mi">0</span>
    127. <span class="k">for</span> <span class="n">batch_idx</span><span class="p">,</span> <span class="p">(</span><span class="n">inputs</span><span class="p">,</span> <span class="n">targets</span><span class="p">)</span> <span class="ow">in</span> <span class="nb">enumerate</span><span class="p">(</span><span class="n">dataloader</span><span class="p">):</span>
    128. <span class="n">inputs</span> <span class="o">=</span> <span class="n">inputs</span><span class="o">.</span><span class="n">to</span><span class="p">(</span><span class="n">device</span><span class="p">)</span>
    129. <span class="k">if</span> <span class="n">batch_idx</span> <span class="o">==</span> <span class="mi">0</span><span class="p">:</span>
    130. <span class="n">h</span><span class="p">,</span> <span class="n">w</span> <span class="o">=</span> <span class="n">inputs</span><span class="o">.</span><span class="n">size</span><span class="p">(</span><span class="mi">2</span><span class="p">),</span> <span class="n">inputs</span><span class="o">.</span><span class="n">size</span><span class="p">(</span><span class="mi">3</span><span class="p">)</span>
    131. <span class="nb">print</span><span class="p">(</span><span class="sa">f</span><span class="s1">&#39;Min: </span><span class="si">{</span><span class="n">inputs</span><span class="o">.</span><span class="n">min</span><span class="p">()</span><span class="si">}</span><span class="s1">, Max: </span><span class="si">{</span><span class="n">inputs</span><span class="o">.</span><span class="n">max</span><span class="p">()</span><span class="si">}</span><span class="s1">&#39;</span><span class="p">)</span>
    132. <span class="n">chsum</span> <span class="o">=</span> <span class="n">inputs</span><span class="o">.</span><span class="n">sum</span><span class="p">(</span><span class="n">dim</span><span class="o">=</span><span class="p">(</span><span class="mi">0</span><span class="p">,</span> <span class="mi">2</span><span class="p">,</span> <span class="mi">3</span><span class="p">),</span> <span class="n">keepdim</span><span class="o">=</span><span class="kc">True</span><span class="p">)</span>
    133. <span class="k">else</span><span class="p">:</span>
    134. <span class="n">chsum</span> <span class="o">+=</span> <span class="n">inputs</span><span class="o">.</span><span class="n">sum</span><span class="p">(</span><span class="n">dim</span><span class="o">=</span><span class="p">(</span><span class="mi">0</span><span class="p">,</span> <span class="mi">2</span><span class="p">,</span> <span class="mi">3</span><span class="p">),</span> <span class="n">keepdim</span><span class="o">=</span><span class="kc">True</span><span class="p">)</span>
    135. <span class="n">mean</span> <span class="o">=</span> <span class="n">chsum</span> <span class="o">/</span> <span class="nb">len</span><span class="p">(</span><span class="n">trainset</span><span class="p">)</span> <span class="o">/</span> <span class="n">h</span> <span class="o">/</span> <span class="n">w</span>
    136. <span class="nb">print</span><span class="p">(</span><span class="sa">f</span><span class="s1">&#39;mean: </span><span class="si">{</span><span class="n">mean</span><span class="o">.</span><span class="n">view</span><span class="p">(</span><span class="o">-</span><span class="mi">1</span><span class="p">)</span><span class="si">}</span><span class="s1">&#39;</span><span class="p">)</span>
    137. <span class="n">chsum</span> <span class="o">=</span> <span class="kc">None</span>
    138. <span class="k">for</span> <span class="n">batch_idx</span><span class="p">,</span> <span class="p">(</span><span class="n">inputs</span><span class="p">,</span> <span class="n">targets</span><span class="p">)</span> <span class="ow">in</span> <span class="nb">enumerate</span><span class="p">(</span><span class="n">dataloader</span><span class="p">):</span>
    139. <span class="n">inputs</span> <span class="o">=</span> <span class="n">inputs</span><span class="o">.</span><span class="n">to</span><span class="p">(</span><span class="n">device</span><span class="p">)</span>
    140. <span class="k">if</span> <span class="n">batch_idx</span> <span class="o">==</span> <span class="mi">0</span><span class="p">:</span>
    141. <span class="n">chsum</span> <span class="o">=</span> <span class="p">(</span><span class="n">inputs</span> <span class="o">-</span> <span class="n">mean</span><span class="p">)</span><span class="o">.</span><span class="n">pow</span><span class="p">(</span><span class="mi">2</span><span class="p">)</span><span class="o">.</span><span class="n">sum</span><span class="p">(</span><span class="n">dim</span><span class="o">=</span><span class="p">(</span><span class="mi">0</span><span class="p">,</span> <span class="mi">2</span><span class="p">,</span> <span class="mi">3</span><span class="p">),</span> <span class="n">keepdim</span><span class="o">=</span><span class="kc">True</span><span class="p">)</span>
    142. <span class="k">else</span><span class="p">:</span>
    143. <span class="n">chsum</span> <span class="o">+=</span> <span class="p">(</span><span class="n">inputs</span> <span class="o">-</span> <span class="n">mean</span><span class="p">)</span><span class="o">.</span><span class="n">pow</span><span class="p">(</span><span class="mi">2</span><span class="p">)</span><span class="o">.</span><span class="n">sum</span><span class="p">(</span><span class="n">dim</span><span class="o">=</span><span class="p">(</span><span class="mi">0</span><span class="p">,</span> <span class="mi">2</span><span class="p">,</span> <span class="mi">3</span><span class="p">),</span> <span class="n">keepdim</span><span class="o">=</span><span class="kc">True</span><span class="p">)</span>
    144. <span class="n">std</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">sqrt</span><span class="p">(</span><span class="n">chsum</span> <span class="o">/</span> <span class="p">(</span><span class="nb">len</span><span class="p">(</span><span class="n">trainset</span><span class="p">)</span> <span class="o">*</span> <span class="n">h</span> <span class="o">*</span> <span class="n">w</span> <span class="o">-</span> <span class="mi">1</span><span class="p">))</span>
    145. <span class="nb">print</span><span class="p">(</span><span class="sa">f</span><span class="s1">&#39;std: </span><span class="si">{</span><span class="n">std</span><span class="o">.</span><span class="n">view</span><span class="p">(</span><span class="o">-</span><span class="mi">1</span><span class="p">)</span><span class="si">}</span><span class="s1">&#39;</span><span class="p">)</span>
    146. <span class="k">return</span> <span class="n">mean</span><span class="o">.</span><span class="n">view</span><span class="p">(</span><span class="o">-</span><span class="mi">1</span><span class="p">)</span><span class="o">.</span><span class="n">cpu</span><span class="p">()</span><span class="o">.</span><span class="n">numpy</span><span class="p">()</span><span class="o">.</span><span class="n">tolist</span><span class="p">(),</span> <span class="n">std</span><span class="o">.</span><span class="n">view</span><span class="p">(</span><span class="o">-</span><span class="mi">1</span><span class="p">)</span><span class="o">.</span><span class="n">cpu</span><span class="p">()</span><span class="o">.</span><span class="n">numpy</span><span class="p">()</span><span class="o">.</span><span class="n">tolist</span><span class="p">()</span></div>
    147. <div class="viewcode-block" id="get_mean_and_std"><a class="viewcode-back" href="../../../../super_gradients.training.datasets.html#super_gradients.training.datasets.datasets_utils.get_mean_and_std">[docs]</a><span class="nd">@deprecated</span><span class="p">(</span><span class="n">reason</span><span class="o">=</span><span class="s1">&#39;Use get_mean_and_std_torch() instead. It is faster and more accurate&#39;</span><span class="p">)</span>
    148. <span class="k">def</span> <span class="nf">get_mean_and_std</span><span class="p">(</span><span class="n">dataset</span><span class="p">):</span>
    149. <span class="sd">&#39;&#39;&#39;Compute the mean and std value of dataset.&#39;&#39;&#39;</span>
    150. <span class="n">dataloader</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">utils</span><span class="o">.</span><span class="n">data</span><span class="o">.</span><span class="n">DataLoader</span><span class="p">(</span><span class="n">dataset</span><span class="p">,</span> <span class="n">batch_size</span><span class="o">=</span><span class="mi">1</span><span class="p">,</span> <span class="n">shuffle</span><span class="o">=</span><span class="kc">True</span><span class="p">,</span> <span class="n">num_workers</span><span class="o">=</span><span class="mi">1</span><span class="p">)</span>
    151. <span class="n">mean</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">zeros</span><span class="p">(</span><span class="mi">3</span><span class="p">)</span>
    152. <span class="n">std</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">zeros</span><span class="p">(</span><span class="mi">3</span><span class="p">)</span>
    153. <span class="nb">print</span><span class="p">(</span><span class="s1">&#39;==&gt; Computing mean and std..&#39;</span><span class="p">)</span>
    154. <span class="n">j</span> <span class="o">=</span> <span class="mi">0</span>
    155. <span class="k">for</span> <span class="n">inputs</span><span class="p">,</span> <span class="n">targets</span> <span class="ow">in</span> <span class="n">dataloader</span><span class="p">:</span>
    156. <span class="k">if</span> <span class="n">j</span> <span class="o">%</span> <span class="mi">10</span> <span class="o">==</span> <span class="mi">0</span><span class="p">:</span>
    157. <span class="nb">print</span><span class="p">(</span><span class="n">j</span><span class="p">)</span>
    158. <span class="n">j</span> <span class="o">+=</span> <span class="mi">1</span>
    159. <span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="mi">3</span><span class="p">):</span>
    160. <span class="n">mean</span><span class="p">[</span><span class="n">i</span><span class="p">]</span> <span class="o">+=</span> <span class="n">inputs</span><span class="p">[:,</span> <span class="n">i</span><span class="p">,</span> <span class="p">:,</span> <span class="p">:]</span><span class="o">.</span><span class="n">mean</span><span class="p">()</span>
    161. <span class="n">std</span><span class="p">[</span><span class="n">i</span><span class="p">]</span> <span class="o">+=</span> <span class="n">inputs</span><span class="p">[:,</span> <span class="n">i</span><span class="p">,</span> <span class="p">:,</span> <span class="p">:]</span><span class="o">.</span><span class="n">std</span><span class="p">()</span>
    162. <span class="n">mean</span><span class="o">.</span><span class="n">div_</span><span class="p">(</span><span class="nb">len</span><span class="p">(</span><span class="n">dataset</span><span class="p">))</span>
    163. <span class="n">std</span><span class="o">.</span><span class="n">div_</span><span class="p">(</span><span class="nb">len</span><span class="p">(</span><span class="n">dataset</span><span class="p">))</span>
    164. <span class="k">return</span> <span class="n">mean</span><span class="p">,</span> <span class="n">std</span></div>
    165. <div class="viewcode-block" id="AbstractCollateFunction"><a class="viewcode-back" href="../../../../super_gradients.training.datasets.html#super_gradients.training.datasets.datasets_utils.AbstractCollateFunction">[docs]</a><span class="k">class</span> <span class="nc">AbstractCollateFunction</span><span class="p">(</span><span class="n">ABC</span><span class="p">):</span>
    166. <span class="sd">&quot;&quot;&quot;</span>
    167. <span class="sd"> A collate function (for torch DataLoader)</span>
    168. <span class="sd"> &quot;&quot;&quot;</span>
    169. <span class="nd">@abstractmethod</span>
    170. <span class="k">def</span> <span class="fm">__call__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">batch</span><span class="p">):</span>
    171. <span class="k">pass</span></div>
    172. <div class="viewcode-block" id="ComposedCollateFunction"><a class="viewcode-back" href="../../../../super_gradients.training.datasets.html#super_gradients.training.datasets.datasets_utils.ComposedCollateFunction">[docs]</a><span class="k">class</span> <span class="nc">ComposedCollateFunction</span><span class="p">(</span><span class="n">AbstractCollateFunction</span><span class="p">):</span>
    173. <span class="sd">&quot;&quot;&quot;</span>
    174. <span class="sd"> A function (for torch DataLoader) which executes a sequence of sub collate functions</span>
    175. <span class="sd"> &quot;&quot;&quot;</span>
    176. <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">functions</span><span class="p">:</span> <span class="nb">list</span><span class="p">):</span>
    177. <span class="bp">self</span><span class="o">.</span><span class="n">functions</span> <span class="o">=</span> <span class="n">functions</span>
    178. <span class="k">def</span> <span class="fm">__call__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">batch</span><span class="p">):</span>
    179. <span class="k">for</span> <span class="n">f</span> <span class="ow">in</span> <span class="bp">self</span><span class="o">.</span><span class="n">functions</span><span class="p">:</span>
    180. <span class="n">batch</span> <span class="o">=</span> <span class="n">f</span><span class="p">(</span><span class="n">batch</span><span class="p">)</span>
    181. <span class="k">return</span> <span class="n">batch</span></div>
    182. <div class="viewcode-block" id="AtomicInteger"><a class="viewcode-back" href="../../../../super_gradients.training.datasets.html#super_gradients.training.datasets.datasets_utils.AtomicInteger">[docs]</a><span class="k">class</span> <span class="nc">AtomicInteger</span><span class="p">:</span>
    183. <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">value</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="mi">0</span><span class="p">):</span>
    184. <span class="bp">self</span><span class="o">.</span><span class="n">_value</span> <span class="o">=</span> <span class="n">Value</span><span class="p">(</span><span class="s1">&#39;i&#39;</span><span class="p">,</span> <span class="n">value</span><span class="p">)</span>
    185. <span class="k">def</span> <span class="fm">__set__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">instance</span><span class="p">,</span> <span class="n">value</span><span class="p">):</span>
    186. <span class="bp">self</span><span class="o">.</span><span class="n">_value</span><span class="o">.</span><span class="n">value</span> <span class="o">=</span> <span class="n">value</span>
    187. <span class="k">def</span> <span class="fm">__get__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">instance</span><span class="p">,</span> <span class="n">owner</span><span class="p">):</span>
    188. <span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">_value</span><span class="o">.</span><span class="n">value</span></div>
    189. <div class="viewcode-block" id="MultiScaleCollateFunction"><a class="viewcode-back" href="../../../../super_gradients.training.datasets.html#super_gradients.training.datasets.datasets_utils.MultiScaleCollateFunction">[docs]</a><span class="k">class</span> <span class="nc">MultiScaleCollateFunction</span><span class="p">(</span><span class="n">AbstractCollateFunction</span><span class="p">):</span>
    190. <span class="sd">&quot;&quot;&quot;</span>
    191. <span class="sd"> a collate function to implement multi-scale data augmentation</span>
    192. <span class="sd"> according to https://arxiv.org/pdf/1612.08242.pdf</span>
    193. <span class="sd"> &quot;&quot;&quot;</span>
    194. <span class="n">_counter</span> <span class="o">=</span> <span class="n">AtomicInteger</span><span class="p">(</span><span class="mi">0</span><span class="p">)</span>
    195. <span class="n">_current_size</span> <span class="o">=</span> <span class="n">AtomicInteger</span><span class="p">(</span><span class="mi">0</span><span class="p">)</span>
    196. <span class="n">_lock</span> <span class="o">=</span> <span class="n">Lock</span><span class="p">()</span>
    197. <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">target_size</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span> <span class="n">min_image_size</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span> <span class="n">max_image_size</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span>
    198. <span class="n">image_size_steps</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="mi">32</span><span class="p">,</span>
    199. <span class="n">change_frequency</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="mi">10</span><span class="p">):</span>
    200. <span class="sd">&quot;&quot;&quot;</span>
    201. <span class="sd"> set parameters for the multi-scale collate function</span>
    202. <span class="sd"> the possible image sizes are in range [min_image_size, max_image_size] in steps of image_size_steps</span>
    203. <span class="sd"> a new size will be randomly selected every change_frequency calls to the collate_fn()</span>
    204. <span class="sd"> :param target_size: scales will be [0.66 * target_size, 1.5 * target_size]</span>
    205. <span class="sd"> :param min_image_size: the minimum size to scale down to (in pixels)</span>
    206. <span class="sd"> :param max_image_size: the maximum size to scale up to (in pixels)</span>
    207. <span class="sd"> :param image_size_steps: typically, the stride of the net, which defines the possible image</span>
    208. <span class="sd"> size multiplications</span>
    209. <span class="sd"> :param change_frequency:</span>
    210. <span class="sd"> &quot;&quot;&quot;</span>
    211. <span class="k">assert</span> <span class="n">target_size</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span> <span class="ow">or</span> <span class="p">(</span><span class="n">max_image_size</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span> <span class="ow">and</span> <span class="n">min_image_size</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span><span class="p">),</span> \
    212. <span class="s1">&#39;either target_size or min_image_size and max_image_size has to be set&#39;</span>
    213. <span class="k">assert</span> <span class="n">target_size</span> <span class="ow">is</span> <span class="kc">None</span> <span class="ow">or</span> <span class="n">max_image_size</span> <span class="ow">is</span> <span class="kc">None</span><span class="p">,</span> <span class="s1">&#39;target_size and max_image_size cannot be both defined&#39;</span>
    214. <span class="k">if</span> <span class="n">target_size</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span><span class="p">:</span>
    215. <span class="n">min_image_size</span> <span class="o">=</span> <span class="nb">int</span><span class="p">(</span><span class="mf">0.66</span> <span class="o">*</span> <span class="n">target_size</span> <span class="o">-</span> <span class="p">((</span><span class="mf">0.66</span> <span class="o">*</span> <span class="n">target_size</span><span class="p">)</span> <span class="o">%</span> <span class="n">image_size_steps</span><span class="p">)</span> <span class="o">+</span> <span class="n">image_size_steps</span><span class="p">)</span>
    216. <span class="n">max_image_size</span> <span class="o">=</span> <span class="nb">int</span><span class="p">(</span><span class="mf">1.5</span> <span class="o">*</span> <span class="n">target_size</span> <span class="o">-</span> <span class="p">((</span><span class="mf">1.5</span> <span class="o">*</span> <span class="n">target_size</span><span class="p">)</span> <span class="o">%</span> <span class="n">image_size_steps</span><span class="p">))</span>
    217. <span class="nb">print</span><span class="p">(</span><span class="s1">&#39;Using multi-scale </span><span class="si">%g</span><span class="s1"> - </span><span class="si">%g</span><span class="s1">&#39;</span> <span class="o">%</span> <span class="p">(</span><span class="n">min_image_size</span><span class="p">,</span> <span class="n">max_image_size</span><span class="p">))</span>
    218. <span class="bp">self</span><span class="o">.</span><span class="n">sizes</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">arange</span><span class="p">(</span><span class="n">min_image_size</span><span class="p">,</span> <span class="n">max_image_size</span> <span class="o">+</span> <span class="n">image_size_steps</span><span class="p">,</span> <span class="n">image_size_steps</span><span class="p">)</span>
    219. <span class="bp">self</span><span class="o">.</span><span class="n">image_size_steps</span> <span class="o">=</span> <span class="n">image_size_steps</span>
    220. <span class="bp">self</span><span class="o">.</span><span class="n">frequency</span> <span class="o">=</span> <span class="n">change_frequency</span>
    221. <span class="bp">self</span><span class="o">.</span><span class="n">_current_size</span> <span class="o">=</span> <span class="n">random</span><span class="o">.</span><span class="n">choice</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">sizes</span><span class="p">)</span>
    222. <span class="k">def</span> <span class="fm">__call__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">batch</span><span class="p">):</span>
    223. <span class="k">with</span> <span class="bp">self</span><span class="o">.</span><span class="n">_lock</span><span class="p">:</span>
    224. <span class="c1"># Important: this implementation was tailored for a specific input. it assumes the batch is a tuple where</span>
    225. <span class="c1"># the images are the first item</span>
    226. <span class="k">assert</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">batch</span><span class="p">,</span> <span class="nb">tuple</span><span class="p">),</span> <span class="s1">&#39;this collate function expects the input to be a tuple (images, labels)&#39;</span>
    227. <span class="n">images</span> <span class="o">=</span> <span class="n">batch</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span>
    228. <span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">_counter</span> <span class="o">%</span> <span class="bp">self</span><span class="o">.</span><span class="n">frequency</span> <span class="o">==</span> <span class="mi">0</span><span class="p">:</span>
    229. <span class="bp">self</span><span class="o">.</span><span class="n">_current_size</span> <span class="o">=</span> <span class="n">random</span><span class="o">.</span><span class="n">choice</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">sizes</span><span class="p">)</span>
    230. <span class="bp">self</span><span class="o">.</span><span class="n">_counter</span> <span class="o">+=</span> <span class="mi">1</span>
    231. <span class="k">assert</span> <span class="n">images</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">2</span><span class="p">]</span> <span class="o">%</span> <span class="bp">self</span><span class="o">.</span><span class="n">image_size_steps</span> <span class="o">==</span> <span class="mi">0</span> <span class="ow">and</span> <span class="n">images</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">3</span><span class="p">]</span> <span class="o">%</span> <span class="bp">self</span><span class="o">.</span><span class="n">image_size_steps</span> <span class="o">==</span> <span class="mi">0</span><span class="p">,</span> \
    232. <span class="s1">&#39;images sized not divisible by </span><span class="si">%d</span><span class="s1">. (resize images before calling multi_scale)&#39;</span> <span class="o">%</span> <span class="bp">self</span><span class="o">.</span><span class="n">image_size_steps</span>
    233. <span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">_current_size</span> <span class="o">!=</span> <span class="nb">max</span><span class="p">(</span><span class="n">images</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">2</span><span class="p">:]):</span>
    234. <span class="n">ratio</span> <span class="o">=</span> <span class="nb">float</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">_current_size</span><span class="p">)</span> <span class="o">/</span> <span class="nb">max</span><span class="p">(</span><span class="n">images</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">2</span><span class="p">:])</span>
    235. <span class="n">new_size</span> <span class="o">=</span> <span class="p">(</span><span class="nb">int</span><span class="p">(</span><span class="nb">round</span><span class="p">(</span><span class="n">images</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">2</span><span class="p">]</span> <span class="o">*</span> <span class="n">ratio</span><span class="p">)),</span> <span class="nb">int</span><span class="p">(</span><span class="nb">round</span><span class="p">(</span><span class="n">images</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">3</span><span class="p">]</span> <span class="o">*</span> <span class="n">ratio</span><span class="p">)))</span>
    236. <span class="n">images</span> <span class="o">=</span> <span class="n">F</span><span class="o">.</span><span class="n">interpolate</span><span class="p">(</span><span class="n">images</span><span class="p">,</span> <span class="n">size</span><span class="o">=</span><span class="n">new_size</span><span class="p">,</span> <span class="n">mode</span><span class="o">=</span><span class="s1">&#39;bilinear&#39;</span><span class="p">,</span> <span class="n">align_corners</span><span class="o">=</span><span class="kc">False</span><span class="p">)</span>
    237. <span class="k">return</span> <span class="n">images</span><span class="p">,</span> <span class="n">batch</span><span class="p">[</span><span class="mi">1</span><span class="p">]</span></div>
    238. <div class="viewcode-block" id="AbstractPrePredictionCallback"><a class="viewcode-back" href="../../../../super_gradients.training.datasets.html#super_gradients.training.datasets.datasets_utils.AbstractPrePredictionCallback">[docs]</a><span class="k">class</span> <span class="nc">AbstractPrePredictionCallback</span><span class="p">(</span><span class="n">ABC</span><span class="p">):</span>
    239. <span class="sd">&quot;&quot;&quot;</span>
    240. <span class="sd"> Abstract class for forward pass preprocessing function, to be used by passing its inheritors through training_params</span>
    241. <span class="sd"> pre_prediction_callback keyword arg.</span>
    242. <span class="sd"> Should implement __call__ and return images, targets after applying the desired preprocessing.</span>
    243. <span class="sd"> &quot;&quot;&quot;</span>
    244. <span class="nd">@abstractmethod</span>
    245. <span class="k">def</span> <span class="fm">__call__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">inputs</span><span class="p">,</span> <span class="n">targets</span><span class="p">,</span> <span class="n">batch_idx</span><span class="p">):</span>
    246. <span class="k">pass</span></div>
    247. <div class="viewcode-block" id="MultiscalePrePredictionCallback"><a class="viewcode-back" href="../../../../super_gradients.training.datasets.html#super_gradients.training.datasets.datasets_utils.MultiscalePrePredictionCallback">[docs]</a><span class="k">class</span> <span class="nc">MultiscalePrePredictionCallback</span><span class="p">(</span><span class="n">AbstractPrePredictionCallback</span><span class="p">):</span>
    248. <span class="sd">&quot;&quot;&quot;</span>
    249. <span class="sd"> Mutiscale pre-prediction callback pass function.</span>
    250. <span class="sd"> When passed through train_params images, targets will be applied by the below transform to support multi scaling</span>
    251. <span class="sd"> on the fly.</span>
    252. <span class="sd"> After each self.frequency forward passes, change size randomly from</span>
    253. <span class="sd"> (input_size-self.multiscale_range*self.image_size_steps, input_size-(self.multiscale_range-1)*self.image_size_steps,</span>
    254. <span class="sd"> ...input_size+self.multiscale_range*self.image_size_steps)</span>
    255. <span class="sd"> Attributes:</span>
    256. <span class="sd"> multiscale_range: (int) Range of values for resize sizes as discussed above (default=5)</span>
    257. <span class="sd"> image_size_steps: (int) Image step sizes as discussed abov (default=32)</span>
    258. <span class="sd"> change_frequency: (int) The frequency to apply change in input size.</span>
    259. <span class="sd"> &quot;&quot;&quot;</span>
    260. <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">multiscale_range</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="mi">5</span><span class="p">,</span>
    261. <span class="n">image_size_steps</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="mi">32</span><span class="p">,</span>
    262. <span class="n">change_frequency</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="mi">10</span><span class="p">):</span>
    263. <span class="bp">self</span><span class="o">.</span><span class="n">multiscale_range</span> <span class="o">=</span> <span class="n">multiscale_range</span>
    264. <span class="bp">self</span><span class="o">.</span><span class="n">image_size_steps</span> <span class="o">=</span> <span class="n">image_size_steps</span>
    265. <span class="bp">self</span><span class="o">.</span><span class="n">frequency</span> <span class="o">=</span> <span class="n">change_frequency</span>
    266. <span class="bp">self</span><span class="o">.</span><span class="n">rank</span> <span class="o">=</span> <span class="kc">None</span>
    267. <span class="bp">self</span><span class="o">.</span><span class="n">is_distributed</span> <span class="o">=</span> <span class="kc">None</span>
    268. <span class="bp">self</span><span class="o">.</span><span class="n">sampled_imres_once</span> <span class="o">=</span> <span class="kc">False</span>
    269. <span class="bp">self</span><span class="o">.</span><span class="n">new_input_size</span> <span class="o">=</span> <span class="kc">None</span>
    270. <span class="k">def</span> <span class="fm">__call__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">inputs</span><span class="p">,</span> <span class="n">targets</span><span class="p">,</span> <span class="n">batch_idx</span><span class="p">):</span>
    271. <span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">rank</span> <span class="ow">is</span> <span class="kc">None</span><span class="p">:</span>
    272. <span class="bp">self</span><span class="o">.</span><span class="n">rank</span> <span class="o">=</span> <span class="n">get_local_rank</span><span class="p">()</span>
    273. <span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">is_distributed</span> <span class="ow">is</span> <span class="kc">None</span><span class="p">:</span>
    274. <span class="bp">self</span><span class="o">.</span><span class="n">is_distributed</span> <span class="o">=</span> <span class="n">get_world_size</span><span class="p">()</span> <span class="o">&gt;</span> <span class="mi">1</span>
    275. <span class="c1"># GENERATE A NEW SIZE AND BROADCAST IT TO THE THE OTHER RANKS SO THEY HAVE THE SAME SCALE</span>
    276. <span class="n">input_size</span> <span class="o">=</span> <span class="n">inputs</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">2</span><span class="p">:]</span>
    277. <span class="k">if</span> <span class="n">batch_idx</span> <span class="o">%</span> <span class="bp">self</span><span class="o">.</span><span class="n">frequency</span> <span class="o">==</span> <span class="mi">0</span><span class="p">:</span>
    278. <span class="n">tensor</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">LongTensor</span><span class="p">(</span><span class="mi">2</span><span class="p">)</span><span class="o">.</span><span class="n">to</span><span class="p">(</span><span class="n">inputs</span><span class="o">.</span><span class="n">device</span><span class="p">)</span>
    279. <span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">rank</span> <span class="o">==</span> <span class="mi">0</span><span class="p">:</span>
    280. <span class="n">size_factor</span> <span class="o">=</span> <span class="n">input_size</span><span class="p">[</span><span class="mi">1</span><span class="p">]</span> <span class="o">*</span> <span class="mf">1.0</span> <span class="o">/</span> <span class="n">input_size</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span>
    281. <span class="n">min_size</span> <span class="o">=</span> <span class="nb">int</span><span class="p">(</span><span class="n">input_size</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span> <span class="o">/</span> <span class="bp">self</span><span class="o">.</span><span class="n">image_size_steps</span><span class="p">)</span> <span class="o">-</span> <span class="bp">self</span><span class="o">.</span><span class="n">multiscale_range</span>
    282. <span class="n">max_size</span> <span class="o">=</span> <span class="nb">int</span><span class="p">(</span><span class="n">input_size</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span> <span class="o">/</span> <span class="bp">self</span><span class="o">.</span><span class="n">image_size_steps</span><span class="p">)</span> <span class="o">+</span> <span class="bp">self</span><span class="o">.</span><span class="n">multiscale_range</span>
    283. <span class="n">random_size</span> <span class="o">=</span> <span class="p">(</span><span class="n">min_size</span><span class="p">,</span> <span class="n">max_size</span><span class="p">)</span>
    284. <span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">sampled_imres_once</span><span class="p">:</span>
    285. <span class="n">size</span> <span class="o">=</span> <span class="n">random</span><span class="o">.</span><span class="n">randint</span><span class="p">(</span><span class="o">*</span><span class="n">random_size</span><span class="p">)</span>
    286. <span class="k">else</span><span class="p">:</span>
    287. <span class="c1"># sample the biggest resolution first to make sure the run fits into the GPU memory</span>
    288. <span class="n">size</span> <span class="o">=</span> <span class="n">max_size</span>
    289. <span class="bp">self</span><span class="o">.</span><span class="n">sampled_imres_once</span> <span class="o">=</span> <span class="kc">True</span>
    290. <span class="n">size</span> <span class="o">=</span> <span class="p">(</span><span class="nb">int</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">image_size_steps</span> <span class="o">*</span> <span class="n">size</span><span class="p">),</span> <span class="bp">self</span><span class="o">.</span><span class="n">image_size_steps</span> <span class="o">*</span> <span class="nb">int</span><span class="p">(</span><span class="n">size</span> <span class="o">*</span> <span class="n">size_factor</span><span class="p">))</span>
    291. <span class="n">tensor</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span> <span class="o">=</span> <span class="n">size</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span>
    292. <span class="n">tensor</span><span class="p">[</span><span class="mi">1</span><span class="p">]</span> <span class="o">=</span> <span class="n">size</span><span class="p">[</span><span class="mi">1</span><span class="p">]</span>
    293. <span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">is_distributed</span><span class="p">:</span>
    294. <span class="n">dist</span><span class="o">.</span><span class="n">barrier</span><span class="p">()</span>
    295. <span class="n">dist</span><span class="o">.</span><span class="n">broadcast</span><span class="p">(</span><span class="n">tensor</span><span class="p">,</span> <span class="mi">0</span><span class="p">)</span>
    296. <span class="bp">self</span><span class="o">.</span><span class="n">new_input_size</span> <span class="o">=</span> <span class="p">(</span><span class="n">tensor</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span><span class="o">.</span><span class="n">item</span><span class="p">(),</span> <span class="n">tensor</span><span class="p">[</span><span class="mi">1</span><span class="p">]</span><span class="o">.</span><span class="n">item</span><span class="p">())</span>
    297. <span class="n">scale_y</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">new_input_size</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span> <span class="o">/</span> <span class="n">input_size</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span>
    298. <span class="n">scale_x</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">new_input_size</span><span class="p">[</span><span class="mi">1</span><span class="p">]</span> <span class="o">/</span> <span class="n">input_size</span><span class="p">[</span><span class="mi">1</span><span class="p">]</span>
    299. <span class="k">if</span> <span class="n">scale_x</span> <span class="o">!=</span> <span class="mi">1</span> <span class="ow">or</span> <span class="n">scale_y</span> <span class="o">!=</span> <span class="mi">1</span><span class="p">:</span>
    300. <span class="n">inputs</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">nn</span><span class="o">.</span><span class="n">functional</span><span class="o">.</span><span class="n">interpolate</span><span class="p">(</span><span class="n">inputs</span><span class="p">,</span> <span class="n">size</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">new_input_size</span><span class="p">,</span> <span class="n">mode</span><span class="o">=</span><span class="s2">&quot;bilinear&quot;</span><span class="p">,</span> <span class="n">align_corners</span><span class="o">=</span><span class="kc">False</span><span class="p">)</span>
    301. <span class="k">return</span> <span class="n">inputs</span><span class="p">,</span> <span class="n">targets</span></div>
    302. <div class="viewcode-block" id="DetectionMultiscalePrePredictionCallback"><a class="viewcode-back" href="../../../../super_gradients.training.datasets.html#super_gradients.training.datasets.datasets_utils.DetectionMultiscalePrePredictionCallback">[docs]</a><span class="k">class</span> <span class="nc">DetectionMultiscalePrePredictionCallback</span><span class="p">(</span><span class="n">MultiscalePrePredictionCallback</span><span class="p">):</span>
    303. <span class="sd">&quot;&quot;&quot;</span>
    304. <span class="sd"> Mutiscalepre-prediction callback for object detection.</span>
    305. <span class="sd"> When passed through train_params images, targets will be applied by the below transform to support multi scaling</span>
    306. <span class="sd"> on the fly.</span>
    307. <span class="sd"> After each self.frequency forward passes, change size randomly from</span>
    308. <span class="sd"> (input_size-self.multiscale_range*self.image_size_steps, input_size-(self.multiscale_range-1)*self.image_size_steps,</span>
    309. <span class="sd"> ...input_size+self.multiscale_range*self.image_size_steps) and apply the same rescaling to the box coordinates.</span>
    310. <span class="sd"> Attributes:</span>
    311. <span class="sd"> multiscale_range: (int) Range of values for resize sizes as discussed above (default=5)</span>
    312. <span class="sd"> image_size_steps: (int) Image step sizes as discussed abov (default=32)</span>
    313. <span class="sd"> change_frequency: (int) The frequency to apply change in input size.</span>
    314. <span class="sd"> &quot;&quot;&quot;</span>
    315. <span class="k">def</span> <span class="fm">__call__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">inputs</span><span class="p">,</span> <span class="n">targets</span><span class="p">,</span> <span class="n">batch_idx</span><span class="p">):</span>
    316. <span class="c1"># RESCALE THE IMAGE FIRST WITH SUPER(), AND IF RESCALING HAS ACTUALLY BEEN DONE APPLY TO BOXES AS WELL</span>
    317. <span class="n">input_size</span> <span class="o">=</span> <span class="n">inputs</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">2</span><span class="p">:]</span>
    318. <span class="n">inputs</span><span class="p">,</span> <span class="n">targets</span> <span class="o">=</span> <span class="nb">super</span><span class="p">(</span><span class="n">DetectionMultiscalePrePredictionCallback</span><span class="p">,</span> <span class="bp">self</span><span class="p">)</span><span class="o">.</span><span class="fm">__call__</span><span class="p">(</span><span class="n">inputs</span><span class="p">,</span> <span class="n">targets</span><span class="p">,</span> <span class="n">batch_idx</span><span class="p">)</span>
    319. <span class="n">new_input_size</span> <span class="o">=</span> <span class="n">inputs</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">2</span><span class="p">:]</span>
    320. <span class="n">scale_y</span> <span class="o">=</span> <span class="n">new_input_size</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span> <span class="o">/</span> <span class="n">input_size</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span>
    321. <span class="n">scale_x</span> <span class="o">=</span> <span class="n">new_input_size</span><span class="p">[</span><span class="mi">1</span><span class="p">]</span> <span class="o">/</span> <span class="n">input_size</span><span class="p">[</span><span class="mi">1</span><span class="p">]</span>
    322. <span class="k">if</span> <span class="n">scale_x</span> <span class="o">!=</span> <span class="mi">1</span> <span class="ow">or</span> <span class="n">scale_y</span> <span class="o">!=</span> <span class="mi">1</span><span class="p">:</span>
    323. <span class="n">targets</span><span class="p">[</span><span class="o">...</span><span class="p">,</span> <span class="mi">2</span><span class="p">::</span><span class="mi">2</span><span class="p">]</span> <span class="o">=</span> <span class="n">targets</span><span class="p">[</span><span class="o">...</span><span class="p">,</span> <span class="mi">2</span><span class="p">::</span><span class="mi">2</span><span class="p">]</span> <span class="o">*</span> <span class="n">scale_x</span>
    324. <span class="n">targets</span><span class="p">[</span><span class="o">...</span><span class="p">,</span> <span class="mi">3</span><span class="p">::</span><span class="mi">2</span><span class="p">]</span> <span class="o">=</span> <span class="n">targets</span><span class="p">[</span><span class="o">...</span><span class="p">,</span> <span class="mi">3</span><span class="p">::</span><span class="mi">2</span><span class="p">]</span> <span class="o">*</span> <span class="n">scale_y</span>
    325. <span class="k">return</span> <span class="n">inputs</span><span class="p">,</span> <span class="n">targets</span></div>
    326. <span class="n">_pil_interpolation_to_str</span> <span class="o">=</span> <span class="p">{</span>
    327. <span class="n">Image</span><span class="o">.</span><span class="n">NEAREST</span><span class="p">:</span> <span class="s1">&#39;PIL.Image.NEAREST&#39;</span><span class="p">,</span>
    328. <span class="n">Image</span><span class="o">.</span><span class="n">BILINEAR</span><span class="p">:</span> <span class="s1">&#39;PIL.Image.BILINEAR&#39;</span><span class="p">,</span>
    329. <span class="n">Image</span><span class="o">.</span><span class="n">BICUBIC</span><span class="p">:</span> <span class="s1">&#39;PIL.Image.BICUBIC&#39;</span><span class="p">,</span>
    330. <span class="n">Image</span><span class="o">.</span><span class="n">LANCZOS</span><span class="p">:</span> <span class="s1">&#39;PIL.Image.LANCZOS&#39;</span><span class="p">,</span>
    331. <span class="n">Image</span><span class="o">.</span><span class="n">HAMMING</span><span class="p">:</span> <span class="s1">&#39;PIL.Image.HAMMING&#39;</span><span class="p">,</span>
    332. <span class="n">Image</span><span class="o">.</span><span class="n">BOX</span><span class="p">:</span> <span class="s1">&#39;PIL.Image.BOX&#39;</span><span class="p">,</span>
    333. <span class="p">}</span>
    334. <span class="k">def</span> <span class="nf">_pil_interp</span><span class="p">(</span><span class="n">method</span><span class="p">):</span>
    335. <span class="k">if</span> <span class="n">method</span> <span class="o">==</span> <span class="s1">&#39;bicubic&#39;</span><span class="p">:</span>
    336. <span class="k">return</span> <span class="n">InterpolationMode</span><span class="o">.</span><span class="n">BICUBIC</span>
    337. <span class="k">elif</span> <span class="n">method</span> <span class="o">==</span> <span class="s1">&#39;lanczos&#39;</span><span class="p">:</span>
    338. <span class="k">return</span> <span class="n">InterpolationMode</span><span class="o">.</span><span class="n">LANCZOS</span>
    339. <span class="k">elif</span> <span class="n">method</span> <span class="o">==</span> <span class="s1">&#39;hamming&#39;</span><span class="p">:</span>
    340. <span class="k">return</span> <span class="n">InterpolationMode</span><span class="o">.</span><span class="n">HAMMING</span>
    341. <span class="k">elif</span> <span class="n">method</span> <span class="o">==</span> <span class="s1">&#39;nearest&#39;</span><span class="p">:</span>
    342. <span class="k">return</span> <span class="n">InterpolationMode</span><span class="o">.</span><span class="n">NEAREST</span>
    343. <span class="k">elif</span> <span class="n">method</span> <span class="o">==</span> <span class="s1">&#39;bilinear&#39;</span><span class="p">:</span>
    344. <span class="k">return</span> <span class="n">InterpolationMode</span><span class="o">.</span><span class="n">BILINEAR</span>
    345. <span class="k">elif</span> <span class="n">method</span> <span class="o">==</span> <span class="s1">&#39;box&#39;</span><span class="p">:</span>
    346. <span class="k">return</span> <span class="n">InterpolationMode</span><span class="o">.</span><span class="n">BOX</span>
    347. <span class="k">else</span><span class="p">:</span>
    348. <span class="k">raise</span> <span class="ne">ValueError</span><span class="p">(</span><span class="s2">&quot;interpolation type must be one of [&#39;bilinear&#39;, &#39;bicubic&#39;, &#39;lanczos&#39;, &#39;hamming&#39;, &quot;</span>
    349. <span class="s2">&quot;&#39;nearest&#39;, &#39;box&#39;] for explicit interpolation type, or &#39;random&#39; for random&quot;</span><span class="p">)</span>
    350. <span class="n">_RANDOM_INTERPOLATION</span> <span class="o">=</span> <span class="p">(</span><span class="n">InterpolationMode</span><span class="o">.</span><span class="n">BILINEAR</span><span class="p">,</span> <span class="n">InterpolationMode</span><span class="o">.</span><span class="n">BICUBIC</span><span class="p">)</span>
    351. <div class="viewcode-block" id="RandomResizedCropAndInterpolation"><a class="viewcode-back" href="../../../../super_gradients.training.datasets.html#super_gradients.training.datasets.datasets_utils.RandomResizedCropAndInterpolation">[docs]</a><span class="k">class</span> <span class="nc">RandomResizedCropAndInterpolation</span><span class="p">(</span><span class="n">RandomResizedCrop</span><span class="p">):</span>
    352. <span class="sd">&quot;&quot;&quot;</span>
    353. <span class="sd"> Crop the given PIL Image to random size and aspect ratio with explicitly chosen or random interpolation.</span>
    354. <span class="sd"> A crop of random size (default: of 0.08 to 1.0) of the original size and a random</span>
    355. <span class="sd"> aspect ratio (default: of 3/4 to 4/3) of the original aspect ratio is made. This crop</span>
    356. <span class="sd"> is finally resized to given size.</span>
    357. <span class="sd"> This is popularly used to train the Inception networks.</span>
    358. <span class="sd"> Args:</span>
    359. <span class="sd"> size: expected output size of each edge</span>
    360. <span class="sd"> scale: range of size of the origin size cropped</span>
    361. <span class="sd"> ratio: range of aspect ratio of the origin aspect ratio cropped</span>
    362. <span class="sd"> interpolation: Default: PIL.Image.BILINEAR</span>
    363. <span class="sd"> &quot;&quot;&quot;</span>
    364. <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">size</span><span class="p">,</span> <span class="n">scale</span><span class="o">=</span><span class="p">(</span><span class="mf">0.08</span><span class="p">,</span> <span class="mf">1.0</span><span class="p">),</span> <span class="n">ratio</span><span class="o">=</span><span class="p">(</span><span class="mf">3.</span> <span class="o">/</span> <span class="mf">4.</span><span class="p">,</span> <span class="mf">4.</span> <span class="o">/</span> <span class="mf">3.</span><span class="p">),</span>
    365. <span class="n">interpolation</span><span class="o">=</span><span class="s1">&#39;default&#39;</span><span class="p">):</span>
    366. <span class="nb">super</span><span class="p">(</span><span class="n">RandomResizedCropAndInterpolation</span><span class="p">,</span> <span class="bp">self</span><span class="p">)</span><span class="o">.</span><span class="fm">__init__</span><span class="p">(</span><span class="n">size</span><span class="o">=</span><span class="n">size</span><span class="p">,</span> <span class="n">scale</span><span class="o">=</span><span class="n">scale</span><span class="p">,</span> <span class="n">ratio</span><span class="o">=</span><span class="n">ratio</span><span class="p">,</span> <span class="n">interpolation</span><span class="o">=</span><span class="n">interpolation</span><span class="p">)</span>
    367. <span class="k">if</span> <span class="n">interpolation</span> <span class="o">==</span> <span class="s1">&#39;random&#39;</span><span class="p">:</span>
    368. <span class="bp">self</span><span class="o">.</span><span class="n">interpolation</span> <span class="o">=</span> <span class="n">_RANDOM_INTERPOLATION</span>
    369. <span class="k">elif</span> <span class="n">interpolation</span> <span class="o">==</span> <span class="s1">&#39;default&#39;</span><span class="p">:</span>
    370. <span class="bp">self</span><span class="o">.</span><span class="n">interpolation</span> <span class="o">=</span> <span class="n">InterpolationMode</span><span class="o">.</span><span class="n">BILINEAR</span>
    371. <span class="k">else</span><span class="p">:</span>
    372. <span class="bp">self</span><span class="o">.</span><span class="n">interpolation</span> <span class="o">=</span> <span class="n">_pil_interp</span><span class="p">(</span><span class="n">interpolation</span><span class="p">)</span>
    373. <div class="viewcode-block" id="RandomResizedCropAndInterpolation.forward"><a class="viewcode-back" href="../../../../super_gradients.training.datasets.html#super_gradients.training.datasets.datasets_utils.RandomResizedCropAndInterpolation.forward">[docs]</a> <span class="k">def</span> <span class="nf">forward</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">img</span><span class="p">):</span>
    374. <span class="sd">&quot;&quot;&quot;</span>
    375. <span class="sd"> Args:</span>
    376. <span class="sd"> img (PIL Image): Image to be cropped and resized.</span>
    377. <span class="sd"> Returns:</span>
    378. <span class="sd"> PIL Image: Randomly cropped and resized image.</span>
    379. <span class="sd"> &quot;&quot;&quot;</span>
    380. <span class="n">i</span><span class="p">,</span> <span class="n">j</span><span class="p">,</span> <span class="n">h</span><span class="p">,</span> <span class="n">w</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">get_params</span><span class="p">(</span><span class="n">img</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">scale</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">ratio</span><span class="p">)</span>
    381. <span class="k">if</span> <span class="nb">isinstance</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">interpolation</span><span class="p">,</span> <span class="p">(</span><span class="nb">tuple</span><span class="p">,</span> <span class="nb">list</span><span class="p">)):</span>
    382. <span class="n">interpolation</span> <span class="o">=</span> <span class="n">random</span><span class="o">.</span><span class="n">choice</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">interpolation</span><span class="p">)</span>
    383. <span class="k">else</span><span class="p">:</span>
    384. <span class="n">interpolation</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">interpolation</span>
    385. <span class="k">return</span> <span class="n">torchvision</span><span class="o">.</span><span class="n">transforms</span><span class="o">.</span><span class="n">functional</span><span class="o">.</span><span class="n">resized_crop</span><span class="p">(</span><span class="n">img</span><span class="p">,</span> <span class="n">i</span><span class="p">,</span> <span class="n">j</span><span class="p">,</span> <span class="n">h</span><span class="p">,</span> <span class="n">w</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">size</span><span class="p">,</span> <span class="n">interpolation</span><span class="p">)</span></div>
    386. <span class="k">def</span> <span class="fm">__repr__</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span>
    387. <span class="k">if</span> <span class="nb">isinstance</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">interpolation</span><span class="p">,</span> <span class="p">(</span><span class="nb">tuple</span><span class="p">,</span> <span class="nb">list</span><span class="p">)):</span>
    388. <span class="n">interpolate_str</span> <span class="o">=</span> <span class="s1">&#39; &#39;</span><span class="o">.</span><span class="n">join</span><span class="p">([</span><span class="n">_pil_interpolation_to_str</span><span class="p">[</span><span class="n">x</span><span class="p">]</span> <span class="k">for</span> <span class="n">x</span> <span class="ow">in</span> <span class="bp">self</span><span class="o">.</span><span class="n">interpolation</span><span class="p">])</span>
    389. <span class="k">else</span><span class="p">:</span>
    390. <span class="n">interpolate_str</span> <span class="o">=</span> <span class="n">_pil_interpolation_to_str</span><span class="p">[</span><span class="bp">self</span><span class="o">.</span><span class="n">interpolation</span><span class="p">]</span>
    391. <span class="n">format_string</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="vm">__class__</span><span class="o">.</span><span class="vm">__name__</span> <span class="o">+</span> <span class="s1">&#39;(size=</span><span class="si">{0}</span><span class="s1">&#39;</span><span class="o">.</span><span class="n">format</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">size</span><span class="p">)</span>
    392. <span class="n">format_string</span> <span class="o">+=</span> <span class="s1">&#39;, scale=</span><span class="si">{0}</span><span class="s1">&#39;</span><span class="o">.</span><span class="n">format</span><span class="p">(</span><span class="nb">tuple</span><span class="p">(</span><span class="nb">round</span><span class="p">(</span><span class="n">s</span><span class="p">,</span> <span class="mi">4</span><span class="p">)</span> <span class="k">for</span> <span class="n">s</span> <span class="ow">in</span> <span class="bp">self</span><span class="o">.</span><span class="n">scale</span><span class="p">))</span>
    393. <span class="n">format_string</span> <span class="o">+=</span> <span class="s1">&#39;, ratio=</span><span class="si">{0}</span><span class="s1">&#39;</span><span class="o">.</span><span class="n">format</span><span class="p">(</span><span class="nb">tuple</span><span class="p">(</span><span class="nb">round</span><span class="p">(</span><span class="n">r</span><span class="p">,</span> <span class="mi">4</span><span class="p">)</span> <span class="k">for</span> <span class="n">r</span> <span class="ow">in</span> <span class="bp">self</span><span class="o">.</span><span class="n">ratio</span><span class="p">))</span>
    394. <span class="n">format_string</span> <span class="o">+=</span> <span class="s1">&#39;, interpolation=</span><span class="si">{0}</span><span class="s1">)&#39;</span><span class="o">.</span><span class="n">format</span><span class="p">(</span><span class="n">interpolate_str</span><span class="p">)</span>
    395. <span class="k">return</span> <span class="n">format_string</span></div>
    396. <span class="n">STAT_LOGGER_FONT_SIZE</span> <span class="o">=</span> <span class="mi">15</span>
    397. <div class="viewcode-block" id="DatasetStatisticsTensorboardLogger"><a class="viewcode-back" href="../../../../super_gradients.training.datasets.html#super_gradients.training.datasets.datasets_utils.DatasetStatisticsTensorboardLogger">[docs]</a><span class="k">class</span> <span class="nc">DatasetStatisticsTensorboardLogger</span><span class="p">:</span>
    398. <span class="n">logger</span> <span class="o">=</span> <span class="n">get_logger</span><span class="p">(</span><span class="vm">__name__</span><span class="p">)</span>
    399. <span class="n">DEFAULT_SUMMARY_PARAMS</span> <span class="o">=</span> <span class="p">{</span>
    400. <span class="s1">&#39;sample_images&#39;</span><span class="p">:</span> <span class="mi">32</span><span class="p">,</span> <span class="c1"># by default, 32 images will be sampled from each dataset</span>
    401. <span class="s1">&#39;plot_class_distribution&#39;</span><span class="p">:</span> <span class="kc">True</span><span class="p">,</span>
    402. <span class="s1">&#39;plot_box_size_distribution&#39;</span><span class="p">:</span> <span class="kc">True</span><span class="p">,</span>
    403. <span class="s1">&#39;plot_anchors_coverage&#39;</span><span class="p">:</span> <span class="kc">True</span><span class="p">,</span>
    404. <span class="s1">&#39;max_batches&#39;</span><span class="p">:</span> <span class="mi">30</span>
    405. <span class="p">}</span>
    406. <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">sg_logger</span><span class="p">:</span> <span class="n">AbstractSGLogger</span><span class="p">,</span> <span class="n">summary_params</span><span class="p">:</span> <span class="nb">dict</span> <span class="o">=</span> <span class="n">DEFAULT_SUMMARY_PARAMS</span><span class="p">):</span>
    407. <span class="bp">self</span><span class="o">.</span><span class="n">sg_logger</span> <span class="o">=</span> <span class="n">sg_logger</span>
    408. <span class="bp">self</span><span class="o">.</span><span class="n">summary_params</span> <span class="o">=</span> <span class="p">{</span><span class="o">**</span><span class="n">DatasetStatisticsTensorboardLogger</span><span class="o">.</span><span class="n">DEFAULT_SUMMARY_PARAMS</span><span class="p">,</span> <span class="o">**</span><span class="n">summary_params</span><span class="p">}</span>
    409. <div class="viewcode-block" id="DatasetStatisticsTensorboardLogger.analyze"><a class="viewcode-back" href="../../../../super_gradients.training.datasets.html#super_gradients.training.datasets.datasets_utils.DatasetStatisticsTensorboardLogger.analyze">[docs]</a> <span class="k">def</span> <span class="nf">analyze</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">data_loader</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">utils</span><span class="o">.</span><span class="n">data</span><span class="o">.</span><span class="n">DataLoader</span><span class="p">,</span> <span class="n">title</span><span class="p">:</span> <span class="nb">str</span><span class="p">,</span>
    410. <span class="n">all_classes</span><span class="p">:</span> <span class="n">List</span><span class="p">[</span><span class="nb">str</span><span class="p">],</span> <span class="n">anchors</span><span class="p">:</span> <span class="nb">list</span> <span class="o">=</span> <span class="kc">None</span><span class="p">):</span>
    411. <span class="sd">&quot;&quot;&quot;</span>
    412. <span class="sd"> :param data_loader: the dataset data loader</span>
    413. <span class="sd"> :param dataset_params: the dataset parameters</span>
    414. <span class="sd"> :param title: the title for this dataset (i.e. Coco 2017 test set)</span>
    415. <span class="sd"> :param anchors: the list of anchors used by the model. applicable only for detection datasets</span>
    416. <span class="sd"> :param all_classes: the list of all classes names</span>
    417. <span class="sd"> &quot;&quot;&quot;</span>
    418. <span class="c1"># FIXME: UNCOMMENT AND APPLY TO NEW DetectionDataSet ONCE ITS MERGED</span>
    419. <span class="c1"># if isinstance(data_loader.dataset, DetectionDataSet):</span>
    420. <span class="c1"># self._analyze_detection(data_loader=data_loader, title=title,</span>
    421. <span class="c1"># all_classes=all_classes, anchors=anchors)</span>
    422. <span class="c1"># else:</span>
    423. <span class="c1"># DatasetStatisticsTensorboardLogger.logger.warning(&#39;only DetectionDataSet are currently supported&#39;)</span>
    424. <span class="n">DatasetStatisticsTensorboardLogger</span><span class="o">.</span><span class="n">logger</span><span class="o">.</span><span class="n">warning</span><span class="p">(</span><span class="s1">&#39;only DetectionDataSet are currently supported&#39;</span><span class="p">)</span></div>
    425. <span class="k">def</span> <span class="nf">_analyze_detection</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">data_loader</span><span class="p">,</span> <span class="n">title</span><span class="p">,</span> <span class="n">all_classes</span><span class="p">,</span> <span class="n">anchors</span><span class="o">=</span><span class="kc">None</span><span class="p">):</span>
    426. <span class="sd">&quot;&quot;&quot;</span>
    427. <span class="sd"> Analyze a detection dataset</span>
    428. <span class="sd"> :param data_loader: the dataset data loader</span>
    429. <span class="sd"> :param dataset_params: the dataset parameters</span>
    430. <span class="sd"> :param all_classes: the list of all classes names</span>
    431. <span class="sd"> :param title: the title for this dataset (i.e. Coco 2017 test set)</span>
    432. <span class="sd"> :param anchors: the list of anchors used by the model. if not provided, anchors coverage will not be analyzed</span>
    433. <span class="sd"> &quot;&quot;&quot;</span>
    434. <span class="k">try</span><span class="p">:</span>
    435. <span class="n">color_mean</span> <span class="o">=</span> <span class="n">AverageMeter</span><span class="p">()</span>
    436. <span class="n">color_std</span> <span class="o">=</span> <span class="n">AverageMeter</span><span class="p">()</span>
    437. <span class="n">all_labels</span> <span class="o">=</span> <span class="p">[]</span>
    438. <span class="n">image_size</span> <span class="o">=</span> <span class="mi">0</span>
    439. <span class="k">for</span> <span class="n">i</span><span class="p">,</span> <span class="p">(</span><span class="n">images</span><span class="p">,</span> <span class="n">labels</span><span class="p">)</span> <span class="ow">in</span> <span class="nb">enumerate</span><span class="p">(</span><span class="n">tqdm</span><span class="p">(</span><span class="n">data_loader</span><span class="p">)):</span>
    440. <span class="k">if</span> <span class="n">i</span> <span class="o">&gt;=</span> <span class="bp">self</span><span class="o">.</span><span class="n">summary_params</span><span class="p">[</span><span class="s1">&#39;max_batches&#39;</span><span class="p">]</span> <span class="o">&gt;</span> <span class="mi">0</span><span class="p">:</span>
    441. <span class="k">break</span>
    442. <span class="k">if</span> <span class="n">i</span> <span class="o">==</span> <span class="mi">0</span><span class="p">:</span>
    443. <span class="n">image_size</span> <span class="o">=</span> <span class="nb">max</span><span class="p">(</span><span class="n">images</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">1</span><span class="p">],</span> <span class="n">images</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">2</span><span class="p">])</span>
    444. <span class="k">if</span> <span class="n">images</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span> <span class="o">&gt;</span> <span class="bp">self</span><span class="o">.</span><span class="n">summary_params</span><span class="p">[</span><span class="s1">&#39;sample_images&#39;</span><span class="p">]:</span>
    445. <span class="n">samples</span> <span class="o">=</span> <span class="n">images</span><span class="p">[:</span><span class="bp">self</span><span class="o">.</span><span class="n">summary_params</span><span class="p">[</span><span class="s1">&#39;sample_images&#39;</span><span class="p">]]</span>
    446. <span class="k">else</span><span class="p">:</span>
    447. <span class="n">samples</span> <span class="o">=</span> <span class="n">images</span>
    448. <span class="n">pred</span> <span class="o">=</span> <span class="p">[</span><span class="n">torch</span><span class="o">.</span><span class="n">zeros</span><span class="p">(</span><span class="n">size</span><span class="o">=</span><span class="p">(</span><span class="mi">0</span><span class="p">,</span> <span class="mi">6</span><span class="p">))</span> <span class="k">for</span> <span class="n">_</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="nb">len</span><span class="p">(</span><span class="n">samples</span><span class="p">))]</span>
    449. <span class="k">try</span><span class="p">:</span>
    450. <span class="n">result_images</span> <span class="o">=</span> <span class="n">DetectionVisualization</span><span class="o">.</span><span class="n">visualize_batch</span><span class="p">(</span><span class="n">image_tensor</span><span class="o">=</span><span class="n">samples</span><span class="p">,</span> <span class="n">pred_boxes</span><span class="o">=</span><span class="n">pred</span><span class="p">,</span>
    451. <span class="n">target_boxes</span><span class="o">=</span><span class="n">copy</span><span class="o">.</span><span class="n">deepcopy</span><span class="p">(</span><span class="n">labels</span><span class="p">),</span>
    452. <span class="n">batch_name</span><span class="o">=</span><span class="n">title</span><span class="p">,</span>
    453. <span class="n">class_names</span><span class="o">=</span><span class="n">all_classes</span><span class="p">,</span>
    454. <span class="n">box_thickness</span><span class="o">=</span><span class="mi">1</span><span class="p">,</span>
    455. <span class="n">gt_alpha</span><span class="o">=</span><span class="mf">1.0</span><span class="p">)</span>
    456. <span class="bp">self</span><span class="o">.</span><span class="n">sg_logger</span><span class="o">.</span><span class="n">add_images</span><span class="p">(</span><span class="n">tag</span><span class="o">=</span><span class="sa">f</span><span class="s1">&#39;</span><span class="si">{</span><span class="n">title</span><span class="si">}</span><span class="s1"> sample images&#39;</span><span class="p">,</span> <span class="n">images</span><span class="o">=</span><span class="n">np</span><span class="o">.</span><span class="n">stack</span><span class="p">(</span><span class="n">result_images</span><span class="p">)</span>
    457. <span class="o">.</span><span class="n">transpose</span><span class="p">([</span><span class="mi">0</span><span class="p">,</span> <span class="mi">3</span><span class="p">,</span> <span class="mi">1</span><span class="p">,</span> <span class="mi">2</span><span class="p">])[:,</span> <span class="p">::</span><span class="o">-</span><span class="mi">1</span><span class="p">,</span> <span class="p">:,</span> <span class="p">:])</span>
    458. <span class="k">except</span> <span class="ne">Exception</span> <span class="k">as</span> <span class="n">e</span><span class="p">:</span>
    459. <span class="n">DatasetStatisticsTensorboardLogger</span><span class="o">.</span><span class="n">logger</span><span class="o">.</span><span class="n">error</span><span class="p">(</span>
    460. <span class="sa">f</span><span class="s1">&#39;Dataset Statistics failed at adding an example batch:</span><span class="se">\n</span><span class="si">{</span><span class="n">e</span><span class="si">}</span><span class="s1">&#39;</span><span class="p">)</span>
    461. <span class="k">return</span>
    462. <span class="n">all_labels</span><span class="o">.</span><span class="n">append</span><span class="p">(</span><span class="n">labels</span><span class="p">)</span>
    463. <span class="n">color_mean</span><span class="o">.</span><span class="n">update</span><span class="p">(</span><span class="n">torch</span><span class="o">.</span><span class="n">mean</span><span class="p">(</span><span class="n">images</span><span class="p">,</span> <span class="n">dim</span><span class="o">=</span><span class="p">[</span><span class="mi">0</span><span class="p">,</span> <span class="mi">2</span><span class="p">,</span> <span class="mi">3</span><span class="p">]),</span> <span class="mi">1</span><span class="p">)</span>
    464. <span class="n">color_std</span><span class="o">.</span><span class="n">update</span><span class="p">(</span><span class="n">torch</span><span class="o">.</span><span class="n">std</span><span class="p">(</span><span class="n">images</span><span class="p">,</span> <span class="n">dim</span><span class="o">=</span><span class="p">[</span><span class="mi">0</span><span class="p">,</span> <span class="mi">2</span><span class="p">,</span> <span class="mi">3</span><span class="p">]),</span> <span class="mi">1</span><span class="p">)</span>
    465. <span class="n">all_labels</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">cat</span><span class="p">(</span><span class="n">all_labels</span><span class="p">,</span> <span class="n">dim</span><span class="o">=</span><span class="mi">0</span><span class="p">)[</span><span class="mi">1</span><span class="p">:]</span><span class="o">.</span><span class="n">numpy</span><span class="p">()</span>
    466. <span class="k">try</span><span class="p">:</span>
    467. <span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">summary_params</span><span class="p">[</span><span class="s1">&#39;plot_class_distribution&#39;</span><span class="p">]:</span>
    468. <span class="bp">self</span><span class="o">.</span><span class="n">_analyze_class_distribution</span><span class="p">(</span><span class="n">labels</span><span class="o">=</span><span class="n">all_labels</span><span class="p">,</span> <span class="n">num_classes</span><span class="o">=</span><span class="nb">len</span><span class="p">(</span><span class="n">all_classes</span><span class="p">),</span> <span class="n">title</span><span class="o">=</span><span class="n">title</span><span class="p">)</span>
    469. <span class="k">except</span> <span class="ne">Exception</span> <span class="k">as</span> <span class="n">e</span><span class="p">:</span>
    470. <span class="n">DatasetStatisticsTensorboardLogger</span><span class="o">.</span><span class="n">logger</span><span class="o">.</span><span class="n">error</span><span class="p">(</span><span class="sa">f</span><span class="s1">&#39;Dataset Statistics failed at analyzing class distributions.</span><span class="se">\n</span><span class="si">{</span><span class="n">e</span><span class="si">}</span><span class="s1">&#39;</span><span class="p">)</span>
    471. <span class="k">return</span>
    472. <span class="k">try</span><span class="p">:</span>
    473. <span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">summary_params</span><span class="p">[</span><span class="s1">&#39;plot_box_size_distribution&#39;</span><span class="p">]:</span>
    474. <span class="bp">self</span><span class="o">.</span><span class="n">_analyze_object_size_distribution</span><span class="p">(</span><span class="n">labels</span><span class="o">=</span><span class="n">all_labels</span><span class="p">,</span> <span class="n">title</span><span class="o">=</span><span class="n">title</span><span class="p">)</span>
    475. <span class="k">except</span> <span class="ne">Exception</span> <span class="k">as</span> <span class="n">e</span><span class="p">:</span>
    476. <span class="n">DatasetStatisticsTensorboardLogger</span><span class="o">.</span><span class="n">logger</span><span class="o">.</span><span class="n">error</span><span class="p">(</span><span class="sa">f</span><span class="s1">&#39;Dataset Statistics failed at analyzing object size &#39;</span>
    477. <span class="sa">f</span><span class="s1">&#39;distributions.</span><span class="se">\n</span><span class="si">{</span><span class="n">e</span><span class="si">}</span><span class="s1">&#39;</span><span class="p">)</span>
    478. <span class="k">return</span>
    479. <span class="n">summary</span> <span class="o">=</span> <span class="s1">&#39;&#39;</span>
    480. <span class="n">summary</span> <span class="o">+=</span> <span class="sa">f</span><span class="s1">&#39;dataset size: </span><span class="si">{</span><span class="nb">len</span><span class="p">(</span><span class="n">data_loader</span><span class="p">)</span><span class="si">}</span><span class="s1"> </span><span class="se">\n</span><span class="s1">&#39;</span>
    481. <span class="n">summary</span> <span class="o">+=</span> <span class="sa">f</span><span class="s1">&#39;color mean: </span><span class="si">{</span><span class="n">color_mean</span><span class="o">.</span><span class="n">average</span><span class="si">}</span><span class="s1"> </span><span class="se">\n</span><span class="s1">&#39;</span>
    482. <span class="n">summary</span> <span class="o">+=</span> <span class="sa">f</span><span class="s1">&#39;color std: </span><span class="si">{</span><span class="n">color_std</span><span class="o">.</span><span class="n">average</span><span class="si">}</span><span class="s1"> </span><span class="se">\n</span><span class="s1">&#39;</span>
    483. <span class="k">try</span><span class="p">:</span>
    484. <span class="k">if</span> <span class="n">anchors</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span> <span class="ow">and</span> <span class="n">image_size</span> <span class="o">&gt;</span> <span class="mi">0</span><span class="p">:</span>
    485. <span class="n">coverage</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">_analyze_anchors_coverage</span><span class="p">(</span><span class="n">anchors</span><span class="o">=</span><span class="n">anchors</span><span class="p">,</span> <span class="n">image_size</span><span class="o">=</span><span class="n">image_size</span><span class="p">,</span>
    486. <span class="n">title</span><span class="o">=</span><span class="n">title</span><span class="p">,</span> <span class="n">labels</span><span class="o">=</span><span class="n">all_labels</span><span class="p">)</span>
    487. <span class="n">summary</span> <span class="o">+=</span> <span class="sa">f</span><span class="s1">&#39;anchors: </span><span class="si">{</span><span class="n">anchors</span><span class="si">}</span><span class="s1"> </span><span class="se">\n</span><span class="s1">&#39;</span>
    488. <span class="n">summary</span> <span class="o">+=</span> <span class="sa">f</span><span class="s1">&#39;anchors coverage: </span><span class="si">{</span><span class="n">coverage</span><span class="si">}</span><span class="s1"> </span><span class="se">\n</span><span class="s1">&#39;</span>
    489. <span class="k">except</span> <span class="ne">Exception</span> <span class="k">as</span> <span class="n">e</span><span class="p">:</span>
    490. <span class="n">DatasetStatisticsTensorboardLogger</span><span class="o">.</span><span class="n">logger</span><span class="o">.</span><span class="n">error</span><span class="p">(</span><span class="sa">f</span><span class="s1">&#39;Dataset Statistics failed at analyzing anchors &#39;</span>
    491. <span class="sa">f</span><span class="s1">&#39;coverage.</span><span class="se">\n</span><span class="si">{</span><span class="n">e</span><span class="si">}</span><span class="s1">&#39;</span><span class="p">)</span>
    492. <span class="k">return</span>
    493. <span class="bp">self</span><span class="o">.</span><span class="n">sg_logger</span><span class="o">.</span><span class="n">add_text</span><span class="p">(</span><span class="n">tag</span><span class="o">=</span><span class="sa">f</span><span class="s1">&#39;</span><span class="si">{</span><span class="n">title</span><span class="si">}</span><span class="s1"> Statistics&#39;</span><span class="p">,</span> <span class="n">text_string</span><span class="o">=</span><span class="n">summary</span><span class="p">)</span>
    494. <span class="bp">self</span><span class="o">.</span><span class="n">sg_logger</span><span class="o">.</span><span class="n">flush</span><span class="p">()</span>
    495. <span class="k">except</span> <span class="ne">Exception</span> <span class="k">as</span> <span class="n">e</span><span class="p">:</span>
    496. <span class="n">DatasetStatisticsTensorboardLogger</span><span class="o">.</span><span class="n">logger</span><span class="o">.</span><span class="n">error</span><span class="p">(</span><span class="sa">f</span><span class="s1">&#39;dataset analysis failed!</span><span class="se">\n</span><span class="si">{</span><span class="n">e</span><span class="si">}</span><span class="s1">&#39;</span><span class="p">)</span>
    497. <span class="k">def</span> <span class="nf">_analyze_class_distribution</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">labels</span><span class="p">:</span> <span class="nb">list</span><span class="p">,</span> <span class="n">num_classes</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span> <span class="n">title</span><span class="p">:</span> <span class="nb">str</span><span class="p">):</span>
    498. <span class="n">hist</span><span class="p">,</span> <span class="n">edges</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">histogram</span><span class="p">(</span><span class="n">labels</span><span class="p">[:,</span> <span class="mi">0</span><span class="p">],</span> <span class="n">num_classes</span><span class="p">)</span>
    499. <span class="n">f</span> <span class="o">=</span> <span class="n">plt</span><span class="o">.</span><span class="n">figure</span><span class="p">(</span><span class="n">figsize</span><span class="o">=</span><span class="p">[</span><span class="mi">10</span><span class="p">,</span> <span class="mi">8</span><span class="p">])</span>
    500. <span class="n">plt</span><span class="o">.</span><span class="n">bar</span><span class="p">(</span><span class="nb">range</span><span class="p">(</span><span class="n">num_classes</span><span class="p">),</span> <span class="n">hist</span><span class="p">,</span> <span class="n">width</span><span class="o">=</span><span class="mf">0.5</span><span class="p">,</span> <span class="n">color</span><span class="o">=</span><span class="s1">&#39;#0504aa&#39;</span><span class="p">,</span> <span class="n">alpha</span><span class="o">=</span><span class="mf">0.7</span><span class="p">)</span>
    501. <span class="n">plt</span><span class="o">.</span><span class="n">xlim</span><span class="p">(</span><span class="o">-</span><span class="mi">1</span><span class="p">,</span> <span class="n">num_classes</span><span class="p">)</span>
    502. <span class="n">plt</span><span class="o">.</span><span class="n">grid</span><span class="p">(</span><span class="n">axis</span><span class="o">=</span><span class="s1">&#39;y&#39;</span><span class="p">,</span> <span class="n">alpha</span><span class="o">=</span><span class="mf">0.75</span><span class="p">)</span>
    503. <span class="n">plt</span><span class="o">.</span><span class="n">xlabel</span><span class="p">(</span><span class="s1">&#39;Value&#39;</span><span class="p">,</span> <span class="n">fontsize</span><span class="o">=</span><span class="n">STAT_LOGGER_FONT_SIZE</span><span class="p">)</span>
    504. <span class="n">plt</span><span class="o">.</span><span class="n">ylabel</span><span class="p">(</span><span class="s1">&#39;Frequency&#39;</span><span class="p">,</span> <span class="n">fontsize</span><span class="o">=</span><span class="n">STAT_LOGGER_FONT_SIZE</span><span class="p">)</span>
    505. <span class="n">plt</span><span class="o">.</span><span class="n">xticks</span><span class="p">(</span><span class="n">fontsize</span><span class="o">=</span><span class="n">STAT_LOGGER_FONT_SIZE</span><span class="p">)</span>
    506. <span class="n">plt</span><span class="o">.</span><span class="n">yticks</span><span class="p">(</span><span class="n">fontsize</span><span class="o">=</span><span class="n">STAT_LOGGER_FONT_SIZE</span><span class="p">)</span>
    507. <span class="n">plt</span><span class="o">.</span><span class="n">title</span><span class="p">(</span><span class="sa">f</span><span class="s1">&#39;</span><span class="si">{</span><span class="n">title</span><span class="si">}</span><span class="s1"> class distribution&#39;</span><span class="p">,</span> <span class="n">fontsize</span><span class="o">=</span><span class="n">STAT_LOGGER_FONT_SIZE</span><span class="p">)</span>
    508. <span class="bp">self</span><span class="o">.</span><span class="n">sg_logger</span><span class="o">.</span><span class="n">add_figure</span><span class="p">(</span><span class="sa">f</span><span class="s2">&quot;</span><span class="si">{</span><span class="n">title</span><span class="si">}</span><span class="s2"> class distribution&quot;</span><span class="p">,</span> <span class="n">figure</span><span class="o">=</span><span class="n">f</span><span class="p">)</span>
    509. <span class="n">text_dist</span> <span class="o">=</span> <span class="s1">&#39;&#39;</span>
    510. <span class="k">for</span> <span class="n">i</span><span class="p">,</span> <span class="n">val</span> <span class="ow">in</span> <span class="nb">enumerate</span><span class="p">(</span><span class="n">hist</span><span class="p">):</span>
    511. <span class="n">text_dist</span> <span class="o">+=</span> <span class="sa">f</span><span class="s1">&#39;[</span><span class="si">{</span><span class="n">i</span><span class="si">}</span><span class="s1">]: </span><span class="si">{</span><span class="n">val</span><span class="si">}</span><span class="s1">, &#39;</span>
    512. <span class="bp">self</span><span class="o">.</span><span class="n">sg_logger</span><span class="o">.</span><span class="n">add_text</span><span class="p">(</span><span class="n">tag</span><span class="o">=</span><span class="sa">f</span><span class="s2">&quot;</span><span class="si">{</span><span class="n">title</span><span class="si">}</span><span class="s2"> class distribution&quot;</span><span class="p">,</span> <span class="n">text_string</span><span class="o">=</span><span class="n">text_dist</span><span class="p">)</span>
    513. <span class="k">def</span> <span class="nf">_analyze_object_size_distribution</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">labels</span><span class="p">:</span> <span class="nb">list</span><span class="p">,</span> <span class="n">title</span><span class="p">:</span> <span class="nb">str</span><span class="p">):</span>
    514. <span class="sd">&quot;&quot;&quot;</span>
    515. <span class="sd"> This function will add two plots to the tensorboard.</span>
    516. <span class="sd"> one is a 2D histogram and the other is a scatter plot. in both cases the X axis is the object width and Y axis</span>
    517. <span class="sd"> is the object width (both normalized by image size)</span>
    518. <span class="sd"> :param labels: all the labels of the dataset of the shape [class_label, x_center, y_center, w, h]</span>
    519. <span class="sd"> :param title: the dataset title</span>
    520. <span class="sd"> &quot;&quot;&quot;</span>
    521. <span class="c1"># histogram plot</span>
    522. <span class="n">hist</span><span class="p">,</span> <span class="n">xedges</span><span class="p">,</span> <span class="n">yedges</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">histogram2d</span><span class="p">(</span><span class="n">labels</span><span class="p">[:,</span> <span class="mi">4</span><span class="p">],</span> <span class="n">labels</span><span class="p">[:,</span> <span class="mi">3</span><span class="p">],</span> <span class="mi">50</span><span class="p">)</span> <span class="c1"># x and y are deliberately switched</span>
    523. <span class="n">fig</span> <span class="o">=</span> <span class="n">plt</span><span class="o">.</span><span class="n">figure</span><span class="p">(</span><span class="n">figsize</span><span class="o">=</span><span class="p">(</span><span class="mi">10</span><span class="p">,</span> <span class="mi">6</span><span class="p">))</span>
    524. <span class="n">fig</span><span class="o">.</span><span class="n">suptitle</span><span class="p">(</span><span class="sa">f</span><span class="s1">&#39;</span><span class="si">{</span><span class="n">title</span><span class="si">}</span><span class="s1"> boxes w/h distribution&#39;</span><span class="p">)</span>
    525. <span class="n">ax</span> <span class="o">=</span> <span class="n">fig</span><span class="o">.</span><span class="n">add_subplot</span><span class="p">(</span><span class="mi">121</span><span class="p">)</span>
    526. <span class="n">ax</span><span class="o">.</span><span class="n">set_xlabel</span><span class="p">(</span><span class="s1">&#39;W&#39;</span><span class="p">,</span> <span class="n">fontsize</span><span class="o">=</span><span class="n">STAT_LOGGER_FONT_SIZE</span><span class="p">)</span>
    527. <span class="n">ax</span><span class="o">.</span><span class="n">set_ylabel</span><span class="p">(</span><span class="s1">&#39;H&#39;</span><span class="p">,</span> <span class="n">fontsize</span><span class="o">=</span><span class="n">STAT_LOGGER_FONT_SIZE</span><span class="p">)</span>
    528. <span class="n">plt</span><span class="o">.</span><span class="n">imshow</span><span class="p">(</span><span class="n">np</span><span class="o">.</span><span class="n">log</span><span class="p">(</span><span class="n">hist</span> <span class="o">+</span> <span class="mi">1</span><span class="p">),</span> <span class="n">interpolation</span><span class="o">=</span><span class="s1">&#39;nearest&#39;</span><span class="p">,</span> <span class="n">origin</span><span class="o">=</span><span class="s1">&#39;lower&#39;</span><span class="p">,</span>
    529. <span class="n">extent</span><span class="o">=</span><span class="p">[</span><span class="n">xedges</span><span class="p">[</span><span class="mi">0</span><span class="p">],</span> <span class="n">xedges</span><span class="p">[</span><span class="o">-</span><span class="mi">1</span><span class="p">],</span> <span class="n">yedges</span><span class="p">[</span><span class="mi">0</span><span class="p">],</span> <span class="n">yedges</span><span class="p">[</span><span class="o">-</span><span class="mi">1</span><span class="p">]])</span>
    530. <span class="c1"># scatter plot</span>
    531. <span class="k">if</span> <span class="nb">len</span><span class="p">(</span><span class="n">labels</span><span class="p">)</span> <span class="o">&gt;</span> <span class="mi">10000</span><span class="p">:</span>
    532. <span class="c1"># we randomly sample just 10000 objects so that the scatter plot will not get too dense</span>
    533. <span class="n">labels</span> <span class="o">=</span> <span class="n">labels</span><span class="p">[</span><span class="n">np</span><span class="o">.</span><span class="n">random</span><span class="o">.</span><span class="n">randint</span><span class="p">(</span><span class="mi">0</span><span class="p">,</span> <span class="nb">len</span><span class="p">(</span><span class="n">labels</span><span class="p">)</span> <span class="o">-</span> <span class="mi">1</span><span class="p">,</span> <span class="mi">10000</span><span class="p">)]</span>
    534. <span class="n">ax</span> <span class="o">=</span> <span class="n">fig</span><span class="o">.</span><span class="n">add_subplot</span><span class="p">(</span><span class="mi">122</span><span class="p">)</span>
    535. <span class="n">ax</span><span class="o">.</span><span class="n">set_xlabel</span><span class="p">(</span><span class="s1">&#39;W&#39;</span><span class="p">,</span> <span class="n">fontsize</span><span class="o">=</span><span class="n">STAT_LOGGER_FONT_SIZE</span><span class="p">)</span>
    536. <span class="n">ax</span><span class="o">.</span><span class="n">set_ylabel</span><span class="p">(</span><span class="s1">&#39;H&#39;</span><span class="p">,</span> <span class="n">fontsize</span><span class="o">=</span><span class="n">STAT_LOGGER_FONT_SIZE</span><span class="p">)</span>
    537. <span class="n">plt</span><span class="o">.</span><span class="n">scatter</span><span class="p">(</span><span class="n">labels</span><span class="p">[:,</span> <span class="mi">3</span><span class="p">],</span> <span class="n">labels</span><span class="p">[:,</span> <span class="mi">4</span><span class="p">],</span> <span class="n">marker</span><span class="o">=</span><span class="s1">&#39;.&#39;</span><span class="p">)</span>
    538. <span class="bp">self</span><span class="o">.</span><span class="n">sg_logger</span><span class="o">.</span><span class="n">add_figure</span><span class="p">(</span><span class="n">tag</span><span class="o">=</span><span class="sa">f</span><span class="s1">&#39;</span><span class="si">{</span><span class="n">title</span><span class="si">}</span><span class="s1"> boxes w/h distribution&#39;</span><span class="p">,</span> <span class="n">figure</span><span class="o">=</span><span class="n">fig</span><span class="p">)</span>
    539. <span class="nd">@staticmethod</span>
    540. <span class="k">def</span> <span class="nf">_get_rect</span><span class="p">(</span><span class="n">w</span><span class="p">,</span> <span class="n">h</span><span class="p">):</span>
    541. <span class="n">min_w</span> <span class="o">=</span> <span class="n">w</span> <span class="o">/</span> <span class="mf">4.0</span>
    542. <span class="n">min_h</span> <span class="o">=</span> <span class="n">h</span> <span class="o">/</span> <span class="mf">4.0</span>
    543. <span class="k">return</span> <span class="n">Rectangle</span><span class="p">((</span><span class="n">min_w</span><span class="p">,</span> <span class="n">min_h</span><span class="p">),</span> <span class="n">w</span> <span class="o">*</span> <span class="mi">4</span> <span class="o">-</span> <span class="n">min_w</span><span class="p">,</span> <span class="n">h</span> <span class="o">*</span> <span class="mi">4</span> <span class="o">-</span> <span class="n">min_h</span><span class="p">,</span> <span class="n">linewidth</span><span class="o">=</span><span class="mi">1</span><span class="p">,</span> <span class="n">edgecolor</span><span class="o">=</span><span class="s1">&#39;b&#39;</span><span class="p">,</span> <span class="n">facecolor</span><span class="o">=</span><span class="s1">&#39;none&#39;</span><span class="p">)</span>
    544. <span class="nd">@staticmethod</span>
    545. <span class="k">def</span> <span class="nf">_get_score</span><span class="p">(</span><span class="n">anchors</span><span class="p">:</span> <span class="n">np</span><span class="o">.</span><span class="n">ndarray</span><span class="p">,</span> <span class="n">points</span><span class="p">:</span> <span class="n">np</span><span class="o">.</span><span class="n">ndarray</span><span class="p">,</span> <span class="n">image_size</span><span class="p">:</span> <span class="nb">int</span><span class="p">):</span>
    546. <span class="sd">&quot;&quot;&quot;</span>
    547. <span class="sd"> Calculate the ratio (and 1/ratio) between each anchor width and height and each point (representing a possible</span>
    548. <span class="sd"> object width and height).</span>
    549. <span class="sd"> i.e. for an anchor with w=10,h=20 the point w=11,h=25 will have the ratios 11/10=1.1 and 25/20=1.25</span>
    550. <span class="sd"> or 10/11=0.91 and 20/25=0.8 respectively</span>
    551. <span class="sd"> :param anchors: array of anchors of the shape [2,N]</span>
    552. <span class="sd"> :param points: array of points of the shape [2,M]</span>
    553. <span class="sd"> :param image_size the size of the input image</span>
    554. <span class="sd"> :returns: an array of size [image_size - 1, image_size - 1] where each cell i,j represent the minimum ratio</span>
    555. <span class="sd"> for that cell (point) from all anchors</span>
    556. <span class="sd"> &quot;&quot;&quot;</span>
    557. <span class="n">ratio</span> <span class="o">=</span> <span class="n">anchors</span><span class="p">[:,</span> <span class="p">:,</span> <span class="kc">None</span><span class="p">]</span> <span class="o">/</span> <span class="n">points</span><span class="p">[:,</span> <span class="p">]</span>
    558. <span class="n">inv_ratio</span> <span class="o">=</span> <span class="mi">1</span> <span class="o">/</span> <span class="n">ratio</span>
    559. <span class="n">min_ratio</span> <span class="o">=</span> <span class="mi">1</span> <span class="o">-</span> <span class="n">np</span><span class="o">.</span><span class="n">minimum</span><span class="p">(</span><span class="n">ratio</span><span class="p">,</span> <span class="n">inv_ratio</span><span class="p">)</span>
    560. <span class="n">min_ratio</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">max</span><span class="p">(</span><span class="n">min_ratio</span><span class="p">,</span> <span class="n">axis</span><span class="o">=</span><span class="mi">1</span><span class="p">)</span>
    561. <span class="n">to_closest_anchor</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">min</span><span class="p">(</span><span class="n">min_ratio</span><span class="p">,</span> <span class="n">axis</span><span class="o">=</span><span class="mi">0</span><span class="p">)</span>
    562. <span class="n">to_closest_anchor</span><span class="p">[</span><span class="n">to_closest_anchor</span> <span class="o">&gt;</span> <span class="mf">0.75</span><span class="p">]</span> <span class="o">=</span> <span class="mi">2</span>
    563. <span class="k">return</span> <span class="n">to_closest_anchor</span><span class="o">.</span><span class="n">reshape</span><span class="p">(</span><span class="n">image_size</span> <span class="o">-</span> <span class="mi">1</span><span class="p">,</span> <span class="o">-</span><span class="mi">1</span><span class="p">)</span>
    564. <span class="k">def</span> <span class="nf">_analyze_anchors_coverage</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">anchors</span><span class="p">:</span> <span class="n">Anchors</span><span class="p">,</span> <span class="n">image_size</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span> <span class="n">labels</span><span class="p">:</span> <span class="nb">list</span><span class="p">,</span> <span class="n">title</span><span class="p">:</span> <span class="nb">str</span><span class="p">):</span>
    565. <span class="sd">&quot;&quot;&quot;</span>
    566. <span class="sd"> This function will add anchors coverage plots to the tensorboard.</span>
    567. <span class="sd"> :param anchors: a list of anchors</span>
    568. <span class="sd"> :param image_size: the input image size for this training</span>
    569. <span class="sd"> :param labels: all the labels of the dataset of the shape [class_label, x_center, y_center, w, h]</span>
    570. <span class="sd"> :param title: the dataset title</span>
    571. <span class="sd"> &quot;&quot;&quot;</span>
    572. <span class="n">fig</span> <span class="o">=</span> <span class="n">plt</span><span class="o">.</span><span class="n">figure</span><span class="p">(</span><span class="n">figsize</span><span class="o">=</span><span class="p">(</span><span class="mi">12</span><span class="p">,</span> <span class="mi">5</span><span class="p">))</span>
    573. <span class="n">fig</span><span class="o">.</span><span class="n">suptitle</span><span class="p">(</span><span class="sa">f</span><span class="s1">&#39;</span><span class="si">{</span><span class="n">title</span><span class="si">}</span><span class="s1"> anchors coverage&#39;</span><span class="p">)</span>
    574. <span class="c1"># box style plot</span>
    575. <span class="n">ax</span> <span class="o">=</span> <span class="n">fig</span><span class="o">.</span><span class="n">add_subplot</span><span class="p">(</span><span class="mi">121</span><span class="p">)</span>
    576. <span class="n">ax</span><span class="o">.</span><span class="n">set_xlabel</span><span class="p">(</span><span class="s1">&#39;W&#39;</span><span class="p">,</span> <span class="n">fontsize</span><span class="o">=</span><span class="n">STAT_LOGGER_FONT_SIZE</span><span class="p">)</span>
    577. <span class="n">ax</span><span class="o">.</span><span class="n">set_ylabel</span><span class="p">(</span><span class="s1">&#39;H&#39;</span><span class="p">,</span> <span class="n">fontsize</span><span class="o">=</span><span class="n">STAT_LOGGER_FONT_SIZE</span><span class="p">)</span>
    578. <span class="n">ax</span><span class="o">.</span><span class="n">set_xlim</span><span class="p">([</span><span class="mi">0</span><span class="p">,</span> <span class="n">image_size</span><span class="p">])</span>
    579. <span class="n">ax</span><span class="o">.</span><span class="n">set_ylim</span><span class="p">([</span><span class="mi">0</span><span class="p">,</span> <span class="n">image_size</span><span class="p">])</span>
    580. <span class="n">anchors_boxes</span> <span class="o">=</span> <span class="n">anchors</span><span class="o">.</span><span class="n">anchors</span><span class="o">.</span><span class="n">cpu</span><span class="p">()</span><span class="o">.</span><span class="n">numpy</span><span class="p">()</span>
    581. <span class="n">anchors_len</span> <span class="o">=</span> <span class="n">anchors</span><span class="o">.</span><span class="n">num_anchors</span>
    582. <span class="n">anchors_boxes</span> <span class="o">=</span> <span class="n">anchors_boxes</span><span class="o">.</span><span class="n">reshape</span><span class="p">(</span><span class="o">-</span><span class="mi">1</span><span class="p">,</span> <span class="mi">2</span><span class="p">)</span>
    583. <span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="n">anchors_len</span><span class="p">):</span>
    584. <span class="n">rect</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">_get_rect</span><span class="p">(</span><span class="n">anchors_boxes</span><span class="p">[</span><span class="n">i</span><span class="p">][</span><span class="mi">0</span><span class="p">],</span> <span class="n">anchors_boxes</span><span class="p">[</span><span class="n">i</span><span class="p">][</span><span class="mi">1</span><span class="p">])</span>
    585. <span class="n">rect</span><span class="o">.</span><span class="n">set_alpha</span><span class="p">(</span><span class="mf">0.3</span><span class="p">)</span>
    586. <span class="n">rect</span><span class="o">.</span><span class="n">set_facecolor</span><span class="p">([</span><span class="n">random</span><span class="o">.</span><span class="n">random</span><span class="p">(),</span> <span class="n">random</span><span class="o">.</span><span class="n">random</span><span class="p">(),</span> <span class="n">random</span><span class="o">.</span><span class="n">random</span><span class="p">(),</span> <span class="mf">0.3</span><span class="p">])</span>
    587. <span class="n">ax</span><span class="o">.</span><span class="n">add_patch</span><span class="p">(</span><span class="n">rect</span><span class="p">)</span>
    588. <span class="c1"># distance from anchor plot</span>
    589. <span class="n">ax</span> <span class="o">=</span> <span class="n">fig</span><span class="o">.</span><span class="n">add_subplot</span><span class="p">(</span><span class="mi">122</span><span class="p">)</span>
    590. <span class="n">ax</span><span class="o">.</span><span class="n">set_xlabel</span><span class="p">(</span><span class="s1">&#39;W&#39;</span><span class="p">,</span> <span class="n">fontsize</span><span class="o">=</span><span class="n">STAT_LOGGER_FONT_SIZE</span><span class="p">)</span>
    591. <span class="n">ax</span><span class="o">.</span><span class="n">set_ylabel</span><span class="p">(</span><span class="s1">&#39;H&#39;</span><span class="p">,</span> <span class="n">fontsize</span><span class="o">=</span><span class="n">STAT_LOGGER_FONT_SIZE</span><span class="p">)</span>
    592. <span class="n">x</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">arange</span><span class="p">(</span><span class="mi">1</span><span class="p">,</span> <span class="n">image_size</span><span class="p">,</span> <span class="mi">1</span><span class="p">)</span>
    593. <span class="n">y</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">arange</span><span class="p">(</span><span class="mi">1</span><span class="p">,</span> <span class="n">image_size</span><span class="p">,</span> <span class="mi">1</span><span class="p">)</span>
    594. <span class="n">xx</span><span class="p">,</span> <span class="n">yy</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">meshgrid</span><span class="p">(</span><span class="n">x</span><span class="p">,</span> <span class="n">y</span><span class="p">,</span> <span class="n">sparse</span><span class="o">=</span><span class="kc">False</span><span class="p">)</span>
    595. <span class="n">points</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">concatenate</span><span class="p">([</span><span class="n">xx</span><span class="o">.</span><span class="n">reshape</span><span class="p">(</span><span class="mi">1</span><span class="p">,</span> <span class="o">-</span><span class="mi">1</span><span class="p">),</span> <span class="n">yy</span><span class="o">.</span><span class="n">reshape</span><span class="p">(</span><span class="mi">1</span><span class="p">,</span> <span class="o">-</span><span class="mi">1</span><span class="p">)])</span>
    596. <span class="n">color</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">_get_score</span><span class="p">(</span><span class="n">anchors_boxes</span><span class="p">,</span> <span class="n">points</span><span class="p">,</span> <span class="n">image_size</span><span class="p">)</span>
    597. <span class="n">ax</span><span class="o">.</span><span class="n">set_xlabel</span><span class="p">(</span><span class="s1">&#39;W&#39;</span><span class="p">,</span> <span class="n">fontsize</span><span class="o">=</span><span class="n">STAT_LOGGER_FONT_SIZE</span><span class="p">)</span>
    598. <span class="n">ax</span><span class="o">.</span><span class="n">set_ylabel</span><span class="p">(</span><span class="s1">&#39;H&#39;</span><span class="p">,</span> <span class="n">fontsize</span><span class="o">=</span><span class="n">STAT_LOGGER_FONT_SIZE</span><span class="p">)</span>
    599. <span class="n">plt</span><span class="o">.</span><span class="n">imshow</span><span class="p">(</span><span class="n">color</span><span class="p">,</span> <span class="n">interpolation</span><span class="o">=</span><span class="s1">&#39;nearest&#39;</span><span class="p">,</span> <span class="n">origin</span><span class="o">=</span><span class="s1">&#39;lower&#39;</span><span class="p">,</span>
    600. <span class="n">extent</span><span class="o">=</span><span class="p">[</span><span class="mi">0</span><span class="p">,</span> <span class="n">image_size</span><span class="p">,</span> <span class="mi">0</span><span class="p">,</span> <span class="n">image_size</span><span class="p">])</span>
    601. <span class="c1"># calculate the coverage for the dataset labels</span>
    602. <span class="n">cover_masks</span> <span class="o">=</span> <span class="p">[]</span>
    603. <span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="n">anchors_len</span><span class="p">):</span>
    604. <span class="n">w_max</span> <span class="o">=</span> <span class="p">(</span><span class="n">anchors_boxes</span><span class="p">[</span><span class="n">i</span><span class="p">][</span><span class="mi">0</span><span class="p">]</span> <span class="o">/</span> <span class="n">image_size</span><span class="p">)</span> <span class="o">*</span> <span class="mi">4</span>
    605. <span class="n">w_min</span> <span class="o">=</span> <span class="p">(</span><span class="n">anchors_boxes</span><span class="p">[</span><span class="n">i</span><span class="p">][</span><span class="mi">0</span><span class="p">]</span> <span class="o">/</span> <span class="n">image_size</span><span class="p">)</span> <span class="o">*</span> <span class="mf">0.25</span>
    606. <span class="n">h_max</span> <span class="o">=</span> <span class="p">(</span><span class="n">anchors_boxes</span><span class="p">[</span><span class="n">i</span><span class="p">][</span><span class="mi">1</span><span class="p">]</span> <span class="o">/</span> <span class="n">image_size</span><span class="p">)</span> <span class="o">*</span> <span class="mi">4</span>
    607. <span class="n">h_min</span> <span class="o">=</span> <span class="p">(</span><span class="n">anchors_boxes</span><span class="p">[</span><span class="n">i</span><span class="p">][</span><span class="mi">1</span><span class="p">]</span> <span class="o">/</span> <span class="n">image_size</span><span class="p">)</span> <span class="o">*</span> <span class="mf">0.25</span>
    608. <span class="n">cover_masks</span><span class="o">.</span><span class="n">append</span><span class="p">(</span><span class="n">np</span><span class="o">.</span><span class="n">logical_and</span><span class="p">(</span>
    609. <span class="n">np</span><span class="o">.</span><span class="n">logical_and</span><span class="p">(</span><span class="n">np</span><span class="o">.</span><span class="n">logical_and</span><span class="p">(</span><span class="n">labels</span><span class="p">[:,</span> <span class="mi">3</span><span class="p">]</span> <span class="o">&lt;</span> <span class="n">w_max</span><span class="p">,</span> <span class="n">labels</span><span class="p">[:,</span> <span class="mi">3</span><span class="p">]</span> <span class="o">&gt;</span> <span class="n">w_min</span><span class="p">),</span> <span class="n">labels</span><span class="p">[:,</span> <span class="mi">4</span><span class="p">]</span> <span class="o">&lt;</span> <span class="n">h_max</span><span class="p">),</span>
    610. <span class="n">labels</span><span class="p">[:,</span> <span class="mi">4</span><span class="p">]</span> <span class="o">&gt;</span> <span class="n">h_min</span><span class="p">))</span>
    611. <span class="n">cover_masks</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">stack</span><span class="p">(</span><span class="n">cover_masks</span><span class="p">)</span>
    612. <span class="n">coverage</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">count_nonzero</span><span class="p">(</span><span class="n">np</span><span class="o">.</span><span class="n">any</span><span class="p">(</span><span class="n">cover_masks</span><span class="p">,</span> <span class="n">axis</span><span class="o">=</span><span class="mi">0</span><span class="p">))</span> <span class="o">/</span> <span class="nb">len</span><span class="p">(</span><span class="n">labels</span><span class="p">)</span>
    613. <span class="bp">self</span><span class="o">.</span><span class="n">sg_logger</span><span class="o">.</span><span class="n">add_figure</span><span class="p">(</span><span class="n">tag</span><span class="o">=</span><span class="sa">f</span><span class="s1">&#39;</span><span class="si">{</span><span class="n">title</span><span class="si">}</span><span class="s1"> anchors coverage&#39;</span><span class="p">,</span> <span class="n">figure</span><span class="o">=</span><span class="n">fig</span><span class="p">)</span>
    614. <span class="k">return</span> <span class="n">coverage</span></div>
    615. <div class="viewcode-block" id="get_color_augmentation"><a class="viewcode-back" href="../../../../super_gradients.training.datasets.html#super_gradients.training.datasets.datasets_utils.get_color_augmentation">[docs]</a><span class="k">def</span> <span class="nf">get_color_augmentation</span><span class="p">(</span><span class="n">rand_augment_config_string</span><span class="p">:</span> <span class="nb">str</span><span class="p">,</span> <span class="n">color_jitter</span><span class="p">:</span> <span class="nb">tuple</span><span class="p">,</span> <span class="n">crop_size</span><span class="o">=</span><span class="mi">224</span><span class="p">,</span> <span class="n">img_mean</span><span class="o">=</span><span class="p">[</span><span class="mf">0.485</span><span class="p">,</span> <span class="mf">0.456</span><span class="p">,</span> <span class="mf">0.406</span><span class="p">]):</span>
    616. <span class="sd">&quot;&quot;&quot;</span>
    617. <span class="sd"> Returns color augmentation class. As these augmentation cannot work on top one another, only one is returned</span>
    618. <span class="sd"> according to rand_augment_config_string</span>
    619. <span class="sd"> :param rand_augment_config_string: string which defines the auto augment configurations.</span>
    620. <span class="sd"> If none, color jitter will be returned. For possibile values see auto_augment.py</span>
    621. <span class="sd"> :param color_jitter: tuple for color jitter value.</span>
    622. <span class="sd"> :param crop_size: relevant only for auto augment</span>
    623. <span class="sd"> :param img_mean: relevant only for auto augment</span>
    624. <span class="sd"> :return: RandAugment transform or ColorJitter</span>
    625. <span class="sd"> &quot;&quot;&quot;</span>
    626. <span class="k">if</span> <span class="n">rand_augment_config_string</span><span class="p">:</span>
    627. <span class="n">auto_augment_params</span> <span class="o">=</span> <span class="nb">dict</span><span class="p">(</span><span class="n">translate_const</span><span class="o">=</span><span class="nb">int</span><span class="p">(</span><span class="n">crop_size</span> <span class="o">*</span> <span class="mf">0.45</span><span class="p">),</span>
    628. <span class="n">img_mean</span><span class="o">=</span><span class="nb">tuple</span><span class="p">([</span><span class="nb">min</span><span class="p">(</span><span class="mi">255</span><span class="p">,</span> <span class="nb">round</span><span class="p">(</span><span class="mi">255</span> <span class="o">*</span> <span class="n">x</span><span class="p">))</span> <span class="k">for</span> <span class="n">x</span> <span class="ow">in</span> <span class="n">img_mean</span><span class="p">]))</span>
    629. <span class="n">color_augmentation</span> <span class="o">=</span> <span class="n">rand_augment_transform</span><span class="p">(</span><span class="n">rand_augment_config_string</span><span class="p">,</span> <span class="n">auto_augment_params</span><span class="p">)</span>
    630. <span class="k">else</span><span class="p">:</span> <span class="c1"># RandAugment includes colorjitter like augmentations, both cannot be applied together.</span>
    631. <span class="n">color_augmentation</span> <span class="o">=</span> <span class="n">transforms</span><span class="o">.</span><span class="n">ColorJitter</span><span class="p">(</span><span class="o">*</span><span class="n">color_jitter</span><span class="p">)</span>
    632. <span class="k">return</span> <span class="n">color_augmentation</span></div>
    633. <div class="viewcode-block" id="worker_init_reset_seed"><a class="viewcode-back" href="../../../../super_gradients.training.datasets.html#super_gradients.training.datasets.datasets_utils.worker_init_reset_seed">[docs]</a><span class="k">def</span> <span class="nf">worker_init_reset_seed</span><span class="p">(</span><span class="n">worker_id</span><span class="p">):</span>
    634. <span class="sd">&quot;&quot;&quot;</span>
    635. <span class="sd"> Make sure each process has different random seed, especially for &#39;fork&#39; method.</span>
    636. <span class="sd"> Check https://github.com/pytorch/pytorch/issues/63311 for more details.</span>
    637. <span class="sd"> :param worker_id: placeholder (needs to be passed to DataLoader init).</span>
    638. <span class="sd"> &quot;&quot;&quot;</span>
    639. <span class="n">seed</span> <span class="o">=</span> <span class="n">uuid</span><span class="o">.</span><span class="n">uuid4</span><span class="p">()</span><span class="o">.</span><span class="n">int</span> <span class="o">%</span> <span class="mi">2</span> <span class="o">**</span> <span class="mi">32</span>
    640. <span class="n">random</span><span class="o">.</span><span class="n">seed</span><span class="p">(</span><span class="n">seed</span><span class="p">)</span>
    641. <span class="n">torch</span><span class="o">.</span><span class="n">set_rng_state</span><span class="p">(</span><span class="n">torch</span><span class="o">.</span><span class="n">manual_seed</span><span class="p">(</span><span class="n">seed</span><span class="p">)</span><span class="o">.</span><span class="n">get_state</span><span class="p">())</span>
    642. <span class="n">np</span><span class="o">.</span><span class="n">random</span><span class="o">.</span><span class="n">seed</span><span class="p">(</span><span class="n">seed</span><span class="p">)</span></div>
    643. </pre></div>
    644. </div>
    645. </div>
    646. <footer>
    647. <hr/>
    648. <div role="contentinfo">
    649. <p>&#169; Copyright 2021, SuperGradients team.</p>
    650. </div>
    651. Built with <a href="https://www.sphinx-doc.org/">Sphinx</a> using a
    652. <a href="https://github.com/readthedocs/sphinx_rtd_theme">theme</a>
    653. provided by <a href="https://readthedocs.org">Read the Docs</a>.
    654. </footer>
    655. </div>
    656. </div>
    657. </section>
    658. </div>
    659. <script>
    660. jQuery(function () {
    661. SphinxRtdTheme.Navigation.enable(true);
    662. });
    663. </script>
    664. </body>
    665. </html>
    Discard
    Some lines were truncated since they exceed the maximum allowed length of 500, please use a local Git client to see the full diff.
    @@ -3,10 +3,11 @@
     <head>
     <head>
       <meta charset="utf-8" />
       <meta charset="utf-8" />
       <meta name="viewport" content="width=device-width, initial-scale=1.0" />
       <meta name="viewport" content="width=device-width, initial-scale=1.0" />
    -  <title>super_gradients.training.datasets.detection_datasets.coco_detection &mdash; SuperGradients 1.0 documentation</title>
    +  <title>super_gradients.training.datasets.detection_datasets.coco_detection &mdash; SuperGradients 3.0.3 documentation</title>
           <link rel="stylesheet" href="../../../../../_static/pygments.css" type="text/css" />
           <link rel="stylesheet" href="../../../../../_static/pygments.css" type="text/css" />
           <link rel="stylesheet" href="../../../../../_static/css/theme.css" type="text/css" />
           <link rel="stylesheet" href="../../../../../_static/css/theme.css" type="text/css" />
           <link rel="stylesheet" href="../../../../../_static/graphviz.css" type="text/css" />
           <link rel="stylesheet" href="../../../../../_static/graphviz.css" type="text/css" />
    +      <link rel="stylesheet" href="../../../../../_static/custom.css" type="text/css" />
       <!--[if lt IE 9]>
       <!--[if lt IE 9]>
         <script src="../../../../../_static/js/html5shiv.min.js"></script>
         <script src="../../../../../_static/js/html5shiv.min.js"></script>
       <![endif]-->
       <![endif]-->
    @@ -14,7 +15,9 @@
             <script data-url_root="../../../../../" id="documentation_options" src="../../../../../_static/documentation_options.js"></script>
             <script data-url_root="../../../../../" id="documentation_options" src="../../../../../_static/documentation_options.js"></script>
             <script src="../../../../../_static/jquery.js"></script>
             <script src="../../../../../_static/jquery.js"></script>
             <script src="../../../../../_static/underscore.js"></script>
             <script src="../../../../../_static/underscore.js"></script>
    +        <script src="../../../../../_static/_sphinx_javascript_frameworks_compat.js"></script>
             <script src="../../../../../_static/doctools.js"></script>
             <script src="../../../../../_static/doctools.js"></script>
    +        <script src="../../../../../_static/sphinx_highlight.js"></script>
         <script src="../../../../../_static/js/theme.js"></script>
         <script src="../../../../../_static/js/theme.js"></script>
         <link rel="index" title="Index" href="../../../../../genindex.html" />
         <link rel="index" title="Index" href="../../../../../genindex.html" />
         <link rel="search" title="Search" href="../../../../../search.html" /> 
         <link rel="search" title="Search" href="../../../../../search.html" /> 
    @@ -35,30 +38,29 @@
       </form>
       </form>
     </div>
     </div>
             </div><div class="wy-menu wy-menu-vertical" data-spy="affix" role="navigation" aria-label="Navigation menu">
             </div><div class="wy-menu wy-menu-vertical" data-spy="affix" role="navigation" aria-label="Navigation menu">
    -              <p class="caption"><span class="caption-text">Welcome To SuperGradients</span></p>
    +              <p class="caption" role="heading"><span class="caption-text">Welcome To SuperGradients</span></p>
     <ul>
     <ul>
    -<li class="toctree-l1"><a class="reference internal" href="../../../../../welcome.html">Fill our 4-question quick survey! We will raffle free SuperGradients swag between those who will participate -&gt; Fill Survey</a></li>
    -<li class="toctree-l1"><a class="reference internal" href="../../../../../welcome.html#supergradients">SuperGradients</a></li>
    +<li class="toctree-l1"><a class="reference internal" href="../../../../../welcome.html">Version 3 is out! Notebooks have been updated!</a></li>
    +<li class="toctree-l1"><a class="reference internal" href="../../../../../welcome.html#build-with-supergradients">Build with SuperGradients</a></li>
    +<li class="toctree-l1"><a class="reference internal" href="../../../../../welcome.html#quick-installation">Quick Installation</a></li>
    +<li class="toctree-l1"><a class="reference internal" href="../../../../../welcome.html#what-s-new">What’s New</a></li>
    +<li class="toctree-l1"><a class="reference internal" href="../../../../../welcome.html#coming-soon">Coming soon</a></li>
    +<li class="toctree-l1"><a class="reference internal" href="../../../../../welcome.html#table-of-content">Table of Content</a></li>
    +<li class="toctree-l1"><a class="reference internal" href="../../../../../welcome.html#getting-started">Getting Started</a></li>
    +<li class="toctree-l1"><a class="reference internal" href="../../../../../welcome.html#advanced-features">Advanced Features</a></li>
    +<li class="toctree-l1"><a class="reference internal" href="../../../../../welcome.html#installation-methods">Installation Methods</a></li>
    +<li class="toctree-l1"><a class="reference internal" href="../../../../../welcome.html#implemented-model-architectures">Implemented Model Architectures</a></li>
    +<li class="toctree-l1"><a class="reference internal" href="../../../../../welcome.html#documentation">Documentation</a></li>
    +<li class="toctree-l1"><a class="reference internal" href="../../../../../welcome.html#contributing">Contributing</a></li>
    +<li class="toctree-l1"><a class="reference internal" href="../../../../../welcome.html#citation">Citation</a></li>
    +<li class="toctree-l1"><a class="reference internal" href="../../../../../welcome.html#community">Community</a></li>
    +<li class="toctree-l1"><a class="reference internal" href="../../../../../welcome.html#license">License</a></li>
    +<li class="toctree-l1"><a class="reference internal" href="../../../../../welcome.html#deci-platform">Deci Platform</a></li>
     </ul>
     </ul>
    -<p class="caption"><span class="caption-text">Technical Documentation</span></p>
    +<p class="caption" role="heading"><span class="caption-text">Technical Documentation</span></p>
     <ul>
     <ul>
     <li class="toctree-l1"><a class="reference internal" href="../../../../../super_gradients.common.html">Common package</a></li>
     <li class="toctree-l1"><a class="reference internal" href="../../../../../super_gradients.common.html">Common package</a></li>
     <li class="toctree-l1"><a class="reference internal" href="../../../../../super_gradients.training.html">Training package</a></li>
     <li class="toctree-l1"><a class="reference internal" href="../../../../../super_gradients.training.html">Training package</a></li>
    -</ul>
    -<p class="caption"><span class="caption-text">User Guide</span></p>
    -<ul>
    -<li class="toctree-l1"><a class="reference internal" href="../../../../../user_guide.html">What is SuperGradients?</a></li>
    -<li class="toctree-l1"><a class="reference internal" href="../../../../../user_guide.html#introducing-the-supergradients-library">Introducing the SuperGradients library</a></li>
    -<li class="toctree-l1"><a class="reference internal" href="../../../../../user_guide.html#installation">Installation</a></li>
    -<li class="toctree-l1"><a class="reference internal" href="../../../../../user_guide.html#integrating-your-training-code-complete-walkthrough">Integrating Your Training Code - Complete Walkthrough</a></li>
    -<li class="toctree-l1"><a class="reference internal" href="../../../../../user_guide.html#training-parameters">Training Parameters</a></li>
    -<li class="toctree-l1"><a class="reference internal" href="../../../../../user_guide.html#logs-and-checkpoints">Logs and Checkpoints</a></li>
    -<li class="toctree-l1"><a class="reference internal" href="../../../../../user_guide.html#dataset-parameters">Dataset Parameters</a></li>
    -<li class="toctree-l1"><a class="reference internal" href="../../../../../user_guide.html#network-architectures">Network Architectures</a></li>
    -<li class="toctree-l1"><a class="reference internal" href="../../../../../user_guide.html#pretrained-models">Pretrained Models</a></li>
    -<li class="toctree-l1"><a class="reference internal" href="../../../../../user_guide.html#how-to-reproduce-our-training-recipes">How To Reproduce Our Training Recipes</a></li>
    -<li class="toctree-l1"><a class="reference internal" href="../../../../../user_guide.html#professional-tools-integration">Professional Tools Integration</a></li>
    -<li class="toctree-l1"><a class="reference internal" href="../../../../../user_guide.html#supergradients-faq">SuperGradients FAQ</a></li>
     </ul>
     </ul>
     
     
             </div>
             </div>
    @@ -87,106 +89,97 @@
                  
                  
       <h1>Source code for super_gradients.training.datasets.detection_datasets.coco_detection</h1><div class="highlight"><pre>
       <h1>Source code for super_gradients.training.datasets.detection_datasets.coco_detection</h1><div class="highlight"><pre>
     <span></span><span class="kn">import</span> <span class="nn">os</span>
     <span></span><span class="kn">import</span> <span class="nn">os</span>
    -<span class="kn">from</span> <span class="nn">pycocotools.coco</span> <span class="kn">import</span> <span class="n">COCO</span>
    -<span class="kn">from</span> <span class="nn">super_gradients.common.abstractions.abstract_logger</span> <span class="kn">import</span> <span class="n">get_logger</span>
    -<span class="kn">from</span> <span class="nn">torch.utils.data</span> <span class="kn">import</span> <span class="n">Dataset</span>
    -<span class="kn">import</span> <span class="nn">random</span>
     <span class="kn">import</span> <span class="nn">cv2</span>
     <span class="kn">import</span> <span class="nn">cv2</span>
    +
     <span class="kn">import</span> <span class="nn">numpy</span> <span class="k">as</span> <span class="nn">np</span>
     <span class="kn">import</span> <span class="nn">numpy</span> <span class="k">as</span> <span class="nn">np</span>
    -<span class="kn">from</span> <span class="nn">tqdm</span> <span class="kn">import</span> <span class="n">tqdm</span>
    -<span class="kn">from</span> <span class="nn">multiprocessing.pool</span> <span class="kn">import</span> <span class="n">ThreadPool</span>
    +<span class="kn">from</span> <span class="nn">pycocotools.coco</span> <span class="kn">import</span> <span class="n">COCO</span>
    +
    +<span class="kn">from</span> <span class="nn">super_gradients.common.abstractions.abstract_logger</span> <span class="kn">import</span> <span class="n">get_logger</span>
    +<span class="kn">from</span> <span class="nn">super_gradients.training.datasets.detection_datasets.detection_dataset</span> <span class="kn">import</span> <span class="n">DetectionDataset</span>
    +<span class="kn">from</span> <span class="nn">super_gradients.training.utils.detection_utils</span> <span class="kn">import</span> <span class="n">DetectionTargetsFormat</span>
    +<span class="kn">from</span> <span class="nn">super_gradients.training.datasets.datasets_conf</span> <span class="kn">import</span> <span class="n">COCO_DETECTION_CLASSES_LIST</span>
     
     
     <span class="n">logger</span> <span class="o">=</span> <span class="n">get_logger</span><span class="p">(</span><span class="vm">__name__</span><span class="p">)</span>
     <span class="n">logger</span> <span class="o">=</span> <span class="n">get_logger</span><span class="p">(</span><span class="vm">__name__</span><span class="p">)</span>
     
     
     
     
    -<div class="viewcode-block" id="COCODetectionDataset"><a class="viewcode-back" href="../../../../../super_gradients.training.datasets.detection_datasets.html#super_gradients.training.datasets.COCODetectionDataset">[docs]</a><span class="k">class</span> <span class="nc">COCODetectionDataset</span><span class="p">(</span><span class="n">Dataset</span><span class="p">):</span>
    -    <span class="sd">&quot;&quot;&quot;Detection dataset COCO implementation&quot;&quot;&quot;</span>
    +<div class="viewcode-block" id="COCODetectionDataset"><a class="viewcode-back" href="../../../../../super_gradients.training.html#super_gradients.training.datasets.COCODetectionDataset">[docs]</a><span class="k">class</span> <span class="nc">COCODetectionDataset</span><span class="p">(</span><span class="n">DetectionDataset</span><span class="p">):</span>
    +    <span class="sd">&quot;&quot;&quot;Dataset for COCO object detection.&quot;&quot;&quot;</span>
     
     
         <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span>
         <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span>
    -            <span class="bp">self</span><span class="p">,</span> <span class="n">img_size</span><span class="p">:</span> <span class="nb">tuple</span><span class="p">,</span>
    -            <span class="n">data_dir</span><span class="p">:</span> <span class="nb">str</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span>
    -            <span class="n">json_file</span><span class="p">:</span> <span class="nb">str</span> <span class="o">=</span> <span class="s2">&quot;instances_train2017.json&quot;</span><span class="p">,</span>
    -            <span class="n">name</span><span class="p">:</span> <span class="nb">str</span> <span class="o">=</span> <span class="s2">&quot;images/train2017&quot;</span><span class="p">,</span>
    -            <span class="n">cache</span><span class="p">:</span> <span class="nb">bool</span> <span class="o">=</span> <span class="kc">False</span><span class="p">,</span>
    -            <span class="n">cache_dir_path</span><span class="p">:</span> <span class="nb">str</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span>
    -            <span class="n">tight_box_rotation</span><span class="p">:</span> <span class="nb">bool</span> <span class="o">=</span> <span class="kc">False</span><span class="p">,</span>
    -            <span class="n">transforms</span><span class="p">:</span> <span class="nb">list</span> <span class="o">=</span> <span class="p">[],</span>
    -            <span class="n">with_crowd</span><span class="p">:</span> <span class="nb">bool</span> <span class="o">=</span> <span class="kc">True</span>
    +        <span class="bp">self</span><span class="p">,</span>
    +        <span class="n">json_file</span><span class="p">:</span> <span class="nb">str</span> <span class="o">=</span> <span class="s2">&quot;instances_train2017.json&quot;</span><span class="p">,</span>
    +        <span class="n">subdir</span><span class="p">:</span> <span class="nb">str</span> <span class="o">=</span> <span class="s2">&quot;images/train2017&quot;</span><span class="p">,</span>
    +        <span class="n">tight_box_rotation</span><span class="p">:</span> <span class="nb">bool</span> <span class="o">=</span> <span class="kc">False</span><span class="p">,</span>
    +        <span class="n">with_crowd</span><span class="p">:</span> <span class="nb">bool</span> <span class="o">=</span> <span class="kc">True</span><span class="p">,</span>
    +        <span class="o">*</span><span class="n">args</span><span class="p">,</span>
    +        <span class="o">**</span><span class="n">kwargs</span><span class="p">,</span>
         <span class="p">):</span>
         <span class="p">):</span>
             <span class="sd">&quot;&quot;&quot;</span>
             <span class="sd">&quot;&quot;&quot;</span>
    -<span class="sd">        :param img_size: tuple, Image size (when loaded, before transforms)</span>
    -<span class="sd">        :param data_dir: str, root path to coco data.</span>
    -<span class="sd">        :param json_file: str, path to coco json file, that resides in data_dir/annotations/json_file.</span>
    -<span class="sd">        :param name: str, sub directory of data_dir containing the data.</span>
    -<span class="sd">        :param cache: bool, whether to cache images</span>
    -<span class="sd">        :param cache_dir_path: str, path to a directory that will be used for caching (with memmap).</span>
    -<span class="sd">        :param tight_box_rotation: bool, whether to use of segmentation maps convex hull</span>
    -<span class="sd">         as target_seg (see load_sample docs).</span>
    -<span class="sd">        :param transforms: list of transforms to apply sequentially on sample in __getitem__</span>
    +<span class="sd">        :param json_file:           Name of the coco json file, that resides in data_dir/annotations/json_file.</span>
    +<span class="sd">        :param subdir:              Sub directory of data_dir containing the data.</span>
    +<span class="sd">        :param tight_box_rotation:  bool, whether to use of segmentation maps convex hull as target_seg</span>
    +<span class="sd">                                    (check get_sample docs).</span>
     <span class="sd">        :param with_crowd: Add the crowd groundtruths to __getitem__</span>
     <span class="sd">        :param with_crowd: Add the crowd groundtruths to __getitem__</span>
    +
    +<span class="sd">        kwargs:</span>
    +<span class="sd">            all_classes_list: all classes list, default is COCO_DETECTION_CLASSES_LIST.</span>
     <span class="sd">        &quot;&quot;&quot;</span>
     <span class="sd">        &quot;&quot;&quot;</span>
    -        <span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="fm">__init__</span><span class="p">()</span>
    -        <span class="bp">self</span><span class="o">.</span><span class="n">imgs</span> <span class="o">=</span> <span class="kc">None</span>
    -        <span class="bp">self</span><span class="o">.</span><span class="n">data_dir</span> <span class="o">=</span> <span class="n">data_dir</span>
    -        <span class="bp">self</span><span class="o">.</span><span class="n">input_dim</span> <span class="o">=</span> <span class="n">img_size</span>
    -        <span class="bp">self</span><span class="o">.</span><span class="n">data_dir</span> <span class="o">=</span> <span class="n">data_dir</span>
    +        <span class="bp">self</span><span class="o">.</span><span class="n">subdir</span> <span class="o">=</span> <span class="n">subdir</span>
             <span class="bp">self</span><span class="o">.</span><span class="n">json_file</span> <span class="o">=</span> <span class="n">json_file</span>
             <span class="bp">self</span><span class="o">.</span><span class="n">json_file</span> <span class="o">=</span> <span class="n">json_file</span>
    -        <span class="bp">self</span><span class="o">.</span><span class="n">name</span> <span class="o">=</span> <span class="n">name</span>
    +        <span class="bp">self</span><span class="o">.</span><span class="n">tight_box_rotation</span> <span class="o">=</span> <span class="n">tight_box_rotation</span>
             <span class="bp">self</span><span class="o">.</span><span class="n">with_crowd</span> <span class="o">=</span> <span class="n">with_crowd</span>
             <span class="bp">self</span><span class="o">.</span><span class="n">with_crowd</span> <span class="o">=</span> <span class="n">with_crowd</span>
     
     
    +        <span class="n">target_fields</span> <span class="o">=</span> <span class="p">[</span><span class="s2">&quot;target&quot;</span><span class="p">,</span> <span class="s2">&quot;crowd_target&quot;</span><span class="p">]</span> <span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">with_crowd</span> <span class="k">else</span> <span class="p">[</span><span class="s2">&quot;target&quot;</span><span class="p">]</span>
    +        <span class="n">kwargs</span><span class="p">[</span><span class="s2">&quot;target_fields&quot;</span><span class="p">]</span> <span class="o">=</span> <span class="n">target_fields</span>
    +        <span class="n">kwargs</span><span class="p">[</span><span class="s2">&quot;output_fields&quot;</span><span class="p">]</span> <span class="o">=</span> <span class="p">[</span><span class="s2">&quot;image&quot;</span><span class="p">,</span> <span class="o">*</span><span class="n">target_fields</span><span class="p">]</span>
    +        <span class="n">kwargs</span><span class="p">[</span><span class="s2">&quot;original_target_format&quot;</span><span class="p">]</span> <span class="o">=</span> <span class="n">DetectionTargetsFormat</span><span class="o">.</span><span class="n">XYXY_LABEL</span>
    +        <span class="n">kwargs</span><span class="p">[</span><span class="s2">&quot;all_classes_list&quot;</span><span class="p">]</span> <span class="o">=</span> <span class="n">kwargs</span><span class="o">.</span><span class="n">get</span><span class="p">(</span><span class="s2">&quot;all_classes_list&quot;</span><span class="p">)</span> <span class="ow">or</span> <span class="n">COCO_DETECTION_CLASSES_LIST</span>
    +        <span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="fm">__init__</span><span class="p">(</span><span class="o">*</span><span class="n">args</span><span class="p">,</span> <span class="o">**</span><span class="n">kwargs</span><span class="p">)</span>
    +
    +    <span class="k">def</span> <span class="nf">_setup_data_source</span><span class="p">(</span><span class="bp">self</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="nb">int</span><span class="p">:</span>
    +        <span class="sd">&quot;&quot;&quot;Initialize img_and_target_path_list and warn if label file is missing</span>
    +
    +<span class="sd">        :return: List of tuples made of (img_path,target_path)</span>
    +<span class="sd">        &quot;&quot;&quot;</span>
    +        <span class="bp">self</span><span class="o">.</span><span class="n">coco</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">_init_coco</span><span class="p">()</span>
    +        <span class="bp">self</span><span class="o">.</span><span class="n">class_ids</span> <span class="o">=</span> <span class="nb">sorted</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">coco</span><span class="o">.</span><span class="n">getCatIds</span><span class="p">())</span>
    +        <span class="bp">self</span><span class="o">.</span><span class="n">classes</span> <span class="o">=</span> <span class="nb">list</span><span class="p">([</span><span class="n">category</span><span class="p">[</span><span class="s2">&quot;name&quot;</span><span class="p">]</span> <span class="k">for</span> <span class="n">category</span> <span class="ow">in</span> <span class="bp">self</span><span class="o">.</span><span class="n">coco</span><span class="o">.</span><span class="n">loadC
    +        <span class="bp">self</span><span class="o">.</span><span class="n">sample_id_to_coco_id</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">coco</span><span class="o">.</span><span class="n">getImgIds</span><span class="p">()</span>
    +        <span class="k">return</span> <span class="nb">len</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">sample_id_to_coco_id</span><span class="p">)</span>
    +
    +    <span class="k">def</span> <span class="nf">_init_coco</span><span class="p">(</span><span class="bp">self</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="n">COCO</span><span class="p">:</span>
             <span class="n">annotation_file_path</span> <span class="o">=</span> <span class="n">os</span><span class="o">.</span><span class="n">path</span><span class="o">.</span><span class="n">join</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">data_dir</span><span class="p">,</span> <span class="s2">&quot;annotations&quot;</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">json_file</span><span c
             <span class="n">annotation_file_path</span> <span class="o">=</span> <span class="n">os</span><span class="o">.</span><span class="n">path</span><span class="o">.</span><span class="n">join</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">data_dir</span><span class="p">,</span> <span class="s2">&quot;annotations&quot;</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">json_file</span><span c
             <span class="k">if</span> <span class="ow">not</span> <span class="n">os</span><span class="o">.</span><span class="n">path</span><span class="o">.</span><span class="n">exists</span><span class="p">(</span><span class="n">annotation_file_path</span><span class="p">):</span>
             <span class="k">if</span> <span class="ow">not</span> <span class="n">os</span><span class="o">.</span><span class="n">path</span><span class="o">.</span><span class="n">exists</span><span class="p">(</span><span class="n">annotation_file_path</span><span class="p">):</span>
                 <span class="k">raise</span> <span class="ne">ValueError</span><span class="p">(</span><span class="s2">&quot;Could not find annotation file under &quot;</span> <span class="o">+</span> <span class="nb">str</span><span class="p">(</span><span class="n">annotation_file_path</span><span class="p">))</span>
                 <span class="k">raise</span> <span class="ne">ValueError</span><span class="p">(</span><span class="s2">&quot;Could not find annotation file under &quot;</span> <span class="o">+</span> <span class="nb">str</span><span class="p">(</span><span class="n">annotation_file_path</span><span class="p">))</span>
    -        <span class="bp">self</span><span class="o">.</span><span class="n">coco</span> <span class="o">=</span> <span class="n">COCO</span><span class="p">(</span><span class="n">annotation_file_path</span><span class="p">)</span>  <span class="c1"># duplicate</span>
    -        <span class="bp">self</span><span class="o">.</span><span class="n">tight_box_rotation</span> <span class="o">=</span> <span class="n">tight_box_rotation</span>
    -        <span class="n">remove_useless_info</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">coco</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">tight_box_rotation</span><span class="p">)</span>
    -        <span class="bp">self</span><span class="o">.</span><span class="n">ids</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">coco</span><span class="o">.</span><span class="n">getImgIds</span><span class="p">()</span>
    -        <span class="bp">self</span><span class="o">.</span><span class="n">class_ids</span> <span class="o">=</span> <span class="nb">sorted</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">coco</span><span class="o">.</span><span class="n">getCatIds</span><span class="p">())</span>
    -        <span class="n">cats</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">coco</span><span class="o">.</span><span class="n">loadCats</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">coco</span><span class="o">.</span><span class="n">getCatIds</span><span class="p">())</span>
    -        <span class="bp">self</span><span class="o">.</span><span class="n">_classes</span> <span class="o">=</span> <span class="nb">tuple</span><span class="p">([</span><span class="n">c</span><span class="p">[</span><span class="s2">&quot;name&quot;</span><span class="p">]</span> <span class="k">for</span> <span class="n">c</span> <span class="ow">in</span> <span class="n">cats</span><span class="p">])</span>
    -        <span class="bp">self</span><span class="o">.</span><span class="n">annotations</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">_load_coco_annotations</span><span class="p">()</span>
    -        <span class="k">if</span> <span class="n">cache</span><span class="p">:</span>  <span class="c1"># cache after merged</span>
    -            <span class="k">if</span> <span class="n">cache_dir_path</span> <span class="ow">is</span> <span class="kc">None</span> <span class="ow">or</span> <span class="ow">not</span> <span class="n">os</span><span class="o">.</span><span class="n">path</span><span class="o">.</span><span class="n">exists</span><span class="p">(</span><span class="n">cache_dir_path</span><span class="p">):</span>
    -                <span class="k">raise</span> <span class="ne">ValueError</span><span class="p">(</span><span class="s2">&quot;Must pass valid path through cache_dir_path when caching. Got &quot;</span> <span class="o">+</span> <span class="nb">str</span><span class="p">(</span><span class="n">cache_dir_path</span><span class="p">))</span>
    -            <span class="bp">self</span><span class="o">.</span><span class="n">cache_dir_path</span> <span class="o">=</span> <span class="n">cache_dir_path</span>
    -            <span class="bp">self</span><span class="o">.</span><span class="n">_cache_images</span><span class="p">()</span>
    -
    -        <span class="bp">self</span><span class="o">.</span><span class="n">transforms</span> <span class="o">=</span> <span class="n">transforms</span>
    -
    -    <span class="k">def</span> <span class="fm">__getitem__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">idx</span><span class="p">):</span>
    -        <span class="n">sample</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">load_sample</span><span class="p">(</span><span class="n">idx</span><span class="p">)</span>
    -        <span class="n">sample</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">apply_transforms</span><span class="p">(</span><span class="n">sample</span><span class="p">)</span>
    -        <span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">with_crowd</span><span class="p">:</span>
    -            <span class="k">return</span> <span class="n">sample</span><span class="p">[</span><span class="s2">&quot;image&quot;</span><span class="p">],</span> <span class="n">sample</span><span class="p">[</span><span class="s2">&quot;target&quot;</span><span class="p">],</span> <span class="n">sample</span><span class="p">[</span><span class="s2">&quot;crowd_target&quot;</span><span class="p">],</span> <span class="n">sample</span><span class="p">[</span><span class="s2">&quot;info&quot;</s
    -        <span class="k">else</span><span class="p">:</span>
    -            <span class="k">return</span> <span class="n">sample</span><span class="p">[</span><span class="s2">&quot;image&quot;</span><span class="p">],</span> <span class="n">sample</span><span class="p">[</span><span class="s2">&quot;target&quot;</span><span class="p">],</span> <span class="n">sample</span><span class="p">[</span><span class="s2">&quot;info&quot;</span><span class="p">],</span> <span class="n">sample</span><span class="p">[</span><span class="s2">&quot;id&quot;</span><span 
    -
    -    <span class="k">def</span> <span class="fm">__len__</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span>
    -        <span class="k">return</span> <span class="nb">len</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">ids</span><span class="p">)</span>
    -
    -    <span class="k">def</span> <span class="nf">_load_coco_annotations</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span>
    -        <span class="k">return</span> <span class="p">[</span><span class="bp">self</span><span class="o">.</span><span class="n">_load_anno_from_ids</span><span class="p">(</span><span class="n">_ids</span><span class="p">)</span> <span class="k">for</span> <span class="n">_ids</span> <span class="ow">in</span> <span class="n">tqdm</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">ids</span><span class="p">,</span> <span class="n">desc</span><spa
    -
    -    <span class="k">def</span> <span class="nf">_load_anno_from_ids</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">id_</span><span class="p">):</span>
    +
    +        <span class="n">coco</span> <span class="o">=</span> <span class="n">COCO</span><span class="p">(</span><span class="n">annotation_file_path</span><span class="p">)</span>
    +        <span class="n">remove_useless_info</span><span class="p">(</span><span class="n">coco</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">tight_box_rotation</span><span class="p">)</span>
    +        <span class="k">return</span> <span class="n">coco</span>
    +
    +    <span class="k">def</span> <span class="nf">_load_annotation</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">sample_id</span><span class="p">:</span> <span class="nb">int</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="nb">dict</span><span class="p">:</span>
             <span class="sd">&quot;&quot;&quot;</span>
             <span class="sd">&quot;&quot;&quot;</span>
    -<span class="sd">        Load relevant information of a specific image</span>
    -
    -<span class="sd">        :param id_: image id</span>
    -<span class="sd">        :return res:            Target Bboxes (detection)</span>
    -<span class="sd">        :return res_crowd:      Crowd target Bboxes (detection)</span>
    -<span class="sd">        :return res_seg:        Segmentation</span>
    -<span class="sd">        :return img_info:       Image (height, width)</span>
    -<span class="sd">        :return resized_info:   Resides image (height, width)</span>
    -<span class="sd">        :return img_path:       Path to the associated image</span>
    +<span class="sd">        Load relevant information of a specific image.</span>
    +
    +<span class="sd">        :param sample_id:               Sample_id in the dataset</span>
    +<span class="sd">        :return target:                 Target Bboxes (detection) in XYXY_LABEL format</span>
    +<span class="sd">        :return crowd_target:           Crowd target Bboxes (detection) in XYXY_LABEL format</span>
    +<span class="sd">        :return target_segmentation:    Segmentation</span>
    +<span class="sd">        :return initial_img_shape:      Image (height, width)</span>
    +<span class="sd">        :return resized_img_shape:      Resides image (height, width)</span>
    +<span class="sd">        :return img_path:               Path to the associated image</span>
     <span class="sd">        &quot;&quot;&quot;</span>
     <span class="sd">        &quot;&quot;&quot;</span>
    -        <span class="n">im_ann</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">coco</span><span class="o">.</span><span class="n">loadImgs</span><span class="p">(</span><span class="n">id_</span><span class="p">)[</span><span class="mi">0</span><span class="p">]</span>
    -        <span class="n">width</span> <span class="o">=</span> <span class="n">im_ann</span><span class="p">[</span><span class="s2">&quot;width&quot;</span><span class="p">]</span>
    -        <span class="n">height</span> <span class="o">=</span> <span class="n">im_ann</span><span class="p">[</span><span class="s2">&quot;height&quot;</span><span class="p">]</span>
    -        <span class="n">anno_ids</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">coco</span><span class="o">.</span><span class="n">getAnnIds</span><span class="p">(</span><span class="n">imgIds</span><span class="o">=</span><span class="p">[</span><span class="nb">int</span><span class="p">(</span><span class="n">id_</span><span class="p">)])</span>
    -        <span class="n">annotations</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">coco</span><span class="o">.</span><span class="n">loadAnns</span><span class="p">(</span><span class="n">anno_ids</span><span class="p">)</span>
    +
    +        <span class="n">img_id</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">sample_id_to_coco_id</span><span class="p">[</span><span class="n">sample_id</span><span class="p">]</span>
    +
    +        <span class="n">img_metadata</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">coco</span><span class="o">.</span><span class="n">loadImgs</span><span class="p">(</span><span class="n">img_id</span><span class="p">)[</span><span class="mi">0</span><span class="p">]</span>
    +        <span class="n">width</span> <span class="o">=</span> <span class="n">img_metadata</span><span class="p">[</span><span class="s2">&quot;width&quot;</span><span class="p">]</span>
    +        <span class="n">height</span> <span class="o">=</span> <span class="n">img_metadata</span><span class="p">[</span><span class="s2">&quot;height&quot;</span><span class="p">]</span>
    +
    +        <span class="n">img_annotation_ids</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">coco</span><span class="o">.</span><span class="n">getAnnIds</span><span class="p">(</span><span class="n">imgIds</span><span class="o">=</span><span class="p">[</span><span class="nb">int</span><span class="p">(</span><span class="n">img_id</span><span class="p">)])</span>
    +        <span class="n">img_annotations</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">coco</span><span class="o">.</span><span class="n">loadAnns</span><span class="p">(</span><span class="n">img_annotation_ids</span><span class="p">)</span>
     
     
             <span class="n">cleaned_annotations</span> <span class="o">=</span> <span class="p">[]</span>
             <span class="n">cleaned_annotations</span> <span class="o">=</span> <span class="p">[]</span>
    -        <span class="k">for</span> <span class="n">annotation</span> <span class="ow">in</span> <span class="n">annotations</span><span class="p">:</span>
    +        <span class="k">for</span> <span class="n">annotation</span> <span class="ow">in</span> <span class="n">img_annotations</span><span class="p">:</span>
                 <span class="n">x1</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">max</span><span class="p">((</span><span class="mi">0</span><span class="p">,</span> <span class="n">annotation</span><span class="p">[</span><span class="s2">&quot;bbox&quot;</span><span class="p">][</span><span class="mi">0</span><span class="p">]))</span>
                 <span class="n">x1</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">max</span><span class="p">((</span><span class="mi">0</span><span class="p">,</span> <span class="n">annotation</span><span class="p">[</span><span class="s2">&quot;bbox&quot;</span><span class="p">][</span><span class="mi">0</span><span class="p">]))</span>
                 <span class="n">y1</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">max</span><span class="p">((</span><span class="mi">0</span><span class="p">,</span> <span class="n">annotation</span><span class="p">[</span><span class="s2">&quot;bbox&quot;</span><span class="p">][</span><span class="mi">1</span><span class="p">]))</span>
                 <span class="n">y1</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">max</span><span class="p">((</span><span class="mi">0</span><span class="p">,</span> <span class="n">annotation</span><span class="p">[</span><span class="s2">&quot;bbox&quot;</span><span class="p">][</span><span class="mi">1</span><span class="p">]))</span>
                 <span class="n">x2</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">min</span><span class="p">((</span><span class="n">width</span><span class="p">,</span> <span class="n">x1</span> <span class="o">+</span> <span class="n">np</span><span class="o">.</span><span class="n">max</span><span class="p">((</span><span class="mi">0</span><span class="p">,</span> <span class="n">annotation</span><span class="p">[</span><span class="s2">&quot;bb
                 <span class="n">x2</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">min</span><span class="p">((</span><span class="n">width</span><span class="p">,</span> <span class="n">x1</span> <span class="o">+</span> <span class="n">np</span><span class="o">.</span><span class="n">max</span><span class="p">((</span><span class="mi">0</span><span class="p">,</span> <span class="n">annotation</span><span class="p">[</span><span class="s2">&quot;bb
    @@ -197,14 +190,14 @@
     
     
             <span class="n">non_crowd_annotations</span> <span class="o">=</span> <span class="p">[</span><span class="n">annotation</span> <span class="k">for</span> <span class="n">annotation</span> <span class="ow">in</span> <span class="n">cleaned_annotations</span> <span class="k">if</span> <span class="n">annotation</span><span class="p">[</span><span class="s2">&quot;iscrowd&quot;</span><span class="p">]</span> <span class="o">==</span> <span class="mi">0</span><span class="p">]</span>
             <span class="n">non_crowd_annotations</span> <span class="o">=</span> <span class="p">[</span><span class="n">annotation</span> <span class="k">for</span> <span class="n">annotation</span> <span class="ow">in</span> <span class="n">cleaned_annotations</span> <span class="k">if</span> <span class="n">annotation</span><span class="p">[</span><span class="s2">&quot;iscrowd&quot;</span><span class="p">]</span> <span class="o">==</span> <span class="mi">0</span><span class="p">]</span>
     
     
    -        <span class="n">res</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">zeros</span><span class="p">((</span><span class="nb">len</span><span class="p">(</span><span class="n">non_crowd_annotations</span><span class="p">),</span> <span class="mi">5</span><span class="p">))</span>
    +        <span class="n">target</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">zeros</span><span class="p">((</span><span class="nb">len</span><span class="p">(</span><span class="n">non_crowd_annotations</span><span class="p">),</span> <span class="mi">5</span><span class="p">))</span>
             <span class="n">num_seg_values</span> <span class="o">=</span> <span class="mi">98</span> <span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">tight_box_rotation</span> <span class="k">else</span> <span class="mi">0</span>
             <span class="n">num_seg_values</span> <span class="o">=</span> <span class="mi">98</span> <span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">tight_box_rotation</span> <span class="k">else</span> <span class="mi">0</span>
    -        <span class="n">res_seg</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">ones</span><span class="p">((</span><span class="nb">len</span><span class="p">(</span><span class="n">non_crowd_annotations</span><span class="p">),</span> <span class="n">num_seg_values</span><span class="p">))</span>
    -        <span class="n">res_seg</span><span class="o">.</span><span class="n">fill</span><span class="p">(</span><span class="n">np</span><span class="o">.</span><span class="n">nan</span><span class="p">)</span>
    +        <span class="n">target_segmentation</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">ones</span><span class="p">((</span><span class="nb">len</span><span class="p">(</span><span class="n">non_crowd_annotations</span><span class="p">),</span> <span class="n">num_seg_values</span><span class="p">))</span>
    +        <span class="n">target_segmentation</span><span class="o">.</span><span class="n">fill</span><span class="p">(</span><span class="n">np</span><span class="o">.</span><span class="n">nan</span><span class="p">)</span>
             <span class="k">for</span> <span class="n">ix</span><span class="p">,</span> <span class="n">annotation</span> <span class="ow">in</span> <span class="nb">enumerate</span><span class="p">(</span><span class="n">non_crowd_annotations</span><span class="p">):</span>
             <span class="k">for</span> <span class="n">ix</span><span class="p">,</span> <span class="n">annotation</span> <span class="ow">in</span> <span class="nb">enumerate</span><span class="p">(</span><span class="n">non_crowd_annotations</span><span class="p">):</span>
                 <span class="bp">cls</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">class_ids</span><span class="o">.</span><span class="n">index</span><span class="p">(</span><span class="n">annotation</span><span class="p">[</span><span class="s2">&quot;category_id&quot;</span><span class="p">])</span>
                 <span class="bp">cls</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">class_ids</span><span class="o">.</span><span class="n">index</span><span class="p">(</span><span class="n">annotation</span><span class="p">[</span><span class="s2">&quot;category_id&quot;</span><span class="p">])</span>
    -            <span class="n">res</span><span class="p">[</span><span class="n">ix</span><span class="p">,</span> <span class="mi">0</span><span class="p">:</span><span class="mi">4</span><span class="p">]</span> <span class="o">=</span> <span class="n">annotation</span><span class="p">[</span><span class="s2">&quot;clean_bbox&quot;</span><span class="p">]</span>
    -            <span class="n">res</span><span class="p">[</span><span class="n">ix</span><span class="p">,</span> <span class="mi">4</span><span class="p">]</span> <span class="o">=</span> <span class="bp">cls</span>
    +            <span class="n">target</span><span class="p">[</span><span class="n">ix</span><span class="p">,</span> <span class="mi">0</span><span class="p">:</span><span class="mi">4</span><span class="p">]</span> <span class="o">=</span> <span class="n">annotation</span><span class="p">[</span><span class="s2">&quot;clean_bbox&quot;</span><span class="p">]</span>
    +            <span class="n">target</span><span class="p">[</span><span class="n">ix</span><span class="p">,</span> <span class="mi">4</span><span class="p">]</span> <span class="o">=</span> <span class="bp">cls</span>
                 <span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">tight_box_rotation</span><span class="p">:</span>
                 <span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">tight_box_rotation</span><span class="p">:</span>
                     <span class="n">seg_points</span> <span class="o">=</span> <span class="p">[</span><span class="n">j</span> <span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="n">annotation</span><span class="o">.</span><span class="n">get</span><span class="p">(</span><span class="s2">&quot;segmentation&quot;</span><span class="p">,</span> <span class="p">[])</span> <span class="k">for</span> <span class="n">j</span> <span class="ow">in</span> <span clas
                     <span class="n">seg_points</span> <span class="o">=</span> <span class="p">[</span><span class="n">j</span> <span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="n">annotation</span><span class="o">.</span><span class="n">get</span><span class="p">(</span><span class="s2">&quot;segmentation&quot;</span><span class="p">,</span> <span class="p">[])</span> <span class="k">for</span> <span class="n">j</span> <span class="ow">in</span> <span clas
                     <span class="k">if</span> <span class="n">seg_points</span><span class="p">:</span>
                     <span class="k">if</span> <span class="n">seg_points</span><span class="p">:</span>
    @@ -212,181 +205,41 @@
                         <span class="n">seg_points_convex</span> <span class="o">=</span> <span class="n">cv2</span><span class="o">.</span><span class="n">convexHull</span><span class="p">(</span><span class="n">seg_points_c</span><span class="p">)</span><span class="o">.</span><span class="n">ravel</span><span class="p">()</span>
                         <span class="n">seg_points_convex</span> <span class="o">=</span> <span class="n">cv2</span><span class="o">.</span><span class="n">convexHull</span><span class="p">(</span><span class="n">seg_points_c</span><span class="p">)</span><span class="o">.</span><span class="n">ravel</span><span class="p">()</span>
                     <span class="k">else</span><span class="p">:</span>
                     <span class="k">else</span><span class="p">:</span>
                         <span class="n">seg_points_convex</span> <span class="o">=</span> <span class="p">[]</span>
                         <span class="n">seg_points_convex</span> <span class="o">=</span> <span class="p">[]</span>
    -                <span class="n">res_seg</span><span class="p">[</span><span class="n">ix</span><span class="p">,</span> <span class="p">:</span><span class="nb">len</span><span class="p">(</span><span class="n">seg_points_convex</span><span class="p">)]</span> <span class="o">=</span> <span class="n">seg_points_convex</span>
    +                <span class="n">target_segmentation</span><span class="p">[</span><span class="n">ix</span><span class="p">,</span> <span class="p">:</span> <span class="nb">len</span><span class="p">(</span><span class="n">seg_points_convex</span><span class="p">)]</span> <span class="o">=</span> <span class="n">seg_points_convex</span>
     
     
             <span class="n">crowd_annotations</span> <span class="o">=</span> <span class="p">[</span><span class="n">annotation</span> <span class="k">for</span> <span class="n">annotation</span> <span class="ow">in</span> <span class="n">cleaned_annotations</span> <span class="k">if</span> <span class="n">annotation</span><span class="p">[</span><span class="s2">&quot;iscrowd&quot;</span><span class="p">]</span> <span class="o">==</span> <span class="mi">1</span><span class="p">]</span>
             <span class="n">crowd_annotations</span> <span class="o">=</span> <span class="p">[</span><span class="n">annotation</span> <span class="k">for</span> <span class="n">annotation</span> <span class="ow">in</span> <span class="n">cleaned_annotations</span> <span class="k">if</span> <span class="n">annotation</span><span class="p">[</span><span class="s2">&quot;iscrowd&quot;</span><span class="p">]</span> <span class="o">==</span> <span class="mi">1</span><span class="p">]</span>
     
     
    -        <span class="n">res_crowd</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">zeros</span><span class="p">((</span><span class="nb">len</span><span class="p">(</span><span class="n">crowd_annotations</span><span class="p">),</span> <span class="mi">5</span><span class="p">))</span>
    +        <span class="n">crowd_target</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">zeros</span><span class="p">((</span><span class="nb">len</span><span class="p">(</span><span class="n">crowd_annotations</span><span class="p">),</span> <span class="mi">5</span><span class="p">))</span>
             <span class="k">for</span> <span class="n">ix</span><span class="p">,</span> <span class="n">annotation</span> <span class="ow">in</span> <span class="nb">enumerate</span><span class="p">(</span><span class="n">crowd_annotations</span><span class="p">):</span>
             <span class="k">for</span> <span class="n">ix</span><span class="p">,</span> <span class="n">annotation</span> <span class="ow">in</span> <span class="nb">enumerate</span><span class="p">(</span><span class="n">crowd_annotations</span><span class="p">):</span>
                 <span class="bp">cls</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">class_ids</span><span class="o">.</span><span class="n">index</span><span class="p">(</span><span class="n">annotation</span><span class="p">[</span><span class="s2">&quot;category_id&quot;</span><span class="p">])</span>
                 <span class="bp">cls</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">class_ids</span><span class="o">.</span><span class="n">index</span><span class="p">(</span><span class="n">annotation</span><span class="p">[</span><span class="s2">&quot;category_id&quot;</span><span class="p">])</span>
    -            <span class="n">res_crowd</span><span class="p">[</span><span class="n">ix</span><span class="p">,</span> <span class="mi">0</span><span class="p">:</span><span class="mi">4</span><span class="p">]</span> <span class="o">=</span> <span class="n">annotation</span><span class="p">[</span><span class="s2">&quot;clean_bbox&quot;</span><span class="p">]</span>
    -            <span class="n">res_crowd</span><span class="p">[</span><span class="n">ix</span><span class="p">,</span> <span class="mi">4</span><span class="p">]</span> <span class="o">=</span> <span class="bp">cls</span>
    +            <span class="n">crowd_target</span><span class="p">[</span><span class="n">ix</span><span class="p">,</span> <span class="mi">0</span><span class="p">:</span><span class="mi">4</span><span class="p">]</span> <span class="o">=</span> <span class="n">annotation</span><span class="p">[</span><span class="s2">&quot;clean_bbox&quot;</span><span class="p">]</span>
    +            <span class="n">crowd_target</span><span class="p">[</span><span class="n">ix</span><span class="p">,</span> <span class="mi">4</span><span class="p">]</span> <span class="o">=</span> <span class="bp">cls</span>
     
     
             <span class="n">r</span> <span class="o">=</span> <span class="nb">min</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">input_dim</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span> <span class="o">/</span> <span class="n">height</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">input_dim</span><span class="p">[</span><span class="mi">1</span><span class="p">]</s
             <span class="n">r</span> <span class="o">=</span> <span class="nb">min</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">input_dim</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span> <span class="o">/</span> <span class="n">height</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">input_dim</span><span class="p">[</span><span class="mi">1</span><span class="p">]</s
    -        <span class="n">res</span><span class="p">[:,</span> <span class="p">:</span><span class="mi">4</span><span class="p">]</span> <span class="o">*=</span> <span class="n">r</span>
    -        <span class="n">res_crowd</span><span class="p">[:,</span> <span class="p">:</span><span class="mi">4</span><span class="p">]</span> <span class="o">*=</span> <span class="n">r</span>
    -        <span class="n">res_seg</span> <span class="o">*=</span> <span class="n">r</span>
    -
    -        <span class="n">img_info</span> <span class="o">=</span> <span class="p">(</span><span class="n">height</span><span class="p">,</span> <span class="n">width</span><span class="p">)</span>
    -        <span class="n">resized_info</span> <span class="o">=</span> <span class="p">(</span><span class="nb">int</span><span class="p">(</span><span class="n">height</span> <span class="o">*</span> <span class="n">r</span><span class="p">),</span> <span class="nb">int</span><span class="p">(</span><span class="n">width</span> <span class="o">*</span> <span class="n">r</span><span class="p">))</span>
    -
    -        <span class="n">file_name</span> <span class="o">=</span> <span class="p">(</span>
    -            <span class="n">im_ann</span><span class="p">[</span><span class="s2">&quot;file_name&quot;</span><span class="p">]</span>
    -            <span class="k">if</span> <span class="s2">&quot;file_name&quot;</span> <span class="ow">in</span> <span class="n">im_ann</span>
    -            <span class="k">else</span> <span class="s2">&quot;</span><span class="si">{:012}</span><span class="s2">&quot;</span><span class="o">.</span><span class="n">format</span><span class="p">(</span><span class="n">id_</span><span class="p">)</span> <span class="o">+</span> <span class="s2">&quot;.jpg&quot;</span>
    -        <span class="p">)</span>
    -        <span class="n">img_path</span> <span class="o">=</span> <span class="n">os</span><span class="o">.</span><span class="n">path</span><span class="o">.</span><span class="n">join</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">data_dir</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">name</span><span class="p">,</span> <span class="n">file_name</span><span class="p">)</span>
    -        <span class="k">return</span> <span class="n">res</span><span class="p">,</span> <span class="n">res_crowd</span><span class="p">,</span> <span class="n">res_seg</span><span class="p">,</span> <span class="n">img_info</span><span class="p">,</span> <span class="n">resized_info</span><span class="p">,</span> <span class="n">img_path</span>
    -
    -    <span class="k">def</span> <span class="nf">_cache_images</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span>
    -        <span class="n">logger</span><span class="o">.</span><span class="n">warning</span><span class="p">(</span>
    -            <span class="s2">&quot;</span><span class="se">\n</span><span class="s2">********************************************************************************</span><span class="se">\n</span><span class="s2">&quot;</span>
    -            <span class="s2">&quot;You are using cached images in RAM to accelerate training.</span><span class="se">\n</span><span class="s2">&quot;</span>
    -            <span class="s2">&quot;This requires large system RAM.</span><span class="se">\n</span><span class="s2">&quot;</span>
    -            <span class="s2">&quot;Make sure you have 200G+ RAM and 136G available disk space for training COCO.</span><span class="se">\n</span><span class="s2">&quot;</span>
    -            <span class="s2">&quot;********************************************************************************</span><span class="se">\n</span><span class="s2">&quot;</span>
    -        <span class="p">)</span>
    -        <span class="n">max_h</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">input_dim</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span>
    -        <span class="n">max_w</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">input_dim</span><span class="p">[</span><span class="mi">1</span><span class="p">]</span>
    -        <span class="n">cache_file</span> <span class="o">=</span> <span class="n">os</span><span class="o">.</span><span class="n">path</span><span class="o">.</span><span class="n">join</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">cache_dir_path</span><span class="p">,</span> <span class="sa">f</span><span class="s2">&quot;img_resized_cache_</span><span class="si">{</span><span class="bp">self</span><span class="o">.</span><span class="n">n
    -        <span class="k">if</span> <span class="ow">not</span> <span class="n">os</span><span class="o">.</span><span class="n">path</span><span class="o">.</span><span class="n">exists</span><span class="p">(</span><span class="n">cache_file</span><span class="p">):</span>
    -            <span class="n">logger</span><span class="o">.</span><span class="n">info</span><span class="p">(</span>
    -                <span class="s2">&quot;Caching images for the first time. This might take about 20 minutes for COCO&quot;</span>
    -            <span class="p">)</span>
    -            <span class="bp">self</span><span class="o">.</span><span class="n">imgs</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">memmap</span><span class="p">(</span>
    -                <span class="n">cache_file</span><span class="p">,</span>
    -                <span class="n">shape</span><span class="o">=</span><span class="p">(</span><span class="nb">len</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">ids</span><span class="p">),</span> <span class="n">max_h</span><span class="p">,</span> <span class="n">max_w</span><span class="p">,</span> <span class="mi">3</span><span class="p">),</span>
    -                <span class="n">dtype</span><span class="o">=</span><span class="n">np</span><span class="o">.</span><span class="n">uint8</span><span class="p">,</span>
    -                <span class="n">mode</span><span class="o">=</span><span class="s2">&quot;w+&quot;</span><span class="p">,</span>
    -            <span class="p">)</span>
    -
    -            <span class="n">NUM_THREADs</span> <span class="o">=</span> <span class="nb">min</span><span class="p">(</span><span class="mi">8</span><span class="p">,</span> <span class="n">os</span><span class="o">.</span><span class="n">cpu_count</span><span class="p">())</span>
    -            <span class="n">loaded_images</span> <span class="o">=</span> <span class="n">ThreadPool</span><span class="p">(</span><span class="n">NUM_THREADs</span><span class="p">)</span><span class="o">.</span><span class="n">imap</span><span class="p">(</span>
    -                <span class="k">lambda</span> <span class="n">x</span><span class="p">:</span> <span class="bp">self</span><span class="o">.</span><span class="n">load_resized_img</span><span class="p">(</span><span class="n">x</span><span class="p">),</span>
    -                <span class="nb">range</span><span class="p">(</span><span class="nb">len</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">annotations</span><span class="p">)),</span>
    -            <span class="p">)</span>
    -            <span class="n">pbar</span> <span class="o">=</span> <span class="n">tqdm</span><span class="p">(</span><span class="nb">enumerate</span><span class="p">(</span><span class="n">loaded_images</span><span class="p">),</span> <span class="n">total</span><span class="o">=</span><span class="nb">len</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">annotations</span><span class="p">))</span>
    -            <span class="k">for</span> <span class="n">k</span><span class="p">,</span> <span class="n">out</span> <span class="ow">in</span> <span class="n">pbar</span><span class="p">:</span>
    -                <span class="bp">self</span><span class="o">.</span><span class="n">imgs</span><span class="p">[</span><span class="n">k</span><span class="p">][:</span> <span class="n">out</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">0</span><span class="p">],</span> <span class="p">:</span> <span class="n">out</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">1</span><span class="p">],</span
    -            <span class="bp">self</span><span class="o">.</span><span class="n">imgs</span><span class="o">.</span><span class="n">flush</span><span class="p">()</span>
    -            <span class="n">pbar</span><span class="o">.</span><span class="n">close</span><span class="p">()</span>
    -        <span class="k">else</span><span class="p">:</span>
    -            <span class="n">logger</span><span class="o">.</span><span class="n">warning</span><span class="p">(</span>
    -                <span class="s2">&quot;You are using cached imgs! Make sure your dataset is not changed!!</span><span class="se">\n</span><span class="s2">&quot;</span>
    -                <span class="s2">&quot;Everytime the self.input_size is changed in your exp file, you need to delete</span><span class="se">\n</span><span class="s2">&quot;</span>
    -                <span class="s2">&quot;the cached data and re-generate them.</span><span class="se">\n</span><span class="s2">&quot;</span>
    -            <span class="p">)</span>
    -
    -        <span class="n">logger</span><span class="o">.</span><span class="n">info</span><span class="p">(</span><span class="s2">&quot;Loading cached imgs...&quot;</span><span class="p">)</span>
    -        <span class="bp">self</span><span class="o">.</span><span class="n">imgs</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">memmap</span><span class="p">(</span>
    -            <span class="n">cache_file</span><span class="p">,</span>
    -            <span class="n">shape</span><span class="o">=</span><span class="p">(</span><span class="nb">len</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">ids</span><span class="p">),</span> <span class="n">max_h</span><span class="p">,</span> <span class="n">max_w</span><span class="p">,</span> <span class="mi">3</span><span class="p">),</span>
    -            <span class="n">dtype</span><span class="o">=</span><span class="n">np</span><span class="o">.</span><span class="n">uint8</span><span class="p">,</span>
    -            <span class="n">mode</span><span class="o">=</span><span class="s2">&quot;r+&quot;</span><span class="p">,</span>
    -        <span class="p">)</span>
    -
    -<div class="viewcode-block" id="COCODetectionDataset.load_resized_img"><a class="viewcode-back" href="../../../../../super_gradients.training.datasets.detection_datasets.html#super_gradients.training.datasets.COCODetectionDataset.load_resized_img">[docs]</a>    <span class="k">def</span> <span class="nf">load_resized_img</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">index</span><span class="p">):</span>
    -        <span class="sd">&quot;&quot;&quot;</span>
    -<span class="sd">        Loads image at index, and resizes it to self.input_dim</span>
    -
    -<span class="sd">        :param index: index to load the image from</span>
    -<span class="sd">        :return: resized_img</span>
    -<span class="sd">        &quot;&quot;&quot;</span>
    -        <span class="n">img</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">load_image</span><span class="p">(</span><span class="n">index</span><span class="p">)</span>
    -        <span class="n">r</span> <span class="o">=</span> <span class="nb">min</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">input_dim</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span> <span class="o">/</span> <span class="n">img</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">0</span><span class="p">],</span> <span class="bp">self</span><span class="o">.</span><s
    -        <span class="n">resized_img</span> <span class="o">=</span> <span class="n">cv2</span><span class="o">.</span><span class="n">resize</span><span class="p">(</span>
    -            <span class="n">img</span><span class="p">,</span>
    -            <span class="p">(</span><span class="nb">int</span><span class="p">(</span><span class="n">img</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">1</span><span class="p">]</span> <span class="o">*</span> <span class="n">r</span><span class="p">),</span> <span class="nb">int</span><span class="p">(</span><span class="n">img</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">0</span><span c
    -            <span class="n">interpolation</span><span class="o">=</span><span class="n">cv2</span><span class="o">.</span><span class="n">INTER_LINEAR</span><span class="p">,</span>
    -        <span class="p">)</span><span class="o">.</span><span class="n">astype</span><span class="p">(</span><span class="n">np</span><span class="o">.</span><span class="n">uint8</span><span class="p">)</span>
    -        <span class="k">return</span> <span class="n">resized_img</span></div>
    -
    -<div class="viewcode-block" id="COCODetectionDataset.load_sample"><a class="viewcode-back" href="../../../../../super_gradients.training.datasets.detection_datasets.html#super_gradients.training.datasets.COCODetectionDataset.load_sample">[docs]</a>    <span class="k">def</span> <span class="nf">load_sample</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">index</span><span class="p">):</span>
    -        <span class="sd">&quot;&quot;&quot;</span>
    -<span class="sd">        Loads sample at self.ids[index] as dictionary that holds:</span>
    -<span class="sd">            &quot;image&quot;: Image resized to self.input_dim</span>
    -<span class="sd">            &quot;target&quot;: Detection ground truth, np.array shaped (num_targets, 5), format is [class,x1,y1,x2,y2] with</span>
    -<span class="sd">                image coordinates.</span>
    -<span class="sd">            &quot;target_seg&quot;: Segmentation map convex hull derived detection target.</span>
    -<span class="sd">            &quot;info&quot;: Original shape (height,width).</span>
    -<span class="sd">            &quot;id&quot;: COCO image id</span>
    -
    -<span class="sd">        :param index: Sample index</span>
    -<span class="sd">        :return: sample as described above</span>
    -<span class="sd">        &quot;&quot;&quot;</span>
    -        <span class="n">id_</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">ids</span><span class="p">[</span><span class="n">index</span><span class="p">]</span>
    -        <span class="n">res</span><span class="p">,</span> <span class="n">res_crowd</span><span class="p">,</span> <span class="n">res_seg</span><span class="p">,</span> <span class="n">img_info</span><span class="p">,</span> <span class="n">resized_info</span><span class="p">,</span> <span class="n">_</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">annotations</span><span class="p">[</span><span class="n">index</span><span class="p">]</span>
    -        <span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">imgs</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span><span class="p">:</span>
    -            <span class="n">pad_img</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">imgs</span><span class="p">[</span><span class="n">index</span><span class="p">]</span>
    -            <span class="n">img</span> <span class="o">=</span> <span class="n">pad_img</span><span class="p">[:</span> <span class="n">resized_info</span><span class="p">[</span><span class="mi">0</span><span class="p">],</span> <span class="p">:</span> <span class="n">resized_info</span><span class="p">[</span><span class="mi">1</span><span class="p">],</span> <span class="p">:]</span><span class="o">.</span><span class="n">copy</span><span class="p">()</span>
    -        <span class="k">else</span><span class="p">:</span>
    -            <span class="n">img</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">load_resized_img</span><span class="p">(</span><span class="n">index</span><span class="p">)</span>
    -
    -        <span class="n">sample</span> <span class="o">=</span> <span class="p">{</span><span class="s2">&quot;image&quot;</span><span class="p">:</span> <span class="n">img</span><span class="p">,</span> <span class="s2">&quot;target&quot;</span><span class="p">:</span> <span class="n">res</span><span class="o">.</span><span class="n">copy</span><span class="p">(),</span> <span class="s2">&quot;target_seg&quot;</span><span class="p">:</span> <span class="n">res_seg</span><span class="p">,</span
    -        <span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">with_crowd</span><span class="p">:</span>
    -            <span class="n">sample</span><span class="p">[</span><span class="s2">&quot;crowd_target&quot;</span><span class="p">]</span> <span class="o">=</span> <span class="n">res_crowd</span><span class="o">.</span><span class="n">copy</span><span class="p">()</span>
    -        <span class="k">return</span> <span class="n">sample</span></div>
    -
    -<div class="viewcode-block" id="COCODetectionDataset.load_image"><a class="viewcode-back" href="../../../../../super_gradients.training.datasets.detection_datasets.html#super_gradients.training.datasets.COCODetectionDataset.load_image">[docs]</a>    <span class="k">def</span> <span class="nf">load_image</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">index</span><span class="p">):</span>
    -        <span class="sd">&quot;&quot;&quot;</span>
    -<span class="sd">        Loads image at index with its original resolution</span>
    -<span class="sd">        :param index: index in self.annotations</span>
    -<span class="sd">        :return: image (np.array)</span>
    -<span class="sd">        &quot;&quot;&quot;</span>
    -        <span class="n">file_name</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">annotations</span><span class="p">[</span><span class="n">index</span><span class="p">][</span><span class="mi">5</span><span class="p">]</span>
    -
    -        <span class="n">img_file</span> <span class="o">=</span> <span class="n">os</span><span class="o">.</span><span class="n">path</span><span class="o">.</span><span class="n">join</span><span class="p">(</span><span class="n">file_name</span><span class="p">)</span>
    -
    -        <span class="n">img</span> <span class="o">=</span> <span class="n">cv2</span><span class="o">.</span><span class="n">imread</span><span class="p">(</span><span class="n">img_file</span><span class="p">)</span>
    -        <span class="k">assert</span> <span class="n">img</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span>
    -
    -        <span class="k">return</span> <span class="n">img</span></div>
    -
    -    <span class="k">def</span> <span class="fm">__del__</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span>
    -        <span class="k">del</span> <span class="bp">self</span><span class="o">.</span><span class="n">imgs</span>
    -
    -    <span class="k">def</span> <span class="nf">_load_anno</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">index</span><span class="p">):</span>
    -        <span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">annotations</span><span class="p">[</span><span class="n">index</span><span class="p">][</span><span class="mi">0</span><span class="p">]</span>
    -
    -    <span class="k">def</span> <span class="nf">_get_random_non_empty_target_idx</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span>
    -        <span class="n">target</span> <span class="o">=</span> <span class="p">[]</span>
    -        <span class="k">while</span> <span class="nb">len</span><span class="p">(</span><span class="n">target</span><span class="p">)</span> <span class="o">==</span> <span class="mi">0</span><span class="p">:</span>
    -            <span class="n">idx</span> <span class="o">=</span> <span class="n">random</span><span class="o">.</span><span class="n">randint</span><span class="p">(</span><span class="mi">0</span><span class="p">,</span> <span class="nb">len</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">ids</span><span class="p">)</span> <span class="o">-</span> <span class="mi">1</span><span class="p">)</span>
    -            <span class="n">target</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">_load_anno</span><span class="p">(</span><span class="n">idx</span><span class="p">)</span>
    -        <span class="k">return</span> <span class="n">idx</span>
    -
    -    <span class="k">def</span> <span class="nf">_load_random_samples</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">count</span><span class="p">,</span> <span class="n">non_empty_targets_only</span><span class="o">=</span><span class="kc">False</span><span class="p">):</span>
    -        <span class="n">inds</span> <span class="o">=</span> <span class="p">[</span>
    -            <span class="bp">self</span><span class="o">.</span><span class="n">_get_random_non_empty_target_idx</span><span class="p">()</span> <span class="k">if</span> <span class="n">non_empty_targets_only</span> <span class="k">else</span> <span class="n">random</span><span class="o">.</span><span class="n">randint</span><span class="p">(</span><span class="mi">0</span><span class="p">,</span> <span class="nb">len</span><span class="p">(</span><span class="bp">self</span><span class="o">.<
    -            <span class="k">for</span> <span class="n">_</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="n">count</span><span class="p">)]</span>
    -        <span class="k">return</span> <span class="p">[</span><span class="bp">self</span><span class="o">.</span><span class="n">load_sample</span><span class="p">(</span><span class="n">ind</span><span class="p">)</span> <span class="k">for</span> <span class="n">ind</span> <span class="ow">in</span> <span class="n">inds</span><span class="p">]</span>
    -
    -    <span class="k">def</span> <span class="nf">_load_additional_inputs_for_transform</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">sample</span><span class="p">,</span> <span class="n">transform</span><span class="p">):</span>
    -        <span class="n">additional_samples_count</span> <span class="o">=</span> <span class="n">transform</span><span class="o">.</span><span class="n">additional_samples_count</span> <span class="k">if</span> <span class="nb">hasattr</span><span class="p">(</span><span class="n">transform</span><span class="p">,</span>
    -                                                                                 <span class="s2">&quot;additional_samples_count&quot;</span><span class="p">)</span> <span class="k">else</span> <span class="mi">0</span>
    -        <span class="n">non_empty_targets</span> <span class="o">=</span> <span class="n">transform</span><span class="o">.</span><span class="n">non_empty_targets</span> <span class="k">if</span> <span class="nb">hasattr</span><span class="p">(</span><span class="n">transform</span><span class="p">,</span> <span class="s2">&quot;non_empty_targets&quot;</span><span class="p">)</span> <span class="k">else</span> <span class="kc">False</span>
    -        <span class="n">additional_samples</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">_load_random_samples</span><span class="p">(</span><span class="n">additional_samples_count</span><span class="p">,</span> <span class="n">non_empty_targets</span><span class="p">)</span>
    -        <span class="n">sample</span><span class="p">[</span><span class="s2">&quot;additional_samples&quot;</span><span class="p">]</span> <span class="o">=</span> <span class="n">additional_samples</span>
    -
    -<div class="viewcode-block" id="COCODetectionDataset.apply_transforms"><a class="viewcode-back" href="../../../../../super_gradients.training.datasets.detection_datasets.html#super_gradients.training.datasets.COCODetectionDataset.apply_transforms">[docs]</a>    <span class="k">def</span> <span class="nf">apply_transforms</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">sample</span><span class="p">:</span> <span class="nb">dict</span><span class
    -        <span class="sd">&quot;&quot;&quot;</span>
    -<span class="sd">        Applies self.transforms sequentially to sample</span>
    -
    -<span class="sd">        If a transforms has the attribute &#39;additional_samples_count&#39;, additional samples will be loaded and stored in</span>
    -<span class="sd">         sample[&quot;additional_samples&quot;] prior to applying it. Combining with the attribute &quot;non_empty_targets&quot; will load</span>
    -<span class="sd">         only additional samples with objects in them.</span>
    -
    -<span class="sd">        :param sample: Sample to apply the transforms on to (loaded with self.load_sample)</span>
    -<span class="sd">        :return: Transformed sample</span>
    -<span class="sd">        &quot;&quot;&quot;</span>
    -        <span class="k">for</span> <span class="n">transform</span> <span class="ow">in</span> <span class="bp">self</span><span class="o">.</span><span class="n">transforms</span><span class="p">:</span>
    -            <span class="bp">self</span><span class="o">.</span><span class="n">_load_additional_inputs_for_transform</span><span class="p">(</span><span class="n">sample</span><span class="p">,</span> <span class="n">transform</span><span class="p">)</span>
    -            <span class="n">sample</span> <span class="o">=</span> <span class="n">transform</span><span class="p">(</span><span class="n">sample</span><span class="p">)</span>
    -
    -        <span class="k">return</span> <span class="n">sample</span></div></div>
    -
    -
    -<div class="viewcode-block" id="remove_useless_info"><a class="viewcode-back" href="../../../../../super_gradients.training.datasets.detection_datasets.html#super_gradients.training.datasets.remove_useless_info">[docs]</a><span class="k">def</span> <span class="nf">remove_useless_info</span><span class="p">(</span><span class="n">coco</span><span class="p">,</span> <span class="n">use_seg_info</span><span class="o">=</span><span class="kc">False</span><span class="p">):</span>
    +        <span class="n">target</span><span class="p">[:,</span> <span class="p">:</span><span class="mi">4</span><span class="p">]</span> <span class="o">*=</span> <span class="n">r</span>
    +        <span class="n">crowd_target</span><span class="p">[:,</span> <span class="p">:</span><span class="mi">4</span><span class="p">]</span> <span class="o">*=</span> <span class="n">r</span>
    +        <span class="n">target_segmentation</span> <span class="o">*=</span> <span class="n">r</span>
    +
    +        <span class="n">initial_img_shape</span> <span class="o">=</span> <span class="p">(</span><span class="n">height</span><span class="p">,</span> <span class="n">width</span><span class="p">)</span>
    +        <span class="n">resized_img_shape</span> <span class="o">=</span> <span class="p">(</span><span class="nb">int</span><span class="p">(</span><span class="n">height</span> <span class="o">*</span> <span class="n">r</span><span class="p">),</span> <span class="nb">int</span><span class="p">(</span><span class="n">width</span> <span class="o">*</span> <span class="n">r</span><span class="p">))</span>
    +
    +        <span class="n">file_name</span> <span class="o">=</span> <span class="n">img_metadata</span><span class="p">[</span><span class="s2">&quot;file_name&quot;</span><span class="p">]</span> <span class="k">if</span> <span class="s2">&quot;file_name&quot;</span> <span class="ow">in</span> <span class="n">img_metadata</span> <span class="k">else</span> <span class="s2">&quot;</span><span class="si">{:012}</span><span class="s2">&quot;</span><span class="o">.</span><span class="n">format</spa
    +        <span class="n">img_path</span> <span class="o">=</span> <span class="n">os</span><span class="o">.</span><span class="n">path</span><span class="o">.</span><span class="n">join</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">data_dir</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">subdir</span><span class="p">,</span> <span class="n">file_name</span><span class="p">)</span>
    +        <span class="n">img_id</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">sample_id_to_coco_id</span><span class="p">[</span><span class="n">sample_id</span><span class="p">]</span>
    +
    +        <span class="n">annotation</span> <span class="o">=</span> <span class="p">{</span>
    +            <span class="s2">&quot;target&quot;</span><span class="p">:</span> <span class="n">target</span><span class="p">,</span>
    +            <span class="s2">&quot;crowd_target&quot;</span><span class="p">:</span> <span class="n">crowd_target</span><span class="p">,</span>
    +            <span class="s2">&quot;target_segmentation&quot;</span><span class="p">:</span> <span class="n">target_segmentation</span><span class="p">,</span>
    +            <span class="s2">&quot;initial_img_shape&quot;</span><span class="p">:</span> <span class="n">initial_img_shape</span><span class="p">,</span>
    +            <span class="s2">&quot;resized_img_shape&quot;</span><span class="p">:</span> <span class="n">resized_img_shape</span><span class="p">,</span>
    +            <span class="s2">&quot;img_path&quot;</span><span class="p">:</span> <span class="n">img_path</span><span class="p">,</span>
    +            <span class="s2">&quot;id&quot;</span><span class="p">:</span> <span class="n">np</span><span class="o">.</span><span class="n">array</span><span class="p">([</span><span class="n">img_id</span><span class="p">]),</span>
    +        <span class="p">}</span>
    +        <span class="k">return</span> <span class="n">annotation</span></div>
    +
    +
    +<span class="k">def</span> <span class="nf">remove_useless_info</span><span class="p">(</span><span class="n">coco</span><span class="p">,</span> <span class="n">use_seg_info</span><span class="o">=</span><span class="kc">False</span><span class="p">):</span>
         <span class="sd">&quot;&quot;&quot;</span>
         <span class="sd">&quot;&quot;&quot;</span>
     <span class="sd">    Remove useless info in coco dataset. COCO object is modified inplace.</span>
     <span class="sd">    Remove useless info in coco dataset. COCO object is modified inplace.</span>
     <span class="sd">    This function is mainly used for saving memory (save about 30% mem).</span>
     <span class="sd">    This function is mainly used for saving memory (save about 30% mem).</span>
    @@ -402,7 +255,7 @@
                 <span class="n">img</span><span class="o">.</span><span class="n">pop</span><span class="p">(</span><span class="s2">&quot;flickr_url&quot;</span><span class="p">,</span> <span class="kc">None</span><span class="p">)</span>
                 <span class="n">img</span><span class="o">.</span><span class="n">pop</span><span class="p">(</span><span class="s2">&quot;flickr_url&quot;</span><span class="p">,</span> <span class="kc">None</span><span class="p">)</span>
             <span class="k">if</span> <span class="s2">&quot;annotations&quot;</span> <span class="ow">in</span> <span class="n">coco</span><span class="o">.</span><span class="n">dataset</span> <span class="ow">and</span> <span class="ow">not</span> <span class="n">use_seg_info</span><span class="p">:</span>
             <span class="k">if</span> <span class="s2">&quot;annotations&quot;</span> <span class="ow">in</span> <span class="n">coco</span><span class="o">.</span><span class="n">dataset</span> <span class="ow">and</span> <span class="ow">not</span> <span class="n">use_seg_info</span><span class="p">:</span>
                 <span class="k">for</span> <span class="n">anno</span> <span class="ow">in</span> <span class="n">coco</span><span class="o">.</span><span class="n">dataset</span><span class="p">[</span><span class="s2">&quot;annotations&quot;</span><span class="p">]:</span>
                 <span class="k">for</span> <span class="n">anno</span> <span class="ow">in</span> <span class="n">coco</span><span class="o">.</span><span class="n">dataset</span><span class="p">[</span><span class="s2">&quot;annotations&quot;</span><span class="p">]:</span>
    -                <span class="n">anno</span><span class="o">.</span><span class="n">pop</span><span class="p">(</span><span class="s2">&quot;segmentation&quot;</span><span class="p">,</span> <span class="kc">None</span><span class="p">)</span></div>
    +                <span class="n">anno</span><span class="o">.</span><span class="n">pop</span><span class="p">(</span><span class="s2">&quot;segmentation&quot;</span><span class="p">,</span> <span class="kc">None</span><span class="p">)</span>
     </pre></div>
     </pre></div>
     
     
                </div>
                </div>
    @@ -432,4 +285,4 @@
       </script> 
       </script> 
     
     
     </body>
     </body>
    -</html>
    +</html>
    Discard
    Some lines were truncated since they exceed the maximum allowed length of 500, please use a local Git client to see the full diff.
    @@ -3,10 +3,11 @@
     <head>
     <head>
       <meta charset="utf-8" />
       <meta charset="utf-8" />
       <meta name="viewport" content="width=device-width, initial-scale=1.0" />
       <meta name="viewport" content="width=device-width, initial-scale=1.0" />
    -  <title>super_gradients.training.datasets.detection_datasets.detection_dataset &mdash; SuperGradients 1.0 documentation</title>
    +  <title>super_gradients.training.datasets.detection_datasets.detection_dataset &mdash; SuperGradients 3.0.3 documentation</title>
           <link rel="stylesheet" href="../../../../../_static/pygments.css" type="text/css" />
           <link rel="stylesheet" href="../../../../../_static/pygments.css" type="text/css" />
           <link rel="stylesheet" href="../../../../../_static/css/theme.css" type="text/css" />
           <link rel="stylesheet" href="../../../../../_static/css/theme.css" type="text/css" />
           <link rel="stylesheet" href="../../../../../_static/graphviz.css" type="text/css" />
           <link rel="stylesheet" href="../../../../../_static/graphviz.css" type="text/css" />
    +      <link rel="stylesheet" href="../../../../../_static/custom.css" type="text/css" />
       <!--[if lt IE 9]>
       <!--[if lt IE 9]>
         <script src="../../../../../_static/js/html5shiv.min.js"></script>
         <script src="../../../../../_static/js/html5shiv.min.js"></script>
       <![endif]-->
       <![endif]-->
    @@ -14,7 +15,9 @@
             <script data-url_root="../../../../../" id="documentation_options" src="../../../../../_static/documentation_options.js"></script>
             <script data-url_root="../../../../../" id="documentation_options" src="../../../../../_static/documentation_options.js"></script>
             <script src="../../../../../_static/jquery.js"></script>
             <script src="../../../../../_static/jquery.js"></script>
             <script src="../../../../../_static/underscore.js"></script>
             <script src="../../../../../_static/underscore.js"></script>
    +        <script src="../../../../../_static/_sphinx_javascript_frameworks_compat.js"></script>
             <script src="../../../../../_static/doctools.js"></script>
             <script src="../../../../../_static/doctools.js"></script>
    +        <script src="../../../../../_static/sphinx_highlight.js"></script>
         <script src="../../../../../_static/js/theme.js"></script>
         <script src="../../../../../_static/js/theme.js"></script>
         <link rel="index" title="Index" href="../../../../../genindex.html" />
         <link rel="index" title="Index" href="../../../../../genindex.html" />
         <link rel="search" title="Search" href="../../../../../search.html" /> 
         <link rel="search" title="Search" href="../../../../../search.html" /> 
    @@ -35,30 +38,29 @@
       </form>
       </form>
     </div>
     </div>
             </div><div class="wy-menu wy-menu-vertical" data-spy="affix" role="navigation" aria-label="Navigation menu">
             </div><div class="wy-menu wy-menu-vertical" data-spy="affix" role="navigation" aria-label="Navigation menu">
    -              <p class="caption"><span class="caption-text">Welcome To SuperGradients</span></p>
    +              <p class="caption" role="heading"><span class="caption-text">Welcome To SuperGradients</span></p>
     <ul>
     <ul>
    -<li class="toctree-l1"><a class="reference internal" href="../../../../../welcome.html">Fill our 4-question quick survey! We will raffle free SuperGradients swag between those who will participate -&gt; Fill Survey</a></li>
    -<li class="toctree-l1"><a class="reference internal" href="../../../../../welcome.html#supergradients">SuperGradients</a></li>
    +<li class="toctree-l1"><a class="reference internal" href="../../../../../welcome.html">Version 3 is out! Notebooks have been updated!</a></li>
    +<li class="toctree-l1"><a class="reference internal" href="../../../../../welcome.html#build-with-supergradients">Build with SuperGradients</a></li>
    +<li class="toctree-l1"><a class="reference internal" href="../../../../../welcome.html#quick-installation">Quick Installation</a></li>
    +<li class="toctree-l1"><a class="reference internal" href="../../../../../welcome.html#what-s-new">What’s New</a></li>
    +<li class="toctree-l1"><a class="reference internal" href="../../../../../welcome.html#coming-soon">Coming soon</a></li>
    +<li class="toctree-l1"><a class="reference internal" href="../../../../../welcome.html#table-of-content">Table of Content</a></li>
    +<li class="toctree-l1"><a class="reference internal" href="../../../../../welcome.html#getting-started">Getting Started</a></li>
    +<li class="toctree-l1"><a class="reference internal" href="../../../../../welcome.html#advanced-features">Advanced Features</a></li>
    +<li class="toctree-l1"><a class="reference internal" href="../../../../../welcome.html#installation-methods">Installation Methods</a></li>
    +<li class="toctree-l1"><a class="reference internal" href="../../../../../welcome.html#implemented-model-architectures">Implemented Model Architectures</a></li>
    +<li class="toctree-l1"><a class="reference internal" href="../../../../../welcome.html#documentation">Documentation</a></li>
    +<li class="toctree-l1"><a class="reference internal" href="../../../../../welcome.html#contributing">Contributing</a></li>
    +<li class="toctree-l1"><a class="reference internal" href="../../../../../welcome.html#citation">Citation</a></li>
    +<li class="toctree-l1"><a class="reference internal" href="../../../../../welcome.html#community">Community</a></li>
    +<li class="toctree-l1"><a class="reference internal" href="../../../../../welcome.html#license">License</a></li>
    +<li class="toctree-l1"><a class="reference internal" href="../../../../../welcome.html#deci-platform">Deci Platform</a></li>
     </ul>
     </ul>
    -<p class="caption"><span class="caption-text">Technical Documentation</span></p>
    +<p class="caption" role="heading"><span class="caption-text">Technical Documentation</span></p>
     <ul>
     <ul>
     <li class="toctree-l1"><a class="reference internal" href="../../../../../super_gradients.common.html">Common package</a></li>
     <li class="toctree-l1"><a class="reference internal" href="../../../../../super_gradients.common.html">Common package</a></li>
     <li class="toctree-l1"><a class="reference internal" href="../../../../../super_gradients.training.html">Training package</a></li>
     <li class="toctree-l1"><a class="reference internal" href="../../../../../super_gradients.training.html">Training package</a></li>
    -</ul>
    -<p class="caption"><span class="caption-text">User Guide</span></p>
    -<ul>
    -<li class="toctree-l1"><a class="reference internal" href="../../../../../user_guide.html">What is SuperGradients?</a></li>
    -<li class="toctree-l1"><a class="reference internal" href="../../../../../user_guide.html#introducing-the-supergradients-library">Introducing the SuperGradients library</a></li>
    -<li class="toctree-l1"><a class="reference internal" href="../../../../../user_guide.html#installation">Installation</a></li>
    -<li class="toctree-l1"><a class="reference internal" href="../../../../../user_guide.html#integrating-your-training-code-complete-walkthrough">Integrating Your Training Code - Complete Walkthrough</a></li>
    -<li class="toctree-l1"><a class="reference internal" href="../../../../../user_guide.html#training-parameters">Training Parameters</a></li>
    -<li class="toctree-l1"><a class="reference internal" href="../../../../../user_guide.html#logs-and-checkpoints">Logs and Checkpoints</a></li>
    -<li class="toctree-l1"><a class="reference internal" href="../../../../../user_guide.html#dataset-parameters">Dataset Parameters</a></li>
    -<li class="toctree-l1"><a class="reference internal" href="../../../../../user_guide.html#network-architectures">Network Architectures</a></li>
    -<li class="toctree-l1"><a class="reference internal" href="../../../../../user_guide.html#pretrained-models">Pretrained Models</a></li>
    -<li class="toctree-l1"><a class="reference internal" href="../../../../../user_guide.html#how-to-reproduce-our-training-recipes">How To Reproduce Our Training Recipes</a></li>
    -<li class="toctree-l1"><a class="reference internal" href="../../../../../user_guide.html#professional-tools-integration">Professional Tools Integration</a></li>
    -<li class="toctree-l1"><a class="reference internal" href="../../../../../user_guide.html#supergradients-faq">SuperGradients FAQ</a></li>
     </ul>
     </ul>
     
     
             </div>
             </div>
    @@ -93,21 +95,25 @@
     <span class="kn">import</span> <span class="nn">cv2</span>
     <span class="kn">import</span> <span class="nn">cv2</span>
     <span class="kn">import</span> <span class="nn">matplotlib.pyplot</span> <span class="k">as</span> <span class="nn">plt</span>
     <span class="kn">import</span> <span class="nn">matplotlib.pyplot</span> <span class="k">as</span> <span class="nn">plt</span>
     <span class="kn">from</span> <span class="nn">pathlib</span> <span class="kn">import</span> <span class="n">Path</span>
     <span class="kn">from</span> <span class="nn">pathlib</span> <span class="kn">import</span> <span class="n">Path</span>
    +<span class="kn">from</span> <span class="nn">copy</span> <span class="kn">import</span> <span class="n">deepcopy</span>
     <span class="kn">import</span> <span class="nn">hashlib</span>
     <span class="kn">import</span> <span class="nn">hashlib</span>
     
     
     <span class="kn">import</span> <span class="nn">numpy</span> <span class="k">as</span> <span class="nn">np</span>
     <span class="kn">import</span> <span class="nn">numpy</span> <span class="k">as</span> <span class="nn">np</span>
     <span class="kn">from</span> <span class="nn">tqdm</span> <span class="kn">import</span> <span class="n">tqdm</span>
     <span class="kn">from</span> <span class="nn">tqdm</span> <span class="kn">import</span> <span class="n">tqdm</span>
     <span class="kn">from</span> <span class="nn">torch.utils.data</span> <span class="kn">import</span> <span class="n">Dataset</span>
     <span class="kn">from</span> <span class="nn">torch.utils.data</span> <span class="kn">import</span> <span class="n">Dataset</span>
     
     
    +<span class="kn">from</span> <span class="nn">super_gradients.common.decorators.factory_decorator</span> <span class="kn">import</span> <span class="n">resolve_param</span>
     <span class="kn">from</span> <span class="nn">super_gradients.training.utils.detection_utils</span> <span class="kn">import</span> <span class="n">get_cls_posx_in_target</span><span class="p">,</span> <span class="n">DetectionTargetsFormat</span>
     <span class="kn">from</span> <span class="nn">super_gradients.training.utils.detection_utils</span> <span class="kn">import</span> <span class="n">get_cls_posx_in_target</span><span class="p">,</span> <span class="n">DetectionTargetsFormat</span>
     <span class="kn">from</span> <span class="nn">super_gradients.common.abstractions.abstract_logger</span> <span class="kn">import</span> <span class="n">get_logger</span>
     <span class="kn">from</span> <span class="nn">super_gradients.common.abstractions.abstract_logger</span> <span class="kn">import</span> <span class="n">get_logger</span>
     <span class="kn">from</span> <span class="nn">super_gradients.training.transforms.transforms</span> <span class="kn">import</span> <span class="n">DetectionTransform</span><span class="p">,</span> <span class="n">DetectionTargetsFormatTransform</span>
     <span class="kn">from</span> <span class="nn">super_gradients.training.transforms.transforms</span> <span class="kn">import</span> <span class="n">DetectionTransform</span><span class="p">,</span> <span class="n">DetectionTargetsFormatTransform</span>
     <span class="kn">from</span> <span class="nn">super_gradients.training.exceptions.dataset_exceptions</span> <span class="kn">import</span> <span class="n">EmptyDatasetException</span>
     <span class="kn">from</span> <span class="nn">super_gradients.training.exceptions.dataset_exceptions</span> <span class="kn">import</span> <span class="n">EmptyDatasetException</span>
    +<span class="kn">from</span> <span class="nn">super_gradients.common.factories.list_factory</span> <span class="kn">import</span> <span class="n">ListFactory</span>
    +<span class="kn">from</span> <span class="nn">super_gradients.common.factories.transforms_factory</span> <span class="kn">import</span> <span class="n">TransformsFactory</span>
     
     
     <span class="n">logger</span> <span class="o">=</span> <span class="n">get_logger</span><span class="p">(</span><span class="vm">__name__</span><span class="p">)</span>
     <span class="n">logger</span> <span class="o">=</span> <span class="n">get_logger</span><span class="p">(</span><span class="vm">__name__</span><span class="p">)</span>
     
     
     
     
    -<div class="viewcode-block" id="DetectionDataset"><a class="viewcode-back" href="../../../../../super_gradients.training.datasets.detection_datasets.html#super_gradients.training.datasets.DetectionDataset">[docs]</a><span class="k">class</span> <span class="nc">DetectionDataset</span><span class="p">(</span><span class="n">Dataset</span><span class="p">):</span>
    +<div class="viewcode-block" id="DetectionDataset"><a class="viewcode-back" href="../../../../../super_gradients.training.html#super_gradients.training.datasets.DetectionDataset">[docs]</a><span class="k">class</span> <span class="nc">DetectionDataset</span><span class="p">(</span><span class="n">Dataset</span><span class="p">):</span>
         <span class="sd">&quot;&quot;&quot;Detection dataset.</span>
         <span class="sd">&quot;&quot;&quot;Detection dataset.</span>
     
     
     <span class="sd">    This is a boilerplate class to facilitate the implementation of datasets.</span>
     <span class="sd">    This is a boilerplate class to facilitate the implementation of datasets.</span>
    @@ -146,20 +152,21 @@
     <span class="sd">                                &gt; Therefore, we also have len(self) = 100</span>
     <span class="sd">                                &gt; Therefore, we also have len(self) = 100</span>
     <span class="sd">    &quot;&quot;&quot;</span>
     <span class="sd">    &quot;&quot;&quot;</span>
     
     
    +    <span class="nd">@resolve_param</span><span class="p">(</span><span class="s2">&quot;transforms&quot;</span><span class="p">,</span> <span class="n">ListFactory</span><span class="p">(</span><span class="n">TransformsFactory</span><span class="p">()))</span>
         <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span>
         <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span>
    -            <span class="bp">self</span><span class="p">,</span>
    -            <span class="n">data_dir</span><span class="p">:</span> <span class="nb">str</span><span class="p">,</span>
    -            <span class="n">input_dim</span><span class="p">:</span> <span class="nb">tuple</span><span class="p">,</span>
    -            <span class="n">original_target_format</span><span class="p">:</span> <span class="n">DetectionTargetsFormat</span><span class="p">,</span>
    -            <span class="n">max_num_samples</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span>
    -            <span class="n">cache</span><span class="p">:</span> <span class="nb">bool</span> <span class="o">=</span> <span class="kc">False</span><span class="p">,</span>
    -            <span class="n">cache_path</span><span class="p">:</span> <span class="nb">str</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span>
    -            <span class="n">transforms</span><span class="p">:</span> <span class="n">List</span><span class="p">[</span><span class="n">DetectionTransform</span><span class="p">]</span> <span class="o">=</span> <span class="p">[],</span>
    -            <span class="n">all_classes_list</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="n">List</span><span class="p">[</span><span class="nb">str</span><span class="p">]]</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span>
    -            <span class="n">class_inclusion_list</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="n">List</span><span class="p">[</span><span class="nb">str</span><span class="p">]]</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span>
    -            <span class="n">ignore_empty_annotations</span><span class="p">:</span> <span class="nb">bool</span> <span class="o">=</span> <span class="kc">True</span><span class="p">,</span>
    -            <span class="n">target_fields</span><span class="p">:</span> <span class="n">List</span><span class="p">[</span><span class="nb">str</span><span class="p">]</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span>
    -            <span class="n">output_fields</span><span class="p">:</span> <span class="n">List</span><span class="p">[</span><span class="nb">str</span><span class="p">]</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span>
    +        <span class="bp">self</span><span class="p">,</span>
    +        <span class="n">data_dir</span><span class="p">:</span> <span class="nb">str</span><span class="p">,</span>
    +        <span class="n">input_dim</span><span class="p">:</span> <span class="nb">tuple</span><span class="p">,</span>
    +        <span class="n">original_target_format</span><span class="p">:</span> <span class="n">DetectionTargetsFormat</span><span class="p">,</span>
    +        <span class="n">max_num_samples</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span>
    +        <span class="n">cache</span><span class="p">:</span> <span class="nb">bool</span> <span class="o">=</span> <span class="kc">False</span><span class="p">,</span>
    +        <span class="n">cache_dir</span><span class="p">:</span> <span class="nb">str</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span>
    +        <span class="n">transforms</span><span class="p">:</span> <span class="n">List</span><span class="p">[</span><span class="n">DetectionTransform</span><span class="p">]</span> <span class="o">=</span> <span class="p">[],</span>
    +        <span class="n">all_classes_list</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="n">List</span><span class="p">[</span><span class="nb">str</span><span class="p">]]</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span>
    +        <span class="n">class_inclusion_list</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="n">List</span><span class="p">[</span><span class="nb">str</span><span class="p">]]</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span>
    +        <span class="n">ignore_empty_annotations</span><span class="p">:</span> <span class="nb">bool</span> <span class="o">=</span> <span class="kc">True</span><span class="p">,</span>
    +        <span class="n">target_fields</span><span class="p">:</span> <span class="n">List</span><span class="p">[</span><span class="nb">str</span><span class="p">]</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span>
    +        <span class="n">output_fields</span><span class="p">:</span> <span class="n">List</span><span class="p">[</span><span class="nb">str</span><span class="p">]</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span>
         <span class="p">):</span>
         <span class="p">):</span>
             <span class="sd">&quot;&quot;&quot;Detection dataset.</span>
             <span class="sd">&quot;&quot;&quot;Detection dataset.</span>
     
     
    @@ -169,7 +176,7 @@
     <span class="sd">                                        differ based on transforms.</span>
     <span class="sd">                                        differ based on transforms.</span>
     <span class="sd">        :param max_num_samples:         If not None, set the maximum size of the dataset by only indexing the first n annotations/images.</span>
     <span class="sd">        :param max_num_samples:         If not None, set the maximum size of the dataset by only indexing the first n annotations/images.</span>
     <span class="sd">        :param cache:                   Whether to cache images or not.</span>
     <span class="sd">        :param cache:                   Whether to cache images or not.</span>
    -<span class="sd">        :param cache_path:              Path to the directory where cached images will be stored in an optimized format.</span>
    +<span class="sd">        :param cache_dir:              Path to the directory where cached images will be stored in an optimized format.</span>
     <span class="sd">        :param transforms:              List of transforms to apply sequentially on sample.</span>
     <span class="sd">        :param transforms:              List of transforms to apply sequentially on sample.</span>
     <span class="sd">        :param all_classes_list:        All the class names.</span>
     <span class="sd">        :param all_classes_list:        All the class names.</span>
     <span class="sd">        :param class_inclusion_list:    If not None,every class not included will be ignored.</span>
     <span class="sd">        :param class_inclusion_list:    If not None,every class not included will be ignored.</span>
    @@ -208,11 +215,12 @@
             <span class="k">if</span> <span class="s2">&quot;target&quot;</span> <span class="ow">not</span> <span class="ow">in</span> <span class="bp">self</span><span class="o">.</span><span class="n">target_fields</span><span class="p">:</span>
             <span class="k">if</span> <span class="s2">&quot;target&quot;</span> <span class="ow">not</span> <span class="ow">in</span> <span class="bp">self</span><span class="o">.</span><span class="n">target_fields</span><span class="p">:</span>
                 <span class="k">raise</span> <span class="ne">KeyError</span><span class="p">(</span><span class="s1">&#39;&quot;target&quot; is expected to be in the fields to subclass but it was not included&#39;</span><span class="p">)</span>
                 <span class="k">raise</span> <span class="ne">KeyError</span><span class="p">(</span><span class="s1">&#39;&quot;target&quot; is expected to be in the fields to subclass but it was not included&#39;</span><span class="p">)</span>
     
     
    +        <span class="bp">self</span><span class="o">.</span><span class="n">_required_annotation_fields</span> <span class="o">=</span> <span class="p">{</span><span class="s2">&quot;target&quot;</span><span class="p">,</span> <span class="s2">&quot;img_path&quot;</span><span class="p">,</span> <span class="s2">&quot;resized_img_shape&quot;</span><span class="p">}</span>
             <span class="bp">self</span><span class="o">.</span><span class="n">annotations</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">_cache_annotations</span><span class="p">()</span>
             <span class="bp">self</span><span class="o">.</span><span class="n">annotations</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">_cache_annotations</span><span class="p">()</span>
     
     
             <span class="bp">self</span><span class="o">.</span><span class="n">cache</span> <span class="o">=</span> <span class="n">cache</span>
             <span class="bp">self</span><span class="o">.</span><span class="n">cache</span> <span class="o">=</span> <span class="n">cache</span>
    -        <span class="bp">self</span><span class="o">.</span><span class="n">cache_path</span> <span class="o">=</span> <span class="n">cache_path</span>
    -        <span class="bp">self</span><span class="o">.</span><span class="n">cached_imgs</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">_cache_images</span><span class="p">()</span> <span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">cache</span> <span class="k">else</span> <span class="kc">None</span>
    +        <span class="bp">self</span><span class="o">.</span><span class="n">cache_dir</span> <span class="o">=</span> <span class="n">cache_dir</span>
    +        <span class="bp">self</span><span class="o">.</span><span class="n">cached_imgs_padded</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">_cache_images</span><span class="p">()</span> <span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">cache</span> <span class="k">else</span> <span class="kc">None</span>
     
     
             <span class="bp">self</span><span class="o">.</span><span class="n">transforms</span> <span class="o">=</span> <span class="n">transforms</span>
             <span class="bp">self</span><span class="o">.</span><span class="n">transforms</span> <span class="o">=</span> <span class="n">transforms</span>
     
     
    @@ -231,7 +239,7 @@
     <span class="sd">        Please note that the targets should be resized according to self.input_dim!</span>
     <span class="sd">        Please note that the targets should be resized according to self.input_dim!</span>
     
     
     <span class="sd">        :param sample_id:   Id of the sample to load annotations from.</span>
     <span class="sd">        :param sample_id:   Id of the sample to load annotations from.</span>
    -<span class="sd">        :return:            Annotation, a dict with any field but has to include at least &quot;target&quot; and &quot;img_path&quot;.</span>
    +<span class="sd">        :return:            Annotation, a dict with any field but has to include at least the fields specified in self._required_annotation_fields.</span>
     <span class="sd">        &quot;&quot;&quot;</span>
     <span class="sd">        &quot;&quot;&quot;</span>
             <span class="k">raise</span> <span class="ne">NotImplementedError</span>
             <span class="k">raise</span> <span class="ne">NotImplementedError</span>
     
     
    @@ -246,8 +254,10 @@
                     <span class="k">break</span>
                     <span class="k">break</span>
     
     
                 <span class="n">img_annotation</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">_load_annotation</span><span class="p">(</span><span class="n">img_id</span><span class="p">)</span>
                 <span class="n">img_annotation</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">_load_annotation</span><span class="p">(</span><span class="n">img_id</span><span class="p">)</span>
    -            <span class="k">if</span> <span class="s2">&quot;target&quot;</span> <span class="ow">not</span> <span class="ow">in</span> <span class="n">img_annotation</span> <span class="ow">or</span> <span class="s2">&quot;img_path&quot;</span> <span class="ow">not</span> <span class="ow">in</span> <span class="n">img_annotation</span><span class="p">:</span>
    -                <span class="k">raise</span> <span class="ne">KeyError</span><span class="p">(</span><span class="s1">&#39;_load_annotation is expected to return at least the field &quot;target&quot; and &quot;img_path&quot;&#39;</span><span class="p">)</span>
    +            <span class="k">if</span> <span class="ow">not</span> <span class="bp">self</span><span class="o">.</span><span class="n">_required_annotation_fields</span><span class="o">.</span><span class="n">issubset</span><span class="p">(</span><span class="nb">set</span><span class="p">(</span><span class="n">img_annotation</span><span class="o">.</span><span class="n">keys</span><span class="p">())):</span>
    +                <span class="k">raise</span> <span class="ne">KeyError</span><span class="p">(</span>
    +                    <span class="sa">f</span><span class="s2">&quot;_load_annotation is expected to return at least the fields </span><span class="si">{</span><span class="bp">self</span><span class="o">.</span><span class="n">_required_annotation_fields</span><span class="si">}</span><span class="s2"> &quot;</span> <span class="sa">f</span><span class="s2">&quot;but got </span><span class="si">{</span><span class="nb">set</span><span class="p">(</span><span class="n">img_annotation</span><span
    +                <span class="p">)</span>
     
     
                 <span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">class_inclusion_list</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span><span class="p">:</span>
                 <span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">class_inclusion_list</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span><span class="p">:</span>
                     <span class="n">img_annotation</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">_sub_class_annotation</span><span class="p">(</span><span class="n">img_annotation</span><span class="p">)</span>
                     <span class="n">img_annotation</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">_sub_class_annotation</span><span class="p">(</span><span class="n">img_annotation</span><span class="p">)</span>
    @@ -258,8 +268,9 @@
                 <span class="n">annotations</span><span class="o">.</span><span class="n">append</span><span class="p">(</span><span class="n">img_annotation</span><span class="p">)</span>
                 <span class="n">annotations</span><span class="o">.</span><span class="n">append</span><span class="p">(</span><span class="n">img_annotation</span><span class="p">)</span>
     
     
             <span class="k">if</span> <span class="nb">len</span><span class="p">(</span><span class="n">annotations</span><span class="p">)</span> <span class="o">==</span> <span class="mi">0</span><span class="p">:</span>
             <span class="k">if</span> <span class="nb">len</span><span class="p">(</span><span class="n">annotations</span><span class="p">)</span> <span class="o">==</span> <span class="mi">0</span><span class="p">:</span>
    -            <span class="k">raise</span> <span class="n">EmptyDatasetException</span><span class="p">(</span><span class="sa">f</span><span class="s2">&quot;Out of </span><span class="si">{</span><span class="bp">self</span><span class="o">.</span><span class="n">n_available_samples</span><span class="si">}</span><span class="s2"> images, not a single one was found with&quot;</span>
    -                                        <span class="sa">f</span><span class="s2">&quot;any of these classes: </span><span class="si">{</span><span class="bp">self</span><span class="o">.</span><span class="n">class_inclusion_list</span><span class="si">}</span><span class="s2">&quot;</span><span class="p">)</span>
    +            <span class="k">raise</span> <span class="n">EmptyDatasetException</span><span class="p">(</span>
    +                <span class="sa">f</span><span class="s2">&quot;Out of </span><span class="si">{</span><span class="bp">self</span><span class="o">.</span><span class="n">n_available_samples</span><span class="si">}</span><span class="s2"> images, not a single one was found with&quot;</span> <span class="sa">f</span><span class="s2">&quot;any of these classes: </span><span class="si">{</span><span class="bp">self</span><span class="o">.</span><span class="n">class_inclusion_list</span><span cla
    +            <span class="p">)</span>
             <span class="k">return</span> <span class="n">annotations</span>
             <span class="k">return</span> <span class="n">annotations</span>
     
     
         <span class="k">def</span> <span class="nf">_sub_class_annotation</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">annotation</span><span class="p">:</span> <span class="nb">dict</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="n">Union</span><span class="p">[</span><span class="nb">dict</span><span class="p">,</span> <span class="kc">None</span><span class="p">]:</span>
         <span class="k">def</span> <span class="nf">_sub_class_annotation</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">annotation</span><span class="p">:</span> <span class="nb">dict</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="n">Union</span><span class="p">[</span><span class="nb">dict</span><span class="p">,</span> <span class="kc">None</span><span class="p">]:</span>
    @@ -296,16 +307,17 @@
             <span class="sd">&quot;&quot;&quot;Cache the images. The cached image are stored in a file to be loaded faster mext time.</span>
             <span class="sd">&quot;&quot;&quot;Cache the images. The cached image are stored in a file to be loaded faster mext time.</span>
     <span class="sd">        :return: Cached images</span>
     <span class="sd">        :return: Cached images</span>
     <span class="sd">        &quot;&quot;&quot;</span>
     <span class="sd">        &quot;&quot;&quot;</span>
    -        <span class="n">cache_path</span> <span class="o">=</span> <span class="n">Path</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">cache_path</span><span class="p">)</span>
    -        <span class="k">if</span> <span class="n">cache_path</span> <span class="ow">is</span> <span class="kc">None</span><span class="p">:</span>
    -            <span class="k">raise</span> <span class="ne">ValueError</span><span class="p">(</span><span class="s2">&quot;You must specify a cache_path if you want to cache your images.&quot;</span>
    -                             <span class="s2">&quot;If you did not mean to use cache, please set cache=False &quot;</span><span class="p">)</span>
    -        <span class="n">cache_path</span><span class="o">.</span><span class="n">mkdir</span><span class="p">(</span><span class="n">parents</span><span class="o">=</span><span class="kc">True</span><span class="p">,</span> <span class="n">exist_ok</span><span class="o">=</span><span class="kc">True</span><span class="p">)</span>
    -
    -        <span class="n">logger</span><span class="o">.</span><span class="n">warning</span><span class="p">(</span><span class="s2">&quot;</span><span class="se">\n</span><span class="s2">********************************************************************************</span><span class="se">\n</span><span class="s2">&quot;</span>
    -                       <span class="s2">&quot;You are using cached images in RAM to accelerate training.</span><span class="se">\n</span><span class="s2">&quot;</span>
    -                       <span class="s2">&quot;This requires large system RAM.</span><span class="se">\n</span><span class="s2">&quot;</span>
    -                       <span class="s2">&quot;********************************************************************************&quot;</span><span class="p">)</span>
    +        <span class="n">cache_dir</span> <span class="o">=</span> <span class="n">Path</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">cache_dir</span><span class="p">)</span>
    +        <span class="k">if</span> <span class="n">cache_dir</span> <span class="ow">is</span> <span class="kc">None</span><span class="p">:</span>
    +            <span class="k">raise</span> <span class="ne">ValueError</span><span class="p">(</span><span class="s2">&quot;You must specify a cache_dir if you want to cache your images.&quot;</span> <span class="s2">&quot;If you did not mean to use cache, please set cache=False &quot;</span><span class="p">)</span>
    +        <span class="n">cache_dir</span><span class="o">.</span><span class="n">mkdir</span><span class="p">(</span><span class="n">parents</span><span class="o">=</span><span class="kc">True</span><span class="p">,</span> <span class="n">exist_ok</span><span class="o">=</span><span class="kc">True</span><span class="p">)</span>
    +
    +        <span class="n">logger</span><span class="o">.</span><span class="n">warning</span><span class="p">(</span>
    +            <span class="s2">&quot;</span><span class="se">\n</span><span class="s2">********************************************************************************</span><span class="se">\n</span><span class="s2">&quot;</span>
    +            <span class="s2">&quot;You are using cached images in RAM to accelerate training.</span><span class="se">\n</span><span class="s2">&quot;</span>
    +            <span class="s2">&quot;This requires large system RAM.</span><span class="se">\n</span><span class="s2">&quot;</span>
    +            <span class="s2">&quot;********************************************************************************&quot;</span>
    +        <span class="p">)</span>
     
     
             <span class="n">max_h</span><span class="p">,</span> <span class="n">max_w</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">input_dim</span><span class="p">[</span><span class="mi">0</span><span class="p">],</span> <span class="bp">self</span><span class="o">.</span><span class="n">input_dim</span><span class="p">[</span><span class="mi">1</span><span class="p">]</span>
             <span class="n">max_h</span><span class="p">,</span> <span class="n">max_w</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">input_dim</span><span class="p">[</span><span class="mi">0</span><span class="p">],</span> <span class="bp">self</span><span class="o">.</span><span class="n">input_dim</span><span class="p">[</span><span class="mi">1</span><span class="p">]</span>
     
     
    @@ -314,10 +326,10 @@
             <span class="k">for</span> <span class="n">annotation</span> <span class="ow">in</span> <span class="bp">self</span><span class="o">.</span><span class="n">annotations</span><span class="p">:</span>
             <span class="k">for</span> <span class="n">annotation</span> <span class="ow">in</span> <span class="bp">self</span><span class="o">.</span><span class="n">annotations</span><span class="p">:</span>
                 <span class="n">values_to_hash</span> <span class="o">=</span> <span class="p">[</span><span class="n">annotation</span><span class="p">[</span><span class="s2">&quot;resized_img_shape&quot;</span><span class="p">][</span><span class="mi">0</span><span class="p">],</span> <span class="n">annotation</span><span class="p">[</span><span class="s2">&quot;resized_img_shape&quot;</span><span class="p">][</span><span class="mi">1</span><span class="p">],</span> <span class="n">Path</span><
                 <span class="n">values_to_hash</span> <span class="o">=</span> <span class="p">[</span><span class="n">annotation</span><span class="p">[</span><span class="s2">&quot;resized_img_shape&quot;</span><span class="p">][</span><span class="mi">0</span><span class="p">],</span> <span class="n">annotation</span><span class="p">[</span><span class="s2">&quot;resized_img_shape&quot;</span><span class="p">][</span><span class="mi">1</span><span class="p">],</span> <span class="n">Path</span><
                 <span class="k">for</span> <span class="n">value</span> <span class="ow">in</span> <span class="n">values_to_hash</span><span class="p">:</span>
                 <span class="k">for</span> <span class="n">value</span> <span class="ow">in</span> <span class="n">values_to_hash</span><span class="p">:</span>
    -                <span class="nb">hash</span><span class="o">.</span><span class="n">update</span><span class="p">(</span><span class="nb">str</span><span class="p">(</span><span class="n">value</span><span class="p">)</span><span class="o">.</span><span class="n">encode</span><span class="p">(</span><span class="s1">&#39;utf-8&#39;</span><span class="p">))</span>
    +                <span class="nb">hash</span><span class="o">.</span><span class="n">update</span><span class="p">(</span><span class="nb">str</span><span class="p">(</span><span class="n">value</span><span class="p">)</span><span class="o">.</span><span class="n">encode</span><span class="p">(</span><span class="s2">&quot;utf-8&quot;</span><span class="p">))</span>
             <span class="n">cache_hash</span> <span class="o">=</span> <span class="nb">hash</span><span class="o">.</span><span class="n">hexdigest</span><span class="p">()</span>
             <span class="n">cache_hash</span> <span class="o">=</span> <span class="nb">hash</span><span class="o">.</span><span class="n">hexdigest</span><span class="p">()</span>
     
     
    -        <span class="n">img_resized_cache_path</span> <span class="o">=</span> <span class="n">cache_path</span> <span class="o">/</span> <span class="sa">f</span><span class="s2">&quot;img_resized_cache_</span><span class="si">{</span><span class="n">cache_hash</span><span class="si">}</span><span class="s2">.array&quot;</span>
    +        <span class="n">img_resized_cache_path</span> <span class="o">=</span> <span class="n">cache_dir</span> <span class="o">/</span> <span class="sa">f</span><span class="s2">&quot;img_resized_cache_</span><span class="si">{</span><span class="n">cache_hash</span><span class="si">}</span><span class="s2">.array&quot;</span>
     
     
             <span class="k">if</span> <span class="ow">not</span> <span class="n">img_resized_cache_path</span><span class="o">.</span><span class="n">exists</span><span class="p">():</span>
             <span class="k">if</span> <span class="ow">not</span> <span class="n">img_resized_cache_path</span><span class="o">.</span><span class="n">exists</span><span class="p">():</span>
                 <span class="n">logger</span><span class="o">.</span><span class="n">info</span><span class="p">(</span><span class="s2">&quot;Caching images for the first time. Be aware that this will stay in the disk until you delete it yourself.&quot;</span><span class="p">)</span>
                 <span class="n">logger</span><span class="o">.</span><span class="n">info</span><span class="p">(</span><span class="s2">&quot;Caching images for the first time. Be aware that this will stay in the disk until you delete it yourself.&quot;</span><span class="p">)</span>
    @@ -325,8 +337,7 @@
                 <span class="n">loaded_images</span> <span class="o">=</span> <span class="n">ThreadPool</span><span class="p">(</span><span class="n">NUM_THREADs</span><span class="p">)</span><span class="o">.</span><span class="n">imap</span><span class="p">(</span><span class="n">func</span><span class="o">=</span><span class="k">lambda</span> <span class="n">x</span><span class="p">:</span> <span class="bp">self</span><span class="o">.</span><span class="n">_load_resized_img</span><span class="
                 <span class="n">loaded_images</span> <span class="o">=</span> <span class="n">ThreadPool</span><span class="p">(</span><span class="n">NUM_THREADs</span><span class="p">)</span><span class="o">.</span><span class="n">imap</span><span class="p">(</span><span class="n">func</span><span class="o">=</span><span class="k">lambda</span> <span class="n">x</span><span class="p">:</span> <span class="bp">self</span><span class="o">.</span><span class="n">_load_resized_img</span><span class="
     
     
                 <span class="c1"># Initialize placeholder for images</span>
                 <span class="c1"># Initialize placeholder for images</span>
    -            <span class="n">cached_imgs</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">memmap</span><span class="p">(</span><span class="nb">str</span><span class="p">(</span><span class="n">img_resized_cache_path</span><span class="p">),</span> <span class="n">shape</span><span class="o">=</span><span class="p">(</span><span class="nb">len</span><span class="p">(</span><span class="bp">self</span><span class="p">),</span> <span class="n">max_h<
    -                                    <span class="n">dtype</span><span class="o">=</span><span class="n">np</span><span class="o">.</span><span class="n">uint8</span><span class="p">,</span> <span class="n">mode</span><span class="o">=</span><span class="s2">&quot;w+&quot;</span><span class="p">)</span>
    +            <span class="n">cached_imgs</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">memmap</span><span class="p">(</span><span class="nb">str</span><span class="p">(</span><span class="n">img_resized_cache_path</span><span class="p">),</span> <span class="n">shape</span><span class="o">=</span><span class="p">(</span><span class="nb">len</span><span class="p">(</span><span class="bp">self</span><span class="p">),</span> <span class="n">max_h<
     
     
                 <span class="c1"># Store images in the placeholder</span>
                 <span class="c1"># Store images in the placeholder</span>
                 <span class="n">loaded_images_pbar</span> <span class="o">=</span> <span class="n">tqdm</span><span class="p">(</span><span class="nb">enumerate</span><span class="p">(</span><span class="n">loaded_images</span><span class="p">),</span> <span class="n">total</span><span class="o">=</span><span class="nb">len</span><span class="p">(</span><span class="bp">self</span><span class="p">))</span>
                 <span class="n">loaded_images_pbar</span> <span class="o">=</span> <span class="n">tqdm</span><span class="p">(</span><span class="nb">enumerate</span><span class="p">(</span><span class="n">loaded_images</span><span class="p">),</span> <span class="n">total</span><span class="o">=</span><span class="nb">len</span><span class="p">(</span><span class="bp">self</span><span class="p">))</span>
    @@ -338,8 +349,7 @@
                 <span class="n">logger</span><span class="o">.</span><span class="n">warning</span><span class="p">(</span><span class="s2">&quot;You are using cached imgs!&quot;</span><span class="p">)</span>
                 <span class="n">logger</span><span class="o">.</span><span class="n">warning</span><span class="p">(</span><span class="s2">&quot;You are using cached imgs!&quot;</span><span class="p">)</span>
     
     
             <span class="n">logger</span><span class="o">.</span><span class="n">info</span><span class="p">(</span><span class="s2">&quot;Loading cached imgs...&quot;</span><span class="p">)</span>
             <span class="n">logger</span><span class="o">.</span><span class="n">info</span><span class="p">(</span><span class="s2">&quot;Loading cached imgs...&quot;</span><span class="p">)</span>
    -        <span class="n">cached_imgs</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">memmap</span><span class="p">(</span><span class="nb">str</span><span class="p">(</span><span class="n">img_resized_cache_path</span><span class="p">),</span> <span class="n">shape</span><span class="o">=</span><span class="p">(</span><span class="nb">len</span><span class="p">(</span><span class="bp">self</span><span class="p">),</span> <span class="n">max_h</spa
    -                                <span class="n">dtype</span><span class="o">=</span><span class="n">np</span><span class="o">.</span><span class="n">uint8</span><span class="p">,</span> <span class="n">mode</span><span class="o">=</span><span class="s2">&quot;r+&quot;</span><span class="p">)</span>
    +        <span class="n">cached_imgs</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">memmap</span><span class="p">(</span><span class="nb">str</span><span class="p">(</span><span class="n">img_resized_cache_path</span><span class="p">),</span> <span class="n">shape</span><span class="o">=</span><span class="p">(</span><span class="nb">len</span><span class="p">(</span><span class="bp">self</span><span class="p">),</span> <span class="n">max_h</spa
             <span class="k">return</span> <span class="n">cached_imgs</span>
             <span class="k">return</span> <span class="n">cached_imgs</span>
     
     
         <span class="k">def</span> <span class="nf">_load_resized_img</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">index</span><span class="p">:</span> <span class="nb">int</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="n">np</span><span class="o">.</span><span class="n">ndarray</span><span class="p">:</span>
         <span class="k">def</span> <span class="nf">_load_resized_img</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">index</span><span class="p">:</span> <span class="nb">int</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="n">np</span><span class="o">.</span><span class="n">ndarray</span><span class="p">:</span>
    @@ -366,14 +376,13 @@
             <span class="n">img</span> <span class="o">=</span> <span class="n">cv2</span><span class="o">.</span><span class="n">imread</span><span class="p">(</span><span class="n">img_file</span><span class="p">)</span>
             <span class="n">img</span> <span class="o">=</span> <span class="n">cv2</span><span class="o">.</span><span class="n">imread</span><span class="p">(</span><span class="n">img_file</span><span class="p">)</span>
     
     
             <span class="k">if</span> <span class="n">img</span> <span class="ow">is</span> <span class="kc">None</span><span class="p">:</span>
             <span class="k">if</span> <span class="n">img</span> <span class="ow">is</span> <span class="kc">None</span><span class="p">:</span>
    -            <span class="k">raise</span> <span class="ne">FileNotFoundError</span><span class="p">(</span><span class="sa">f</span><span class="s2">&quot;</span><span class="si">{</span><span class="n">img_file</span><span class="si">}</span><span class="s2"> was no found. Please make sure that the dataset was&quot;</span>
    -                                    <span class="sa">f</span><span class="s2">&quot;downloaded and that the path is correct&quot;</span><span class="p">)</span>
    +            <span class="k">raise</span> <span class="ne">FileNotFoundError</span><span class="p">(</span><span class="sa">f</span><span class="s2">&quot;</span><span class="si">{</span><span class="n">img_file</span><span class="si">}</span><span class="s2"> was no found. Please make sure that the dataset was&quot;</span> <span class="sa">f</span><span class="s2">&quot;downloaded and that the path is correct&quot;</span><span class="p">)</span>
             <span class="k">return</span> <span class="n">img</span>
             <span class="k">return</span> <span class="n">img</span>
     
     
         <span class="k">def</span> <span class="fm">__del__</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span>
         <span class="k">def</span> <span class="fm">__del__</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span>
             <span class="sd">&quot;&quot;&quot;Clear the cached images&quot;&quot;&quot;</span>
             <span class="sd">&quot;&quot;&quot;Clear the cached images&quot;&quot;&quot;</span>
    -        <span class="k">if</span> <span class="nb">hasattr</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="s2">&quot;cached_imgs&quot;</span><span class="p">):</span>
    -            <span class="k">del</span> <span class="bp">self</span><span class="o">.</span><span class="n">cached_imgs</span>
    +        <span class="k">if</span> <span class="nb">hasattr</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="s2">&quot;cached_imgs_padded&quot;</span><span class="p">):</span>
    +            <span class="k">del</span> <span class="bp">self</span><span class="o">.</span><span class="n">cached_imgs_padded</span>
     
     
         <span class="k">def</span> <span class="fm">__len__</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span>
         <span class="k">def</span> <span class="fm">__len__</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span>
             <span class="sd">&quot;&quot;&quot;Get the length of the dataset.&quot;&quot;&quot;</span>
             <span class="sd">&quot;&quot;&quot;Get the length of the dataset.&quot;&quot;&quot;</span>
    @@ -385,34 +394,37 @@
             <span class="n">sample</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">apply_transforms</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">get_sample</span><span class="p">(</span><span class="n">index</span><span class="p">))</span>
             <span class="n">sample</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">apply_transforms</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">get_sample</span><span class="p">(</span><span class="n">index</span><span class="p">))</span>
             <span class="k">for</span> <span class="n">field</span> <span class="ow">in</span> <span class="bp">self</span><span class="o">.</span><span class="n">output_fields</span><span class="p">:</span>
             <span class="k">for</span> <span class="n">field</span> <span class="ow">in</span> <span class="bp">self</span><span class="o">.</span><span class="n">output_fields</span><span class="p">:</span>
                 <span class="k">if</span> <span class="n">field</span> <span class="ow">not</span> <span class="ow">in</span> <span class="n">sample</span><span class="o">.</span><span class="n">keys</span><span class="p">():</span>
                 <span class="k">if</span> <span class="n">field</span> <span class="ow">not</span> <span class="ow">in</span> <span class="n">sample</span><span class="o">.</span><span class="n">keys</span><span class="p">():</span>
    -                <span class="k">raise</span> <span class="ne">KeyError</span><span class="p">(</span><span class="sa">f</span><span class="s1">&#39;The field </span><span class="si">{</span><span class="n">field</span><span class="si">}</span><span class="s1"> must be present in the sample but was not found.&#39;</span>
    -                               <span class="s1">&#39;Please check the output fields of your transforms.&#39;</span><span class="p">)</span>
    +                <span class="k">raise</span> <span class="ne">KeyError</span><span class="p">(</span><span class="sa">f</span><span class="s2">&quot;The field </span><span class="si">{</span><span class="n">field</span><span class="si">}</span><span class="s2"> must be present in the sample but was not found.&quot;</span> <span class="s2">&quot;Please check the output fields of your transforms.&quot;</span><span class="p">)</span>
             <span class="k">return</span> <span class="nb">tuple</span><span class="p">(</span><span class="n">sample</span><span class="p">[</span><span class="n">field</span><span class="p">]</span> <span class="k">for</span> <span class="n">field</span> <span class="ow">in</span> <span class="bp">self</span><span class="o">.</span><span class="n">output_fields</span><span class="p">)</span>
             <span class="k">return</span> <span class="nb">tuple</span><span class="p">(</span><span class="n">sample</span><span class="p">[</span><span class="n">field</span><span class="p">]</span> <span class="k">for</span> <span class="n">field</span> <span class="ow">in</span> <span class="bp">self</span><span class="o">.</span><span class="n">output_fields</span><span class="p">)</span>
     
     
    -<div class="viewcode-block" id="DetectionDataset.get_random_item"><a class="viewcode-back" href="../../../../../super_gradients.training.datasets.detection_datasets.html#super_gradients.training.datasets.DetectionDataset.get_random_item">[docs]</a>    <span class="k">def</span> <span class="nf">get_random_item</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span>
    +<div class="viewcode-block" id="DetectionDataset.get_random_item"><a class="viewcode-back" href="../../../../../super_gradients.training.html#super_gradients.training.datasets.DetectionDataset.get_random_item">[docs]</a>    <span class="k">def</span> <span class="nf">get_random_item</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span>
             <span class="k">return</span> <span class="bp">self</span><span class="p">[</span><span class="bp">self</span><span class="o">.</span><span class="n">_random_index</span><span class="p">()]</span></div>
             <span class="k">return</span> <span class="bp">self</span><span class="p">[</span><span class="bp">self</span><span class="o">.</span><span class="n">_random_index</span><span class="p">()]</span></div>
     
     
    -<div class="viewcode-block" id="DetectionDataset.get_sample"><a class="viewcode-back" href="../../../../../super_gradients.training.datasets.detection_datasets.html#super_gradients.training.datasets.DetectionDataset.get_sample">[docs]</a>    <span class="k">def</span> <span class="nf">get_sample</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">index</span><span class="p">:</span> <span class="nb">int</span><span class="p">)</span> <span class="o
    +<div class="viewcode-block" id="DetectionDataset.get_sample"><a class="viewcode-back" href="../../../../../super_gradients.training.html#super_gradients.training.datasets.DetectionDataset.get_sample">[docs]</a>    <span class="k">def</span> <span class="nf">get_sample</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">index</span><span class="p">:</span> <span class="nb">int</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="
             <span class="sd">&quot;&quot;&quot;Get raw sample, before any transform (beside subclassing).</span>
             <span class="sd">&quot;&quot;&quot;Get raw sample, before any transform (beside subclassing).</span>
     <span class="sd">        :param index:   Image index</span>
     <span class="sd">        :param index:   Image index</span>
     <span class="sd">        :return:        Sample, i.e. a dictionary including at least &quot;image&quot; and &quot;target&quot;</span>
     <span class="sd">        :return:        Sample, i.e. a dictionary including at least &quot;image&quot; and &quot;target&quot;</span>
     <span class="sd">        &quot;&quot;&quot;</span>
     <span class="sd">        &quot;&quot;&quot;</span>
             <span class="n">img</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">get_resized_image</span><span class="p">(</span><span class="n">index</span><span class="p">)</span>
             <span class="n">img</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">get_resized_image</span><span class="p">(</span><span class="n">index</span><span class="p">)</span>
    -        <span class="n">annotation</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">annotations</span><span class="p">[</span><span class="n">index</span><span class="p">]</span>
    +        <span class="n">annotation</span> <span class="o">=</span> <span class="n">deepcopy</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">annotations</span><span class="p">[</span><span class="n">index</span><span class="p">])</span>
             <span class="k">return</span> <span class="p">{</span><span class="s2">&quot;image&quot;</span><span class="p">:</span> <span class="n">img</span><span class="p">,</span> <span class="o">**</span><span class="n">annotation</span><span class="p">}</span></div>
             <span class="k">return</span> <span class="p">{</span><span class="s2">&quot;image&quot;</span><span class="p">:</span> <span class="n">img</span><span class="p">,</span> <span class="o">**</span><span class="n">annotation</span><span class="p">}</span></div>
     
     
    -<div class="viewcode-block" id="DetectionDataset.get_resized_image"><a class="viewcode-back" href="../../../../../super_gradients.training.datasets.detection_datasets.html#super_gradients.training.datasets.DetectionDataset.get_resized_image">[docs]</a>    <span class="k">def</span> <span class="nf">get_resized_image</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">index</span><span class="p">:</span> <span class="nb">int</span><span class="p">)<
    +<div class="viewcode-block" id="DetectionDataset.get_resized_image"><a class="viewcode-back" href="../../../../../super_gradients.training.html#super_gradients.training.datasets.DetectionDataset.get_resized_image">[docs]</a>    <span class="k">def</span> <span class="nf">get_resized_image</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">index</span><span class="p">:</span> <span class="nb">int</span><span class="p">)</span> <span class="o">-&gt;
             <span class="sd">&quot;&quot;&quot;</span>
             <span class="sd">&quot;&quot;&quot;</span>
    -<span class="sd">        Get the resized image at a specific sample_id, either from cache or by loading from disk, based on self.cached_imgs</span>
    +<span class="sd">        Get the resized image (i.e. either width or height reaches its input_dim) at a specific sample_id,</span>
    +<span class="sd">        either from cache or by loading from disk, based on self.cached_imgs_padded</span>
     <span class="sd">        :param index:  Image index</span>
     <span class="sd">        :param index:  Image index</span>
     <span class="sd">        :return:       Resized image</span>
     <span class="sd">        :return:       Resized image</span>
     <span class="sd">        &quot;&quot;&quot;</span>
     <span class="sd">        &quot;&quot;&quot;</span>
             <span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">cache</span><span class="p">:</span>
             <span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">cache</span><span class="p">:</span>
    -            <span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">cached_imgs</span><span class="p">[</span><span class="n">index</span><span class="p">]</span><span class="o">.</span><span class="n">copy</span><span class="p">()</span>
    +            <span class="n">padded_image</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">cached_imgs_padded</span><span class="p">[</span><span class="n">index</span><span class="p">]</span>
    +            <span class="n">resized_height</span><span class="p">,</span> <span class="n">resized_width</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">annotations</span><span class="p">[</span><span class="n">index</span><span class="p">][</span><span class="s2">&quot;resized_img_shape&quot;</span><span class="p">]</span>
    +            <span class="n">resized_image</span> <span class="o">=</span> <span class="n">padded_image</span><span class="p">[:</span><span class="n">resized_height</span><span class="p">,</span> <span class="p">:</span><span class="n">resized_width</span><span class="p">,</span> <span class="p">:]</span>
    +            <span class="k">return</span> <span class="n">resized_image</span><span class="o">.</span><span class="n">copy</span><span class="p">()</span>
             <span class="k">else</span><span class="p">:</span>
             <span class="k">else</span><span class="p">:</span>
                 <span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">_load_resized_img</span><span class="p">(</span><span class="n">index</span><span class="p">)</span></div>
                 <span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">_load_resized_img</span><span class="p">(</span><span class="n">index</span><span class="p">)</span></div>
     
     
    -<div class="viewcode-block" id="DetectionDataset.apply_transforms"><a class="viewcode-back" href="../../../../../super_gradients.training.datasets.detection_datasets.html#super_gradients.training.datasets.DetectionDataset.apply_transforms">[docs]</a>    <span class="k">def</span> <span class="nf">apply_transforms</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">sample</span><span class="p">:</span> <span class="n">Dict</span><span class="p">[</s
    +<div class="viewcode-block" id="DetectionDataset.apply_transforms"><a class="viewcode-back" href="../../../../../super_gradients.training.html#super_gradients.training.datasets.DetectionDataset.apply_transforms">[docs]</a>    <span class="k">def</span> <span class="nf">apply_transforms</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">sample</span><span class="p">:</span> <span class="n">Dict</span><span class="p">[</span><span class="nb">str</sp
             <span class="sd">&quot;&quot;&quot;</span>
             <span class="sd">&quot;&quot;&quot;</span>
     <span class="sd">        Applies self.transforms sequentially to sample</span>
     <span class="sd">        Applies self.transforms sequentially to sample</span>
     
     
    @@ -429,17 +441,14 @@
                 <span class="n">sample</span><span class="o">.</span><span class="n">pop</span><span class="p">(</span><span class="s2">&quot;additional_samples&quot;</span><span class="p">)</span>  <span class="c1"># additional_samples is not useful after the transform</span>
                 <span class="n">sample</span><span class="o">.</span><span class="n">pop</span><span class="p">(</span><span class="s2">&quot;additional_samples&quot;</span><span class="p">)</span>  <span class="c1"># additional_samples is not useful after the transform</span>
             <span class="k">return</span> <span class="n">sample</span></div>
             <span class="k">return</span> <span class="n">sample</span></div>
     
     
    -    <span class="k">def</span> <span class="nf">_add_additional_inputs_for_transform</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">sample</span><span class="p">:</span> <span class="n">Dict</span><span class="p">[</span><span class="nb">str</span><span class="p">,</span> <span class="n">Union</span><span class="p">[</span><span class="n">np</span><span class="o">.</span><span class="n">ndarray</span><span class="p">,</span> <span class="n">An
    -                                             <span class="n">transform</span><span class="p">:</span> <span class="n">DetectionTransform</span><span class="p">):</span>
    +    <span class="k">def</span> <span class="nf">_add_additional_inputs_for_transform</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">sample</span><span class="p">:</span> <span class="n">Dict</span><span class="p">[</span><span class="nb">str</span><span class="p">,</span> <span class="n">Union</span><span class="p">[</span><span class="n">np</span><span class="o">.</span><span class="n">ndarray</span><span class="p">,</span> <span class="n">An
             <span class="sd">&quot;&quot;&quot;Add additional inputs required by a transform to the sample&quot;&quot;&quot;</span>
             <span class="sd">&quot;&quot;&quot;Add additional inputs required by a transform to the sample&quot;&quot;&quot;</span>
    -        <span class="n">additional_samples_count</span> <span class="o">=</span> <span class="n">transform</span><span class="o">.</span><span class="n">additional_samples_count</span> <span class="k">if</span> <span class="nb">hasattr</span><span class="p">(</span><span class="n">transform</span><span class="p">,</span>
    -                                                                                 <span class="s2">&quot;additional_samples_count&quot;</span><span class="p">)</span> <span class="k">else</span> <span class="mi">0</span>
    +        <span class="n">additional_samples_count</span> <span class="o">=</span> <span class="n">transform</span><span class="o">.</span><span class="n">additional_samples_count</span> <span class="k">if</span> <span class="nb">hasattr</span><span class="p">(</span><span class="n">transform</span><span class="p">,</span> <span class="s2">&quot;additional_samples_count&quot;</span><span class="p">)</span> <span class="k">else</span> <span class="mi">0</span>
             <span class="n">non_empty_annotations</span> <span class="o">=</span> <span class="n">transform</span><span class="o">.</span><span class="n">non_empty_annotations</span> <span class="k">if</span> <span class="nb">hasattr</span><span class="p">(</span><span class="n">transform</span><span class="p">,</span> <span class="s2">&quot;non_empty_annotations&quot;</span><span class="p">)</span> <span class="k">else</span> <span class="kc">False</span>
             <span class="n">non_empty_annotations</span> <span class="o">=</span> <span class="n">transform</span><span class="o">.</span><span class="n">non_empty_annotations</span> <span class="k">if</span> <span class="nb">hasattr</span><span class="p">(</span><span class="n">transform</span><span class="p">,</span> <span class="s2">&quot;non_empty_annotations&quot;</span><span class="p">)</span> <span class="k">else</span> <span class="kc">False</span>
             <span class="n">additional_samples</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">get_random_samples</span><span class="p">(</span><span class="n">additional_samples_count</span><span class="p">,</span> <span class="n">non_empty_annotations</span><span class="p">)</span>
             <span class="n">additional_samples</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">get_random_samples</span><span class="p">(</span><span class="n">additional_samples_count</span><span class="p">,</span> <span class="n">non_empty_annotations</span><span class="p">)</span>
             <span class="n">sample</span><span class="p">[</span><span class="s2">&quot;additional_samples&quot;</span><span class="p">]</span> <span class="o">=</span> <span class="n">additional_samples</span>
             <span class="n">sample</span><span class="p">[</span><span class="s2">&quot;additional_samples&quot;</span><span class="p">]</span> <span class="o">=</span> <span class="n">additional_samples</span>
     
     
    -<div class="viewcode-block" id="DetectionDataset.get_random_samples"><a class="viewcode-back" href="../../../../../super_gradients.training.datasets.detection_datasets.html#super_gradients.training.datasets.DetectionDataset.get_random_samples">[docs]</a>    <span class="k">def</span> <span class="nf">get_random_samples</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">count</span><span class="p">:</span> <span class="nb">int</span><span class="p"
    -                           <span class="n">non_empty_annotations_only</span><span class="p">:</span> <span class="nb">bool</span> <span class="o">=</span> <span class="kc">False</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="n">List</span><span class="p">[</span><span class="n">Dict</span><span class="p">[</span><span class="nb">str</span><span class="p">,</span> <span class="n">Union</span><span class="p">[</span><span class="n">np</span><span class="o">.</span><span 
    +<div class="viewcode-block" id="DetectionDataset.get_random_samples"><a class="viewcode-back" href="../../../../../super_gradients.training.html#super_gradients.training.datasets.DetectionDataset.get_random_samples">[docs]</a>    <span class="k">def</span> <span class="nf">get_random_samples</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">count</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span> <span class="n">no
             <span class="sd">&quot;&quot;&quot;Load random samples.</span>
             <span class="sd">&quot;&quot;&quot;Load random samples.</span>
     
     
     <span class="sd">        :param count: The number of samples wanted</span>
     <span class="sd">        :param count: The number of samples wanted</span>
    @@ -448,7 +457,7 @@
     <span class="sd">        &quot;&quot;&quot;</span>
     <span class="sd">        &quot;&quot;&quot;</span>
             <span class="k">return</span> <span class="p">[</span><span class="bp">self</span><span class="o">.</span><span class="n">get_random_sample</span><span class="p">(</span><span class="n">non_empty_annotations_only</span><span class="p">)</span> <span class="k">for</span> <span class="n">_</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="n">count</span><span class="p">)]</span></div>
             <span class="k">return</span> <span class="p">[</span><span class="bp">self</span><span class="o">.</span><span class="n">get_random_sample</span><span class="p">(</span><span class="n">non_empty_annotations_only</span><span class="p">)</span> <span class="k">for</span> <span class="n">_</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="n">count</span><span class="p">)]</span></div>
     
     
    -<div class="viewcode-block" id="DetectionDataset.get_random_sample"><a class="viewcode-back" href="../../../../../super_gradients.training.datasets.detection_datasets.html#super_gradients.training.datasets.DetectionDataset.get_random_sample">[docs]</a>    <span class="k">def</span> <span class="nf">get_random_sample</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">non_empty_annotations_only</span><span class="p">:</span> <span class="nb">bool</s
    +<div class="viewcode-block" id="DetectionDataset.get_random_sample"><a class="viewcode-back" href="../../../../../super_gradients.training.html#super_gradients.training.datasets.DetectionDataset.get_random_sample">[docs]</a>    <span class="k">def</span> <span class="nf">get_random_sample</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">non_empty_annotations_only</span><span class="p">:</span> <span class="nb">bool</span> <span class="o">=</span
             <span class="k">if</span> <span class="n">non_empty_annotations_only</span><span class="p">:</span>
             <span class="k">if</span> <span class="n">non_empty_annotations_only</span><span class="p">:</span>
                 <span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">get_sample</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">_get_random_non_empty_annotation_available_indexes</span><span class="p">())</span>
                 <span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">get_sample</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">_get_random_non_empty_annotation_available_indexes</span><span class="p">())</span>
             <span class="k">else</span><span class="p">:</span>
             <span class="k">else</span><span class="p">:</span>
    @@ -475,23 +484,22 @@
                     <span class="n">target_format</span> <span class="o">=</span> <span class="n">transform</span><span class="o">.</span><span class="n">output_format</span>
                     <span class="n">target_format</span> <span class="o">=</span> <span class="n">transform</span><span class="o">.</span><span class="n">output_format</span>
             <span class="k">return</span> <span class="n">target_format</span>
             <span class="k">return</span> <span class="n">target_format</span>
     
     
    -<div class="viewcode-block" id="DetectionDataset.plot"><a class="viewcode-back" href="../../../../../super_gradients.training.datasets.detection_datasets.html#super_gradients.training.datasets.DetectionDataset.plot">[docs]</a>    <span class="k">def</span> <span class="nf">plot</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">max_samples_per_plot</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="mi"
    +<div class="viewcode-block" id="DetectionDataset.plot"><a class="viewcode-back" href="../../../../../super_gradients.training.html#super_gradients.training.datasets.DetectionDataset.plot">[docs]</a>    <span class="k">def</span> <span class="nf">plot</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">max_samples_per_plot</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="mi">16</span><span class="p">,<
             <span class="sd">&quot;&quot;&quot;Combine samples of images with bbox into plots and display the result.</span>
             <span class="sd">&quot;&quot;&quot;Combine samples of images with bbox into plots and display the result.</span>
     
     
    -<span class="sd">            :param max_samples_per_plot:    Maximum number of images to be displayed per plot</span>
    -<span class="sd">            :param n_plots:                 Number of plots to display (each plot being a combination of img with bbox)</span>
    -<span class="sd">            :param plot_transformed_data:   If True, the plot will be over samples after applying transforms (i.e. on __getitem__).</span>
    -<span class="sd">                                            If False, the plot will be over the raw samples (i.e. on get_sample)</span>
    -<span class="sd">            :return:</span>
    +<span class="sd">        :param max_samples_per_plot:    Maximum number of images to be displayed per plot</span>
    +<span class="sd">        :param n_plots:                 Number of plots to display (each plot being a combination of img with bbox)</span>
    +<span class="sd">        :param plot_transformed_data:   If True, the plot will be over samples after applying transforms (i.e. on __getitem__).</span>
    +<span class="sd">                                        If False, the plot will be over the raw samples (i.e. on get_sample)</span>
    +<span class="sd">        :return:</span>
     <span class="sd">        &quot;&quot;&quot;</span>
     <span class="sd">        &quot;&quot;&quot;</span>
             <span class="n">plot_counter</span> <span class="o">=</span> <span class="mi">0</span>
             <span class="n">plot_counter</span> <span class="o">=</span> <span class="mi">0</span>
             <span class="n">input_format</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">output_target_format</span> <span class="k">if</span> <span class="n">plot_transformed_data</span> <span class="k">else</span> <span class="bp">self</span><span class="o">.</span><span class="n">original_target_format</span>
             <span class="n">input_format</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">output_target_format</span> <span class="k">if</span> <span class="n">plot_transformed_data</span> <span class="k">else</span> <span class="bp">self</span><span class="o">.</span><span class="n">original_target_format</span>
    -        <span class="n">target_format_transform</span> <span class="o">=</span> <span class="n">DetectionTargetsFormatTransform</span><span class="p">(</span><span class="n">input_format</span><span class="o">=</span><span class="n">input_format</span><span class="p">,</span>
    -                                                                  <span class="n">output_format</span><span class="o">=</span><span class="n">DetectionTargetsFormat</span><span class="o">.</span><span class="n">XYXY_LABEL</span><span class="p">)</span>
    +        <span class="n">target_format_transform</span> <span class="o">=</span> <span class="n">DetectionTargetsFormatTransform</span><span class="p">(</span><span class="n">input_format</span><span class="o">=</span><span class="n">input_format</span><span class="p">,</span> <span class="n">output_format</span><span class="o">=</span><span class="n">DetectionTargetsFormat</span><span class="o">.</span><span class="n">XYXY_LABEL</span><span class="p">)</span>
     
     
             <span class="k">for</span> <span class="n">plot_i</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="n">n_plots</span><span class="p">):</span>
             <span class="k">for</span> <span class="n">plot_i</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="n">n_plots</span><span class="p">):</span>
                 <span class="n">fig</span> <span class="o">=</span> <span class="n">plt</span><span class="o">.</span><span class="n">figure</span><span class="p">(</span><span class="n">figsize</span><span class="o">=</span><span class="p">(</span><span class="mi">10</span><span class="p">,</span> <span class="mi">10</span><span class="p">))</span>
                 <span class="n">fig</span> <span class="o">=</span> <span class="n">plt</span><span class="o">.</span><span class="n">figure</span><span class="p">(</span><span class="n">figsize</span><span class="o">=</span><span class="p">(</span><span class="mi">10</span><span class="p">,</span> <span class="mi">10</span><span class="p">))</span>
    -            <span class="n">n_subplot</span> <span class="o">=</span> <span class="nb">int</span><span class="p">(</span><span class="n">np</span><span class="o">.</span><span class="n">ceil</span><span class="p">(</span><span class="n">max_samples_per_plot</span> <span class="o">**</span> <span class="mf">0.5</span><span class="p">))</span>
    +            <span class="n">n_subplot</span> <span class="o">=</span> <span class="nb">int</span><span class="p">(</span><span class="n">np</span><span class="o">.</span><span class="n">ceil</span><span class="p">(</span><span class="n">max_samples_per_plot</span><span class="o">**</span><span class="mf">0.5</span><span class="p">))</span>
                 <span class="k">for</span> <span class="n">img_i</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="n">max_samples_per_plot</span><span class="p">):</span>
                 <span class="k">for</span> <span class="n">img_i</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="n">max_samples_per_plot</span><span class="p">):</span>
                     <span class="n">index</span> <span class="o">=</span> <span class="n">img_i</span> <span class="o">+</span> <span class="n">plot_i</span> <span class="o">*</span> <span class="mi">16</span>
                     <span class="n">index</span> <span class="o">=</span> <span class="n">img_i</span> <span class="o">+</span> <span class="n">plot_i</span> <span class="o">*</span> <span class="mi">16</span>
     
     
    @@ -509,9 +517,9 @@
     
     
                     <span class="c1"># shape = [n_box x 4] (We remove padded boxes, which corresponds to boxes with only 0)</span>
                     <span class="c1"># shape = [n_box x 4] (We remove padded boxes, which corresponds to boxes with only 0)</span>
                     <span class="n">boxes</span> <span class="o">=</span> <span class="n">boxes</span><span class="p">[(</span><span class="n">boxes</span> <span class="o">!=</span> <span class="mi">0</span><span class="p">)</span><span class="o">.</span><span class="n">any</span><span class="p">(</span><span class="n">axis</span><span class="o">=</span><span class="mi">1</span><span class="p">)]</span>
                     <span class="n">boxes</span> <span class="o">=</span> <span class="n">boxes</span><span class="p">[(</span><span class="n">boxes</span> <span class="o">!=</span> <span class="mi">0</span><span class="p">)</span><span class="o">.</span><span class="n">any</span><span class="p">(</span><span class="n">axis</span><span class="o">=</span><span class="mi">1</span><span class="p">)]</span>
    -                <span class="n">plt</span><span class="o">.</span><span class="n">subplot</span><span class="p">(</span><span class="n">n_subplot</span><span class="p">,</span> <span class="n">n_subplot</span><span class="p">,</span> <span class="n">img_i</span> <span class="o">+</span> <span class="mi">1</span><span class="p">)</span><span class="o">.</span><span class="n">imshow</span><span class="p">(</span><span class="n">image</span><span class="p">)</span>
    -                <span class="n">plt</span><span class="o">.</span><span class="n">plot</span><span class="p">(</span><span class="n">boxes</span><span class="p">[:,</span> <span class="p">[</span><span class="mi">0</span><span class="p">,</span> <span class="mi">2</span><span class="p">,</span> <span class="mi">2</span><span class="p">,</span> <span class="mi">0</span><span class="p">,</span> <span class="mi">0</span><span class="p">]]</span><span class="o">.</span><span class="n">T</span><span
    -                <span class="n">plt</span><span class="o">.</span><span class="n">axis</span><span class="p">(</span><span class="s1">&#39;off&#39;</span><span class="p">)</span>
    +                <span class="n">plt</span><span class="o">.</span><span class="n">subplot</span><span class="p">(</span><span class="n">n_subplot</span><span class="p">,</span> <span class="n">n_subplot</span><span class="p">,</span> <span class="n">img_i</span> <span class="o">+</span> <span class="mi">1</span><span class="p">)</span><span class="o">.</span><span class="n">imshow</span><span class="p">(</span><span class="n">image</span><span class="p">[:,</span> <span class="p">:,</span> <spa
    +                <span class="n">plt</span><span class="o">.</span><span class="n">plot</span><span class="p">(</span><span class="n">boxes</span><span class="p">[:,</span> <span class="p">[</span><span class="mi">0</span><span class="p">,</span> <span class="mi">2</span><span class="p">,</span> <span class="mi">2</span><span class="p">,</span> <span class="mi">0</span><span class="p">,</span> <span class="mi">0</span><span class="p">]]</span><span class="o">.</span><span class="n">T</span><span
    +                <span class="n">plt</span><span class="o">.</span><span class="n">axis</span><span class="p">(</span><span class="s2">&quot;off&quot;</span><span class="p">)</span>
                 <span class="n">fig</span><span class="o">.</span><span class="n">tight_layout</span><span class="p">()</span>
                 <span class="n">fig</span><span class="o">.</span><span class="n">tight_layout</span><span class="p">()</span>
                 <span class="n">plt</span><span class="o">.</span><span class="n">show</span><span class="p">()</span>
                 <span class="n">plt</span><span class="o">.</span><span class="n">show</span><span class="p">()</span>
                 <span class="n">plt</span><span class="o">.</span><span class="n">close</span><span class="p">()</span>
                 <span class="n">plt</span><span class="o">.</span><span class="n">close</span><span class="p">()</span>
    @@ -548,4 +556,4 @@
       </script> 
       </script> 
     
     
     </body>
     </body>
    -</html>
    +</html>
    Discard
    Some lines were truncated since they exceed the maximum allowed length of 500, please use a local Git client to see the full diff.
    @@ -3,10 +3,11 @@
     <head>
     <head>
       <meta charset="utf-8" />
       <meta charset="utf-8" />
       <meta name="viewport" content="width=device-width, initial-scale=1.0" />
       <meta name="viewport" content="width=device-width, initial-scale=1.0" />
    -  <title>super_gradients.training.datasets.detection_datasets.pascal_voc_detection &mdash; SuperGradients 1.0 documentation</title>
    +  <title>super_gradients.training.datasets.detection_datasets.pascal_voc_detection &mdash; SuperGradients 3.0.3 documentation</title>
           <link rel="stylesheet" href="../../../../../_static/pygments.css" type="text/css" />
           <link rel="stylesheet" href="../../../../../_static/pygments.css" type="text/css" />
           <link rel="stylesheet" href="../../../../../_static/css/theme.css" type="text/css" />
           <link rel="stylesheet" href="../../../../../_static/css/theme.css" type="text/css" />
           <link rel="stylesheet" href="../../../../../_static/graphviz.css" type="text/css" />
           <link rel="stylesheet" href="../../../../../_static/graphviz.css" type="text/css" />
    +      <link rel="stylesheet" href="../../../../../_static/custom.css" type="text/css" />
       <!--[if lt IE 9]>
       <!--[if lt IE 9]>
         <script src="../../../../../_static/js/html5shiv.min.js"></script>
         <script src="../../../../../_static/js/html5shiv.min.js"></script>
       <![endif]-->
       <![endif]-->
    @@ -14,7 +15,9 @@
             <script data-url_root="../../../../../" id="documentation_options" src="../../../../../_static/documentation_options.js"></script>
             <script data-url_root="../../../../../" id="documentation_options" src="../../../../../_static/documentation_options.js"></script>
             <script src="../../../../../_static/jquery.js"></script>
             <script src="../../../../../_static/jquery.js"></script>
             <script src="../../../../../_static/underscore.js"></script>
             <script src="../../../../../_static/underscore.js"></script>
    +        <script src="../../../../../_static/_sphinx_javascript_frameworks_compat.js"></script>
             <script src="../../../../../_static/doctools.js"></script>
             <script src="../../../../../_static/doctools.js"></script>
    +        <script src="../../../../../_static/sphinx_highlight.js"></script>
         <script src="../../../../../_static/js/theme.js"></script>
         <script src="../../../../../_static/js/theme.js"></script>
         <link rel="index" title="Index" href="../../../../../genindex.html" />
         <link rel="index" title="Index" href="../../../../../genindex.html" />
         <link rel="search" title="Search" href="../../../../../search.html" /> 
         <link rel="search" title="Search" href="../../../../../search.html" /> 
    @@ -35,30 +38,29 @@
       </form>
       </form>
     </div>
     </div>
             </div><div class="wy-menu wy-menu-vertical" data-spy="affix" role="navigation" aria-label="Navigation menu">
             </div><div class="wy-menu wy-menu-vertical" data-spy="affix" role="navigation" aria-label="Navigation menu">
    -              <p class="caption"><span class="caption-text">Welcome To SuperGradients</span></p>
    +              <p class="caption" role="heading"><span class="caption-text">Welcome To SuperGradients</span></p>
     <ul>
     <ul>
    -<li class="toctree-l1"><a class="reference internal" href="../../../../../welcome.html">Fill our 4-question quick survey! We will raffle free SuperGradients swag between those who will participate -&gt; Fill Survey</a></li>
    -<li class="toctree-l1"><a class="reference internal" href="../../../../../welcome.html#supergradients">SuperGradients</a></li>
    +<li class="toctree-l1"><a class="reference internal" href="../../../../../welcome.html">Version 3 is out! Notebooks have been updated!</a></li>
    +<li class="toctree-l1"><a class="reference internal" href="../../../../../welcome.html#build-with-supergradients">Build with SuperGradients</a></li>
    +<li class="toctree-l1"><a class="reference internal" href="../../../../../welcome.html#quick-installation">Quick Installation</a></li>
    +<li class="toctree-l1"><a class="reference internal" href="../../../../../welcome.html#what-s-new">What’s New</a></li>
    +<li class="toctree-l1"><a class="reference internal" href="../../../../../welcome.html#coming-soon">Coming soon</a></li>
    +<li class="toctree-l1"><a class="reference internal" href="../../../../../welcome.html#table-of-content">Table of Content</a></li>
    +<li class="toctree-l1"><a class="reference internal" href="../../../../../welcome.html#getting-started">Getting Started</a></li>
    +<li class="toctree-l1"><a class="reference internal" href="../../../../../welcome.html#advanced-features">Advanced Features</a></li>
    +<li class="toctree-l1"><a class="reference internal" href="../../../../../welcome.html#installation-methods">Installation Methods</a></li>
    +<li class="toctree-l1"><a class="reference internal" href="../../../../../welcome.html#implemented-model-architectures">Implemented Model Architectures</a></li>
    +<li class="toctree-l1"><a class="reference internal" href="../../../../../welcome.html#documentation">Documentation</a></li>
    +<li class="toctree-l1"><a class="reference internal" href="../../../../../welcome.html#contributing">Contributing</a></li>
    +<li class="toctree-l1"><a class="reference internal" href="../../../../../welcome.html#citation">Citation</a></li>
    +<li class="toctree-l1"><a class="reference internal" href="../../../../../welcome.html#community">Community</a></li>
    +<li class="toctree-l1"><a class="reference internal" href="../../../../../welcome.html#license">License</a></li>
    +<li class="toctree-l1"><a class="reference internal" href="../../../../../welcome.html#deci-platform">Deci Platform</a></li>
     </ul>
     </ul>
    -<p class="caption"><span class="caption-text">Technical Documentation</span></p>
    +<p class="caption" role="heading"><span class="caption-text">Technical Documentation</span></p>
     <ul>
     <ul>
     <li class="toctree-l1"><a class="reference internal" href="../../../../../super_gradients.common.html">Common package</a></li>
     <li class="toctree-l1"><a class="reference internal" href="../../../../../super_gradients.common.html">Common package</a></li>
     <li class="toctree-l1"><a class="reference internal" href="../../../../../super_gradients.training.html">Training package</a></li>
     <li class="toctree-l1"><a class="reference internal" href="../../../../../super_gradients.training.html">Training package</a></li>
    -</ul>
    -<p class="caption"><span class="caption-text">User Guide</span></p>
    -<ul>
    -<li class="toctree-l1"><a class="reference internal" href="../../../../../user_guide.html">What is SuperGradients?</a></li>
    -<li class="toctree-l1"><a class="reference internal" href="../../../../../user_guide.html#introducing-the-supergradients-library">Introducing the SuperGradients library</a></li>
    -<li class="toctree-l1"><a class="reference internal" href="../../../../../user_guide.html#installation">Installation</a></li>
    -<li class="toctree-l1"><a class="reference internal" href="../../../../../user_guide.html#integrating-your-training-code-complete-walkthrough">Integrating Your Training Code - Complete Walkthrough</a></li>
    -<li class="toctree-l1"><a class="reference internal" href="../../../../../user_guide.html#training-parameters">Training Parameters</a></li>
    -<li class="toctree-l1"><a class="reference internal" href="../../../../../user_guide.html#logs-and-checkpoints">Logs and Checkpoints</a></li>
    -<li class="toctree-l1"><a class="reference internal" href="../../../../../user_guide.html#dataset-parameters">Dataset Parameters</a></li>
    -<li class="toctree-l1"><a class="reference internal" href="../../../../../user_guide.html#network-architectures">Network Architectures</a></li>
    -<li class="toctree-l1"><a class="reference internal" href="../../../../../user_guide.html#pretrained-models">Pretrained Models</a></li>
    -<li class="toctree-l1"><a class="reference internal" href="../../../../../user_guide.html#how-to-reproduce-our-training-recipes">How To Reproduce Our Training Recipes</a></li>
    -<li class="toctree-l1"><a class="reference internal" href="../../../../../user_guide.html#professional-tools-integration">Professional Tools Integration</a></li>
    -<li class="toctree-l1"><a class="reference internal" href="../../../../../user_guide.html#supergradients-faq">SuperGradients FAQ</a></li>
     </ul>
     </ul>
     
     
             </div>
             </div>
    @@ -89,11 +91,15 @@
     <span></span><span class="kn">import</span> <span class="nn">os</span>
     <span></span><span class="kn">import</span> <span class="nn">os</span>
     <span class="kn">import</span> <span class="nn">glob</span>
     <span class="kn">import</span> <span class="nn">glob</span>
     <span class="kn">from</span> <span class="nn">pathlib</span> <span class="kn">import</span> <span class="n">Path</span>
     <span class="kn">from</span> <span class="nn">pathlib</span> <span class="kn">import</span> <span class="n">Path</span>
    +<span class="kn">from</span> <span class="nn">typing</span> <span class="kn">import</span> <span class="n">List</span><span class="p">,</span> <span class="n">Optional</span>
     <span class="kn">from</span> <span class="nn">xml.etree</span> <span class="kn">import</span> <span class="n">ElementTree</span>
     <span class="kn">from</span> <span class="nn">xml.etree</span> <span class="kn">import</span> <span class="n">ElementTree</span>
    +
    +<span class="kn">from</span> <span class="nn">torch.utils.data</span> <span class="kn">import</span> <span class="n">ConcatDataset</span>
     <span class="kn">from</span> <span class="nn">tqdm</span> <span class="kn">import</span> <span class="n">tqdm</span>
     <span class="kn">from</span> <span class="nn">tqdm</span> <span class="kn">import</span> <span class="n">tqdm</span>
     
     
     <span class="kn">import</span> <span class="nn">numpy</span> <span class="k">as</span> <span class="nn">np</span>
     <span class="kn">import</span> <span class="nn">numpy</span> <span class="k">as</span> <span class="nn">np</span>
     
     
    +<span class="kn">from</span> <span class="nn">super_gradients.training.transforms.transforms</span> <span class="kn">import</span> <span class="n">DetectionTransform</span>
     <span class="kn">from</span> <span class="nn">super_gradients.training.utils.utils</span> <span class="kn">import</span> <span class="n">download_and_untar_from_url</span><span class="p">,</span> <span class="n">get_image_size_from_path</span>
     <span class="kn">from</span> <span class="nn">super_gradients.training.utils.utils</span> <span class="kn">import</span> <span class="n">download_and_untar_from_url</span><span class="p">,</span> <span class="n">get_image_size_from_path</span>
     <span class="kn">from</span> <span class="nn">super_gradients.training.datasets.detection_datasets.detection_dataset</span> <span class="kn">import</span> <span class="n">DetectionDataset</span>
     <span class="kn">from</span> <span class="nn">super_gradients.training.datasets.detection_datasets.detection_dataset</span> <span class="kn">import</span> <span class="n">DetectionDataset</span>
     <span class="kn">from</span> <span class="nn">super_gradients.training.utils.detection_utils</span> <span class="kn">import</span> <span class="n">DetectionTargetsFormat</span>
     <span class="kn">from</span> <span class="nn">super_gradients.training.utils.detection_utils</span> <span class="kn">import</span> <span class="n">DetectionTargetsFormat</span>
    @@ -103,19 +109,25 @@
     <span class="n">logger</span> <span class="o">=</span> <span class="n">get_logger</span><span class="p">(</span><span class="vm">__name__</span><span class="p">)</span>
     <span class="n">logger</span> <span class="o">=</span> <span class="n">get_logger</span><span class="p">(</span><span class="vm">__name__</span><span class="p">)</span>
     
     
     
     
    -<div class="viewcode-block" id="PascalVOCDetectionDataset"><a class="viewcode-back" href="../../../../../super_gradients.training.datasets.detection_datasets.html#super_gradients.training.datasets.PascalVOCDetectionDataset">[docs]</a><span class="k">class</span> <span class="nc">PascalVOCDetectionDataset</span><span class="p">(</span><span class="n">DetectionDataset</span><span class="p">):</span>
    +<div class="viewcode-block" id="PascalVOCDetectionDataset"><a class="viewcode-back" href="../../../../../super_gradients.training.html#super_gradients.training.datasets.PascalVOCDetectionDataset">[docs]</a><span class="k">class</span> <span class="nc">PascalVOCDetectionDataset</span><span class="p">(</span><span class="n">DetectionDataset</span><span class="p">):</span>
         <span class="sd">&quot;&quot;&quot;Dataset for Pascal VOC object detection&quot;&quot;&quot;</span>
         <span class="sd">&quot;&quot;&quot;Dataset for Pascal VOC object detection&quot;&quot;&quot;</span>
     
     
    -    <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">images_sub_directory</span><span class="p">:</span> <span class="nb">str</span><span class="p">,</span> <span class="o">*</span><span class="n">args</span><span class="p">,</span> <span class="o">**</span><span class="n">kwargs</span><span class="p">):</span>
    +    <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">images_sub_directory</span><span class="p">:</span> <span class="nb">str</span><span class="p">,</span> <span class="n">download</span><span class="p">:</span> <span class="nb">bool</span> <span class="o">=</span> <span class="kc">False</span><span class="p">,</span> <span class="o">*</span><span class="n">args</span><span class="p">,</span>
             <span class="sd">&quot;&quot;&quot;Dataset for Pascal VOC object detection</span>
             <span class="sd">&quot;&quot;&quot;Dataset for Pascal VOC object detection</span>
     
     
     <span class="sd">        :param images_sub_directory:    Sub directory of data_dir that includes images.</span>
     <span class="sd">        :param images_sub_directory:    Sub directory of data_dir that includes images.</span>
     <span class="sd">        &quot;&quot;&quot;</span>
     <span class="sd">        &quot;&quot;&quot;</span>
    +
             <span class="bp">self</span><span class="o">.</span><span class="n">images_sub_directory</span> <span class="o">=</span> <span class="n">images_sub_directory</span>
             <span class="bp">self</span><span class="o">.</span><span class="n">images_sub_directory</span> <span class="o">=</span> <span class="n">images_sub_directory</span>
             <span class="bp">self</span><span class="o">.</span><span class="n">img_and_target_path_list</span> <span class="o">=</span> <span class="kc">None</span>
             <span class="bp">self</span><span class="o">.</span><span class="n">img_and_target_path_list</span> <span class="o">=</span> <span class="kc">None</span>
    -
    -        <span class="n">kwargs</span><span class="p">[</span><span class="s1">&#39;all_classes_list&#39;</span><span class="p">]</span> <span class="o">=</span> <span class="n">PASCAL_VOC_2012_CLASSES_LIST</span>
    -        <span class="n">kwargs</span><span class="p">[</span><span class="s1">&#39;original_target_format&#39;</span><span class="p">]</span> <span class="o">=</span> <span class="n">DetectionTargetsFormat</span><span class="o">.</span><span class="n">XYXY_LABEL</span>
    +        <span class="n">data_dir</span> <span class="o">=</span> <span class="n">kwargs</span><span class="o">.</span><span class="n">get</span><span class="p">(</span><span class="s2">&quot;data_dir&quot;</span><span class="p">)</span>
    +        <span class="k">if</span> <span class="n">data_dir</span> <span class="ow">is</span> <span class="kc">None</span><span class="p">:</span>
    +            <span class="k">raise</span> <span class="ne">ValueError</span><span class="p">(</span><span class="s2">&quot;Must pass data_dir != None through **kwargs&quot;</span><span class="p">)</span>
    +        <span class="k">if</span> <span class="n">download</span><span class="p">:</span>
    +            <span class="n">PascalVOCDetectionDataset</span><span class="o">.</span><span class="n">download</span><span class="p">(</span><span class="n">data_dir</span><span class="p">)</span>
    +
    +        <span class="n">kwargs</span><span class="p">[</span><span class="s2">&quot;original_target_format&quot;</span><span class="p">]</span> <span class="o">=</span> <span class="n">DetectionTargetsFormat</span><span class="o">.</span><span class="n">XYXY_LABEL</span>
    +        <span class="n">kwargs</span><span class="p">[</span><span class="s2">&quot;all_classes_list&quot;</span><span class="p">]</span> <span class="o">=</span> <span class="n">PASCAL_VOC_2012_CLASSES_LIST</span>
             <span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="fm">__init__</span><span class="p">(</span><span class="o">*</span><span class="n">args</span><span class="p">,</span> <span class="o">**</span><span class="n">kwargs</span><span class="p">)</span>
             <span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="fm">__init__</span><span class="p">(</span><span class="o">*</span><span class="n">args</span><span class="p">,</span> <span class="o">**</span><span class="n">kwargs</span><span class="p">)</span>
     
     
         <span class="k">def</span> <span class="nf">_setup_data_source</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span>
         <span class="k">def</span> <span class="nf">_setup_data_source</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span>
    @@ -125,9 +137,11 @@
     <span class="sd">        &quot;&quot;&quot;</span>
     <span class="sd">        &quot;&quot;&quot;</span>
             <span class="n">img_files_folder</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">data_dir</span> <span class="o">+</span> <span class="bp">self</span><span class="o">.</span><span class="n">images_sub_directory</span>
             <span class="n">img_files_folder</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">data_dir</span> <span class="o">+</span> <span class="bp">self</span><span class="o">.</span><span class="n">images_sub_directory</span>
             <span class="k">if</span> <span class="ow">not</span> <span class="n">Path</span><span class="p">(</span><span class="n">img_files_folder</span><span class="p">)</span><span class="o">.</span><span class="n">exists</span><span class="p">():</span>
             <span class="k">if</span> <span class="ow">not</span> <span class="n">Path</span><span class="p">(</span><span class="n">img_files_folder</span><span class="p">)</span><span class="o">.</span><span class="n">exists</span><span class="p">():</span>
    -            <span class="k">raise</span> <span class="ne">FileNotFoundError</span><span class="p">(</span><span class="sa">f</span><span class="s2">&quot;</span><span class="si">{</span><span class="bp">self</span><span class="o">.</span><span class="n">data_dir</span><span class="si">}</span><span class="s2"> does not include </span><span class="si">{</span><span class="bp">self</span><span class="o">.</span><span class="n">images_sub_directory</span><span class="si">}</span><span class="s2">.
    -                                    <span class="sa">f</span><span class="s2">&quot;Please make sure that f</span><span class="si">{</span><span class="bp">self</span><span class="o">.</span><span class="n">data_dir</span><span class="si">}</span><span class="s2"> refers to PascalVOC dataset and that &quot;</span>
    -                                    <span class="s2">&quot;it was downloaded using PascalVOCDetectionDataSetV2.download()&quot;</span><span class="p">)</span>
    +            <span class="k">raise</span> <span class="ne">FileNotFoundError</span><span class="p">(</span>
    +                <span class="sa">f</span><span class="s2">&quot;</span><span class="si">{</span><span class="bp">self</span><span class="o">.</span><span class="n">data_dir</span><span class="si">}</span><span class="s2"> does not include </span><span class="si">{</span><span class="bp">self</span><span class="o">.</span><span class="n">images_sub_directory</span><span class="si">}</span><span class="s2">. &quot;</span>
    +                <span class="sa">f</span><span class="s2">&quot;Please make sure that f</span><span class="si">{</span><span class="bp">self</span><span class="o">.</span><span class="n">data_dir</span><span class="si">}</span><span class="s2"> refers to PascalVOC dataset and that &quot;</span>
    +                <span class="s2">&quot;it was downloaded using PascalVOCDetectionDataSetV2.download()&quot;</span>
    +            <span class="p">)</span>
     
     
             <span class="n">img_files</span> <span class="o">=</span> <span class="n">glob</span><span class="o">.</span><span class="n">glob</span><span class="p">(</span><span class="n">img_files_folder</span> <span class="o">+</span> <span class="s2">&quot;*.jpg&quot;</span><span class="p">)</span>
             <span class="n">img_files</span> <span class="o">=</span> <span class="n">glob</span><span class="o">.</span><span class="n">glob</span><span class="p">(</span><span class="n">img_files_folder</span> <span class="o">+</span> <span class="s2">&quot;*.jpg&quot;</span><span class="p">)</span>
             <span class="k">if</span> <span class="nb">len</span><span class="p">(</span><span class="n">img_files</span><span class="p">)</span> <span class="o">==</span> <span class="mi">0</span><span class="p">:</span>
             <span class="k">if</span> <span class="nb">len</span><span class="p">(</span><span class="n">img_files</span><span class="p">)</span> <span class="o">==</span> <span class="mi">0</span><span class="p">:</span>
    @@ -135,15 +149,13 @@
     
     
             <span class="n">target_files</span> <span class="o">=</span> <span class="p">[</span><span class="n">img_file</span><span class="o">.</span><span class="n">replace</span><span class="p">(</span><span class="s2">&quot;images&quot;</span><span class="p">,</span> <span class="s2">&quot;labels&quot;</span><span class="p">)</span><span class="o">.</span><span class="n">replace</span><span class="p">(</span><span class="s2">&quot;.jpg&quot;</span><span class="p">,</span> <span class="s2">&quo
             <span class="n">target_files</span> <span class="o">=</span> <span class="p">[</span><span class="n">img_file</span><span class="o">.</span><span class="n">replace</span><span class="p">(</span><span class="s2">&quot;images&quot;</span><span class="p">,</span> <span class="s2">&quot;labels&quot;</span><span class="p">)</span><span class="o">.</span><span class="n">replace</span><span class="p">(</span><span class="s2">&quot;.jpg&quot;</span><span class="p">,</span> <span class="s2">&quo
     
     
    -        <span class="n">img_and_target_path_list</span> <span class="o">=</span> <span class="p">[(</span><span class="n">img_file</span><span class="p">,</span> <span class="n">target_file</span><span class="p">)</span>
    -                                    <span class="k">for</span> <span class="n">img_file</span><span class="p">,</span> <span class="n">target_file</span> <span class="ow">in</span> <span class="nb">zip</span><span class="p">(</span><span class="n">img_files</span><span class="p">,</span> <span class="n">target_files</span><span class="p">)</span>
    -                                    <span class="k">if</span> <span class="n">os</span><span class="o">.</span><span class="n">path</span><span class="o">.</span><span class="n">exists</span><span class="p">(</span><span class="n">target_file</span><span class="p">)]</span>
    +        <span class="n">img_and_target_path_list</span> <span class="o">=</span> <span class="p">[(</span><span class="n">img_file</span><span class="p">,</span> <span class="n">target_file</span><span class="p">)</span> <span class="k">for</span> <span class="n">img_file</span><span class="p">,</span> <span class="n">target_file</span> <span class="ow">in</span> <span class="nb">zip</span><span class="p">(</span><span class="n">img_files</span><span class="p">,</span> <span class="n">target_fi
             <span class="k">if</span> <span class="nb">len</span><span class="p">(</span><span class="n">img_and_target_path_list</span><span class="p">)</span> <span class="o">==</span> <span class="mi">0</span><span class="p">:</span>
             <span class="k">if</span> <span class="nb">len</span><span class="p">(</span><span class="n">img_and_target_path_list</span><span class="p">)</span> <span class="o">==</span> <span class="mi">0</span><span class="p">:</span>
                 <span class="k">raise</span> <span class="ne">FileNotFoundError</span><span class="p">(</span><span class="s2">&quot;No target file associated to the images was found&quot;</span><span class="p">)</span>
                 <span class="k">raise</span> <span class="ne">FileNotFoundError</span><span class="p">(</span><span class="s2">&quot;No target file associated to the images was found&quot;</span><span class="p">)</span>
     
     
             <span class="n">num_missing_files</span> <span class="o">=</span> <span class="nb">len</span><span class="p">(</span><span class="n">img_files</span><span class="p">)</span> <span class="o">-</span> <span class="nb">len</span><span class="p">(</span><span class="n">img_and_target_path_list</span><span class="p">)</span>
             <span class="n">num_missing_files</span> <span class="o">=</span> <span class="nb">len</span><span class="p">(</span><span class="n">img_files</span><span class="p">)</span> <span class="o">-</span> <span class="nb">len</span><span class="p">(</span><span class="n">img_and_target_path_list</span><span class="p">)</span>
             <span class="k">if</span> <span class="n">num_missing_files</span> <span class="o">&gt;</span> <span class="mi">0</span><span class="p">:</span>
             <span class="k">if</span> <span class="n">num_missing_files</span> <span class="o">&gt;</span> <span class="mi">0</span><span class="p">:</span>
    -            <span class="n">logger</span><span class="o">.</span><span class="n">warning</span><span class="p">(</span><span class="sa">f</span><span class="s1">&#39;</span><span class="si">{</span><span class="n">num_missing_files</span><span class="si">}</span><span class="s1"> labels files were not loaded our of </span><span class="si">{</span><span class="nb">len</span><span class="p">(</span><span class="n">img_files</span><span class="p">)</span><span class="si">}</span><span class="s1"> 
    +            <span class="n">logger</span><span class="o">.</span><span class="n">warning</span><span class="p">(</span><span class="sa">f</span><span class="s2">&quot;</span><span class="si">{</span><span class="n">num_missing_files</span><span class="si">}</span><span class="s2"> labels files were not loaded our of </span><span class="si">{</span><span class="nb">len</span><span class="p">(</span><span class="n">img_files</span><span class="p">)</span><span class="si">}</span><span class="s2">
     
     
             <span class="bp">self</span><span class="o">.</span><span class="n">img_and_target_path_list</span> <span class="o">=</span> <span class="n">img_and_target_path_list</span>
             <span class="bp">self</span><span class="o">.</span><span class="n">img_and_target_path_list</span> <span class="o">=</span> <span class="n">img_and_target_path_list</span>
             <span class="k">return</span> <span class="nb">len</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">img_and_target_path_list</span><span class="p">)</span>
             <span class="k">return</span> <span class="nb">len</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">img_and_target_path_list</span><span class="p">)</span>
    @@ -156,38 +168,37 @@
     <span class="sd">                    - img_path</span>
     <span class="sd">                    - img_path</span>
     <span class="sd">        &quot;&quot;&quot;</span>
     <span class="sd">        &quot;&quot;&quot;</span>
             <span class="n">img_path</span><span class="p">,</span> <span class="n">target_path</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">img_and_target_path_list</span><span class="p">[</span><span class="n">sample_id</span><span class="p">]</span>
             <span class="n">img_path</span><span class="p">,</span> <span class="n">target_path</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">img_and_target_path_list</span><span class="p">[</span><span class="n">sample_id</span><span class="p">]</span>
    -        <span class="k">with</span> <span class="nb">open</span><span class="p">(</span><span class="n">target_path</span><span class="p">,</span> <span class="s1">&#39;r&#39;</span><span class="p">)</span> <span class="k">as</span> <span class="n">targets_file</span><span class="p">:</span>
    +        <span class="k">with</span> <span class="nb">open</span><span class="p">(</span><span class="n">target_path</span><span class="p">,</span> <span class="s2">&quot;r&quot;</span><span class="p">)</span> <span class="k">as</span> <span class="n">targets_file</span><span class="p">:</span>
                 <span class="n">target</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">array</span><span class="p">([</span><span class="n">x</span><span class="o">.</span><span class="n">split</span><span class="p">()</span> <span class="k">for</span> <span class="n">x</span> <span class="ow">in</span> <span class="n">targets_file</span><span class="o">.</span><span class="n">read</span><span class="p">()</span><span class="o">.</span><span class="n
                 <span class="n">target</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">array</span><span class="p">([</span><span class="n">x</span><span class="o">.</span><span class="n">split</span><span class="p">()</span> <span class="k">for</span> <span class="n">x</span> <span class="ow">in</span> <span class="n">targets_file</span><span class="o">.</span><span class="n">read</span><span class="p">()</span><span class="o">.</span><span class="n
     
     
    -        <span class="n">width</span><span class="p">,</span> <span class="n">height</span> <span class="o">=</span> <span class="n">get_image_size_from_path</span><span class="p">(</span><span class="n">img_path</span><span class="p">)</span>
    +        <span class="n">height</span><span class="p">,</span> <span class="n">width</span> <span class="o">=</span> <span class="n">get_image_size_from_path</span><span class="p">(</span><span class="n">img_path</span><span class="p">)</span>
     
     
    -        <span class="c1"># We have to rescale the targets because the images will be rescaled.</span>
    +        <span class="c1"># We have to rescale the targets because the images will be resized.</span>
             <span class="n">r</span> <span class="o">=</span> <span class="nb">min</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">input_dim</span><span class="p">[</span><span class="mi">1</span><span class="p">]</span> <span class="o">/</span> <span class="n">height</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">input_dim</span><span class="p">[</span><span class="mi">0</span><span class="p">]</s
             <span class="n">r</span> <span class="o">=</span> <span class="nb">min</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">input_dim</span><span class="p">[</span><span class="mi">1</span><span class="p">]</span> <span class="o">/</span> <span class="n">height</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">input_dim</span><span class="p">[</span><span class="mi">0</span><span class="p">]</s
             <span class="n">target</span><span class="p">[:,</span> <span class="p">:</span><span class="mi">4</span><span class="p">]</span> <span class="o">*=</span> <span class="n">r</span>
             <span class="n">target</span><span class="p">[:,</span> <span class="p">:</span><span class="mi">4</span><span class="p">]</span> <span class="o">*=</span> <span class="n">r</span>
     
     
    -        <span class="n">initial_img_shape</span> <span class="o">=</span> <span class="p">(</span><span class="n">width</span><span class="p">,</span> <span class="n">height</span><span class="p">)</span>
    -        <span class="n">resized_img_shape</span> <span class="o">=</span> <span class="p">(</span><span class="nb">int</span><span class="p">(</span><span class="n">width</span> <span class="o">*</span> <span class="n">r</span><span class="p">),</span> <span class="nb">int</span><span class="p">(</span><span class="n">height</span> <span class="o">*</span> <span class="n">r</span><span class="p">))</span>
    +        <span class="n">resized_img_shape</span> <span class="o">=</span> <span class="p">(</span><span class="nb">int</span><span class="p">(</span><span class="n">height</span> <span class="o">*</span> <span class="n">r</span><span class="p">),</span> <span class="nb">int</span><span class="p">(</span><span class="n">width</span> <span class="o">*</span> <span class="n">r</span><span class="p">))</span>
     
     
    -        <span class="k">return</span> <span class="p">{</span><span class="s2">&quot;img_path&quot;</span><span class="p">:</span> <span class="n">img_path</span><span class="p">,</span> <span class="s2">&quot;target&quot;</span><span class="p">:</span> <span class="n">target</span><span class="p">,</span>
    -                <span class="s2">&quot;initial_img_shape&quot;</span><span class="p">:</span> <span class="n">initial_img_shape</span><span class="p">,</span> <span class="s2">&quot;resized_img_shape&quot;</span><span class="p">:</span> <span class="n">resized_img_shape</span><span class="p">}</span>
    +        <span class="k">return</span> <span class="p">{</span><span class="s2">&quot;img_path&quot;</span><span class="p">:</span> <span class="n">img_path</span><span class="p">,</span> <span class="s2">&quot;target&quot;</span><span class="p">:</span> <span class="n">target</span><span class="p">,</span> <span class="s2">&quot;resized_img_shape&quot;</span><span class="p">:</span> <span class="n">resized_img_shape</span><span class="p">}</span>
     
     
    -<div class="viewcode-block" id="PascalVOCDetectionDataset.download"><a class="viewcode-back" href="../../../../../super_gradients.training.datasets.detection_datasets.html#super_gradients.training.datasets.PascalVOCDetectionDataset.download">[docs]</a>    <span class="nd">@staticmethod</span>
    +<div class="viewcode-block" id="PascalVOCDetectionDataset.download"><a class="viewcode-back" href="../../../../../super_gradients.training.html#super_gradients.training.datasets.PascalVOCDetectionDataset.download">[docs]</a>    <span class="nd">@staticmethod</span>
         <span class="k">def</span> <span class="nf">download</span><span class="p">(</span><span class="n">data_dir</span><span class="p">:</span> <span class="nb">str</span><span class="p">):</span>
         <span class="k">def</span> <span class="nf">download</span><span class="p">(</span><span class="n">data_dir</span><span class="p">:</span> <span class="nb">str</span><span class="p">):</span>
             <span class="sd">&quot;&quot;&quot;Download Pascal dataset in XYXY_LABEL format.</span>
             <span class="sd">&quot;&quot;&quot;Download Pascal dataset in XYXY_LABEL format.</span>
     
     
     <span class="sd">        Data extracted form http://host.robots.ox.ac.uk/pascal/VOC/</span>
     <span class="sd">        Data extracted form http://host.robots.ox.ac.uk/pascal/VOC/</span>
     <span class="sd">        &quot;&quot;&quot;</span>
     <span class="sd">        &quot;&quot;&quot;</span>
    +
             <span class="k">def</span> <span class="nf">_parse_and_save_labels</span><span class="p">(</span><span class="n">path</span><span class="p">,</span> <span class="n">new_label_path</span><span class="p">,</span> <span class="n">year</span><span class="p">,</span> <span class="n">image_id</span><span class="p">):</span>
             <span class="k">def</span> <span class="nf">_parse_and_save_labels</span><span class="p">(</span><span class="n">path</span><span class="p">,</span> <span class="n">new_label_path</span><span class="p">,</span> <span class="n">year</span><span class="p">,</span> <span class="n">image_id</span><span class="p">):</span>
                 <span class="sd">&quot;&quot;&quot;Parse and save the labels of an image in XYXY_LABEL format.&quot;&quot;&quot;</span>
                 <span class="sd">&quot;&quot;&quot;Parse and save the labels of an image in XYXY_LABEL format.&quot;&quot;&quot;</span>
     
     
    -            <span class="k">with</span> <span class="nb">open</span><span class="p">(</span><span class="sa">f</span><span class="s1">&#39;</span><span class="si">{</span><span class="n">path</span><span class="si">}</span><span class="s1">/VOC</span><span class="si">{</span><span class="n">year</span><span class="si">}</span><span class="s1">/Annotations/</span><span class="si">{</span><span class="n">image_id</span><span class="si">}</span><span class="s1">.xml&#39;</span><span class="p">)</s
    +            <span class="k">with</span> <span class="nb">open</span><span class="p">(</span><span class="sa">f</span><span class="s2">&quot;</span><span class="si">{</span><span class="n">path</span><span class="si">}</span><span class="s2">/VOC</span><span class="si">{</span><span class="n">year</span><span class="si">}</span><span class="s2">/Annotations/</span><span class="si">{</span><span class="n">image_id</span><span class="si">}</span><span class="s2">.xml&quot;</span><span class="p">)<
                     <span class="n">xml_parser</span> <span class="o">=</span> <span class="n">ElementTree</span><span class="o">.</span><span class="n">parse</span><span class="p">(</span><span class="n">f</span><span class="p">)</span><span class="o">.</span><span class="n">getroot</span><span class="p">()</span>
                     <span class="n">xml_parser</span> <span class="o">=</span> <span class="n">ElementTree</span><span class="o">.</span><span class="n">parse</span><span class="p">(</span><span class="n">f</span><span class="p">)</span><span class="o">.</span><span class="n">getroot</span><span class="p">()</span>
     
     
                 <span class="n">labels</span> <span class="o">=</span> <span class="p">[]</span>
                 <span class="n">labels</span> <span class="o">=</span> <span class="p">[]</span>
    -            <span class="k">for</span> <span class="n">obj</span> <span class="ow">in</span> <span class="n">xml_parser</span><span class="o">.</span><span class="n">iter</span><span class="p">(</span><span class="s1">&#39;object&#39;</span><span class="p">):</span>
    -                <span class="bp">cls</span> <span class="o">=</span> <span class="n">obj</span><span class="o">.</span><span class="n">find</span><span class="p">(</span><span class="s1">&#39;name&#39;</span><span class="p">)</span><span class="o">.</span><span class="n">text</span>
    -                <span class="k">if</span> <span class="bp">cls</span> <span class="ow">in</span> <span class="n">PASCAL_VOC_2012_CLASSES_LIST</span> <span class="ow">and</span> <span class="ow">not</span> <span class="nb">int</span><span class="p">(</span><span class="n">obj</span><span class="o">.</span><span class="n">find</span><span class="p">(</span><span class="s1">&#39;difficult&#39;</span><span class="p">)</span><span class="o">.</span><span class="n">text</span><span class="p">)</span>
    -                    <span class="n">xml_box</span> <span class="o">=</span> <span class="n">obj</span><span class="o">.</span><span class="n">find</span><span class="p">(</span><span class="s1">&#39;bndbox&#39;</span><span class="p">)</span>
    +            <span class="k">for</span> <span class="n">obj</span> <span class="ow">in</span> <span class="n">xml_parser</span><span class="o">.</span><span class="n">iter</span><span class="p">(</span><span class="s2">&quot;object&quot;</span><span class="p">):</span>
    +                <span class="bp">cls</span> <span class="o">=</span> <span class="n">obj</span><span class="o">.</span><span class="n">find</span><span class="p">(</span><span class="s2">&quot;name&quot;</span><span class="p">)</span><span class="o">.</span><span class="n">text</span>
    +                <span class="k">if</span> <span class="bp">cls</span> <span class="ow">in</span> <span class="n">PASCAL_VOC_2012_CLASSES_LIST</span> <span class="ow">and</span> <span class="ow">not</span> <span class="nb">int</span><span class="p">(</span><span class="n">obj</span><span class="o">.</span><span class="n">find</span><span class="p">(</span><span class="s2">&quot;difficult&quot;</span><span class="p">)</span><span class="o">.</span><span class="n">text</span><span class="p">)</spa
    +                    <span class="n">xml_box</span> <span class="o">=</span> <span class="n">obj</span><span class="o">.</span><span class="n">find</span><span class="p">(</span><span class="s2">&quot;bndbox&quot;</span><span class="p">)</span>
     
     
                         <span class="k">def</span> <span class="nf">get_coord</span><span class="p">(</span><span class="n">box_coord</span><span class="p">):</span>
                         <span class="k">def</span> <span class="nf">get_coord</span><span class="p">(</span><span class="n">box_coord</span><span class="p">):</span>
                             <span class="k">return</span> <span class="n">xml_box</span><span class="o">.</span><span class="n">find</span><span class="p">(</span><span class="n">box_coord</span><span class="p">)</span><span class="o">.</span><span class="n">text</span>
                             <span class="k">return</span> <span class="n">xml_box</span><span class="o">.</span><span class="n">find</span><span class="p">(</span><span class="n">box_coord</span><span class="p">)</span><span class="o">.</span><span class="n">text</span>
    @@ -195,33 +206,72 @@
                         <span class="n">xmin</span><span class="p">,</span> <span class="n">ymin</span><span class="p">,</span> <span class="n">xmax</span><span class="p">,</span> <span class="n">ymax</span> <span class="o">=</span> <span class="n">get_coord</span><span class="p">(</span><span class="s2">&quot;xmin&quot;</span><span class="p">),</span> <span class="n">get_coord</span><span class="p">(</span><span class="s2">&quot;ymin&quot;</span><span class="p">),</span> <span class="n">get_coord<
                         <span class="n">xmin</span><span class="p">,</span> <span class="n">ymin</span><span class="p">,</span> <span class="n">xmax</span><span class="p">,</span> <span class="n">ymax</span> <span class="o">=</span> <span class="n">get_coord</span><span class="p">(</span><span class="s2">&quot;xmin&quot;</span><span class="p">),</span> <span class="n">get_coord</span><span class="p">(</span><span class="s2">&quot;ymin&quot;</span><span class="p">),</span> <span class="n">get_coord<
                         <span class="n">labels</span><span class="o">.</span><span class="n">append</span><span class="p">(</span><span class="s2">&quot; &quot;</span><span class="o">.</span><span class="n">join</span><span class="p">([</span><span class="n">xmin</span><span class="p">,</span> <span class="n">ymin</span><span class="p">,</span> <span class="n">xmax</span><span class="p">,</span> <span class="n">ymax</span><span class="p">,</span> <span class="nb">str</span><span class="p">(</span><
                         <span class="n">labels</span><span class="o">.</span><span class="n">append</span><span class="p">(</span><span class="s2">&quot; &quot;</span><span class="o">.</span><span class="n">join</span><span class="p">([</span><span class="n">xmin</span><span class="p">,</span> <span class="n">ymin</span><span class="p">,</span> <span class="n">xmax</span><span class="p">,</span> <span class="n">ymax</span><span class="p">,</span> <span class="nb">str</span><span class="p">(</span><
     
     
    -            <span class="k">with</span> <span class="nb">open</span><span class="p">(</span><span class="n">new_label_path</span><span class="p">,</span> <span class="s1">&#39;w&#39;</span><span class="p">)</span> <span class="k">as</span> <span class="n">f</span><span class="p">:</span>
    +            <span class="k">with</span> <span class="nb">open</span><span class="p">(</span><span class="n">new_label_path</span><span class="p">,</span> <span class="s2">&quot;w&quot;</span><span class="p">)</span> <span class="k">as</span> <span class="n">f</span><span class="p">:</span>
                     <span class="n">f</span><span class="o">.</span><span class="n">write</span><span class="p">(</span><span class="s2">&quot;</span><span class="se">\n</span><span class="s2">&quot;</span><span class="o">.</span><span class="n">join</span><span class="p">(</span><span class="n">labels</span><span class="p">))</span>
                     <span class="n">f</span><span class="o">.</span><span class="n">write</span><span class="p">(</span><span class="s2">&quot;</span><span class="se">\n</span><span class="s2">&quot;</span><span class="o">.</span><span class="n">join</span><span class="p">(</span><span class="n">labels</span><span class="p">))</span>
     
     
    -        <span class="n">urls</span> <span class="o">=</span> <span class="p">[</span><span class="s2">&quot;http://host.robots.ox.ac.uk/pascal/VOC/voc2007/VOCtrainval_06-Nov-2007.tar&quot;</span><span class="p">,</span>  <span class="c1"># 439M 5011 images</span>
    -                <span class="s2">&quot;http://host.robots.ox.ac.uk/pascal/VOC/voc2007/VOCtest_06-Nov-2007.tar&quot;</span><span class="p">,</span>  <span class="c1"># 430M, 4952 images</span>
    -                <span class="s2">&quot;http://host.robots.ox.ac.uk/pascal/VOC/voc2012/VOCtrainval_11-May-2012.tar&quot;</span><span class="p">]</span>  <span class="c1"># 1.86G, 17125 images</span>
    +        <span class="n">urls</span> <span class="o">=</span> <span class="p">[</span>
    +            <span class="s2">&quot;http://host.robots.ox.ac.uk/pascal/VOC/voc2007/VOCtrainval_06-Nov-2007.tar&quot;</span><span class="p">,</span>  <span class="c1"># 439M 5011 images</span>
    +            <span class="s2">&quot;http://host.robots.ox.ac.uk/pascal/VOC/voc2007/VOCtest_06-Nov-2007.tar&quot;</span><span class="p">,</span>  <span class="c1"># 430M, 4952 images</span>
    +            <span class="s2">&quot;http://host.robots.ox.ac.uk/pascal/VOC/voc2012/VOCtrainval_11-May-2012.tar&quot;</span><span class="p">,</span>
    +        <span class="p">]</span>  <span class="c1"># 1.86G, 17125 images</span>
             <span class="n">data_dir</span> <span class="o">=</span> <span class="n">Path</span><span class="p">(</span><span class="n">data_dir</span><span class="p">)</span>
             <span class="n">data_dir</span> <span class="o">=</span> <span class="n">Path</span><span class="p">(</span><span class="n">data_dir</span><span class="p">)</span>
    -        <span class="n">download_and_untar_from_url</span><span class="p">(</span><span class="n">urls</span><span class="p">,</span> <span class="nb">dir</span><span class="o">=</span><span class="n">data_dir</span> <span class="o">/</span> <span class="s1">&#39;images&#39;</span><span class="p">)</span>
    +        <span class="n">download_and_untar_from_url</span><span class="p">(</span><span class="n">urls</span><span class="p">,</span> <span class="nb">dir</span><span class="o">=</span><span class="n">data_dir</span> <span class="o">/</span> <span class="s2">&quot;images&quot;</span><span class="p">)</span>
     
     
             <span class="c1"># Convert</span>
             <span class="c1"># Convert</span>
    -        <span class="n">data_path</span> <span class="o">=</span> <span class="n">data_dir</span> <span class="o">/</span> <span class="s1">&#39;images&#39;</span> <span class="o">/</span> <span class="s1">&#39;VOCdevkit&#39;</span>
    -        <span class="k">for</span> <span class="n">year</span><span class="p">,</span> <span class="n">image_set</span> <span class="ow">in</span> <span class="p">(</span><span class="s1">&#39;2012&#39;</span><span class="p">,</span> <span class="s1">&#39;train&#39;</span><span class="p">),</span> <span class="p">(</span><span class="s1">&#39;2012&#39;</span><span class="p">,</span> <span class="s1">&#39;val&#39;</span><span class="p">),</span> <span class="p">(</span><span class="s1">&#39;2007
    -            <span class="n">dest_imgs_path</span> <span class="o">=</span> <span class="n">data_dir</span> <span class="o">/</span> <span class="s1">&#39;images&#39;</span> <span class="o">/</span> <span class="sa">f</span><span class="s1">&#39;</span><span class="si">{</span><span class="n">image_set</span><span class="si">}{</span><span class="n">year</span><span class="si">}</span><span class="s1">&#39;</span>
    +        <span class="n">data_path</span> <span class="o">=</span> <span class="n">data_dir</span> <span class="o">/</span> <span class="s2">&quot;images&quot;</span> <span class="o">/</span> <span class="s2">&quot;VOCdevkit&quot;</span>
    +        <span class="k">for</span> <span class="n">year</span><span class="p">,</span> <span class="n">image_set</span> <span class="ow">in</span> <span class="p">(</span><span class="s2">&quot;2012&quot;</span><span class="p">,</span> <span class="s2">&quot;train&quot;</span><span class="p">),</span> <span class="p">(</span><span class="s2">&quot;2012&quot;</span><span class="p">,</span> <span class="s2">&quot;val&quot;</span><span class="p">),</span> <span class="p">(</span><span class="s2">&
    +            <span class="n">dest_imgs_path</span> <span class="o">=</span> <span class="n">data_dir</span> <span class="o">/</span> <span class="s2">&quot;images&quot;</span> <span class="o">/</span> <span class="sa">f</span><span class="s2">&quot;</span><span class="si">{</span><span class="n">image_set</span><span class="si">}{</span><span class="n">year</span><span class="si">}</span><span class="s2">&quot;</span>
                 <span class="n">dest_imgs_path</span><span class="o">.</span><span class="n">mkdir</span><span class="p">(</span><span class="n">exist_ok</span><span class="o">=</span><span class="kc">True</span><span class="p">,</span> <span class="n">parents</span><span class="o">=</span><span class="kc">True</span><span class="p">)</span>
                 <span class="n">dest_imgs_path</span><span class="o">.</span><span class="n">mkdir</span><span class="p">(</span><span class="n">exist_ok</span><span class="o">=</span><span class="kc">True</span><span class="p">,</span> <span class="n">parents</span><span class="o">=</span><span class="kc">True</span><span class="p">)</span>
     
     
    -            <span class="n">dest_labels_path</span> <span class="o">=</span> <span class="n">data_dir</span> <span class="o">/</span> <span class="s1">&#39;labels&#39;</span> <span class="o">/</span> <span class="sa">f</span><span class="s1">&#39;</span><span class="si">{</span><span class="n">image_set</span><span class="si">}{</span><span class="n">year</span><span class="si">}</span><span class="s1">&#39;</span>
    +            <span class="n">dest_labels_path</span> <span class="o">=</span> <span class="n">data_dir</span> <span class="o">/</span> <span class="s2">&quot;labels&quot;</span> <span class="o">/</span> <span class="sa">f</span><span class="s2">&quot;</span><span class="si">{</span><span class="n">image_set</span><span class="si">}{</span><span class="n">year</span><span class="si">}</span><span class="s2">&quot;</span>
                 <span class="n">dest_labels_path</span><span class="o">.</span><span class="n">mkdir</span><span class="p">(</span><span class="n">exist_ok</span><span class="o">=</span><span class="kc">True</span><span class="p">,</span> <span class="n">parents</span><span class="o">=</span><span class="kc">True</span><span class="p">)</span>
                 <span class="n">dest_labels_path</span><span class="o">.</span><span class="n">mkdir</span><span class="p">(</span><span class="n">exist_ok</span><span class="o">=</span><span class="kc">True</span><span class="p">,</span> <span class="n">parents</span><span class="o">=</span><span class="kc">True</span><span class="p">)</span>
     
     
    -            <span class="k">with</span> <span class="nb">open</span><span class="p">(</span><span class="n">data_path</span> <span class="o">/</span> <span class="sa">f</span><span class="s1">&#39;VOC</span><span class="si">{</span><span class="n">year</span><span class="si">}</span><span class="s1">/ImageSets/Main/</span><span class="si">{</span><span class="n">image_set</span><span class="si">}</span><span class="s1">.txt&#39;</span><span class="p">)</span> <span class="k">as</span> <span cla
    +            <span class="k">with</span> <span class="nb">open</span><span class="p">(</span><span class="n">data_path</span> <span class="o">/</span> <span class="sa">f</span><span class="s2">&quot;VOC</span><span class="si">{</span><span class="n">year</span><span class="si">}</span><span class="s2">/ImageSets/Main/</span><span class="si">{</span><span class="n">image_set</span><span class="si">}</span><span class="s2">.txt&quot;</span><span class="p">)</span> <span class="k">as</span> <span c
                     <span class="n">image_ids</span> <span class="o">=</span> <span class="n">f</span><span class="o">.</span><span class="n">read</span><span class="p">()</span><span class="o">.</span><span class="n">strip</span><span class="p">()</span><span class="o">.</span><span class="n">split</span><span class="p">()</span>
                     <span class="n">image_ids</span> <span class="o">=</span> <span class="n">f</span><span class="o">.</span><span class="n">read</span><span class="p">()</span><span class="o">.</span><span class="n">strip</span><span class="p">()</span><span class="o">.</span><span class="n">split</span><span class="p">()</span>
     
     
    -            <span class="k">for</span> <span class="nb">id</span> <span class="ow">in</span> <span class="n">tqdm</span><span class="p">(</span><span class="n">image_ids</span><span class="p">,</span> <span class="n">desc</span><span class="o">=</span><span class="sa">f</span><span class="s1">&#39;</span><span class="si">{</span><span class="n">image_set</span><span class="si">}{</span><span class="n">year</span><span class="si">}</span><span class="s1">&#39;</span><span class="p">):</span>
    -                <span class="n">img_path</span> <span class="o">=</span> <span class="n">data_path</span> <span class="o">/</span> <span class="sa">f</span><span class="s1">&#39;VOC</span><span class="si">{</span><span class="n">year</span><span class="si">}</span><span class="s1">/JPEGImages/</span><span class="si">{</span><span class="nb">id</span><span class="si">}</span><span class="s1">.jpg&#39;</span>
    +            <span class="k">for</span> <span class="nb">id</span> <span class="ow">in</span> <span class="n">tqdm</span><span class="p">(</span><span class="n">image_ids</span><span class="p">,</span> <span class="n">desc</span><span class="o">=</span><span class="sa">f</span><span class="s2">&quot;</span><span class="si">{</span><span class="n">image_set</span><span class="si">}{</span><span class="n">year</span><span class="si">}</span><span class="s2">&quot;</span><span class="p">):</span>
    +                <span class="n">img_path</span> <span class="o">=</span> <span class="n">data_path</span> <span class="o">/</span> <span class="sa">f</span><span class="s2">&quot;VOC</span><span class="si">{</span><span class="n">year</span><span class="si">}</span><span class="s2">/JPEGImages/</span><span class="si">{</span><span class="nb">id</span><span class="si">}</span><span class="s2">.jpg&quot;</span>
                     <span class="n">new_img_path</span> <span class="o">=</span> <span class="n">dest_imgs_path</span> <span class="o">/</span> <span class="n">img_path</span><span class="o">.</span><span class="n">name</span>
                     <span class="n">new_img_path</span> <span class="o">=</span> <span class="n">dest_imgs_path</span> <span class="o">/</span> <span class="n">img_path</span><span class="o">.</span><span class="n">name</span>
    -                <span class="n">new_label_path</span> <span class="o">=</span> <span class="p">(</span><span class="n">dest_labels_path</span> <span class="o">/</span> <span class="n">img_path</span><span class="o">.</span><span class="n">name</span><span class="p">)</span><span class="o">.</span><span class="n">with_suffix</span><span class="p">(</span><span class="s1">&#39;.txt&#39;</span><span class="p">)</span>
    +                <span class="n">new_label_path</span> <span class="o">=</span> <span class="p">(</span><span class="n">dest_labels_path</span> <span class="o">/</span> <span class="n">img_path</span><span class="o">.</span><span class="n">name</span><span class="p">)</span><span class="o">.</span><span class="n">with_suffix</span><span class="p">(</span><span class="s2">&quot;.txt&quot;</span><span class="p">)</span>
                     <span class="n">img_path</span><span class="o">.</span><span class="n">rename</span><span class="p">(</span><span class="n">new_img_path</span><span class="p">)</span>  <span class="c1"># Move image to dest folder</span>
                     <span class="n">img_path</span><span class="o">.</span><span class="n">rename</span><span class="p">(</span><span class="n">new_img_path</span><span class="p">)</span>  <span class="c1"># Move image to dest folder</span>
                     <span class="n">_parse_and_save_labels</span><span class="p">(</span><span class="n">data_path</span><span class="p">,</span> <span class="n">new_label_path</span><span class="p">,</span> <span class="n">year</span><span class="p">,</span> <span class="nb">id</span><span class="p">)</span></div></div>
                     <span class="n">_parse_and_save_labels</span><span class="p">(</span><span class="n">data_path</span><span class="p">,</span> <span class="n">new_label_path</span><span class="p">,</span> <span class="n">year</span><span class="p">,</span> <span class="nb">id</span><span class="p">)</span></div></div>
    +
    +
    +<span class="k">class</span> <span class="nc">PascalVOCUnifiedDetectionTrainDataset</span><span class="p">(</span><span class="n">ConcatDataset</span><span class="p">):</span>
    +    <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span>
    +        <span class="bp">self</span><span class="p">,</span>
    +        <span class="n">data_dir</span><span class="p">,</span>
    +        <span class="n">input_dim</span><span class="p">:</span> <span class="nb">tuple</span><span class="p">,</span>
    +        <span class="n">cache</span><span class="p">:</span> <span class="nb">bool</span> <span class="o">=</span> <span class="kc">False</span><span class="p">,</span>
    +        <span class="n">cache_dir</span><span class="p">:</span> <span class="nb">str</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span>
    +        <span class="n">transforms</span><span class="p">:</span> <span class="n">List</span><span class="p">[</span><span class="n">DetectionTransform</span><span class="p">]</span> <span class="o">=</span> <span class="p">[],</span>
    +        <span class="n">class_inclusion_list</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="n">List</span><span class="p">[</span><span class="nb">str</span><span class="p">]]</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span>
    +        <span class="n">max_num_samples</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span>
    +        <span class="n">download</span><span class="p">:</span> <span class="nb">bool</span> <span class="o">=</span> <span class="kc">False</span><span class="p">,</span>
    +    <span class="p">):</span>
    +        <span class="k">if</span> <span class="n">download</span><span class="p">:</span>
    +            <span class="n">PascalVOCDetectionDataset</span><span class="o">.</span><span class="n">download</span><span class="p">(</span><span class="n">data_dir</span><span class="o">=</span><span class="n">data_dir</span><span class="p">)</span>
    +
    +        <span class="n">train_dataset_names</span> <span class="o">=</span> <span class="p">[</span><span class="s2">&quot;train2007&quot;</span><span class="p">,</span> <span class="s2">&quot;val2007&quot;</span><span class="p">,</span> <span class="s2">&quot;train2012&quot;</span><span class="p">,</span> <span class="s2">&quot;val2012&quot;</span><span class="p">]</span>
    +        <span class="c1"># We divide train_max_num_samples between the datasets</span>
    +        <span class="k">if</span> <span class="n">max_num_samples</span><span class="p">:</span>
    +            <span class="n">max_num_samples_per_train_dataset</span> <span class="o">=</span> <span class="p">[</span><span class="nb">len</span><span class="p">(</span><span class="n">segment</span><span class="p">)</span> <span class="k">for</span> <span class="n">segment</span> <span class="ow">in</span> <span class="n">np</span><span class="o">.</span><span class="n">array_split</span><span class="p">(</span><span class="nb">range</span><span class="p">(</span><span class="n">max_num_sample
    +        <span class="k">else</span><span class="p">:</span>
    +            <span class="n">max_num_samples_per_train_dataset</span> <span class="o">=</span> <span class="p">[</span><span class="kc">None</span><span class="p">]</span> <span class="o">*</span> <span class="nb">len</span><span class="p">(</span><span class="n">train_dataset_names</span><span class="p">)</span>
    +        <span class="n">train_sets</span> <span class="o">=</span> <span class="p">[</span>
    +            <span class="n">PascalVOCDetectionDataset</span><span class="p">(</span>
    +                <span class="n">data_dir</span><span class="o">=</span><span class="n">data_dir</span><span class="p">,</span>
    +                <span class="n">input_dim</span><span class="o">=</span><span class="n">input_dim</span><span class="p">,</span>
    +                <span class="n">cache</span><span class="o">=</span><span class="n">cache</span><span class="p">,</span>
    +                <span class="n">cache_dir</span><span class="o">=</span><span class="n">cache_dir</span><span class="p">,</span>
    +                <span class="n">transforms</span><span class="o">=</span><span class="n">transforms</span><span class="p">,</span>
    +                <span class="n">images_sub_directory</span><span class="o">=</span><span class="s2">&quot;images/&quot;</span> <span class="o">+</span> <span class="n">trainset_name</span> <span class="o">+</span> <span class="s2">&quot;/&quot;</span><span class="p">,</span>
    +                <span class="n">class_inclusion_list</span><span class="o">=</span><span class="n">class_inclusion_list</span><span class="p">,</span>
    +                <span class="n">max_num_samples</span><span class="o">=</span><span class="n">max_num_samples_per_train_dataset</span><span class="p">[</span><span class="n">i</span><span class="p">],</span>
    +            <span class="p">)</span>
    +            <span class="k">for</span> <span class="n">i</span><span class="p">,</span> <span class="n">trainset_name</span> <span class="ow">in</span> <span class="nb">enumerate</span><span class="p">(</span><span class="n">train_dataset_names</span><span class="p">)</span>
    +        <span class="p">]</span>
    +        <span class="nb">super</span><span class="p">(</span><span class="n">PascalVOCUnifiedDetectionTrainDataset</span><span class="p">,</span> <span class="bp">self</span><span class="p">)</span><span class="o">.</span><span class="fm">__init__</span><span class="p">(</span><span class="n">train_sets</span><span class="p">)</span>
     </pre></div>
     </pre></div>
     
     
                </div>
                </div>
    @@ -251,4 +301,4 @@
       </script> 
       </script> 
     
     
     </body>
     </body>
    -</html>
    +</html>
    Discard
    1
    2
    3
    4
    5
    6
    7
    8
    9
    10
    11
    12
    13
    14
    15
    16
    17
    18
    19
    20
    21
    22
    23
    24
    25
    26
    27
    28
    29
    30
    31
    32
    33
    34
    35
    36
    37
    38
    39
    40
    41
    42
    43
    44
    45
    46
    47
    48
    49
    50
    51
    52
    53
    54
    55
    56
    57
    58
    59
    60
    61
    62
    63
    64
    65
    66
    67
    68
    69
    70
    71
    72
    73
    74
    75
    76
    77
    78
    79
    80
    81
    82
    83
    84
    85
    86
    87
    88
    89
    90
    91
    92
    93
    94
    95
    96
    97
    98
    99
    100
    101
    102
    103
    104
    105
    106
    107
    108
    109
    110
    111
    112
    113
    114
    115
    116
    117
    118
    119
    120
    121
    122
    123
    124
    125
    126
    127
    128
    129
    130
    131
    132
    133
    134
    135
    136
    137
    138
    139
    140
    141
    142
    143
    144
    145
    146
    147
    148
    149
    150
    151
    152
    153
    154
    155
    156
    157
    158
    159
    160
    161
    162
    163
    164
    165
    166
    167
    168
    169
    170
    171
    172
    173
    174
    175
    176
    177
    178
    179
    180
    181
    182
    183
    184
    185
    186
    187
    188
    189
    190
    191
    192
    193
    194
    195
    196
    197
    198
    199
    200
    201
    202
    203
    204
    205
    206
    207
    208
    209
    210
    211
    212
    213
    214
    215
    216
    217
    218
    219
    220
    221
    222
    223
    224
    225
    226
    227
    228
    229
    230
    231
    232
    233
    234
    235
    236
    237
    238
    239
    240
    241
    242
    243
    244
    245
    246
    247
    248
    249
    250
    251
    252
    253
    254
    255
    256
    257
    258
    259
    260
    261
    262
    263
    264
    265
    266
    267
    268
    269
    270
    271
    272
    273
    274
    275
    276
    277
    278
    279
    280
    281
    282
    283
    284
    285
    286
    287
    288
    289
    290
    291
    292
    293
    294
    295
    296
    297
    298
    299
    300
    301
    302
    303
    304
    305
    306
    307
    308
    309
    310
    311
    312
    313
    314
    315
    316
    317
    318
    319
    320
    321
    322
    323
    324
    325
    326
    327
    328
    329
    330
    331
    332
    333
    334
    335
    336
    337
    338
    339
    340
    341
    342
    343
    344
    345
    346
    347
    348
    349
    350
    351
    352
    353
    354
    355
    356
    357
    358
    359
    360
    361
    362
    363
    364
    365
    366
    367
    368
    369
    370
    371
    372
    373
    374
    375
    376
    377
    378
    379
    380
    381
    382
    383
    384
    385
    386
    387
    388
    389
    390
    391
    392
    393
    394
    395
    396
    397
    398
    399
    400
    401
    402
    403
    404
    405
    406
    407
    408
    409
    410
    411
    412
    413
    414
    415
    416
    417
    418
    419
    420
    421
    422
    1. <!DOCTYPE html>
    2. <html class="writer-html5" lang="en" >
    3. <head>
    4. <meta charset="utf-8" />
    5. <meta name="viewport" content="width=device-width, initial-scale=1.0" />
    6. <title>super_gradients.training.datasets.mixup &mdash; SuperGradients 1.0 documentation</title>
    7. <link rel="stylesheet" href="../../../../_static/pygments.css" type="text/css" />
    8. <link rel="stylesheet" href="../../../../_static/css/theme.css" type="text/css" />
    9. <link rel="stylesheet" href="../../../../_static/graphviz.css" type="text/css" />
    10. <!--[if lt IE 9]>
    11. <script src="../../../../_static/js/html5shiv.min.js"></script>
    12. <![endif]-->
    13. <script data-url_root="../../../../" id="documentation_options" src="../../../../_static/documentation_options.js"></script>
    14. <script src="../../../../_static/jquery.js"></script>
    15. <script src="../../../../_static/underscore.js"></script>
    16. <script src="../../../../_static/doctools.js"></script>
    17. <script src="../../../../_static/js/theme.js"></script>
    18. <link rel="index" title="Index" href="../../../../genindex.html" />
    19. <link rel="search" title="Search" href="../../../../search.html" />
    20. </head>
    21. <body class="wy-body-for-nav">
    22. <div class="wy-grid-for-nav">
    23. <nav data-toggle="wy-nav-shift" class="wy-nav-side">
    24. <div class="wy-side-scroll">
    25. <div class="wy-side-nav-search" >
    26. <a href="../../../../index.html" class="icon icon-home"> SuperGradients
    27. </a>
    28. <div role="search">
    29. <form id="rtd-search-form" class="wy-form" action="../../../../search.html" method="get">
    30. <input type="text" name="q" placeholder="Search docs" />
    31. <input type="hidden" name="check_keywords" value="yes" />
    32. <input type="hidden" name="area" value="default" />
    33. </form>
    34. </div>
    35. </div><div class="wy-menu wy-menu-vertical" data-spy="affix" role="navigation" aria-label="Navigation menu">
    36. <p class="caption"><span class="caption-text">Welcome To SuperGradients</span></p>
    37. <ul>
    38. <li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html">Fill our 4-question quick survey! We will raffle free SuperGradients swag between those who will participate -&gt; Fill Survey</a></li>
    39. <li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#supergradients">SuperGradients</a></li>
    40. </ul>
    41. <p class="caption"><span class="caption-text">Technical Documentation</span></p>
    42. <ul>
    43. <li class="toctree-l1"><a class="reference internal" href="../../../../super_gradients.common.html">Common package</a></li>
    44. <li class="toctree-l1"><a class="reference internal" href="../../../../super_gradients.training.html">Training package</a></li>
    45. </ul>
    46. <p class="caption"><span class="caption-text">User Guide</span></p>
    47. <ul>
    48. <li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html">What is SuperGradients?</a></li>
    49. <li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#introducing-the-supergradients-library">Introducing the SuperGradients library</a></li>
    50. <li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#installation">Installation</a></li>
    51. <li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#integrating-your-training-code-complete-walkthrough">Integrating Your Training Code - Complete Walkthrough</a></li>
    52. <li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#training-parameters">Training Parameters</a></li>
    53. <li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#logs-and-checkpoints">Logs and Checkpoints</a></li>
    54. <li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#dataset-parameters">Dataset Parameters</a></li>
    55. <li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#network-architectures">Network Architectures</a></li>
    56. <li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#pretrained-models">Pretrained Models</a></li>
    57. <li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#how-to-reproduce-our-training-recipes">How To Reproduce Our Training Recipes</a></li>
    58. <li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#professional-tools-integration">Professional Tools Integration</a></li>
    59. <li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#supergradients-faq">SuperGradients FAQ</a></li>
    60. </ul>
    61. </div>
    62. </div>
    63. </nav>
    64. <section data-toggle="wy-nav-shift" class="wy-nav-content-wrap"><nav class="wy-nav-top" aria-label="Mobile navigation menu" >
    65. <i data-toggle="wy-nav-top" class="fa fa-bars"></i>
    66. <a href="../../../../index.html">SuperGradients</a>
    67. </nav>
    68. <div class="wy-nav-content">
    69. <div class="rst-content">
    70. <div role="navigation" aria-label="Page navigation">
    71. <ul class="wy-breadcrumbs">
    72. <li><a href="../../../../index.html" class="icon icon-home"></a> &raquo;</li>
    73. <li><a href="../../../index.html">Module code</a> &raquo;</li>
    74. <li>super_gradients.training.datasets.mixup</li>
    75. <li class="wy-breadcrumbs-aside">
    76. </li>
    77. </ul>
    78. <hr/>
    79. </div>
    80. <div role="main" class="document" itemscope="itemscope" itemtype="http://schema.org/Article">
    81. <div itemprop="articleBody">
    82. <h1>Source code for super_gradients.training.datasets.mixup</h1><div class="highlight"><pre>
    83. <span></span><span class="sd">&quot;&quot;&quot; Mixup and Cutmix</span>
    84. <span class="sd">Papers:</span>
    85. <span class="sd">mixup: Beyond Empirical Risk Minimization (https://arxiv.org/abs/1710.09412)</span>
    86. <span class="sd">CutMix: Regularization Strategy to Train Strong Classifiers with Localizable Features (https://arxiv.org/abs/1905.04899)</span>
    87. <span class="sd">Code Reference:</span>
    88. <span class="sd">CutMix: https://github.com/clovaai/CutMix-PyTorch</span>
    89. <span class="sd">CutMix by timm: https://github.com/rwightman/pytorch-image-models/timm</span>
    90. <span class="sd">&quot;&quot;&quot;</span>
    91. <span class="kn">from</span> <span class="nn">typing</span> <span class="kn">import</span> <span class="n">List</span><span class="p">,</span> <span class="n">Union</span>
    92. <span class="kn">import</span> <span class="nn">numpy</span> <span class="k">as</span> <span class="nn">np</span>
    93. <span class="kn">import</span> <span class="nn">torch</span>
    94. <span class="kn">from</span> <span class="nn">super_gradients.training.exceptions.dataset_exceptions</span> <span class="kn">import</span> <span class="n">IllegalDatasetParameterException</span>
    95. <div class="viewcode-block" id="one_hot"><a class="viewcode-back" href="../../../../super_gradients.training.datasets.html#super_gradients.training.datasets.mixup.one_hot">[docs]</a><span class="k">def</span> <span class="nf">one_hot</span><span class="p">(</span><span class="n">x</span><span class="p">,</span> <span class="n">num_classes</span><span class="p">,</span> <span class="n">on_value</span><span class="o">=</span><span class="mf">1.</span><span class="p">,</span> <span class="n">off_value</span><span class="o">=</span><span class="mf">0.</span><span class="p">,</span> <span class="n">device</span><span class="o">=</span><span class="s1">&#39;cuda&#39;</span><span class="p">):</span>
    96. <span class="n">x</span> <span class="o">=</span> <span class="n">x</span><span class="o">.</span><span class="n">long</span><span class="p">()</span><span class="o">.</span><span class="n">view</span><span class="p">(</span><span class="o">-</span><span class="mi">1</span><span class="p">,</span> <span class="mi">1</span><span class="p">)</span>
    97. <span class="k">return</span> <span class="n">torch</span><span class="o">.</span><span class="n">full</span><span class="p">((</span><span class="n">x</span><span class="o">.</span><span class="n">size</span><span class="p">()[</span><span class="mi">0</span><span class="p">],</span> <span class="n">num_classes</span><span class="p">),</span> <span class="n">off_value</span><span class="p">,</span> <span class="n">device</span><span class="o">=</span><span class="n">device</span><span class="p">)</span><span class="o">.</span><span class="n">scatter_</span><span class="p">(</span><span class="mi">1</span><span class="p">,</span> <span class="n">x</span><span class="p">,</span> <span class="n">on_value</span><span class="p">)</span></div>
    98. <div class="viewcode-block" id="mixup_target"><a class="viewcode-back" href="../../../../super_gradients.training.datasets.html#super_gradients.training.datasets.mixup.mixup_target">[docs]</a><span class="k">def</span> <span class="nf">mixup_target</span><span class="p">(</span><span class="n">target</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">,</span> <span class="n">num_classes</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span> <span class="n">lam</span><span class="p">:</span> <span class="nb">float</span> <span class="o">=</span> <span class="mf">1.</span><span class="p">,</span> <span class="n">smoothing</span><span class="p">:</span> <span class="nb">float</span> <span class="o">=</span> <span class="mf">0.0</span><span class="p">,</span> <span class="n">device</span><span class="p">:</span> <span class="nb">str</span> <span class="o">=</span> <span class="s1">&#39;cuda&#39;</span><span class="p">):</span>
    99. <span class="sd">&quot;&quot;&quot;</span>
    100. <span class="sd"> generate a smooth target (label) two-hot tensor to support the mixed images with different labels</span>
    101. <span class="sd"> :param target: the targets tensor</span>
    102. <span class="sd"> :param num_classes: number of classes (to set the final tensor size)</span>
    103. <span class="sd"> :param lam: percentage of label a range [0, 1] in the mixing</span>
    104. <span class="sd"> :param smoothing: the smoothing multiplier</span>
    105. <span class="sd"> :param device: usable device [&#39;cuda&#39;, &#39;cpu&#39;]</span>
    106. <span class="sd"> :return:</span>
    107. <span class="sd"> &quot;&quot;&quot;</span>
    108. <span class="n">off_value</span> <span class="o">=</span> <span class="n">smoothing</span> <span class="o">/</span> <span class="n">num_classes</span>
    109. <span class="n">on_value</span> <span class="o">=</span> <span class="mf">1.</span> <span class="o">-</span> <span class="n">smoothing</span> <span class="o">+</span> <span class="n">off_value</span>
    110. <span class="n">y1</span> <span class="o">=</span> <span class="n">one_hot</span><span class="p">(</span><span class="n">target</span><span class="p">,</span> <span class="n">num_classes</span><span class="p">,</span> <span class="n">on_value</span><span class="o">=</span><span class="n">on_value</span><span class="p">,</span> <span class="n">off_value</span><span class="o">=</span><span class="n">off_value</span><span class="p">,</span> <span class="n">device</span><span class="o">=</span><span class="n">device</span><span class="p">)</span>
    111. <span class="n">y2</span> <span class="o">=</span> <span class="n">one_hot</span><span class="p">(</span><span class="n">target</span><span class="o">.</span><span class="n">flip</span><span class="p">(</span><span class="mi">0</span><span class="p">),</span> <span class="n">num_classes</span><span class="p">,</span> <span class="n">on_value</span><span class="o">=</span><span class="n">on_value</span><span class="p">,</span> <span class="n">off_value</span><span class="o">=</span><span class="n">off_value</span><span class="p">,</span> <span class="n">device</span><span class="o">=</span><span class="n">device</span><span class="p">)</span>
    112. <span class="k">return</span> <span class="n">y1</span> <span class="o">*</span> <span class="n">lam</span> <span class="o">+</span> <span class="n">y2</span> <span class="o">*</span> <span class="p">(</span><span class="mf">1.</span> <span class="o">-</span> <span class="n">lam</span><span class="p">)</span></div>
    113. <div class="viewcode-block" id="rand_bbox"><a class="viewcode-back" href="../../../../super_gradients.training.datasets.html#super_gradients.training.datasets.mixup.rand_bbox">[docs]</a><span class="k">def</span> <span class="nf">rand_bbox</span><span class="p">(</span><span class="n">img_shape</span><span class="p">:</span> <span class="nb">tuple</span><span class="p">,</span> <span class="n">lam</span><span class="p">:</span> <span class="nb">float</span><span class="p">,</span> <span class="n">margin</span><span class="p">:</span> <span class="nb">float</span> <span class="o">=</span> <span class="mf">0.</span><span class="p">,</span> <span class="n">count</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="kc">None</span><span class="p">):</span>
    114. <span class="sd">&quot;&quot;&quot; Standard CutMix bounding-box</span>
    115. <span class="sd"> Generates a random square bbox based on lambda value. This impl includes</span>
    116. <span class="sd"> support for enforcing a border margin as percent of bbox dimensions.</span>
    117. <span class="sd"> :param img_shape: Image shape as tuple</span>
    118. <span class="sd"> :param lam: Cutmix lambda value</span>
    119. <span class="sd"> :param margin: Percentage of bbox dimension to enforce as margin (reduce amount of box outside image)</span>
    120. <span class="sd"> :param count: Number of bbox to generate</span>
    121. <span class="sd"> &quot;&quot;&quot;</span>
    122. <span class="n">ratio</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">sqrt</span><span class="p">(</span><span class="mi">1</span> <span class="o">-</span> <span class="n">lam</span><span class="p">)</span>
    123. <span class="n">img_h</span><span class="p">,</span> <span class="n">img_w</span> <span class="o">=</span> <span class="n">img_shape</span><span class="p">[</span><span class="o">-</span><span class="mi">2</span><span class="p">:]</span>
    124. <span class="n">cut_h</span><span class="p">,</span> <span class="n">cut_w</span> <span class="o">=</span> <span class="nb">int</span><span class="p">(</span><span class="n">img_h</span> <span class="o">*</span> <span class="n">ratio</span><span class="p">),</span> <span class="nb">int</span><span class="p">(</span><span class="n">img_w</span> <span class="o">*</span> <span class="n">ratio</span><span class="p">)</span>
    125. <span class="n">margin_y</span><span class="p">,</span> <span class="n">margin_x</span> <span class="o">=</span> <span class="nb">int</span><span class="p">(</span><span class="n">margin</span> <span class="o">*</span> <span class="n">cut_h</span><span class="p">),</span> <span class="nb">int</span><span class="p">(</span><span class="n">margin</span> <span class="o">*</span> <span class="n">cut_w</span><span class="p">)</span>
    126. <span class="n">cy</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">random</span><span class="o">.</span><span class="n">randint</span><span class="p">(</span><span class="mi">0</span> <span class="o">+</span> <span class="n">margin_y</span><span class="p">,</span> <span class="n">img_h</span> <span class="o">-</span> <span class="n">margin_y</span><span class="p">,</span> <span class="n">size</span><span class="o">=</span><span class="n">count</span><span class="p">)</span>
    127. <span class="n">cx</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">random</span><span class="o">.</span><span class="n">randint</span><span class="p">(</span><span class="mi">0</span> <span class="o">+</span> <span class="n">margin_x</span><span class="p">,</span> <span class="n">img_w</span> <span class="o">-</span> <span class="n">margin_x</span><span class="p">,</span> <span class="n">size</span><span class="o">=</span><span class="n">count</span><span class="p">)</span>
    128. <span class="n">yl</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">clip</span><span class="p">(</span><span class="n">cy</span> <span class="o">-</span> <span class="n">cut_h</span> <span class="o">//</span> <span class="mi">2</span><span class="p">,</span> <span class="mi">0</span><span class="p">,</span> <span class="n">img_h</span><span class="p">)</span>
    129. <span class="n">yh</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">clip</span><span class="p">(</span><span class="n">cy</span> <span class="o">+</span> <span class="n">cut_h</span> <span class="o">//</span> <span class="mi">2</span><span class="p">,</span> <span class="mi">0</span><span class="p">,</span> <span class="n">img_h</span><span class="p">)</span>
    130. <span class="n">xl</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">clip</span><span class="p">(</span><span class="n">cx</span> <span class="o">-</span> <span class="n">cut_w</span> <span class="o">//</span> <span class="mi">2</span><span class="p">,</span> <span class="mi">0</span><span class="p">,</span> <span class="n">img_w</span><span class="p">)</span>
    131. <span class="n">xh</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">clip</span><span class="p">(</span><span class="n">cx</span> <span class="o">+</span> <span class="n">cut_w</span> <span class="o">//</span> <span class="mi">2</span><span class="p">,</span> <span class="mi">0</span><span class="p">,</span> <span class="n">img_w</span><span class="p">)</span>
    132. <span class="k">return</span> <span class="n">yl</span><span class="p">,</span> <span class="n">yh</span><span class="p">,</span> <span class="n">xl</span><span class="p">,</span> <span class="n">xh</span></div>
    133. <div class="viewcode-block" id="rand_bbox_minmax"><a class="viewcode-back" href="../../../../super_gradients.training.datasets.html#super_gradients.training.datasets.mixup.rand_bbox_minmax">[docs]</a><span class="k">def</span> <span class="nf">rand_bbox_minmax</span><span class="p">(</span><span class="n">img_shape</span><span class="p">:</span> <span class="nb">tuple</span><span class="p">,</span> <span class="n">minmax</span><span class="p">:</span> <span class="n">Union</span><span class="p">[</span><span class="nb">tuple</span><span class="p">,</span> <span class="nb">list</span><span class="p">],</span> <span class="n">count</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="kc">None</span><span class="p">):</span>
    134. <span class="sd">&quot;&quot;&quot; Min-Max CutMix bounding-box</span>
    135. <span class="sd"> Inspired by Darknet cutmix impl, generates a random rectangular bbox</span>
    136. <span class="sd"> based on min/max percent values applied to each dimension of the input image.</span>
    137. <span class="sd"> Typical defaults for minmax are usually in the .2-.3 for min and .8-.9 range for max.</span>
    138. <span class="sd"> :param img_shape: Image shape as tuple</span>
    139. <span class="sd"> :param minmax: Min and max bbox ratios (as percent of image size)</span>
    140. <span class="sd"> :param count: Number of bbox to generate</span>
    141. <span class="sd"> &quot;&quot;&quot;</span>
    142. <span class="k">assert</span> <span class="nb">len</span><span class="p">(</span><span class="n">minmax</span><span class="p">)</span> <span class="o">==</span> <span class="mi">2</span>
    143. <span class="n">img_h</span><span class="p">,</span> <span class="n">img_w</span> <span class="o">=</span> <span class="n">img_shape</span><span class="p">[</span><span class="o">-</span><span class="mi">2</span><span class="p">:]</span>
    144. <span class="n">cut_h</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">random</span><span class="o">.</span><span class="n">randint</span><span class="p">(</span><span class="nb">int</span><span class="p">(</span><span class="n">img_h</span> <span class="o">*</span> <span class="n">minmax</span><span class="p">[</span><span class="mi">0</span><span class="p">]),</span> <span class="nb">int</span><span class="p">(</span><span class="n">img_h</span> <span class="o">*</span> <span class="n">minmax</span><span class="p">[</span><span class="mi">1</span><span class="p">]),</span> <span class="n">size</span><span class="o">=</span><span class="n">count</span><span class="p">)</span>
    145. <span class="n">cut_w</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">random</span><span class="o">.</span><span class="n">randint</span><span class="p">(</span><span class="nb">int</span><span class="p">(</span><span class="n">img_w</span> <span class="o">*</span> <span class="n">minmax</span><span class="p">[</span><span class="mi">0</span><span class="p">]),</span> <span class="nb">int</span><span class="p">(</span><span class="n">img_w</span> <span class="o">*</span> <span class="n">minmax</span><span class="p">[</span><span class="mi">1</span><span class="p">]),</span> <span class="n">size</span><span class="o">=</span><span class="n">count</span><span class="p">)</span>
    146. <span class="n">yl</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">random</span><span class="o">.</span><span class="n">randint</span><span class="p">(</span><span class="mi">0</span><span class="p">,</span> <span class="n">img_h</span> <span class="o">-</span> <span class="n">cut_h</span><span class="p">,</span> <span class="n">size</span><span class="o">=</span><span class="n">count</span><span class="p">)</span>
    147. <span class="n">xl</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">random</span><span class="o">.</span><span class="n">randint</span><span class="p">(</span><span class="mi">0</span><span class="p">,</span> <span class="n">img_w</span> <span class="o">-</span> <span class="n">cut_w</span><span class="p">,</span> <span class="n">size</span><span class="o">=</span><span class="n">count</span><span class="p">)</span>
    148. <span class="n">yu</span> <span class="o">=</span> <span class="n">yl</span> <span class="o">+</span> <span class="n">cut_h</span>
    149. <span class="n">xu</span> <span class="o">=</span> <span class="n">xl</span> <span class="o">+</span> <span class="n">cut_w</span>
    150. <span class="k">return</span> <span class="n">yl</span><span class="p">,</span> <span class="n">yu</span><span class="p">,</span> <span class="n">xl</span><span class="p">,</span> <span class="n">xu</span></div>
    151. <div class="viewcode-block" id="cutmix_bbox_and_lam"><a class="viewcode-back" href="../../../../super_gradients.training.datasets.html#super_gradients.training.datasets.mixup.cutmix_bbox_and_lam">[docs]</a><span class="k">def</span> <span class="nf">cutmix_bbox_and_lam</span><span class="p">(</span><span class="n">img_shape</span><span class="p">:</span> <span class="nb">tuple</span><span class="p">,</span> <span class="n">lam</span><span class="p">:</span> <span class="nb">float</span><span class="p">,</span> <span class="n">ratio_minmax</span><span class="p">:</span> <span class="n">Union</span><span class="p">[</span><span class="nb">tuple</span><span class="p">,</span> <span class="nb">list</span><span class="p">]</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span> <span class="n">correct_lam</span><span class="p">:</span> <span class="nb">bool</span> <span class="o">=</span> <span class="kc">True</span><span class="p">,</span>
    152. <span class="n">count</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="kc">None</span><span class="p">):</span>
    153. <span class="sd">&quot;&quot;&quot;</span>
    154. <span class="sd"> Generate bbox and apply lambda correction.</span>
    155. <span class="sd"> &quot;&quot;&quot;</span>
    156. <span class="k">if</span> <span class="n">ratio_minmax</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span><span class="p">:</span>
    157. <span class="n">yl</span><span class="p">,</span> <span class="n">yu</span><span class="p">,</span> <span class="n">xl</span><span class="p">,</span> <span class="n">xu</span> <span class="o">=</span> <span class="n">rand_bbox_minmax</span><span class="p">(</span><span class="n">img_shape</span><span class="p">,</span> <span class="n">ratio_minmax</span><span class="p">,</span> <span class="n">count</span><span class="o">=</span><span class="n">count</span><span class="p">)</span>
    158. <span class="k">else</span><span class="p">:</span>
    159. <span class="n">yl</span><span class="p">,</span> <span class="n">yu</span><span class="p">,</span> <span class="n">xl</span><span class="p">,</span> <span class="n">xu</span> <span class="o">=</span> <span class="n">rand_bbox</span><span class="p">(</span><span class="n">img_shape</span><span class="p">,</span> <span class="n">lam</span><span class="p">,</span> <span class="n">count</span><span class="o">=</span><span class="n">count</span><span class="p">)</span>
    160. <span class="k">if</span> <span class="n">correct_lam</span> <span class="ow">or</span> <span class="n">ratio_minmax</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span><span class="p">:</span>
    161. <span class="n">bbox_area</span> <span class="o">=</span> <span class="p">(</span><span class="n">yu</span> <span class="o">-</span> <span class="n">yl</span><span class="p">)</span> <span class="o">*</span> <span class="p">(</span><span class="n">xu</span> <span class="o">-</span> <span class="n">xl</span><span class="p">)</span>
    162. <span class="n">lam</span> <span class="o">=</span> <span class="mf">1.</span> <span class="o">-</span> <span class="n">bbox_area</span> <span class="o">/</span> <span class="nb">float</span><span class="p">(</span><span class="n">img_shape</span><span class="p">[</span><span class="o">-</span><span class="mi">2</span><span class="p">]</span> <span class="o">*</span> <span class="n">img_shape</span><span class="p">[</span><span class="o">-</span><span class="mi">1</span><span class="p">])</span>
    163. <span class="k">return</span> <span class="p">(</span><span class="n">yl</span><span class="p">,</span> <span class="n">yu</span><span class="p">,</span> <span class="n">xl</span><span class="p">,</span> <span class="n">xu</span><span class="p">),</span> <span class="n">lam</span></div>
    164. <div class="viewcode-block" id="CollateMixup"><a class="viewcode-back" href="../../../../super_gradients.training.datasets.html#super_gradients.training.datasets.mixup.CollateMixup">[docs]</a><span class="k">class</span> <span class="nc">CollateMixup</span><span class="p">:</span>
    165. <span class="sd">&quot;&quot;&quot;</span>
    166. <span class="sd"> Collate with Mixup/Cutmix that applies different params to each element or whole batch</span>
    167. <span class="sd"> A Mixup impl that&#39;s performed while collating the batches.</span>
    168. <span class="sd"> &quot;&quot;&quot;</span>
    169. <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">mixup_alpha</span><span class="p">:</span> <span class="nb">float</span> <span class="o">=</span> <span class="mf">1.</span><span class="p">,</span> <span class="n">cutmix_alpha</span><span class="p">:</span> <span class="nb">float</span> <span class="o">=</span> <span class="mf">0.</span><span class="p">,</span> <span class="n">cutmix_minmax</span><span class="p">:</span> <span class="n">List</span><span class="p">[</span><span class="nb">float</span><span class="p">]</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span>
    170. <span class="n">prob</span><span class="p">:</span> <span class="nb">float</span> <span class="o">=</span> <span class="mf">1.0</span><span class="p">,</span> <span class="n">switch_prob</span><span class="p">:</span> <span class="nb">float</span> <span class="o">=</span> <span class="mf">0.5</span><span class="p">,</span>
    171. <span class="n">mode</span><span class="p">:</span> <span class="nb">str</span> <span class="o">=</span> <span class="s1">&#39;batch&#39;</span><span class="p">,</span> <span class="n">correct_lam</span><span class="p">:</span> <span class="nb">bool</span> <span class="o">=</span> <span class="kc">True</span><span class="p">,</span> <span class="n">label_smoothing</span><span class="p">:</span> <span class="nb">float</span> <span class="o">=</span> <span class="mf">0.1</span><span class="p">,</span> <span class="n">num_classes</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="mi">1000</span><span class="p">):</span>
    172. <span class="sd">&quot;&quot;&quot;</span>
    173. <span class="sd"> Mixup/Cutmix that applies different params to each element or whole batch</span>
    174. <span class="sd"> :param mixup_alpha: mixup alpha value, mixup is active if &gt; 0.</span>
    175. <span class="sd"> :param cutmix_alpha: cutmix alpha value, cutmix is active if &gt; 0.</span>
    176. <span class="sd"> :param cutmix_minmax: cutmix min/max image ratio, cutmix is active and uses this vs alpha if not None.</span>
    177. <span class="sd"> :param prob: probability of applying mixup or cutmix per batch or element</span>
    178. <span class="sd"> :param switch_prob: probability of switching to cutmix instead of mixup when both are active</span>
    179. <span class="sd"> :param mode: how to apply mixup/cutmix params (per &#39;batch&#39;, &#39;pair&#39; (pair of elements), &#39;elem&#39; (element)</span>
    180. <span class="sd"> :param correct_lam: apply lambda correction when cutmix bbox clipped by image borders</span>
    181. <span class="sd"> :param label_smoothing: apply label smoothing to the mixed target tensor</span>
    182. <span class="sd"> :param num_classes: number of classes for target</span>
    183. <span class="sd"> &quot;&quot;&quot;</span>
    184. <span class="bp">self</span><span class="o">.</span><span class="n">mixup_alpha</span> <span class="o">=</span> <span class="n">mixup_alpha</span>
    185. <span class="bp">self</span><span class="o">.</span><span class="n">cutmix_alpha</span> <span class="o">=</span> <span class="n">cutmix_alpha</span>
    186. <span class="bp">self</span><span class="o">.</span><span class="n">cutmix_minmax</span> <span class="o">=</span> <span class="n">cutmix_minmax</span>
    187. <span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">cutmix_minmax</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span><span class="p">:</span>
    188. <span class="k">assert</span> <span class="nb">len</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">cutmix_minmax</span><span class="p">)</span> <span class="o">==</span> <span class="mi">2</span>
    189. <span class="c1"># force cutmix alpha == 1.0 when minmax active to keep logic simple &amp; safe</span>
    190. <span class="bp">self</span><span class="o">.</span><span class="n">cutmix_alpha</span> <span class="o">=</span> <span class="mf">1.0</span>
    191. <span class="bp">self</span><span class="o">.</span><span class="n">mix_prob</span> <span class="o">=</span> <span class="n">prob</span>
    192. <span class="bp">self</span><span class="o">.</span><span class="n">switch_prob</span> <span class="o">=</span> <span class="n">switch_prob</span>
    193. <span class="bp">self</span><span class="o">.</span><span class="n">label_smoothing</span> <span class="o">=</span> <span class="n">label_smoothing</span>
    194. <span class="bp">self</span><span class="o">.</span><span class="n">num_classes</span> <span class="o">=</span> <span class="n">num_classes</span>
    195. <span class="bp">self</span><span class="o">.</span><span class="n">mode</span> <span class="o">=</span> <span class="n">mode</span>
    196. <span class="bp">self</span><span class="o">.</span><span class="n">correct_lam</span> <span class="o">=</span> <span class="n">correct_lam</span> <span class="c1"># correct lambda based on clipped area for cutmix</span>
    197. <span class="bp">self</span><span class="o">.</span><span class="n">mixup_enabled</span> <span class="o">=</span> <span class="kc">True</span> <span class="c1"># set to false to disable mixing (intended tp be set by train loop)</span>
    198. <span class="k">def</span> <span class="nf">_params_per_elem</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">batch_size</span><span class="p">):</span>
    199. <span class="sd">&quot;&quot;&quot;</span>
    200. <span class="sd"> generate two random masks to define which elements of the batch will be mixed and how (depending on the</span>
    201. <span class="sd"> self.mixup_enabled, self.mixup_alpha, self.cutmix_alpha parameters</span>
    202. <span class="sd"> :param batch_size:</span>
    203. <span class="sd"> :return: two tensors with shape=batch_size - the first contains the lambda value per batch element</span>
    204. <span class="sd"> and the second is a binary flag indicating use of cutmix per batch element</span>
    205. <span class="sd"> &quot;&quot;&quot;</span>
    206. <span class="n">lam</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">ones</span><span class="p">(</span><span class="n">batch_size</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="n">torch</span><span class="o">.</span><span class="n">float32</span><span class="p">)</span>
    207. <span class="n">use_cutmix</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">zeros</span><span class="p">(</span><span class="n">batch_size</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="n">torch</span><span class="o">.</span><span class="n">bool</span><span class="p">)</span>
    208. <span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">mixup_enabled</span><span class="p">:</span>
    209. <span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">mixup_alpha</span> <span class="o">&gt;</span> <span class="mf">0.</span> <span class="ow">and</span> <span class="bp">self</span><span class="o">.</span><span class="n">cutmix_alpha</span> <span class="o">&gt;</span> <span class="mf">0.</span><span class="p">:</span>
    210. <span class="n">use_cutmix</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">rand</span><span class="p">(</span><span class="n">batch_size</span><span class="p">)</span> <span class="o">&lt;</span> <span class="bp">self</span><span class="o">.</span><span class="n">switch_prob</span>
    211. <span class="n">lam_mix</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">where</span><span class="p">(</span>
    212. <span class="n">use_cutmix</span><span class="p">,</span>
    213. <span class="n">torch</span><span class="o">.</span><span class="n">distributions</span><span class="o">.</span><span class="n">beta</span><span class="o">.</span><span class="n">Beta</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">cutmix_alpha</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">cutmix_alpha</span><span class="p">)</span><span class="o">.</span><span class="n">sample</span><span class="p">(</span><span class="n">sample_shape</span><span class="o">=</span><span class="n">batch_size</span><span class="p">),</span>
    214. <span class="n">torch</span><span class="o">.</span><span class="n">distributions</span><span class="o">.</span><span class="n">beta</span><span class="o">.</span><span class="n">Beta</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">mixup_alpha</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">mixup_alpha</span><span class="p">)</span><span class="o">.</span><span class="n">sample</span><span class="p">(</span><span class="n">sample_shape</span><span class="o">=</span><span class="n">batch_size</span><span class="p">))</span>
    215. <span class="k">elif</span> <span class="bp">self</span><span class="o">.</span><span class="n">mixup_alpha</span> <span class="o">&gt;</span> <span class="mf">0.</span><span class="p">:</span>
    216. <span class="n">lam_mix</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">distributions</span><span class="o">.</span><span class="n">beta</span><span class="o">.</span><span class="n">Beta</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">mixup_alpha</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">mixup_alpha</span><span class="p">)</span><span class="o">.</span><span class="n">sample</span><span class="p">(</span><span class="n">sample_shape</span><span class="o">=</span><span class="n">batch_size</span><span class="p">)</span>
    217. <span class="k">elif</span> <span class="bp">self</span><span class="o">.</span><span class="n">cutmix_alpha</span> <span class="o">&gt;</span> <span class="mf">0.</span><span class="p">:</span>
    218. <span class="n">use_cutmix</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">ones</span><span class="p">(</span><span class="n">batch_size</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="n">torch</span><span class="o">.</span><span class="n">bool</span><span class="p">)</span>
    219. <span class="n">lam_mix</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">distributions</span><span class="o">.</span><span class="n">beta</span><span class="o">.</span><span class="n">Beta</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">cutmix_alpha</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">cutmix_alpha</span><span class="p">)</span><span class="o">.</span><span class="n">sample</span><span class="p">(</span><span class="n">sample_shape</span><span class="o">=</span><span class="n">batch_size</span><span class="p">)</span>
    220. <span class="k">else</span><span class="p">:</span>
    221. <span class="k">raise</span> <span class="n">IllegalDatasetParameterException</span><span class="p">(</span><span class="s2">&quot;One of mixup_alpha &gt; 0., cutmix_alpha &gt; 0., &quot;</span>
    222. <span class="s2">&quot;cutmix_minmax not None should be true.&quot;</span><span class="p">)</span>
    223. <span class="n">lam</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">where</span><span class="p">(</span><span class="n">torch</span><span class="o">.</span><span class="n">rand</span><span class="p">(</span><span class="n">batch_size</span><span class="p">)</span> <span class="o">&lt;</span> <span class="bp">self</span><span class="o">.</span><span class="n">mix_prob</span><span class="p">,</span> <span class="n">lam_mix</span><span class="o">.</span><span class="n">type</span><span class="p">(</span><span class="n">torch</span><span class="o">.</span><span class="n">float32</span><span class="p">),</span> <span class="n">lam</span><span class="p">)</span>
    224. <span class="k">return</span> <span class="n">lam</span><span class="p">,</span> <span class="n">use_cutmix</span>
    225. <span class="k">def</span> <span class="nf">_params_per_batch</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span>
    226. <span class="sd">&quot;&quot;&quot;</span>
    227. <span class="sd"> generate two random parameters to define if batch will be mixed and how (depending on the</span>
    228. <span class="sd"> self.mixup_enabled, self.mixup_alpha, self.cutmix_alpha parameters</span>
    229. <span class="sd"> :return: two parameters - the first contains the lambda value for the whole batch</span>
    230. <span class="sd"> and the second is a binary flag indicating use of cutmix for the batch</span>
    231. <span class="sd"> &quot;&quot;&quot;</span>
    232. <span class="n">lam</span> <span class="o">=</span> <span class="mf">1.</span>
    233. <span class="n">use_cutmix</span> <span class="o">=</span> <span class="kc">False</span>
    234. <span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">mixup_enabled</span> <span class="ow">and</span> <span class="n">torch</span><span class="o">.</span><span class="n">rand</span><span class="p">(</span><span class="mi">1</span><span class="p">)</span> <span class="o">&lt;</span> <span class="bp">self</span><span class="o">.</span><span class="n">mix_prob</span><span class="p">:</span>
    235. <span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">mixup_alpha</span> <span class="o">&gt;</span> <span class="mf">0.</span> <span class="ow">and</span> <span class="bp">self</span><span class="o">.</span><span class="n">cutmix_alpha</span> <span class="o">&gt;</span> <span class="mf">0.</span><span class="p">:</span>
    236. <span class="n">use_cutmix</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">rand</span><span class="p">(</span><span class="mi">1</span><span class="p">)</span> <span class="o">&lt;</span> <span class="bp">self</span><span class="o">.</span><span class="n">switch_prob</span>
    237. <span class="n">lam_mix</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">distributions</span><span class="o">.</span><span class="n">beta</span><span class="o">.</span><span class="n">Beta</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">cutmix_alpha</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">cutmix_alpha</span><span class="p">)</span><span class="o">.</span><span class="n">sample</span><span class="p">()</span> <span class="k">if</span> <span class="n">use_cutmix</span> <span class="k">else</span> \
    238. <span class="n">torch</span><span class="o">.</span><span class="n">distributions</span><span class="o">.</span><span class="n">beta</span><span class="o">.</span><span class="n">Beta</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">mixup_alpha</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">mixup_alpha</span><span class="p">)</span><span class="o">.</span><span class="n">sample</span><span class="p">()</span>
    239. <span class="k">elif</span> <span class="bp">self</span><span class="o">.</span><span class="n">mixup_alpha</span> <span class="o">&gt;</span> <span class="mf">0.</span><span class="p">:</span>
    240. <span class="n">lam_mix</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">distributions</span><span class="o">.</span><span class="n">beta</span><span class="o">.</span><span class="n">Beta</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">mixup_alpha</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">mixup_alpha</span><span class="p">)</span><span class="o">.</span><span class="n">sample</span><span class="p">()</span>
    241. <span class="k">elif</span> <span class="bp">self</span><span class="o">.</span><span class="n">cutmix_alpha</span> <span class="o">&gt;</span> <span class="mf">0.</span><span class="p">:</span>
    242. <span class="n">use_cutmix</span> <span class="o">=</span> <span class="kc">True</span>
    243. <span class="n">lam_mix</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">distributions</span><span class="o">.</span><span class="n">beta</span><span class="o">.</span><span class="n">Beta</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">cutmix_alpha</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">cutmix_alpha</span><span class="p">)</span><span class="o">.</span><span class="n">sample</span><span class="p">()</span>
    244. <span class="k">else</span><span class="p">:</span>
    245. <span class="k">raise</span> <span class="n">IllegalDatasetParameterException</span><span class="p">(</span><span class="s2">&quot;One of mixup_alpha &gt; 0., cutmix_alpha &gt; 0., &quot;</span>
    246. <span class="s2">&quot;cutmix_minmax not None should be true.&quot;</span><span class="p">)</span>
    247. <span class="n">lam</span> <span class="o">=</span> <span class="nb">float</span><span class="p">(</span><span class="n">lam_mix</span><span class="p">)</span>
    248. <span class="k">return</span> <span class="n">lam</span><span class="p">,</span> <span class="n">use_cutmix</span>
    249. <span class="k">def</span> <span class="nf">_mix_elem_collate</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">output</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">,</span> <span class="n">batch</span><span class="p">:</span> <span class="nb">list</span><span class="p">,</span> <span class="n">half</span><span class="p">:</span> <span class="nb">bool</span> <span class="o">=</span> <span class="kc">False</span><span class="p">):</span>
    250. <span class="sd">&quot;&quot;&quot;</span>
    251. <span class="sd"> This is the implementation for &#39;elem&#39; or &#39;half&#39; modes</span>
    252. <span class="sd"> :param output: the output tensor to fill</span>
    253. <span class="sd"> :param batch: list of thr batch items</span>
    254. <span class="sd"> :return: a tensor containing the lambda values used for the mixing (this vector can be used for</span>
    255. <span class="sd"> mixing the labels as well)</span>
    256. <span class="sd"> &quot;&quot;&quot;</span>
    257. <span class="n">batch_size</span> <span class="o">=</span> <span class="nb">len</span><span class="p">(</span><span class="n">batch</span><span class="p">)</span>
    258. <span class="n">num_elem</span> <span class="o">=</span> <span class="n">batch_size</span> <span class="o">//</span> <span class="mi">2</span> <span class="k">if</span> <span class="n">half</span> <span class="k">else</span> <span class="n">batch_size</span>
    259. <span class="k">assert</span> <span class="nb">len</span><span class="p">(</span><span class="n">output</span><span class="p">)</span> <span class="o">==</span> <span class="n">num_elem</span>
    260. <span class="n">lam_batch</span><span class="p">,</span> <span class="n">use_cutmix</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">_params_per_elem</span><span class="p">(</span><span class="n">num_elem</span><span class="p">)</span>
    261. <span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="n">num_elem</span><span class="p">):</span>
    262. <span class="n">j</span> <span class="o">=</span> <span class="n">batch_size</span> <span class="o">-</span> <span class="n">i</span> <span class="o">-</span> <span class="mi">1</span>
    263. <span class="n">lam</span> <span class="o">=</span> <span class="n">lam_batch</span><span class="p">[</span><span class="n">i</span><span class="p">]</span>
    264. <span class="n">mixed</span> <span class="o">=</span> <span class="n">batch</span><span class="p">[</span><span class="n">i</span><span class="p">][</span><span class="mi">0</span><span class="p">]</span>
    265. <span class="k">if</span> <span class="n">lam</span> <span class="o">!=</span> <span class="mf">1.</span><span class="p">:</span>
    266. <span class="k">if</span> <span class="n">use_cutmix</span><span class="p">[</span><span class="n">i</span><span class="p">]:</span>
    267. <span class="k">if</span> <span class="ow">not</span> <span class="n">half</span><span class="p">:</span>
    268. <span class="n">mixed</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">clone</span><span class="p">(</span><span class="n">mixed</span><span class="p">)</span>
    269. <span class="p">(</span><span class="n">yl</span><span class="p">,</span> <span class="n">yh</span><span class="p">,</span> <span class="n">xl</span><span class="p">,</span> <span class="n">xh</span><span class="p">),</span> <span class="n">lam</span> <span class="o">=</span> <span class="n">cutmix_bbox_and_lam</span><span class="p">(</span>
    270. <span class="n">output</span><span class="o">.</span><span class="n">shape</span><span class="p">,</span> <span class="n">lam</span><span class="p">,</span> <span class="n">ratio_minmax</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">cutmix_minmax</span><span class="p">,</span> <span class="n">correct_lam</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">correct_lam</span><span class="p">)</span>
    271. <span class="n">mixed</span><span class="p">[:,</span> <span class="n">yl</span><span class="p">:</span><span class="n">yh</span><span class="p">,</span> <span class="n">xl</span><span class="p">:</span><span class="n">xh</span><span class="p">]</span> <span class="o">=</span> <span class="n">batch</span><span class="p">[</span><span class="n">j</span><span class="p">][</span><span class="mi">0</span><span class="p">][:,</span> <span class="n">yl</span><span class="p">:</span><span class="n">yh</span><span class="p">,</span> <span class="n">xl</span><span class="p">:</span><span class="n">xh</span><span class="p">]</span>
    272. <span class="n">lam_batch</span><span class="p">[</span><span class="n">i</span><span class="p">]</span> <span class="o">=</span> <span class="n">lam</span>
    273. <span class="k">else</span><span class="p">:</span>
    274. <span class="n">mixed</span> <span class="o">=</span> <span class="n">mixed</span> <span class="o">*</span> <span class="n">lam</span> <span class="o">+</span> <span class="n">batch</span><span class="p">[</span><span class="n">j</span><span class="p">][</span><span class="mi">0</span><span class="p">]</span> <span class="o">*</span> <span class="p">(</span><span class="mi">1</span> <span class="o">-</span> <span class="n">lam</span><span class="p">)</span>
    275. <span class="n">output</span><span class="p">[</span><span class="n">i</span><span class="p">]</span> <span class="o">+=</span> <span class="n">mixed</span>
    276. <span class="k">if</span> <span class="n">half</span><span class="p">:</span>
    277. <span class="n">lam_batch</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">cat</span><span class="p">((</span><span class="n">lam_batch</span><span class="p">,</span> <span class="n">torch</span><span class="o">.</span><span class="n">ones</span><span class="p">(</span><span class="n">num_elem</span><span class="p">)))</span>
    278. <span class="k">return</span> <span class="n">torch</span><span class="o">.</span><span class="n">tensor</span><span class="p">(</span><span class="n">lam_batch</span><span class="p">)</span><span class="o">.</span><span class="n">unsqueeze</span><span class="p">(</span><span class="mi">1</span><span class="p">)</span>
    279. <span class="k">def</span> <span class="nf">_mix_pair_collate</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">output</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">,</span> <span class="n">batch</span><span class="p">:</span> <span class="nb">list</span><span class="p">):</span>
    280. <span class="sd">&quot;&quot;&quot;</span>
    281. <span class="sd"> This is the implementation for &#39;pair&#39; mode</span>
    282. <span class="sd"> :param output: the output tensor to fill</span>
    283. <span class="sd"> :param batch: list of thr batch items</span>
    284. <span class="sd"> :return: a tensor containing the lambda values used for the mixing (this vector can be used for</span>
    285. <span class="sd"> mixing the labels as well)</span>
    286. <span class="sd"> &quot;&quot;&quot;</span>
    287. <span class="n">batch_size</span> <span class="o">=</span> <span class="nb">len</span><span class="p">(</span><span class="n">batch</span><span class="p">)</span>
    288. <span class="n">lam_batch</span><span class="p">,</span> <span class="n">use_cutmix</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">_params_per_elem</span><span class="p">(</span><span class="n">batch_size</span> <span class="o">//</span> <span class="mi">2</span><span class="p">)</span>
    289. <span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="n">batch_size</span> <span class="o">//</span> <span class="mi">2</span><span class="p">):</span>
    290. <span class="n">j</span> <span class="o">=</span> <span class="n">batch_size</span> <span class="o">-</span> <span class="n">i</span> <span class="o">-</span> <span class="mi">1</span>
    291. <span class="n">lam</span> <span class="o">=</span> <span class="n">lam_batch</span><span class="p">[</span><span class="n">i</span><span class="p">]</span>
    292. <span class="n">mixed_i</span> <span class="o">=</span> <span class="n">batch</span><span class="p">[</span><span class="n">i</span><span class="p">][</span><span class="mi">0</span><span class="p">]</span>
    293. <span class="n">mixed_j</span> <span class="o">=</span> <span class="n">batch</span><span class="p">[</span><span class="n">j</span><span class="p">][</span><span class="mi">0</span><span class="p">]</span>
    294. <span class="k">assert</span> <span class="mi">0</span> <span class="o">&lt;=</span> <span class="n">lam</span> <span class="o">&lt;=</span> <span class="mf">1.0</span>
    295. <span class="k">if</span> <span class="n">lam</span> <span class="o">&lt;</span> <span class="mf">1.</span><span class="p">:</span>
    296. <span class="k">if</span> <span class="n">use_cutmix</span><span class="p">[</span><span class="n">i</span><span class="p">]:</span>
    297. <span class="p">(</span><span class="n">yl</span><span class="p">,</span> <span class="n">yh</span><span class="p">,</span> <span class="n">xl</span><span class="p">,</span> <span class="n">xh</span><span class="p">),</span> <span class="n">lam</span> <span class="o">=</span> <span class="n">cutmix_bbox_and_lam</span><span class="p">(</span>
    298. <span class="n">output</span><span class="o">.</span><span class="n">shape</span><span class="p">,</span> <span class="n">lam</span><span class="p">,</span> <span class="n">ratio_minmax</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">cutmix_minmax</span><span class="p">,</span> <span class="n">correct_lam</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">correct_lam</span><span class="p">)</span>
    299. <span class="n">patch_i</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">clone</span><span class="p">(</span><span class="n">mixed_i</span><span class="p">[:,</span> <span class="n">yl</span><span class="p">:</span><span class="n">yh</span><span class="p">,</span> <span class="n">xl</span><span class="p">:</span><span class="n">xh</span><span class="p">])</span>
    300. <span class="n">mixed_i</span><span class="p">[:,</span> <span class="n">yl</span><span class="p">:</span><span class="n">yh</span><span class="p">,</span> <span class="n">xl</span><span class="p">:</span><span class="n">xh</span><span class="p">]</span> <span class="o">=</span> <span class="n">mixed_j</span><span class="p">[:,</span> <span class="n">yl</span><span class="p">:</span><span class="n">yh</span><span class="p">,</span> <span class="n">xl</span><span class="p">:</span><span class="n">xh</span><span class="p">]</span>
    301. <span class="n">mixed_j</span><span class="p">[:,</span> <span class="n">yl</span><span class="p">:</span><span class="n">yh</span><span class="p">,</span> <span class="n">xl</span><span class="p">:</span><span class="n">xh</span><span class="p">]</span> <span class="o">=</span> <span class="n">patch_i</span>
    302. <span class="n">lam_batch</span><span class="p">[</span><span class="n">i</span><span class="p">]</span> <span class="o">=</span> <span class="n">lam</span>
    303. <span class="k">else</span><span class="p">:</span>
    304. <span class="n">mixed_temp</span> <span class="o">=</span> <span class="n">mixed_i</span><span class="o">.</span><span class="n">type</span><span class="p">(</span><span class="n">torch</span><span class="o">.</span><span class="n">float32</span><span class="p">)</span> <span class="o">*</span> <span class="n">lam</span> <span class="o">+</span> <span class="n">mixed_j</span><span class="o">.</span><span class="n">type</span><span class="p">(</span><span class="n">torch</span><span class="o">.</span><span class="n">float32</span><span class="p">)</span> <span class="o">*</span> <span class="p">(</span><span class="mi">1</span> <span class="o">-</span> <span class="n">lam</span><span class="p">)</span>
    305. <span class="n">mixed_j</span> <span class="o">=</span> <span class="n">mixed_j</span><span class="o">.</span><span class="n">type</span><span class="p">(</span><span class="n">torch</span><span class="o">.</span><span class="n">float32</span><span class="p">)</span> <span class="o">*</span> <span class="n">lam</span> <span class="o">+</span> <span class="n">mixed_i</span><span class="o">.</span><span class="n">type</span><span class="p">(</span><span class="n">torch</span><span class="o">.</span><span class="n">float32</span><span class="p">)</span> <span class="o">*</span> <span class="p">(</span><span class="mi">1</span> <span class="o">-</span> <span class="n">lam</span><span class="p">)</span>
    306. <span class="n">mixed_i</span> <span class="o">=</span> <span class="n">mixed_temp</span>
    307. <span class="n">torch</span><span class="o">.</span><span class="n">rint</span><span class="p">(</span><span class="n">mixed_j</span><span class="p">,</span> <span class="n">out</span><span class="o">=</span><span class="n">mixed_j</span><span class="p">)</span>
    308. <span class="n">torch</span><span class="o">.</span><span class="n">rint</span><span class="p">(</span><span class="n">mixed_i</span><span class="p">,</span> <span class="n">out</span><span class="o">=</span><span class="n">mixed_i</span><span class="p">)</span>
    309. <span class="n">output</span><span class="p">[</span><span class="n">i</span><span class="p">]</span> <span class="o">+=</span> <span class="n">mixed_i</span>
    310. <span class="n">output</span><span class="p">[</span><span class="n">j</span><span class="p">]</span> <span class="o">+=</span> <span class="n">mixed_j</span>
    311. <span class="n">lam_batch</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">cat</span><span class="p">((</span><span class="n">lam_batch</span><span class="p">,</span> <span class="n">lam_batch</span><span class="p">[::</span><span class="o">-</span><span class="mi">1</span><span class="p">]))</span>
    312. <span class="k">return</span> <span class="n">torch</span><span class="o">.</span><span class="n">tensor</span><span class="p">(</span><span class="n">lam_batch</span><span class="p">)</span><span class="o">.</span><span class="n">unsqueeze</span><span class="p">(</span><span class="mi">1</span><span class="p">)</span>
    313. <span class="k">def</span> <span class="nf">_mix_batch_collate</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">output</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">,</span> <span class="n">batch</span><span class="p">:</span> <span class="nb">list</span><span class="p">):</span>
    314. <span class="sd">&quot;&quot;&quot;</span>
    315. <span class="sd"> This is the implementation for &#39;batch&#39; mode</span>
    316. <span class="sd"> :param output: the output tensor to fill</span>
    317. <span class="sd"> :param batch: list of thr batch items</span>
    318. <span class="sd"> :return: the lambda value used for the mixing</span>
    319. <span class="sd"> &quot;&quot;&quot;</span>
    320. <span class="n">batch_size</span> <span class="o">=</span> <span class="nb">len</span><span class="p">(</span><span class="n">batch</span><span class="p">)</span>
    321. <span class="n">lam</span><span class="p">,</span> <span class="n">use_cutmix</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">_params_per_batch</span><span class="p">()</span>
    322. <span class="k">if</span> <span class="n">use_cutmix</span><span class="p">:</span>
    323. <span class="p">(</span><span class="n">yl</span><span class="p">,</span> <span class="n">yh</span><span class="p">,</span> <span class="n">xl</span><span class="p">,</span> <span class="n">xh</span><span class="p">),</span> <span class="n">lam</span> <span class="o">=</span> <span class="n">cutmix_bbox_and_lam</span><span class="p">(</span>
    324. <span class="n">output</span><span class="o">.</span><span class="n">shape</span><span class="p">,</span> <span class="n">lam</span><span class="p">,</span> <span class="n">ratio_minmax</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">cutmix_minmax</span><span class="p">,</span> <span class="n">correct_lam</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">correct_lam</span><span class="p">)</span>
    325. <span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="n">batch_size</span><span class="p">):</span>
    326. <span class="n">j</span> <span class="o">=</span> <span class="n">batch_size</span> <span class="o">-</span> <span class="n">i</span> <span class="o">-</span> <span class="mi">1</span>
    327. <span class="n">mixed</span> <span class="o">=</span> <span class="n">batch</span><span class="p">[</span><span class="n">i</span><span class="p">][</span><span class="mi">0</span><span class="p">]</span>
    328. <span class="k">if</span> <span class="n">lam</span> <span class="o">!=</span> <span class="mf">1.</span><span class="p">:</span>
    329. <span class="k">if</span> <span class="n">use_cutmix</span><span class="p">:</span>
    330. <span class="n">mixed</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">clone</span><span class="p">(</span><span class="n">mixed</span><span class="p">)</span> <span class="c1"># don&#39;t want to modify the original while iterating</span>
    331. <span class="n">mixed</span><span class="p">[:,</span> <span class="n">yl</span><span class="p">:</span><span class="n">yh</span><span class="p">,</span> <span class="n">xl</span><span class="p">:</span><span class="n">xh</span><span class="p">]</span> <span class="o">=</span> <span class="n">batch</span><span class="p">[</span><span class="n">j</span><span class="p">][</span><span class="mi">0</span><span class="p">][:,</span> <span class="n">yl</span><span class="p">:</span><span class="n">yh</span><span class="p">,</span> <span class="n">xl</span><span class="p">:</span><span class="n">xh</span><span class="p">]</span>
    332. <span class="k">else</span><span class="p">:</span>
    333. <span class="n">mixed</span> <span class="o">=</span> <span class="n">mixed</span> <span class="o">*</span> <span class="n">lam</span> <span class="o">+</span> <span class="n">batch</span><span class="p">[</span><span class="n">j</span><span class="p">][</span><span class="mi">0</span><span class="p">]</span> <span class="o">*</span> <span class="p">(</span><span class="mi">1</span> <span class="o">-</span> <span class="n">lam</span><span class="p">)</span>
    334. <span class="n">output</span><span class="p">[</span><span class="n">i</span><span class="p">]</span> <span class="o">+=</span> <span class="n">mixed</span>
    335. <span class="k">return</span> <span class="n">lam</span>
    336. <span class="k">def</span> <span class="fm">__call__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">batch</span><span class="p">,</span> <span class="n">_</span><span class="o">=</span><span class="kc">None</span><span class="p">):</span>
    337. <span class="n">batch_size</span> <span class="o">=</span> <span class="nb">len</span><span class="p">(</span><span class="n">batch</span><span class="p">)</span>
    338. <span class="k">if</span> <span class="n">batch_size</span> <span class="o">%</span> <span class="mi">2</span> <span class="o">!=</span> <span class="mi">0</span><span class="p">:</span>
    339. <span class="k">raise</span> <span class="n">IllegalDatasetParameterException</span><span class="p">(</span><span class="s1">&#39;Batch size should be even when using this&#39;</span><span class="p">)</span>
    340. <span class="n">half</span> <span class="o">=</span> <span class="s1">&#39;half&#39;</span> <span class="ow">in</span> <span class="bp">self</span><span class="o">.</span><span class="n">mode</span>
    341. <span class="k">if</span> <span class="n">half</span><span class="p">:</span>
    342. <span class="n">batch_size</span> <span class="o">//=</span> <span class="mi">2</span>
    343. <span class="n">output</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">zeros</span><span class="p">((</span><span class="n">batch_size</span><span class="p">,</span> <span class="o">*</span><span class="n">batch</span><span class="p">[</span><span class="mi">0</span><span class="p">][</span><span class="mi">0</span><span class="p">]</span><span class="o">.</span><span class="n">shape</span><span class="p">),</span> <span class="n">dtype</span><span class="o">=</span><span class="n">torch</span><span class="o">.</span><span class="n">float32</span><span class="p">)</span>
    344. <span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">mode</span> <span class="o">==</span> <span class="s1">&#39;elem&#39;</span> <span class="ow">or</span> <span class="bp">self</span><span class="o">.</span><span class="n">mode</span> <span class="o">==</span> <span class="s1">&#39;half&#39;</span><span class="p">:</span>
    345. <span class="n">lam</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">_mix_elem_collate</span><span class="p">(</span><span class="n">output</span><span class="p">,</span> <span class="n">batch</span><span class="p">,</span> <span class="n">half</span><span class="o">=</span><span class="n">half</span><span class="p">)</span>
    346. <span class="k">elif</span> <span class="bp">self</span><span class="o">.</span><span class="n">mode</span> <span class="o">==</span> <span class="s1">&#39;pair&#39;</span><span class="p">:</span>
    347. <span class="n">lam</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">_mix_pair_collate</span><span class="p">(</span><span class="n">output</span><span class="p">,</span> <span class="n">batch</span><span class="p">)</span>
    348. <span class="k">else</span><span class="p">:</span>
    349. <span class="n">lam</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">_mix_batch_collate</span><span class="p">(</span><span class="n">output</span><span class="p">,</span> <span class="n">batch</span><span class="p">)</span>
    350. <span class="n">target</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">tensor</span><span class="p">([</span><span class="n">b</span><span class="p">[</span><span class="mi">1</span><span class="p">]</span> <span class="k">for</span> <span class="n">b</span> <span class="ow">in</span> <span class="n">batch</span><span class="p">],</span> <span class="n">dtype</span><span class="o">=</span><span class="n">torch</span><span class="o">.</span><span class="n">int32</span><span class="p">)</span>
    351. <span class="n">target</span> <span class="o">=</span> <span class="n">mixup_target</span><span class="p">(</span><span class="n">target</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">num_classes</span><span class="p">,</span> <span class="n">lam</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">label_smoothing</span><span class="p">,</span> <span class="n">device</span><span class="o">=</span><span class="s1">&#39;cpu&#39;</span><span class="p">)</span>
    352. <span class="n">target</span> <span class="o">=</span> <span class="n">target</span><span class="p">[:</span><span class="n">batch_size</span><span class="p">]</span>
    353. <span class="k">return</span> <span class="n">output</span><span class="p">,</span> <span class="n">target</span></div>
    354. </pre></div>
    355. </div>
    356. </div>
    357. <footer>
    358. <hr/>
    359. <div role="contentinfo">
    360. <p>&#169; Copyright 2021, SuperGradients team.</p>
    361. </div>
    362. Built with <a href="https://www.sphinx-doc.org/">Sphinx</a> using a
    363. <a href="https://github.com/readthedocs/sphinx_rtd_theme">theme</a>
    364. provided by <a href="https://readthedocs.org">Read the Docs</a>.
    365. </footer>
    366. </div>
    367. </div>
    368. </section>
    369. </div>
    370. <script>
    371. jQuery(function () {
    372. SphinxRtdTheme.Navigation.enable(true);
    373. });
    374. </script>
    375. </body>
    376. </html>
    Discard
    1
    2
    3
    4
    5
    6
    7
    8
    9
    10
    11
    12
    13
    14
    15
    16
    17
    18
    19
    20
    21
    22
    23
    24
    25
    26
    27
    28
    29
    30
    31
    32
    33
    34
    35
    36
    37
    38
    39
    40
    41
    42
    43
    44
    45
    46
    47
    48
    49
    50
    51
    52
    53
    54
    55
    56
    57
    58
    59
    60
    61
    62
    63
    64
    65
    66
    67
    68
    69
    70
    71
    72
    73
    74
    75
    76
    77
    78
    79
    80
    81
    82
    83
    84
    85
    86
    87
    88
    89
    90
    91
    92
    93
    94
    95
    96
    97
    98
    99
    100
    101
    102
    103
    104
    105
    106
    107
    108
    109
    110
    111
    112
    113
    114
    115
    116
    117
    118
    119
    120
    121
    122
    123
    124
    125
    126
    127
    128
    129
    130
    131
    132
    133
    134
    135
    136
    137
    138
    139
    140
    141
    142
    143
    144
    145
    146
    147
    148
    149
    150
    151
    152
    153
    154
    155
    156
    157
    158
    159
    160
    161
    162
    163
    164
    165
    166
    167
    168
    169
    170
    171
    172
    173
    174
    175
    176
    177
    178
    179
    180
    181
    182
    183
    184
    185
    186
    187
    188
    189
    190
    191
    192
    193
    194
    195
    196
    197
    198
    199
    200
    201
    202
    203
    204
    205
    206
    207
    208
    209
    210
    211
    212
    213
    214
    215
    216
    217
    218
    219
    220
    221
    222
    1. <!DOCTYPE html>
    2. <html class="writer-html5" lang="en" >
    3. <head>
    4. <meta charset="utf-8" />
    5. <meta name="viewport" content="width=device-width, initial-scale=1.0" />
    6. <title>super_gradients.training.datasets.segmentation_datasets.cityscape_segmentation &mdash; SuperGradients 1.0 documentation</title>
    7. <link rel="stylesheet" href="../../../../../_static/pygments.css" type="text/css" />
    8. <link rel="stylesheet" href="../../../../../_static/css/theme.css" type="text/css" />
    9. <link rel="stylesheet" href="../../../../../_static/graphviz.css" type="text/css" />
    10. <!--[if lt IE 9]>
    11. <script src="../../../../../_static/js/html5shiv.min.js"></script>
    12. <![endif]-->
    13. <script data-url_root="../../../../../" id="documentation_options" src="../../../../../_static/documentation_options.js"></script>
    14. <script src="../../../../../_static/jquery.js"></script>
    15. <script src="../../../../../_static/underscore.js"></script>
    16. <script src="../../../../../_static/doctools.js"></script>
    17. <script src="../../../../../_static/js/theme.js"></script>
    18. <link rel="index" title="Index" href="../../../../../genindex.html" />
    19. <link rel="search" title="Search" href="../../../../../search.html" />
    20. </head>
    21. <body class="wy-body-for-nav">
    22. <div class="wy-grid-for-nav">
    23. <nav data-toggle="wy-nav-shift" class="wy-nav-side">
    24. <div class="wy-side-scroll">
    25. <div class="wy-side-nav-search" >
    26. <a href="../../../../../index.html" class="icon icon-home"> SuperGradients
    27. </a>
    28. <div role="search">
    29. <form id="rtd-search-form" class="wy-form" action="../../../../../search.html" method="get">
    30. <input type="text" name="q" placeholder="Search docs" />
    31. <input type="hidden" name="check_keywords" value="yes" />
    32. <input type="hidden" name="area" value="default" />
    33. </form>
    34. </div>
    35. </div><div class="wy-menu wy-menu-vertical" data-spy="affix" role="navigation" aria-label="Navigation menu">
    36. <p class="caption"><span class="caption-text">Welcome To SuperGradients</span></p>
    37. <ul>
    38. <li class="toctree-l1"><a class="reference internal" href="../../../../../welcome.html">Fill our 4-question quick survey! We will raffle free SuperGradients swag between those who will participate -&gt; Fill Survey</a></li>
    39. <li class="toctree-l1"><a class="reference internal" href="../../../../../welcome.html#supergradients">SuperGradients</a></li>
    40. </ul>
    41. <p class="caption"><span class="caption-text">Technical Documentation</span></p>
    42. <ul>
    43. <li class="toctree-l1"><a class="reference internal" href="../../../../../super_gradients.common.html">Common package</a></li>
    44. <li class="toctree-l1"><a class="reference internal" href="../../../../../super_gradients.training.html">Training package</a></li>
    45. </ul>
    46. <p class="caption"><span class="caption-text">User Guide</span></p>
    47. <ul>
    48. <li class="toctree-l1"><a class="reference internal" href="../../../../../user_guide.html">What is SuperGradients?</a></li>
    49. <li class="toctree-l1"><a class="reference internal" href="../../../../../user_guide.html#introducing-the-supergradients-library">Introducing the SuperGradients library</a></li>
    50. <li class="toctree-l1"><a class="reference internal" href="../../../../../user_guide.html#installation">Installation</a></li>
    51. <li class="toctree-l1"><a class="reference internal" href="../../../../../user_guide.html#integrating-your-training-code-complete-walkthrough">Integrating Your Training Code - Complete Walkthrough</a></li>
    52. <li class="toctree-l1"><a class="reference internal" href="../../../../../user_guide.html#training-parameters">Training Parameters</a></li>
    53. <li class="toctree-l1"><a class="reference internal" href="../../../../../user_guide.html#logs-and-checkpoints">Logs and Checkpoints</a></li>
    54. <li class="toctree-l1"><a class="reference internal" href="../../../../../user_guide.html#dataset-parameters">Dataset Parameters</a></li>
    55. <li class="toctree-l1"><a class="reference internal" href="../../../../../user_guide.html#network-architectures">Network Architectures</a></li>
    56. <li class="toctree-l1"><a class="reference internal" href="../../../../../user_guide.html#pretrained-models">Pretrained Models</a></li>
    57. <li class="toctree-l1"><a class="reference internal" href="../../../../../user_guide.html#how-to-reproduce-our-training-recipes">How To Reproduce Our Training Recipes</a></li>
    58. <li class="toctree-l1"><a class="reference internal" href="../../../../../user_guide.html#professional-tools-integration">Professional Tools Integration</a></li>
    59. <li class="toctree-l1"><a class="reference internal" href="../../../../../user_guide.html#supergradients-faq">SuperGradients FAQ</a></li>
    60. </ul>
    61. </div>
    62. </div>
    63. </nav>
    64. <section data-toggle="wy-nav-shift" class="wy-nav-content-wrap"><nav class="wy-nav-top" aria-label="Mobile navigation menu" >
    65. <i data-toggle="wy-nav-top" class="fa fa-bars"></i>
    66. <a href="../../../../../index.html">SuperGradients</a>
    67. </nav>
    68. <div class="wy-nav-content">
    69. <div class="rst-content">
    70. <div role="navigation" aria-label="Page navigation">
    71. <ul class="wy-breadcrumbs">
    72. <li><a href="../../../../../index.html" class="icon icon-home"></a> &raquo;</li>
    73. <li><a href="../../../../index.html">Module code</a> &raquo;</li>
    74. <li>super_gradients.training.datasets.segmentation_datasets.cityscape_segmentation</li>
    75. <li class="wy-breadcrumbs-aside">
    76. </li>
    77. </ul>
    78. <hr/>
    79. </div>
    80. <div role="main" class="document" itemscope="itemscope" itemtype="http://schema.org/Article">
    81. <div itemprop="articleBody">
    82. <h1>Source code for super_gradients.training.datasets.segmentation_datasets.cityscape_segmentation</h1><div class="highlight"><pre>
    83. <span></span><span class="kn">import</span> <span class="nn">os</span>
    84. <span class="kn">import</span> <span class="nn">cv2</span>
    85. <span class="kn">import</span> <span class="nn">numpy</span> <span class="k">as</span> <span class="nn">np</span>
    86. <span class="kn">from</span> <span class="nn">PIL</span> <span class="kn">import</span> <span class="n">Image</span><span class="p">,</span> <span class="n">ImageColor</span>
    87. <span class="kn">from</span> <span class="nn">super_gradients.training.datasets.segmentation_datasets.segmentation_dataset</span> <span class="kn">import</span> <span class="n">SegmentationDataSet</span>
    88. <span class="c1"># TODO - ADD COARSE DATA - right now cityscapes dataset includes fine annotations. It&#39;s optional to use extra coarse</span>
    89. <span class="c1"># annotations.</span>
    90. <span class="c1"># label for background and labels to ignore during training and evaluation.</span>
    91. <span class="n">CITYSCAPES_IGNORE_LABEL</span> <span class="o">=</span> <span class="mi">19</span>
    92. <div class="viewcode-block" id="CityscapesDataset"><a class="viewcode-back" href="../../../../../super_gradients.training.datasets.segmentation_datasets.html#super_gradients.training.datasets.segmentation_datasets.cityscape_segmentation.CityscapesDataset">[docs]</a><span class="k">class</span> <span class="nc">CityscapesDataset</span><span class="p">(</span><span class="n">SegmentationDataSet</span><span class="p">):</span>
    93. <span class="sd">&quot;&quot;&quot;</span>
    94. <span class="sd"> CityscapesDataset - Segmentation Data Set Class for Cityscapes Segmentation Data Set,</span>
    95. <span class="sd"> main resolution of dataset: (2048 x 1024).</span>
    96. <span class="sd"> Not all the original labels are used for training and evaluation, according to cityscape paper:</span>
    97. <span class="sd"> &quot;Classes that are too rare are excluded from our benchmark, leaving 19 classes for evaluation&quot;.</span>
    98. <span class="sd"> For more details about the dataset labels format see:</span>
    99. <span class="sd"> https://github.com/mcordts/cityscapesScripts/blob/master/cityscapesscripts/helpers/labels.py</span>
    100. <span class="sd"> &quot;&quot;&quot;</span>
    101. <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span>
    102. <span class="n">root_dir</span><span class="p">:</span> <span class="nb">str</span><span class="p">,</span>
    103. <span class="n">list_file</span><span class="p">:</span> <span class="nb">str</span><span class="p">,</span>
    104. <span class="n">labels_csv_path</span><span class="p">:</span> <span class="nb">str</span><span class="p">,</span>
    105. <span class="o">**</span><span class="n">kwargs</span><span class="p">):</span>
    106. <span class="sd">&quot;&quot;&quot;</span>
    107. <span class="sd"> :param root: root directory to dataset.</span>
    108. <span class="sd"> :param list_file: list file that contains names of images to load, line format: &lt;image_path&gt; &lt;label_path&gt;</span>
    109. <span class="sd"> :param labels_csv_path: path to csv file, with labels metadata and mapping.</span>
    110. <span class="sd"> :param kwargs: Any hyper params required for the dataset, i.e img_size, crop_size, cache_images</span>
    111. <span class="sd"> &quot;&quot;&quot;</span>
    112. <span class="bp">self</span><span class="o">.</span><span class="n">root_dir</span> <span class="o">=</span> <span class="n">root_dir</span>
    113. <span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="fm">__init__</span><span class="p">(</span><span class="n">root_dir</span><span class="p">,</span> <span class="n">list_file</span><span class="o">=</span><span class="n">list_file</span><span class="p">,</span> <span class="o">**</span><span class="n">kwargs</span><span class="p">)</span>
    114. <span class="c1"># labels dataframe for labels metadata.</span>
    115. <span class="bp">self</span><span class="o">.</span><span class="n">labels_data</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">recfromcsv</span><span class="p">(</span><span class="n">os</span><span class="o">.</span><span class="n">path</span><span class="o">.</span><span class="n">join</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">root_dir</span><span class="p">,</span> <span class="n">labels_csv_path</span><span class="p">),</span> <span class="n">dtype</span><span class="o">=</span><span class="s1">&#39;&lt;i8,U20,&lt;i8,&lt;i8,U12,&lt;i8,?,?,U7&#39;</span><span class="p">,</span> <span class="n">comments</span><span class="o">=</span><span class="s1">&#39;&amp;&#39;</span><span class="p">)</span>
    116. <span class="c1"># map vector to map ground-truth labels to train labels</span>
    117. <span class="bp">self</span><span class="o">.</span><span class="n">labels_map</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">labels_data</span><span class="o">.</span><span class="n">field</span><span class="p">(</span><span class="s2">&quot;trainid&quot;</span><span class="p">)</span>
    118. <span class="c1"># class names</span>
    119. <span class="bp">self</span><span class="o">.</span><span class="n">classes</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">labels_data</span><span class="o">.</span><span class="n">field</span><span class="p">(</span><span class="s2">&quot;name&quot;</span><span class="p">)[</span><span class="n">np</span><span class="o">.</span><span class="n">logical_not</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">labels_data</span><span class="o">.</span><span class="n">field</span><span class="p">(</span><span class="s2">&quot;ignoreineval&quot;</span><span class="p">))]</span><span class="o">.</span><span class="n">tolist</span><span class="p">()</span>
    120. <span class="c1"># color palette for visualization</span>
    121. <span class="bp">self</span><span class="o">.</span><span class="n">train_id_color_palette</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">_create_color_palette</span><span class="p">()</span>
    122. <span class="k">def</span> <span class="nf">_generate_samples_and_targets</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span>
    123. <span class="sd">&quot;&quot;&quot;</span>
    124. <span class="sd"> override _generate_samples_and_targets function, to parse list file.</span>
    125. <span class="sd"> line format of list file: &lt;image_path&gt; &lt;label_path&gt;</span>
    126. <span class="sd"> &quot;&quot;&quot;</span>
    127. <span class="k">with</span> <span class="nb">open</span><span class="p">(</span><span class="n">os</span><span class="o">.</span><span class="n">path</span><span class="o">.</span><span class="n">join</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">root_dir</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">list_file_path</span><span class="p">))</span> <span class="k">as</span> <span class="n">f</span><span class="p">:</span>
    128. <span class="n">img_list</span> <span class="o">=</span> <span class="p">[</span><span class="n">line</span><span class="o">.</span><span class="n">strip</span><span class="p">()</span><span class="o">.</span><span class="n">split</span><span class="p">()</span> <span class="k">for</span> <span class="n">line</span> <span class="ow">in</span> <span class="n">f</span><span class="p">]</span>
    129. <span class="k">for</span> <span class="n">image_path</span><span class="p">,</span> <span class="n">label_path</span> <span class="ow">in</span> <span class="n">img_list</span><span class="p">:</span>
    130. <span class="bp">self</span><span class="o">.</span><span class="n">samples_targets_tuples_list</span><span class="o">.</span><span class="n">append</span><span class="p">((</span>
    131. <span class="n">os</span><span class="o">.</span><span class="n">path</span><span class="o">.</span><span class="n">join</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">root</span><span class="p">,</span> <span class="n">image_path</span><span class="p">),</span>
    132. <span class="n">os</span><span class="o">.</span><span class="n">path</span><span class="o">.</span><span class="n">join</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">root</span><span class="p">,</span> <span class="n">label_path</span><span class="p">)</span>
    133. <span class="p">))</span>
    134. <div class="viewcode-block" id="CityscapesDataset.target_loader"><a class="viewcode-back" href="../../../../../super_gradients.training.datasets.segmentation_datasets.html#super_gradients.training.datasets.segmentation_datasets.cityscape_segmentation.CityscapesDataset.target_loader">[docs]</a> <span class="k">def</span> <span class="nf">target_loader</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">label_path</span><span class="p">:</span> <span class="nb">str</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="n">Image</span><span class="p">:</span>
    135. <span class="sd">&quot;&quot;&quot;</span>
    136. <span class="sd"> Override target_loader function, load the labels mask image.</span>
    137. <span class="sd"> :param label_path: Path to the label image.</span>
    138. <span class="sd"> :return: The mask image created from the array, with converted class labels.</span>
    139. <span class="sd"> &quot;&quot;&quot;</span>
    140. <span class="c1"># assert that is a png file, other file types might alter the class labels value.</span>
    141. <span class="k">assert</span> <span class="n">os</span><span class="o">.</span><span class="n">path</span><span class="o">.</span><span class="n">splitext</span><span class="p">(</span><span class="n">label_path</span><span class="p">)[</span><span class="o">-</span><span class="mi">1</span><span class="p">]</span><span class="o">.</span><span class="n">lower</span><span class="p">()</span> <span class="o">==</span> <span class="s2">&quot;.png&quot;</span>
    142. <span class="n">label</span> <span class="o">=</span> <span class="n">cv2</span><span class="o">.</span><span class="n">imread</span><span class="p">(</span><span class="n">label_path</span><span class="p">,</span> <span class="n">cv2</span><span class="o">.</span><span class="n">IMREAD_GRAYSCALE</span><span class="p">)</span>
    143. <span class="c1"># map ground-truth ids to train ids</span>
    144. <span class="n">label</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">labels_map</span><span class="p">[</span><span class="n">label</span><span class="p">]</span><span class="o">.</span><span class="n">astype</span><span class="p">(</span><span class="n">np</span><span class="o">.</span><span class="n">uint8</span><span class="p">)</span>
    145. <span class="k">return</span> <span class="n">Image</span><span class="o">.</span><span class="n">fromarray</span><span class="p">(</span><span class="n">label</span><span class="p">,</span> <span class="s1">&#39;L&#39;</span><span class="p">)</span></div>
    146. <span class="k">def</span> <span class="nf">_create_color_palette</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span>
    147. <span class="sd">&quot;&quot;&quot;</span>
    148. <span class="sd"> Create color pallete for visualizing the segmentation masks</span>
    149. <span class="sd"> :return: list of rgb color values</span>
    150. <span class="sd"> &quot;&quot;&quot;</span>
    151. <span class="n">palette</span> <span class="o">=</span> <span class="p">[]</span>
    152. <span class="n">hex_colors</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">labels_data</span><span class="o">.</span><span class="n">field</span><span class="p">(</span><span class="s2">&quot;color&quot;</span><span class="p">)[</span><span class="n">np</span><span class="o">.</span><span class="n">logical_not</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">labels_data</span><span class="o">.</span><span class="n">field</span><span class="p">(</span><span class="s2">&quot;ignoreineval&quot;</span><span class="p">))]</span><span class="o">.</span><span class="n">tolist</span><span class="p">()</span>
    153. <span class="k">for</span> <span class="n">hex_color</span> <span class="ow">in</span> <span class="n">hex_colors</span><span class="p">:</span>
    154. <span class="n">rgb_color</span> <span class="o">=</span> <span class="n">ImageColor</span><span class="o">.</span><span class="n">getcolor</span><span class="p">(</span><span class="n">hex_color</span><span class="p">,</span> <span class="s2">&quot;RGB&quot;</span><span class="p">)</span>
    155. <span class="n">palette</span> <span class="o">+=</span> <span class="p">[</span><span class="n">x</span> <span class="k">for</span> <span class="n">x</span> <span class="ow">in</span> <span class="n">rgb_color</span><span class="p">]</span>
    156. <span class="k">return</span> <span class="n">palette</span>
    157. <div class="viewcode-block" id="CityscapesDataset.get_train_ids_color_palette"><a class="viewcode-back" href="../../../../../super_gradients.training.datasets.segmentation_datasets.html#super_gradients.training.datasets.segmentation_datasets.cityscape_segmentation.CityscapesDataset.get_train_ids_color_palette">[docs]</a> <span class="k">def</span> <span class="nf">get_train_ids_color_palette</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span>
    158. <span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">train_id_color_palette</span></div>
    159. <div class="viewcode-block" id="CityscapesDataset.target_transform"><a class="viewcode-back" href="../../../../../super_gradients.training.datasets.segmentation_datasets.html#super_gradients.training.datasets.segmentation_datasets.cityscape_segmentation.CityscapesDataset.target_transform">[docs]</a> <span class="nd">@staticmethod</span>
    160. <span class="k">def</span> <span class="nf">target_transform</span><span class="p">(</span><span class="n">target</span><span class="p">):</span>
    161. <span class="sd">&quot;&quot;&quot;</span>
    162. <span class="sd"> target_transform - Transforms the sample image</span>
    163. <span class="sd"> This function overrides the original function from SegmentationDataSet and changes target pixels with value</span>
    164. <span class="sd"> 255 to value = CITYSCAPES_IGNORE_LABEL. This was done since current IoU metric from torchmetrics does not</span>
    165. <span class="sd"> support such a high ignore label value (crashed on OOM)</span>
    166. <span class="sd"> :param target: The target mask to transform</span>
    167. <span class="sd"> :return: The transformed target mask</span>
    168. <span class="sd"> &quot;&quot;&quot;</span>
    169. <span class="n">out</span> <span class="o">=</span> <span class="n">SegmentationDataSet</span><span class="o">.</span><span class="n">target_transform</span><span class="p">(</span><span class="n">target</span><span class="p">)</span>
    170. <span class="n">out</span><span class="p">[</span><span class="n">out</span> <span class="o">==</span> <span class="mi">255</span><span class="p">]</span> <span class="o">=</span> <span class="n">CITYSCAPES_IGNORE_LABEL</span>
    171. <span class="k">return</span> <span class="n">out</span></div></div>
    172. </pre></div>
    173. </div>
    174. </div>
    175. <footer>
    176. <hr/>
    177. <div role="contentinfo">
    178. <p>&#169; Copyright 2021, SuperGradients team.</p>
    179. </div>
    180. Built with <a href="https://www.sphinx-doc.org/">Sphinx</a> using a
    181. <a href="https://github.com/readthedocs/sphinx_rtd_theme">theme</a>
    182. provided by <a href="https://readthedocs.org">Read the Docs</a>.
    183. </footer>
    184. </div>
    185. </div>
    186. </section>
    187. </div>
    188. <script>
    189. jQuery(function () {
    190. SphinxRtdTheme.Navigation.enable(true);
    191. });
    192. </script>
    193. </body>
    194. </html>
    Discard
    Some lines were truncated since they exceed the maximum allowed length of 500, please use a local Git client to see the full diff.
    @@ -3,10 +3,11 @@
     <head>
     <head>
       <meta charset="utf-8" />
       <meta charset="utf-8" />
       <meta name="viewport" content="width=device-width, initial-scale=1.0" />
       <meta name="viewport" content="width=device-width, initial-scale=1.0" />
    -  <title>super_gradients.training.datasets.segmentation_datasets.coco_segmentation &mdash; SuperGradients 1.0 documentation</title>
    +  <title>super_gradients.training.datasets.segmentation_datasets.coco_segmentation &mdash; SuperGradients 3.0.3 documentation</title>
           <link rel="stylesheet" href="../../../../../_static/pygments.css" type="text/css" />
           <link rel="stylesheet" href="../../../../../_static/pygments.css" type="text/css" />
           <link rel="stylesheet" href="../../../../../_static/css/theme.css" type="text/css" />
           <link rel="stylesheet" href="../../../../../_static/css/theme.css" type="text/css" />
           <link rel="stylesheet" href="../../../../../_static/graphviz.css" type="text/css" />
           <link rel="stylesheet" href="../../../../../_static/graphviz.css" type="text/css" />
    +      <link rel="stylesheet" href="../../../../../_static/custom.css" type="text/css" />
       <!--[if lt IE 9]>
       <!--[if lt IE 9]>
         <script src="../../../../../_static/js/html5shiv.min.js"></script>
         <script src="../../../../../_static/js/html5shiv.min.js"></script>
       <![endif]-->
       <![endif]-->
    @@ -14,7 +15,9 @@
             <script data-url_root="../../../../../" id="documentation_options" src="../../../../../_static/documentation_options.js"></script>
             <script data-url_root="../../../../../" id="documentation_options" src="../../../../../_static/documentation_options.js"></script>
             <script src="../../../../../_static/jquery.js"></script>
             <script src="../../../../../_static/jquery.js"></script>
             <script src="../../../../../_static/underscore.js"></script>
             <script src="../../../../../_static/underscore.js"></script>
    +        <script src="../../../../../_static/_sphinx_javascript_frameworks_compat.js"></script>
             <script src="../../../../../_static/doctools.js"></script>
             <script src="../../../../../_static/doctools.js"></script>
    +        <script src="../../../../../_static/sphinx_highlight.js"></script>
         <script src="../../../../../_static/js/theme.js"></script>
         <script src="../../../../../_static/js/theme.js"></script>
         <link rel="index" title="Index" href="../../../../../genindex.html" />
         <link rel="index" title="Index" href="../../../../../genindex.html" />
         <link rel="search" title="Search" href="../../../../../search.html" /> 
         <link rel="search" title="Search" href="../../../../../search.html" /> 
    @@ -35,30 +38,29 @@
       </form>
       </form>
     </div>
     </div>
             </div><div class="wy-menu wy-menu-vertical" data-spy="affix" role="navigation" aria-label="Navigation menu">
             </div><div class="wy-menu wy-menu-vertical" data-spy="affix" role="navigation" aria-label="Navigation menu">
    -              <p class="caption"><span class="caption-text">Welcome To SuperGradients</span></p>
    +              <p class="caption" role="heading"><span class="caption-text">Welcome To SuperGradients</span></p>
     <ul>
     <ul>
    -<li class="toctree-l1"><a class="reference internal" href="../../../../../welcome.html">Fill our 4-question quick survey! We will raffle free SuperGradients swag between those who will participate -&gt; Fill Survey</a></li>
    -<li class="toctree-l1"><a class="reference internal" href="../../../../../welcome.html#supergradients">SuperGradients</a></li>
    +<li class="toctree-l1"><a class="reference internal" href="../../../../../welcome.html">Version 3 is out! Notebooks have been updated!</a></li>
    +<li class="toctree-l1"><a class="reference internal" href="../../../../../welcome.html#build-with-supergradients">Build with SuperGradients</a></li>
    +<li class="toctree-l1"><a class="reference internal" href="../../../../../welcome.html#quick-installation">Quick Installation</a></li>
    +<li class="toctree-l1"><a class="reference internal" href="../../../../../welcome.html#what-s-new">What’s New</a></li>
    +<li class="toctree-l1"><a class="reference internal" href="../../../../../welcome.html#coming-soon">Coming soon</a></li>
    +<li class="toctree-l1"><a class="reference internal" href="../../../../../welcome.html#table-of-content">Table of Content</a></li>
    +<li class="toctree-l1"><a class="reference internal" href="../../../../../welcome.html#getting-started">Getting Started</a></li>
    +<li class="toctree-l1"><a class="reference internal" href="../../../../../welcome.html#advanced-features">Advanced Features</a></li>
    +<li class="toctree-l1"><a class="reference internal" href="../../../../../welcome.html#installation-methods">Installation Methods</a></li>
    +<li class="toctree-l1"><a class="reference internal" href="../../../../../welcome.html#implemented-model-architectures">Implemented Model Architectures</a></li>
    +<li class="toctree-l1"><a class="reference internal" href="../../../../../welcome.html#documentation">Documentation</a></li>
    +<li class="toctree-l1"><a class="reference internal" href="../../../../../welcome.html#contributing">Contributing</a></li>
    +<li class="toctree-l1"><a class="reference internal" href="../../../../../welcome.html#citation">Citation</a></li>
    +<li class="toctree-l1"><a class="reference internal" href="../../../../../welcome.html#community">Community</a></li>
    +<li class="toctree-l1"><a class="reference internal" href="../../../../../welcome.html#license">License</a></li>
    +<li class="toctree-l1"><a class="reference internal" href="../../../../../welcome.html#deci-platform">Deci Platform</a></li>
     </ul>
     </ul>
    -<p class="caption"><span class="caption-text">Technical Documentation</span></p>
    +<p class="caption" role="heading"><span class="caption-text">Technical Documentation</span></p>
     <ul>
     <ul>
     <li class="toctree-l1"><a class="reference internal" href="../../../../../super_gradients.common.html">Common package</a></li>
     <li class="toctree-l1"><a class="reference internal" href="../../../../../super_gradients.common.html">Common package</a></li>
     <li class="toctree-l1"><a class="reference internal" href="../../../../../super_gradients.training.html">Training package</a></li>
     <li class="toctree-l1"><a class="reference internal" href="../../../../../super_gradients.training.html">Training package</a></li>
    -</ul>
    -<p class="caption"><span class="caption-text">User Guide</span></p>
    -<ul>
    -<li class="toctree-l1"><a class="reference internal" href="../../../../../user_guide.html">What is SuperGradients?</a></li>
    -<li class="toctree-l1"><a class="reference internal" href="../../../../../user_guide.html#introducing-the-supergradients-library">Introducing the SuperGradients library</a></li>
    -<li class="toctree-l1"><a class="reference internal" href="../../../../../user_guide.html#installation">Installation</a></li>
    -<li class="toctree-l1"><a class="reference internal" href="../../../../../user_guide.html#integrating-your-training-code-complete-walkthrough">Integrating Your Training Code - Complete Walkthrough</a></li>
    -<li class="toctree-l1"><a class="reference internal" href="../../../../../user_guide.html#training-parameters">Training Parameters</a></li>
    -<li class="toctree-l1"><a class="reference internal" href="../../../../../user_guide.html#logs-and-checkpoints">Logs and Checkpoints</a></li>
    -<li class="toctree-l1"><a class="reference internal" href="../../../../../user_guide.html#dataset-parameters">Dataset Parameters</a></li>
    -<li class="toctree-l1"><a class="reference internal" href="../../../../../user_guide.html#network-architectures">Network Architectures</a></li>
    -<li class="toctree-l1"><a class="reference internal" href="../../../../../user_guide.html#pretrained-models">Pretrained Models</a></li>
    -<li class="toctree-l1"><a class="reference internal" href="../../../../../user_guide.html#how-to-reproduce-our-training-recipes">How To Reproduce Our Training Recipes</a></li>
    -<li class="toctree-l1"><a class="reference internal" href="../../../../../user_guide.html#professional-tools-integration">Professional Tools Integration</a></li>
    -<li class="toctree-l1"><a class="reference internal" href="../../../../../user_guide.html#supergradients-faq">SuperGradients FAQ</a></li>
     </ul>
     </ul>
     
     
             </div>
             </div>
    @@ -87,14 +89,11 @@
                  
                  
       <h1>Source code for super_gradients.training.datasets.segmentation_datasets.coco_segmentation</h1><div class="highlight"><pre>
       <h1>Source code for super_gradients.training.datasets.segmentation_datasets.coco_segmentation</h1><div class="highlight"><pre>
     <span></span><span class="kn">import</span> <span class="nn">os</span>
     <span></span><span class="kn">import</span> <span class="nn">os</span>
    -<span class="kn">import</span> <span class="nn">torch</span>
    +
     <span class="kn">import</span> <span class="nn">numpy</span> <span class="k">as</span> <span class="nn">np</span>
     <span class="kn">import</span> <span class="nn">numpy</span> <span class="k">as</span> <span class="nn">np</span>
    -<span class="kn">from</span> <span class="nn">tqdm</span> <span class="kn">import</span> <span class="n">tqdm</span>
    +<span class="kn">import</span> <span class="nn">torch</span>
     <span class="kn">from</span> <span class="nn">PIL</span> <span class="kn">import</span> <span class="n">Image</span>
     <span class="kn">from</span> <span class="nn">PIL</span> <span class="kn">import</span> <span class="n">Image</span>
    -<span class="kn">import</span> <span class="nn">torchvision.transforms</span> <span class="k">as</span> <span class="nn">transform</span>
    -<span class="kn">from</span> <span class="nn">super_gradients.training.transforms.transforms</span> <span class="kn">import</span> <span class="n">RandomFlip</span><span class="p">,</span> <span class="n">Rescale</span><span class="p">,</span> <span class="n">RandomRescale</span><span class="p">,</span> <span class="n">CropImageAndMask</span><span class="p">,</span> \
    -    <span class="n">PadShortToCropSize</span>
    -<span class="kn">from</span> <span class="nn">super_gradients.training.utils.utils</span> <span class="kn">import</span> <span class="n">get_param</span>
    +<span class="kn">from</span> <span class="nn">tqdm</span> <span class="kn">import</span> <span class="n">tqdm</span>
     
     
     <span class="k">try</span><span class="p">:</span>
     <span class="k">try</span><span class="p">:</span>
         <span class="kn">from</span> <span class="nn">pycocotools.coco</span> <span class="kn">import</span> <span class="n">COCO</span>
         <span class="kn">from</span> <span class="nn">pycocotools.coco</span> <span class="kn">import</span> <span class="n">COCO</span>
    @@ -106,36 +105,22 @@
     <span class="kn">from</span> <span class="nn">super_gradients.training.datasets.segmentation_datasets.segmentation_dataset</span> <span class="kn">import</span> <span class="n">SegmentationDataSet</span>
     <span class="kn">from</span> <span class="nn">super_gradients.training.datasets.segmentation_datasets.segmentation_dataset</span> <span class="kn">import</span> <span class="n">SegmentationDataSet</span>
     
     
     
     
    -<div class="viewcode-block" id="EmptyCoCoClassesSelectionException"><a class="viewcode-back" href="../../../../../super_gradients.training.datasets.segmentation_datasets.html#super_gradients.training.datasets.EmptyCoCoClassesSelectionException">[docs]</a><span class="k">class</span> <span class="nc">EmptyCoCoClassesSelectionException</span><span class="p">(</span><span class="ne">Exception</span><span class="p">):</span>
    -    <span class="k">pass</span></div>
    +<span class="k">class</span> <span class="nc">EmptyCoCoClassesSelectionException</span><span class="p">(</span><span class="ne">Exception</span><span class="p">):</span>
    +    <span class="k">pass</span>
     
     
     
     
    -<div class="viewcode-block" id="CoCoSegmentationDataSet"><a class="viewcode-back" href="../../../../../super_gradients.training.datasets.segmentation_datasets.html#super_gradients.training.datasets.CoCoSegmentationDataSet">[docs]</a><span class="k">class</span> <span class="nc">CoCoSegmentationDataSet</span><span class="p">(</span><span class="n">SegmentationDataSet</span><span class="p">):</span>
    +<div class="viewcode-block" id="CoCoSegmentationDataSet"><a class="viewcode-back" href="../../../../../super_gradients.training.html#super_gradients.training.datasets.CoCoSegmentationDataSet">[docs]</a><span class="k">class</span> <span class="nc">CoCoSegmentationDataSet</span><span class="p">(</span><span class="n">SegmentationDataSet</span><span class="p">):</span>
         <span class="sd">&quot;&quot;&quot;</span>
         <span class="sd">&quot;&quot;&quot;</span>
     <span class="sd">    CoCoSegmentationDataSet - Segmentation Data Set Class for COCO 2017 Segmentation Data Set</span>
     <span class="sd">    CoCoSegmentationDataSet - Segmentation Data Set Class for COCO 2017 Segmentation Data Set</span>
     <span class="sd">    &quot;&quot;&quot;</span>
     <span class="sd">    &quot;&quot;&quot;</span>
     
     
    -    <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">dataset_classes_inclusion_tuples_list</span><span class="p">:</span> <span class="nb">list</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span> <span class="o">*</span><span class="n">args</span><span class="p">,</span> <span class="o">**</span><span class="n">kwargs</span><span class="p">):</span>
    +    <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">root_dir</span><span class="p">:</span> <span class="nb">str</span><span class="p">,</span>
    +                 <span class="n">dataset_classes_inclusion_tuples_list</span><span class="p">:</span> <span class="nb">list</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span> <span class="o">*</span><span class="n">args</span><span class="p">,</span> <span class="o">**</span><span class="n">kwargs</span><span class="p">):</span>
             <span class="c1"># THERE ARE 91 CLASSES, INCLUDING BACKGROUND - BUT WE ENABLE THE USAGE OF SUBCLASSES, TO PARTIALLY USE THE DATA</span>
             <span class="c1"># THERE ARE 91 CLASSES, INCLUDING BACKGROUND - BUT WE ENABLE THE USAGE OF SUBCLASSES, TO PARTIALLY USE THE DATA</span>
             <span class="bp">self</span><span class="o">.</span><span class="n">dataset_classes_inclusion_tuples_list</span> <span class="o">=</span> <span class="n">dataset_classes_inclusion_tuples_list</span> <span class="ow">or</span> <span class="n">COCO_DEFAULT_CLASSES_TUPLES_LIST</span>
             <span class="bp">self</span><span class="o">.</span><span class="n">dataset_classes_inclusion_tuples_list</span> <span class="o">=</span> <span class="n">dataset_classes_inclusion_tuples_list</span> <span class="ow">or</span> <span class="n">COCO_DEFAULT_CLASSES_TUPLES_LIST</span>
     
     
    -        <span class="c1"># OVERRIDE DEFAULT AUGMENTATIONS, IMG_SIZE, CROP SIZE</span>
    -        <span class="n">dataset_hyper_params</span> <span class="o">=</span> <span class="n">kwargs</span><span class="p">[</span><span class="s1">&#39;dataset_hyper_params&#39;</span><span class="p">]</span>
    -        <span class="n">kwargs</span><span class="p">[</span><span class="s1">&#39;img_size&#39;</span><span class="p">]</span> <span class="o">=</span> <span class="n">get_param</span><span class="p">(</span><span class="n">dataset_hyper_params</span><span class="p">,</span> <span class="s1">&#39;img_size&#39;</span><span class="p">,</span> <span class="mi">608</span><span class="p">)</span>
    -        <span class="n">kwargs</span><span class="p">[</span><span class="s1">&#39;crop_size&#39;</span><span class="p">]</span> <span class="o">=</span> <span class="n">get_param</span><span class="p">(</span><span class="n">dataset_hyper_params</span><span class="p">,</span> <span class="s1">&#39;crop_size&#39;</span><span class="p">,</span> <span class="mi">512</span><span class="p">)</span>
    -        <span class="c1"># FIXME - Rescale before RandomRescale is kept for legacy support, consider removing it like most implementation</span>
    -        <span class="c1">#  papers regimes.</span>
    -        <span class="n">default_image_mask_transforms_aug</span> <span class="o">=</span> <span class="n">transform</span><span class="o">.</span><span class="n">Compose</span><span class="p">([</span><span class="n">RandomFlip</span><span class="p">(),</span>
    -                                                               <span class="n">Rescale</span><span class="p">(</span><span class="n">long_size</span><span class="o">=</span><span class="n">kwargs</span><span class="p">[</span><span class="s2">&quot;img_size&quot;</span><span class="p">]),</span>
    -                                                               <span class="n">RandomRescale</span><span class="p">(</span><span class="n">scales</span><span class="o">=</span><span class="p">(</span><span class="mf">0.5</span><span class="p">,</span> <span class="mf">2.0</span><span class="p">)),</span>
    -                                                               <span class="n">PadShortToCropSize</span><span class="p">(</span><span class="n">crop_size</span><span class="o">=</span><span class="n">kwargs</span><span class="p">[</span><span class="s1">&#39;crop_size&#39;</span><span class="p">]),</span>
    -                                                               <span class="n">CropImageAndMask</span><span class="p">(</span><span class="n">crop_size</span><span class="o">=</span><span class="n">kwargs</span><span class="p">[</span><span class="s1">&#39;crop_size&#39;</span><span class="p">],</span> <span class="n">mode</span><span class="o">=</span><span class="s2">&quot;random&quot;</span><span class="p">)])</span>
    -        <span class="n">kwargs</span><span class="p">[</span><span class="s2">&quot;image_mask_transforms_aug&quot;</span><span class="p">]</span> <span class="o">=</span> <span class="n">get_param</span><span class="p">(</span><span class="n">dataset_hyper_params</span><span class="p">,</span> <span class="s2">&quot;image_mask_transforms_aug&quot;</span><span class="p">,</span>
    -                                                        <span class="n">default_image_mask_transforms_aug</span><span class="p">)</span>
    -        <span class="n">image_mask_transforms</span> <span class="o">=</span> <span class="n">get_param</span><span class="p">(</span><span class="n">dataset_hyper_params</span><span class="p">,</span> <span class="s1">&#39;image_mask_transforms&#39;</span><span class="p">)</span>
    -        <span class="k">if</span> <span class="n">image_mask_transforms</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span><span class="p">:</span>
    -            <span class="n">kwargs</span><span class="p">[</span><span class="s2">&quot;image_mask_transforms&quot;</span><span class="p">]</span> <span class="o">=</span> <span class="n">image_mask_transforms</span>
    -        <span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="fm">__init__</span><span class="p">(</span><span class="o">*</span><span class="n">args</span><span class="p">,</span> <span class="o">**</span><span class="n">kwargs</span><span class="p">)</span>
    +        <span class="bp">self</span><span class="o">.</span><span class="n">root_dir</span> <span class="o">=</span> <span class="n">root_dir</span>
    +        <span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="fm">__init__</span><span class="p">(</span><span class="n">root_dir</span><span class="p">,</span> <span class="o">*</span><span class="n">args</span><span class="p">,</span> <span class="o">**</span><span class="n">kwargs</span><span class="p">)</span>
     
     
             <span class="n">_</span><span class="p">,</span> <span class="n">class_names</span> <span class="o">=</span> <span class="nb">zip</span><span class="p">(</span><span class="o">*</span><span class="bp">self</span><span class="o">.</span><span class="n">dataset_classes_inclusion_tuples_list</span><span class="p">)</span>
             <span class="n">_</span><span class="p">,</span> <span class="n">class_names</span> <span class="o">=</span> <span class="nb">zip</span><span class="p">(</span><span class="o">*</span><span class="bp">self</span><span class="o">.</span><span class="n">dataset_classes_inclusion_tuples_list</span><span class="p">)</span>
             <span class="bp">self</span><span class="o">.</span><span class="n">classes</span> <span class="o">=</span> <span class="n">class_names</span>
             <span class="bp">self</span><span class="o">.</span><span class="n">classes</span> <span class="o">=</span> <span class="n">class_names</span>
    @@ -162,7 +147,9 @@
                 <span class="n">mask_metadata_tuple</span> <span class="o">=</span> <span class="p">(</span><span class="n">relevant_image_id</span><span class="p">,</span> <span class="n">img_metadata</span><span class="p">[</span><span class="s1">&#39;height&#39;</span><span class="p">],</span> <span class="n">img_metadata</span><span class="p">[</span><span class="s1">&#39;width&#39;</span><span class="p">])</span>
                 <span class="n">mask_metadata_tuple</span> <span class="o">=</span> <span class="p">(</span><span class="n">relevant_image_id</span><span class="p">,</span> <span class="n">img_metadata</span><span class="p">[</span><span class="s1">&#39;height&#39;</span><span class="p">],</span> <span class="n">img_metadata</span><span class="p">[</span><span class="s1">&#39;width&#39;</span><span class="p">])</span>
                 <span class="bp">self</span><span class="o">.</span><span class="n">samples_targets_tuples_list</span><span class="o">.</span><span class="n">append</span><span class="p">((</span><span class="n">image_path</span><span class="p">,</span> <span class="n">mask_metadata_tuple</span><span class="p">))</span>
                 <span class="bp">self</span><span class="o">.</span><span class="n">samples_targets_tuples_list</span><span class="o">.</span><span class="n">append</span><span class="p">((</span><span class="n">image_path</span><span class="p">,</span> <span class="n">mask_metadata_tuple</span><span class="p">))</span>
     
     
    -<div class="viewcode-block" id="CoCoSegmentationDataSet.target_loader"><a class="viewcode-back" href="../../../../../super_gradients.training.datasets.segmentation_datasets.html#super_gradients.training.datasets.CoCoSegmentationDataSet.target_loader">[docs]</a>    <span class="k">def</span> <span class="nf">target_loader</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">mask_metadata_tuple</span><span class="p">)</span> <span class="o">-&gt;</spa
    +        <span class="nb">super</span><span class="p">(</span><span class="n">CoCoSegmentationDataSet</span><span class="p">,</span> <span class="bp">self</span><span class="p">)</span><span class="o">.</span><span class="n">_generate_samples_and_targets</span><span class="p">()</span>
    +
    +<div class="viewcode-block" id="CoCoSegmentationDataSet.target_loader"><a class="viewcode-back" href="../../../../../super_gradients.training.html#super_gradients.training.datasets.CoCoSegmentationDataSet.target_loader">[docs]</a>    <span class="k">def</span> <span class="nf">target_loader</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">mask_metadata_tuple</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="n">Image</span>
             <span class="sd">&quot;&quot;&quot;</span>
             <span class="sd">&quot;&quot;&quot;</span>
     <span class="sd">        target_loader</span>
     <span class="sd">        target_loader</span>
     <span class="sd">            :param mask_metadata_tuple:  A tuple of (coco_image_id, original_image_height, original_image_width)</span>
     <span class="sd">            :param mask_metadata_tuple:  A tuple of (coco_image_id, original_image_height, original_image_width)</span>
    @@ -267,4 +254,4 @@
       </script> 
       </script> 
     
     
     </body>
     </body>
    -</html>
    +</html>
    Discard
    1
    2
    3
    4
    5
    6
    7
    8
    9
    10
    11
    12
    13
    14
    15
    16
    17
    18
    19
    20
    21
    22
    23
    24
    25
    26
    27
    28
    29
    30
    31
    32
    33
    34
    35
    36
    37
    38
    39
    40
    41
    42
    43
    44
    45
    46
    47
    48
    49
    50
    51
    52
    53
    54
    55
    56
    57
    58
    59
    60
    61
    62
    63
    64
    65
    66
    67
    68
    69
    70
    71
    72
    73
    74
    75
    76
    77
    78
    79
    80
    81
    82
    83
    84
    85
    86
    87
    88
    89
    90
    91
    92
    93
    94
    95
    96
    97
    98
    99
    100
    101
    102
    103
    104
    105
    106
    107
    108
    109
    110
    111
    112
    113
    114
    115
    116
    117
    118
    119
    120
    121
    122
    123
    124
    125
    126
    127
    128
    129
    130
    131
    132
    133
    134
    135
    136
    137
    138
    139
    140
    141
    142
    143
    144
    145
    146
    147
    148
    149
    150
    151
    152
    153
    154
    155
    1. <!DOCTYPE html>
    2. <html class="writer-html5" lang="en" >
    3. <head>
    4. <meta charset="utf-8" />
    5. <meta name="viewport" content="width=device-width, initial-scale=1.0" />
    6. <title>super_gradients.training.datasets.segmentation_datasets.pascal_aug_segmentation &mdash; SuperGradients 1.0 documentation</title>
    7. <link rel="stylesheet" href="../../../../../_static/pygments.css" type="text/css" />
    8. <link rel="stylesheet" href="../../../../../_static/css/theme.css" type="text/css" />
    9. <link rel="stylesheet" href="../../../../../_static/graphviz.css" type="text/css" />
    10. <!--[if lt IE 9]>
    11. <script src="../../../../../_static/js/html5shiv.min.js"></script>
    12. <![endif]-->
    13. <script data-url_root="../../../../../" id="documentation_options" src="../../../../../_static/documentation_options.js"></script>
    14. <script src="../../../../../_static/jquery.js"></script>
    15. <script src="../../../../../_static/underscore.js"></script>
    16. <script src="../../../../../_static/doctools.js"></script>
    17. <script src="../../../../../_static/js/theme.js"></script>
    18. <link rel="index" title="Index" href="../../../../../genindex.html" />
    19. <link rel="search" title="Search" href="../../../../../search.html" />
    20. </head>
    21. <body class="wy-body-for-nav">
    22. <div class="wy-grid-for-nav">
    23. <nav data-toggle="wy-nav-shift" class="wy-nav-side">
    24. <div class="wy-side-scroll">
    25. <div class="wy-side-nav-search" >
    26. <a href="../../../../../index.html" class="icon icon-home"> SuperGradients
    27. </a>
    28. <div role="search">
    29. <form id="rtd-search-form" class="wy-form" action="../../../../../search.html" method="get">
    30. <input type="text" name="q" placeholder="Search docs" />
    31. <input type="hidden" name="check_keywords" value="yes" />
    32. <input type="hidden" name="area" value="default" />
    33. </form>
    34. </div>
    35. </div><div class="wy-menu wy-menu-vertical" data-spy="affix" role="navigation" aria-label="Navigation menu">
    36. <p class="caption"><span class="caption-text">Welcome To SuperGradients</span></p>
    37. <ul>
    38. <li class="toctree-l1"><a class="reference internal" href="../../../../../welcome.html">Fill our 4-question quick survey! We will raffle free SuperGradients swag between those who will participate -&gt; Fill Survey</a></li>
    39. <li class="toctree-l1"><a class="reference internal" href="../../../../../welcome.html#supergradients">SuperGradients</a></li>
    40. </ul>
    41. <p class="caption"><span class="caption-text">Technical Documentation</span></p>
    42. <ul>
    43. <li class="toctree-l1"><a class="reference internal" href="../../../../../super_gradients.common.html">Common package</a></li>
    44. <li class="toctree-l1"><a class="reference internal" href="../../../../../super_gradients.training.html">Training package</a></li>
    45. </ul>
    46. <p class="caption"><span class="caption-text">User Guide</span></p>
    47. <ul>
    48. <li class="toctree-l1"><a class="reference internal" href="../../../../../user_guide.html">What is SuperGradients?</a></li>
    49. <li class="toctree-l1"><a class="reference internal" href="../../../../../user_guide.html#introducing-the-supergradients-library">Introducing the SuperGradients library</a></li>
    50. <li class="toctree-l1"><a class="reference internal" href="../../../../../user_guide.html#installation">Installation</a></li>
    51. <li class="toctree-l1"><a class="reference internal" href="../../../../../user_guide.html#integrating-your-training-code-complete-walkthrough">Integrating Your Training Code - Complete Walkthrough</a></li>
    52. <li class="toctree-l1"><a class="reference internal" href="../../../../../user_guide.html#training-parameters">Training Parameters</a></li>
    53. <li class="toctree-l1"><a class="reference internal" href="../../../../../user_guide.html#logs-and-checkpoints">Logs and Checkpoints</a></li>
    54. <li class="toctree-l1"><a class="reference internal" href="../../../../../user_guide.html#dataset-parameters">Dataset Parameters</a></li>
    55. <li class="toctree-l1"><a class="reference internal" href="../../../../../user_guide.html#network-architectures">Network Architectures</a></li>
    56. <li class="toctree-l1"><a class="reference internal" href="../../../../../user_guide.html#pretrained-models">Pretrained Models</a></li>
    57. <li class="toctree-l1"><a class="reference internal" href="../../../../../user_guide.html#how-to-reproduce-our-training-recipes">How To Reproduce Our Training Recipes</a></li>
    58. <li class="toctree-l1"><a class="reference internal" href="../../../../../user_guide.html#professional-tools-integration">Professional Tools Integration</a></li>
    59. <li class="toctree-l1"><a class="reference internal" href="../../../../../user_guide.html#supergradients-faq">SuperGradients FAQ</a></li>
    60. </ul>
    61. </div>
    62. </div>
    63. </nav>
    64. <section data-toggle="wy-nav-shift" class="wy-nav-content-wrap"><nav class="wy-nav-top" aria-label="Mobile navigation menu" >
    65. <i data-toggle="wy-nav-top" class="fa fa-bars"></i>
    66. <a href="../../../../../index.html">SuperGradients</a>
    67. </nav>
    68. <div class="wy-nav-content">
    69. <div class="rst-content">
    70. <div role="navigation" aria-label="Page navigation">
    71. <ul class="wy-breadcrumbs">
    72. <li><a href="../../../../../index.html" class="icon icon-home"></a> &raquo;</li>
    73. <li><a href="../../../../index.html">Module code</a> &raquo;</li>
    74. <li>super_gradients.training.datasets.segmentation_datasets.pascal_aug_segmentation</li>
    75. <li class="wy-breadcrumbs-aside">
    76. </li>
    77. </ul>
    78. <hr/>
    79. </div>
    80. <div role="main" class="document" itemscope="itemscope" itemtype="http://schema.org/Article">
    81. <div itemprop="articleBody">
    82. <h1>Source code for super_gradients.training.datasets.segmentation_datasets.pascal_aug_segmentation</h1><div class="highlight"><pre>
    83. <span></span><span class="kn">import</span> <span class="nn">scipy.io</span>
    84. <span class="kn">from</span> <span class="nn">PIL</span> <span class="kn">import</span> <span class="n">Image</span>
    85. <span class="kn">from</span> <span class="nn">super_gradients.training.datasets.segmentation_datasets.pascal_voc_segmentation</span> <span class="kn">import</span> <span class="n">PascalVOC2012SegmentationDataSet</span>
    86. <span class="n">PASCAL_AUG_CLASSES</span> <span class="o">=</span> <span class="p">[</span>
    87. <span class="s1">&#39;background&#39;</span><span class="p">,</span> <span class="s1">&#39;airplane&#39;</span><span class="p">,</span> <span class="s1">&#39;bicycle&#39;</span><span class="p">,</span> <span class="s1">&#39;bird&#39;</span><span class="p">,</span> <span class="s1">&#39;boat&#39;</span><span class="p">,</span> <span class="s1">&#39;bottle&#39;</span><span class="p">,</span>
    88. <span class="s1">&#39;bus&#39;</span><span class="p">,</span> <span class="s1">&#39;car&#39;</span><span class="p">,</span> <span class="s1">&#39;cat&#39;</span><span class="p">,</span> <span class="s1">&#39;chair&#39;</span><span class="p">,</span> <span class="s1">&#39;cow&#39;</span><span class="p">,</span> <span class="s1">&#39;diningtable&#39;</span><span class="p">,</span> <span class="s1">&#39;dog&#39;</span><span class="p">,</span> <span class="s1">&#39;horse&#39;</span><span class="p">,</span>
    89. <span class="s1">&#39;motorcycle&#39;</span><span class="p">,</span> <span class="s1">&#39;person&#39;</span><span class="p">,</span> <span class="s1">&#39;potted-plant&#39;</span><span class="p">,</span> <span class="s1">&#39;sheep&#39;</span><span class="p">,</span> <span class="s1">&#39;sofa&#39;</span><span class="p">,</span> <span class="s1">&#39;train&#39;</span><span class="p">,</span>
    90. <span class="s1">&#39;tv&#39;</span>
    91. <span class="p">]</span>
    92. <div class="viewcode-block" id="PascalAUG2012SegmentationDataSet"><a class="viewcode-back" href="../../../../../super_gradients.training.datasets.segmentation_datasets.html#super_gradients.training.datasets.PascalAUG2012SegmentationDataSet">[docs]</a><span class="k">class</span> <span class="nc">PascalAUG2012SegmentationDataSet</span><span class="p">(</span><span class="n">PascalVOC2012SegmentationDataSet</span><span class="p">):</span>
    93. <span class="sd">&quot;&quot;&quot;</span>
    94. <span class="sd"> PascalAUG2012SegmentationDataSet - Segmentation Data Set Class for Pascal AUG 2012 Data Set</span>
    95. <span class="sd"> &quot;&quot;&quot;</span>
    96. <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="o">*</span><span class="n">args</span><span class="p">,</span> <span class="o">**</span><span class="n">kwargs</span><span class="p">):</span>
    97. <span class="bp">self</span><span class="o">.</span><span class="n">sample_suffix</span> <span class="o">=</span> <span class="s1">&#39;.jpg&#39;</span>
    98. <span class="bp">self</span><span class="o">.</span><span class="n">target_suffix</span> <span class="o">=</span> <span class="s1">&#39;.mat&#39;</span>
    99. <span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="fm">__init__</span><span class="p">(</span><span class="n">sample_suffix</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">sample_suffix</span><span class="p">,</span> <span class="n">target_suffix</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">target_suffix</span><span class="p">,</span> <span class="o">*</span><span class="n">args</span><span class="p">,</span> <span class="o">**</span><span class="n">kwargs</span><span class="p">)</span>
    100. <span class="c1"># THERE ARE 21 CLASSES, INCLUDING BACKGROUND</span>
    101. <span class="bp">self</span><span class="o">.</span><span class="n">classes</span> <span class="o">=</span> <span class="n">PASCAL_AUG_CLASSES</span>
    102. <div class="viewcode-block" id="PascalAUG2012SegmentationDataSet.target_loader"><a class="viewcode-back" href="../../../../../super_gradients.training.datasets.segmentation_datasets.html#super_gradients.training.datasets.PascalAUG2012SegmentationDataSet.target_loader">[docs]</a> <span class="nd">@staticmethod</span>
    103. <span class="k">def</span> <span class="nf">target_loader</span><span class="p">(</span><span class="n">target_path</span><span class="p">:</span> <span class="nb">str</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="n">Image</span><span class="p">:</span>
    104. <span class="sd">&quot;&quot;&quot;</span>
    105. <span class="sd"> target_loader</span>
    106. <span class="sd"> :param target_path: The path to the target data</span>
    107. <span class="sd"> :return: The loaded target</span>
    108. <span class="sd"> &quot;&quot;&quot;</span>
    109. <span class="n">mat</span> <span class="o">=</span> <span class="n">scipy</span><span class="o">.</span><span class="n">io</span><span class="o">.</span><span class="n">loadmat</span><span class="p">(</span><span class="n">target_path</span><span class="p">,</span> <span class="n">mat_dtype</span><span class="o">=</span><span class="kc">True</span><span class="p">,</span> <span class="n">squeeze_me</span><span class="o">=</span><span class="kc">True</span><span class="p">,</span>
    110. <span class="n">struct_as_record</span><span class="o">=</span><span class="kc">False</span><span class="p">)</span>
    111. <span class="n">mask</span> <span class="o">=</span> <span class="n">mat</span><span class="p">[</span><span class="s1">&#39;GTcls&#39;</span><span class="p">]</span><span class="o">.</span><span class="n">Segmentation</span>
    112. <span class="k">return</span> <span class="n">Image</span><span class="o">.</span><span class="n">fromarray</span><span class="p">(</span><span class="n">mask</span><span class="p">)</span></div></div>
    113. </pre></div>
    114. </div>
    115. </div>
    116. <footer>
    117. <hr/>
    118. <div role="contentinfo">
    119. <p>&#169; Copyright 2021, SuperGradients team.</p>
    120. </div>
    121. Built with <a href="https://www.sphinx-doc.org/">Sphinx</a> using a
    122. <a href="https://github.com/readthedocs/sphinx_rtd_theme">theme</a>
    123. provided by <a href="https://readthedocs.org">Read the Docs</a>.
    124. </footer>
    125. </div>
    126. </div>
    127. </section>
    128. </div>
    129. <script>
    130. jQuery(function () {
    131. SphinxRtdTheme.Navigation.enable(true);
    132. });
    133. </script>
    134. </body>
    135. </html>
    Discard
    Some lines were truncated since they exceed the maximum allowed length of 500, please use a local Git client to see the full diff.
    @@ -3,10 +3,11 @@
     <head>
     <head>
       <meta charset="utf-8" />
       <meta charset="utf-8" />
       <meta name="viewport" content="width=device-width, initial-scale=1.0" />
       <meta name="viewport" content="width=device-width, initial-scale=1.0" />
    -  <title>super_gradients.training.datasets.segmentation_datasets.pascal_voc_segmentation &mdash; SuperGradients 1.0 documentation</title>
    +  <title>super_gradients.training.datasets.segmentation_datasets.pascal_voc_segmentation &mdash; SuperGradients 3.0.3 documentation</title>
           <link rel="stylesheet" href="../../../../../_static/pygments.css" type="text/css" />
           <link rel="stylesheet" href="../../../../../_static/pygments.css" type="text/css" />
           <link rel="stylesheet" href="../../../../../_static/css/theme.css" type="text/css" />
           <link rel="stylesheet" href="../../../../../_static/css/theme.css" type="text/css" />
           <link rel="stylesheet" href="../../../../../_static/graphviz.css" type="text/css" />
           <link rel="stylesheet" href="../../../../../_static/graphviz.css" type="text/css" />
    +      <link rel="stylesheet" href="../../../../../_static/custom.css" type="text/css" />
       <!--[if lt IE 9]>
       <!--[if lt IE 9]>
         <script src="../../../../../_static/js/html5shiv.min.js"></script>
         <script src="../../../../../_static/js/html5shiv.min.js"></script>
       <![endif]-->
       <![endif]-->
    @@ -14,7 +15,9 @@
             <script data-url_root="../../../../../" id="documentation_options" src="../../../../../_static/documentation_options.js"></script>
             <script data-url_root="../../../../../" id="documentation_options" src="../../../../../_static/documentation_options.js"></script>
             <script src="../../../../../_static/jquery.js"></script>
             <script src="../../../../../_static/jquery.js"></script>
             <script src="../../../../../_static/underscore.js"></script>
             <script src="../../../../../_static/underscore.js"></script>
    +        <script src="../../../../../_static/_sphinx_javascript_frameworks_compat.js"></script>
             <script src="../../../../../_static/doctools.js"></script>
             <script src="../../../../../_static/doctools.js"></script>
    +        <script src="../../../../../_static/sphinx_highlight.js"></script>
         <script src="../../../../../_static/js/theme.js"></script>
         <script src="../../../../../_static/js/theme.js"></script>
         <link rel="index" title="Index" href="../../../../../genindex.html" />
         <link rel="index" title="Index" href="../../../../../genindex.html" />
         <link rel="search" title="Search" href="../../../../../search.html" /> 
         <link rel="search" title="Search" href="../../../../../search.html" /> 
    @@ -35,30 +38,29 @@
       </form>
       </form>
     </div>
     </div>
             </div><div class="wy-menu wy-menu-vertical" data-spy="affix" role="navigation" aria-label="Navigation menu">
             </div><div class="wy-menu wy-menu-vertical" data-spy="affix" role="navigation" aria-label="Navigation menu">
    -              <p class="caption"><span class="caption-text">Welcome To SuperGradients</span></p>
    +              <p class="caption" role="heading"><span class="caption-text">Welcome To SuperGradients</span></p>
     <ul>
     <ul>
    -<li class="toctree-l1"><a class="reference internal" href="../../../../../welcome.html">Fill our 4-question quick survey! We will raffle free SuperGradients swag between those who will participate -&gt; Fill Survey</a></li>
    -<li class="toctree-l1"><a class="reference internal" href="../../../../../welcome.html#supergradients">SuperGradients</a></li>
    +<li class="toctree-l1"><a class="reference internal" href="../../../../../welcome.html">Version 3 is out! Notebooks have been updated!</a></li>
    +<li class="toctree-l1"><a class="reference internal" href="../../../../../welcome.html#build-with-supergradients">Build with SuperGradients</a></li>
    +<li class="toctree-l1"><a class="reference internal" href="../../../../../welcome.html#quick-installation">Quick Installation</a></li>
    +<li class="toctree-l1"><a class="reference internal" href="../../../../../welcome.html#what-s-new">What’s New</a></li>
    +<li class="toctree-l1"><a class="reference internal" href="../../../../../welcome.html#coming-soon">Coming soon</a></li>
    +<li class="toctree-l1"><a class="reference internal" href="../../../../../welcome.html#table-of-content">Table of Content</a></li>
    +<li class="toctree-l1"><a class="reference internal" href="../../../../../welcome.html#getting-started">Getting Started</a></li>
    +<li class="toctree-l1"><a class="reference internal" href="../../../../../welcome.html#advanced-features">Advanced Features</a></li>
    +<li class="toctree-l1"><a class="reference internal" href="../../../../../welcome.html#installation-methods">Installation Methods</a></li>
    +<li class="toctree-l1"><a class="reference internal" href="../../../../../welcome.html#implemented-model-architectures">Implemented Model Architectures</a></li>
    +<li class="toctree-l1"><a class="reference internal" href="../../../../../welcome.html#documentation">Documentation</a></li>
    +<li class="toctree-l1"><a class="reference internal" href="../../../../../welcome.html#contributing">Contributing</a></li>
    +<li class="toctree-l1"><a class="reference internal" href="../../../../../welcome.html#citation">Citation</a></li>
    +<li class="toctree-l1"><a class="reference internal" href="../../../../../welcome.html#community">Community</a></li>
    +<li class="toctree-l1"><a class="reference internal" href="../../../../../welcome.html#license">License</a></li>
    +<li class="toctree-l1"><a class="reference internal" href="../../../../../welcome.html#deci-platform">Deci Platform</a></li>
     </ul>
     </ul>
    -<p class="caption"><span class="caption-text">Technical Documentation</span></p>
    +<p class="caption" role="heading"><span class="caption-text">Technical Documentation</span></p>
     <ul>
     <ul>
     <li class="toctree-l1"><a class="reference internal" href="../../../../../super_gradients.common.html">Common package</a></li>
     <li class="toctree-l1"><a class="reference internal" href="../../../../../super_gradients.common.html">Common package</a></li>
     <li class="toctree-l1"><a class="reference internal" href="../../../../../super_gradients.training.html">Training package</a></li>
     <li class="toctree-l1"><a class="reference internal" href="../../../../../super_gradients.training.html">Training package</a></li>
    -</ul>
    -<p class="caption"><span class="caption-text">User Guide</span></p>
    -<ul>
    -<li class="toctree-l1"><a class="reference internal" href="../../../../../user_guide.html">What is SuperGradients?</a></li>
    -<li class="toctree-l1"><a class="reference internal" href="../../../../../user_guide.html#introducing-the-supergradients-library">Introducing the SuperGradients library</a></li>
    -<li class="toctree-l1"><a class="reference internal" href="../../../../../user_guide.html#installation">Installation</a></li>
    -<li class="toctree-l1"><a class="reference internal" href="../../../../../user_guide.html#integrating-your-training-code-complete-walkthrough">Integrating Your Training Code - Complete Walkthrough</a></li>
    -<li class="toctree-l1"><a class="reference internal" href="../../../../../user_guide.html#training-parameters">Training Parameters</a></li>
    -<li class="toctree-l1"><a class="reference internal" href="../../../../../user_guide.html#logs-and-checkpoints">Logs and Checkpoints</a></li>
    -<li class="toctree-l1"><a class="reference internal" href="../../../../../user_guide.html#dataset-parameters">Dataset Parameters</a></li>
    -<li class="toctree-l1"><a class="reference internal" href="../../../../../user_guide.html#network-architectures">Network Architectures</a></li>
    -<li class="toctree-l1"><a class="reference internal" href="../../../../../user_guide.html#pretrained-models">Pretrained Models</a></li>
    -<li class="toctree-l1"><a class="reference internal" href="../../../../../user_guide.html#how-to-reproduce-our-training-recipes">How To Reproduce Our Training Recipes</a></li>
    -<li class="toctree-l1"><a class="reference internal" href="../../../../../user_guide.html#professional-tools-integration">Professional Tools Integration</a></li>
    -<li class="toctree-l1"><a class="reference internal" href="../../../../../user_guide.html#supergradients-faq">SuperGradients FAQ</a></li>
     </ul>
     </ul>
     
     
             </div>
             </div>
    @@ -89,31 +91,71 @@
     <span></span><span class="kn">import</span> <span class="nn">os</span>
     <span></span><span class="kn">import</span> <span class="nn">os</span>
     
     
     <span class="kn">import</span> <span class="nn">numpy</span> <span class="k">as</span> <span class="nn">np</span>
     <span class="kn">import</span> <span class="nn">numpy</span> <span class="k">as</span> <span class="nn">np</span>
    +<span class="kn">import</span> <span class="nn">scipy.io</span>
    +<span class="kn">from</span> <span class="nn">PIL</span> <span class="kn">import</span> <span class="n">Image</span>
    +<span class="kn">from</span> <span class="nn">torch.utils.data</span> <span class="kn">import</span> <span class="n">ConcatDataset</span>
     
     
     <span class="kn">from</span> <span class="nn">super_gradients.training.datasets.segmentation_datasets.segmentation_dataset</span> <span class="kn">import</span> <span class="n">SegmentationDataSet</span>
     <span class="kn">from</span> <span class="nn">super_gradients.training.datasets.segmentation_datasets.segmentation_dataset</span> <span class="kn">import</span> <span class="n">SegmentationDataSet</span>
    +<span class="kn">from</span> <span class="nn">super_gradients.common.abstractions.abstract_logger</span> <span class="kn">import</span> <span class="n">get_logger</span>
    +
    +<span class="n">logger</span> <span class="o">=</span> <span class="n">get_logger</span><span class="p">(</span><span class="vm">__name__</span><span class="p">)</span>
     
     
     <span class="n">PASCAL_VOC_2012_CLASSES</span> <span class="o">=</span> <span class="p">[</span>
     <span class="n">PASCAL_VOC_2012_CLASSES</span> <span class="o">=</span> <span class="p">[</span>
    -    <span class="s1">&#39;background&#39;</span><span class="p">,</span> <span class="s1">&#39;aeroplane&#39;</span><span class="p">,</span> <span class="s1">&#39;bicycle&#39;</span><span class="p">,</span> <span class="s1">&#39;bird&#39;</span><span class="p">,</span> <span class="s1">&#39;boat&#39;</span><span class="p">,</span> <span class="s1">&#39;bottle&#39;</span><span class="p">,</span>
    -    <span class="s1">&#39;bus&#39;</span><span class="p">,</span> <span class="s1">&#39;car&#39;</span><span class="p">,</span> <span class="s1">&#39;cat&#39;</span><span class="p">,</span> <span class="s1">&#39;chair&#39;</span><span class="p">,</span> <span class="s1">&#39;cow&#39;</span><span class="p">,</span> <span class="s1">&#39;diningtable&#39;</span><span class="p">,</span> <span class="s1">&#39;dog&#39;</span><span class="p">,</span> <span class="s1">&#39;horse&#39;</span><span class=
    -    <span class="s1">&#39;motorbike&#39;</span><span class="p">,</span> <span class="s1">&#39;person&#39;</span><span class="p">,</span> <span class="s1">&#39;potted-plant&#39;</span><span class="p">,</span> <span class="s1">&#39;sheep&#39;</span><span class="p">,</span> <span class="s1">&#39;sofa&#39;</span><span class="p">,</span> <span class="s1">&#39;train&#39;</span><span class="p">,</span>
    -    <span class="s1">&#39;tv/monitor&#39;</span><span class="p">,</span> <span class="s1">&#39;ambigious&#39;</span>
    +    <span class="s2">&quot;background&quot;</span><span class="p">,</span>
    +    <span class="s2">&quot;aeroplane&quot;</span><span class="p">,</span>
    +    <span class="s2">&quot;bicycle&quot;</span><span class="p">,</span>
    +    <span class="s2">&quot;bird&quot;</span><span class="p">,</span>
    +    <span class="s2">&quot;boat&quot;</span><span class="p">,</span>
    +    <span class="s2">&quot;bottle&quot;</span><span class="p">,</span>
    +    <span class="s2">&quot;bus&quot;</span><span class="p">,</span>
    +    <span class="s2">&quot;car&quot;</span><span class="p">,</span>
    +    <span class="s2">&quot;cat&quot;</span><span class="p">,</span>
    +    <span class="s2">&quot;chair&quot;</span><span class="p">,</span>
    +    <span class="s2">&quot;cow&quot;</span><span class="p">,</span>
    +    <span class="s2">&quot;diningtable&quot;</span><span class="p">,</span>
    +    <span class="s2">&quot;dog&quot;</span><span class="p">,</span>
    +    <span class="s2">&quot;horse&quot;</span><span class="p">,</span>
    +    <span class="s2">&quot;motorbike&quot;</span><span class="p">,</span>
    +    <span class="s2">&quot;person&quot;</span><span class="p">,</span>
    +    <span class="s2">&quot;potted-plant&quot;</span><span class="p">,</span>
    +    <span class="s2">&quot;sheep&quot;</span><span class="p">,</span>
    +    <span class="s2">&quot;sofa&quot;</span><span class="p">,</span>
    +    <span class="s2">&quot;train&quot;</span><span class="p">,</span>
    +    <span class="s2">&quot;tv/monitor&quot;</span><span class="p">,</span>
     <span class="p">]</span>
     <span class="p">]</span>
     
     
     
     
    -<div class="viewcode-block" id="PascalVOC2012SegmentationDataSet"><a class="viewcode-back" href="../../../../../super_gradients.training.datasets.segmentation_datasets.html#super_gradients.training.datasets.PascalVOC2012SegmentationDataSet">[docs]</a><span class="k">class</span> <span class="nc">PascalVOC2012SegmentationDataSet</span><span class="p">(</span><span class="n">SegmentationDataSet</span><span class="p">):</span>
    +<div class="viewcode-block" id="PascalVOC2012SegmentationDataSet"><a class="viewcode-back" href="../../../../../super_gradients.training.html#super_gradients.training.datasets.PascalVOC2012SegmentationDataSet">[docs]</a><span class="k">class</span> <span class="nc">PascalVOC2012SegmentationDataSet</span><span class="p">(</span><span class="n">SegmentationDataSet</span><span class="p">):</span>
         <span class="sd">&quot;&quot;&quot;</span>
         <span class="sd">&quot;&quot;&quot;</span>
     <span class="sd">    PascalVOC2012SegmentationDataSet - Segmentation Data Set Class for Pascal VOC 2012 Data Set</span>
     <span class="sd">    PascalVOC2012SegmentationDataSet - Segmentation Data Set Class for Pascal VOC 2012 Data Set</span>
     <span class="sd">    &quot;&quot;&quot;</span>
     <span class="sd">    &quot;&quot;&quot;</span>
     
     
    +    <span class="n">IGNORE_LABEL</span> <span class="o">=</span> <span class="mi">21</span>
    +    <span class="n">_ORIGINAL_IGNORE_LABEL</span> <span class="o">=</span> <span class="mi">255</span>
    +
         <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">sample_suffix</span><span class="o">=</span><span class="kc">None</span><span class="p">,</span> <span class="n">target_suffix</span><span class="o">=</span><span class="kc">None</span><span class="p">,</span> <span class="o">*</span><span class="n">args</span><span class="p">,</span> <span class="o">**</span><span class="n">kwargs</span><sp
         <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">sample_suffix</span><span class="o">=</span><span class="kc">None</span><span class="p">,</span> <span class="n">target_suffix</span><span class="o">=</span><span class="kc">None</span><span class="p">,</span> <span class="o">*</span><span class="n">args</span><span class="p">,</span> <span class="o">**</span><span class="n">kwargs</span><sp
    -        <span class="bp">self</span><span class="o">.</span><span class="n">sample_suffix</span> <span class="o">=</span> <span class="s1">&#39;.jpg&#39;</span> <span class="k">if</span> <span class="n">sample_suffix</span> <span class="ow">is</span> <span class="kc">None</span> <span class="k">else</span> <span class="n">sample_suffix</span>
    -        <span class="bp">self</span><span class="o">.</span><span class="n">target_suffix</span> <span class="o">=</span> <span class="s1">&#39;.png&#39;</span> <span class="k">if</span> <span class="n">target_suffix</span> <span class="ow">is</span> <span class="kc">None</span> <span class="k">else</span> <span class="n">target_suffix</span>
    +        <span class="bp">self</span><span class="o">.</span><span class="n">sample_suffix</span> <span class="o">=</span> <span class="s2">&quot;.jpg&quot;</span> <span class="k">if</span> <span class="n">sample_suffix</span> <span class="ow">is</span> <span class="kc">None</span> <span class="k">else</span> <span class="n">sample_suffix</span>
    +        <span class="bp">self</span><span class="o">.</span><span class="n">target_suffix</span> <span class="o">=</span> <span class="s2">&quot;.png&quot;</span> <span class="k">if</span> <span class="n">target_suffix</span> <span class="ow">is</span> <span class="kc">None</span> <span class="k">else</span> <span class="n">target_suffix</span>
             <span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="fm">__init__</span><span class="p">(</span><span class="o">*</span><span class="n">args</span><span class="p">,</span> <span class="o">**</span><span class="n">kwargs</span><span class="p">)</span>
             <span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="fm">__init__</span><span class="p">(</span><span class="o">*</span><span class="n">args</span><span class="p">,</span> <span class="o">**</span><span class="n">kwargs</span><span class="p">)</span>
     
     
    -        <span class="c1"># THERE ARE 21 CLASSES, AND BACKGROUND</span>
             <span class="bp">self</span><span class="o">.</span><span class="n">classes</span> <span class="o">=</span> <span class="n">PASCAL_VOC_2012_CLASSES</span>
             <span class="bp">self</span><span class="o">.</span><span class="n">classes</span> <span class="o">=</span> <span class="n">PASCAL_VOC_2012_CLASSES</span>
     
     
    -<div class="viewcode-block" id="PascalVOC2012SegmentationDataSet.decode_segmentation_mask"><a class="viewcode-back" href="../../../../../super_gradients.training.datasets.segmentation_datasets.html#super_gradients.training.datasets.PascalVOC2012SegmentationDataSet.decode_segmentation_mask">[docs]</a>    <span class="k">def</span> <span class="nf">decode_segmentation_mask</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">label_mask</span><span cla
    +<div class="viewcode-block" id="PascalVOC2012SegmentationDataSet.target_transform"><a class="viewcode-back" href="../../../../../super_gradients.training.html#super_gradients.training.datasets.PascalVOC2012SegmentationDataSet.target_transform">[docs]</a>    <span class="nd">@staticmethod</span>
    +    <span class="k">def</span> <span class="nf">target_transform</span><span class="p">(</span><span class="n">target</span><span class="p">):</span>
    +        <span class="sd">&quot;&quot;&quot;</span>
    +<span class="sd">        target_transform - Transforms the label mask</span>
    +<span class="sd">        This function overrides the original function from SegmentationDataSet and changes target pixels with value</span>
    +<span class="sd">        255 to value = IGNORE_LABEL. This was done since current IoU metric from torchmetrics does not</span>
    +<span class="sd">        support such a high ignore label value (crashed on OOM)</span>
    +
    +<span class="sd">            :param target: The target mask to transform</span>
    +<span class="sd">            :return:       The transformed target mask</span>
    +<span class="sd">        &quot;&quot;&quot;</span>
    +        <span class="n">out</span> <span class="o">=</span> <span class="n">SegmentationDataSet</span><span class="o">.</span><span class="n">target_transform</span><span class="p">(</span><span class="n">target</span><span class="p">)</span>
    +        <span class="n">out</span><span class="p">[</span><span class="n">out</span> <span class="o">==</span> <span class="n">PascalVOC2012SegmentationDataSet</span><span class="o">.</span><span class="n">_ORIGINAL_IGNORE_LABEL</span><span class="p">]</span> <span class="o">=</span> <span class="n">PascalVOC2012SegmentationDataSet</span><span class="o">.</span><span class="n">IGNORE_LABEL</span>
    +        <span class="k">return</span> <span class="n">out</span></div>
    +
    +<div class="viewcode-block" id="PascalVOC2012SegmentationDataSet.decode_segmentation_mask"><a class="viewcode-back" href="../../../../../super_gradients.training.html#super_gradients.training.datasets.PascalVOC2012SegmentationDataSet.decode_segmentation_mask">[docs]</a>    <span class="k">def</span> <span class="nf">decode_segmentation_mask</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">label_mask</span><span class="p">:</span> <span class="n"
             <span class="sd">&quot;&quot;&quot;</span>
             <span class="sd">&quot;&quot;&quot;</span>
     <span class="sd">        decode_segmentation_mask - Decodes the colors for the Segmentation Mask</span>
     <span class="sd">        decode_segmentation_mask - Decodes the colors for the Segmentation Mask</span>
     <span class="sd">            :param: label_mask:  an (M,N) array of integer values denoting</span>
     <span class="sd">            :param: label_mask:  an (M,N) array of integer values denoting</span>
    @@ -125,8 +167,7 @@
             <span class="n">g</span> <span class="o">=</span> <span class="n">label_mask</span><span class="o">.</span><span class="n">copy</span><span class="p">()</span>
             <span class="n">g</span> <span class="o">=</span> <span class="n">label_mask</span><span class="o">.</span><span class="n">copy</span><span class="p">()</span>
             <span class="n">b</span> <span class="o">=</span> <span class="n">label_mask</span><span class="o">.</span><span class="n">copy</span><span class="p">()</span>
             <span class="n">b</span> <span class="o">=</span> <span class="n">label_mask</span><span class="o">.</span><span class="n">copy</span><span class="p">()</span>
     
     
    -        <span class="c1"># REMOVING THE BACKGROUND CLASS FROM THE PLOTS</span>
    -        <span class="n">num_classes_to_plot</span> <span class="o">=</span> <span class="nb">len</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">classes</span><span class="p">)</span> <span class="o">-</span> <span class="mi">1</span>
    +        <span class="n">num_classes_to_plot</span> <span class="o">=</span> <span class="nb">len</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">classes</span><span class="p">)</span>
             <span class="k">for</span> <span class="n">ll</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="mi">0</span><span class="p">,</span> <span class="n">num_classes_to_plot</span><span class="p">):</span>
             <span class="k">for</span> <span class="n">ll</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="mi">0</span><span class="p">,</span> <span class="n">num_classes_to_plot</span><span class="p">):</span>
                 <span class="n">r</span><span class="p">[</span><span class="n">label_mask</span> <span class="o">==</span> <span class="n">ll</span><span class="p">]</span> <span class="o">=</span> <span class="n">label_colours</span><span class="p">[</span><span class="n">ll</span><span class="p">,</span> <span class="mi">0</span><span class="p">]</span>
                 <span class="n">r</span><span class="p">[</span><span class="n">label_mask</span> <span class="o">==</span> <span class="n">ll</span><span class="p">]</span> <span class="o">=</span> <span class="n">label_colours</span><span class="p">[</span><span class="n">ll</span><span class="p">,</span> <span class="mi">0</span><span class="p">]</span>
                 <span class="n">g</span><span class="p">[</span><span class="n">label_mask</span> <span class="o">==</span> <span class="n">ll</span><span class="p">]</span> <span class="o">=</span> <span class="n">label_colours</span><span class="p">[</span><span class="n">ll</span><span class="p">,</span> <span class="mi">1</span><span class="p">]</span>
                 <span class="n">g</span><span class="p">[</span><span class="n">label_mask</span> <span class="o">==</span> <span class="n">ll</span><span class="p">]</span> <span class="o">=</span> <span class="n">label_colours</span><span class="p">[</span><span class="n">ll</span><span class="p">,</span> <span class="mi">1</span><span class="p">]</span>
    @@ -145,8 +186,8 @@
             <span class="c1"># GENERATE SAMPLES AND TARGETS HERE SPECIFICALLY FOR PASCAL VOC 2012</span>
             <span class="c1"># GENERATE SAMPLES AND TARGETS HERE SPECIFICALLY FOR PASCAL VOC 2012</span>
             <span class="k">with</span> <span class="nb">open</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">root</span> <span class="o">+</span> <span class="n">os</span><span class="o">.</span><span class="n">path</span><span class="o">.</span><span class="n">sep</span> <span class="o">+</span> <span class="bp">self</span><span class="o">.</span><span class="n">list_file_path</span><span class="p">,</span> <span class="s2">&quot;r&quot;</span><sp
             <span class="k">with</span> <span class="nb">open</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">root</span> <span class="o">+</span> <span class="n">os</span><span class="o">.</span><span class="n">path</span><span class="o">.</span><span class="n">sep</span> <span class="o">+</span> <span class="bp">self</span><span class="o">.</span><span class="n">list_file_path</span><span class="p">,</span> <span class="s2">&quot;r&quot;</span><sp
                 <span class="k">for</span> <span class="n">line</span> <span class="ow">in</span> <span class="n">lines</span><span class="p">:</span>
                 <span class="k">for</span> <span class="n">line</span> <span class="ow">in</span> <span class="n">lines</span><span class="p">:</span>
    -                <span class="n">image_path</span> <span class="o">=</span> <span class="n">os</span><span class="o">.</span><span class="n">path</span><span class="o">.</span><span class="n">join</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">root</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">samples_sub_directory</span><span class="p">,</span> <span class="n">line</span><span class="o">.</sp
    -                <span class="n">mask_path</span> <span class="o">=</span> <span class="n">os</span><span class="o">.</span><span class="n">path</span><span class="o">.</span><span class="n">join</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">root</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">targets_sub_directory</span><span class="p">,</span> <span class="n">line</span><span class="o">.</spa
    +                <span class="n">image_path</span> <span class="o">=</span> <span class="n">os</span><span class="o">.</span><span class="n">path</span><span class="o">.</span><span class="n">join</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">root</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">samples_sub_directory</span><span class="p">,</span> <span class="n">line</span><span class="o">.</sp
    +                <span class="n">mask_path</span> <span class="o">=</span> <span class="n">os</span><span class="o">.</span><span class="n">path</span><span class="o">.</span><span class="n">join</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">root</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">targets_sub_directory</span><span class="p">,</span> <span class="n">line</span><span class="o">.</spa
     
     
                     <span class="k">if</span> <span class="n">os</span><span class="o">.</span><span class="n">path</span><span class="o">.</span><span class="n">exists</span><span class="p">(</span><span class="n">mask_path</span><span class="p">)</span> <span class="ow">and</span> <span class="n">os</span><span class="o">.</span><span class="n">path</span><span class="o">.</span><span class="n">exists</span><span class="p">(</span><span class="n">image_path</span><span class="p">):</span>
                     <span class="k">if</span> <span class="n">os</span><span class="o">.</span><span class="n">path</span><span class="o">.</span><span class="n">exists</span><span class="p">(</span><span class="n">mask_path</span><span class="p">)</span> <span class="ow">and</span> <span class="n">os</span><span class="o">.</span><span class="n">path</span><span class="o">.</span><span class="n">exists</span><span class="p">(</span><span class="n">image_path</span><span class="p">):</span>
                         <span class="bp">self</span><span class="o">.</span><span class="n">samples_targets_tuples_list</span><span class="o">.</span><span class="n">append</span><span class="p">((</span><span class="n">image_path</span><span class="p">,</span> <span class="n">mask_path</span><span class="p">))</span>
                         <span class="bp">self</span><span class="o">.</span><span class="n">samples_targets_tuples_list</span><span class="o">.</span><span class="n">append</span><span class="p">((</span><span class="n">image_path</span><span class="p">,</span> <span class="n">mask_path</span><span class="p">))</span>
    @@ -184,6 +225,59 @@
                     <span class="p">[</span><span class="mi">0</span><span class="p">,</span> <span class="mi">64</span><span class="p">,</span> <span class="mi">128</span><span class="p">],</span>
                     <span class="p">[</span><span class="mi">0</span><span class="p">,</span> <span class="mi">64</span><span class="p">,</span> <span class="mi">128</span><span class="p">],</span>
                 <span class="p">]</span>
                 <span class="p">]</span>
             <span class="p">)</span></div>
             <span class="p">)</span></div>
    +
    +
    +<div class="viewcode-block" id="PascalAUG2012SegmentationDataSet"><a class="viewcode-back" href="../../../../../super_gradients.training.html#super_gradients.training.datasets.PascalAUG2012SegmentationDataSet">[docs]</a><span class="k">class</span> <span class="nc">PascalAUG2012SegmentationDataSet</span><span class="p">(</span><span class="n">PascalVOC2012SegmentationDataSet</span><span class="p">):</span>
    +    <span class="sd">&quot;&quot;&quot;</span>
    +<span class="sd">    PascalAUG2012SegmentationDataSet - Segmentation Data Set Class for Pascal AUG 2012 Data Set</span>
    +<span class="sd">    &quot;&quot;&quot;</span>
    +
    +    <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="o">*</span><span class="n">args</span><span class="p">,</span> <span class="o">**</span><span class="n">kwargs</span><span class="p">):</span>
    +        <span class="bp">self</span><span class="o">.</span><span class="n">sample_suffix</span> <span class="o">=</span> <span class="s2">&quot;.jpg&quot;</span>
    +        <span class="bp">self</span><span class="o">.</span><span class="n">target_suffix</span> <span class="o">=</span> <span class="s2">&quot;.mat&quot;</span>
    +        <span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="fm">__init__</span><span class="p">(</span><span class="n">sample_suffix</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">sample_suffix</span><span class="p">,</span> <span class="n">target_suffix</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">target_suffix</span><span class="p">,</span> <span class
    +
    +<div class="viewcode-block" id="PascalAUG2012SegmentationDataSet.target_loader"><a class="viewcode-back" href="../../../../../super_gradients.training.html#super_gradients.training.datasets.PascalAUG2012SegmentationDataSet.target_loader">[docs]</a>    <span class="nd">@staticmethod</span>
    +    <span class="k">def</span> <span class="nf">target_loader</span><span class="p">(</span><span class="n">target_path</span><span class="p">:</span> <span class="nb">str</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="n">Image</span><span class="p">:</span>
    +        <span class="sd">&quot;&quot;&quot;</span>
    +<span class="sd">        target_loader</span>
    +<span class="sd">            :param target_path: The path to the target data</span>
    +<span class="sd">            :return:            The loaded target</span>
    +<span class="sd">        &quot;&quot;&quot;</span>
    +        <span class="n">mat</span> <span class="o">=</span> <span class="n">scipy</span><span class="o">.</span><span class="n">io</span><span class="o">.</span><span class="n">loadmat</span><span class="p">(</span><span class="n">target_path</span><span class="p">,</span> <span class="n">mat_dtype</span><span class="o">=</span><span class="kc">True</span><span class="p">,</span> <span class="n">squeeze_me</span><span class="o">=</span><span class="kc">True</span><span class="p">,</span> <span 
    +        <span class="n">mask</span> <span class="o">=</span> <span class="n">mat</span><span class="p">[</span><span class="s2">&quot;GTcls&quot;</span><span class="p">]</span><span class="o">.</span><span class="n">Segmentation</span>
    +        <span class="k">return</span> <span class="n">Image</span><span class="o">.</span><span class="n">fromarray</span><span class="p">(</span><span class="n">mask</span><span class="p">)</span></div></div>
    +
    +
    +<div class="viewcode-block" id="PascalVOCAndAUGUnifiedDataset"><a class="viewcode-back" href="../../../../../super_gradients.training.html#super_gradients.training.datasets.PascalVOCAndAUGUnifiedDataset">[docs]</a><span class="k">class</span> <span class="nc">PascalVOCAndAUGUnifiedDataset</span><span class="p">(</span><span class="n">ConcatDataset</span><span class="p">):</span>
    +    <span class="sd">&quot;&quot;&quot;</span>
    +<span class="sd">    Pascal VOC + AUG train dataset, aka `SBD` dataset contributed in &quot;Semantic contours from inverse detectors&quot;.</span>
    +<span class="sd">    This is class implement the common usage of the SBD and PascalVOC datasets as a unified augmented trainset.</span>
    +<span class="sd">    The unified dataset includes a total of 10,582 samples and don&#39;t contains duplicate samples from the PascalVOC</span>
    +<span class="sd">    validation set.</span>
    +<span class="sd">    &quot;&quot;&quot;</span>
    +
    +    <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="o">**</span><span class="n">kwargs</span><span class="p">):</span>
    +        <span class="nb">print</span><span class="p">(</span><span class="n">kwargs</span><span class="p">)</span>
    +        <span class="k">if</span> <span class="nb">any</span><span class="p">([</span><span class="n">kwargs</span><span class="o">.</span><span class="n">pop</span><span class="p">(</span><span class="s2">&quot;list_file&quot;</span><span class="p">),</span> <span class="n">kwargs</span><span class="o">.</span><span class="n">pop</span><span class="p">(</span><span class="s2">&quot;samples_sub_directory&quot;</span><span class="p">),</span> <span class="n">kwargs</span><span class="o">.</span>
    +            <span class="n">logger</span><span class="o">.</span><span class="n">warning</span><span class="p">(</span>
    +                <span class="s2">&quot;[list_file, samples_sub_directory, targets_sub_directory] arguments passed will not be used&quot;</span>
    +                <span class="s2">&quot; when passed to `PascalVOCAndAUGUnifiedDataset`. Those values are predefined for initiating&quot;</span>
    +                <span class="s2">&quot; the Pascal VOC + AUG training set.&quot;</span>
    +            <span class="p">)</span>
    +        <span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="fm">__init__</span><span class="p">(</span>
    +            <span class="n">datasets</span><span class="o">=</span><span class="p">[</span>
    +                <span class="n">PascalVOC2012SegmentationDataSet</span><span class="p">(</span>
    +                    <span class="n">list_file</span><span class="o">=</span><span class="s2">&quot;VOCdevkit/VOC2012/ImageSets/Segmentation/train.txt&quot;</span><span class="p">,</span>
    +                    <span class="n">samples_sub_directory</span><span class="o">=</span><span class="s2">&quot;VOCdevkit/VOC2012/JPEGImages&quot;</span><span class="p">,</span>
    +                    <span class="n">targets_sub_directory</span><span class="o">=</span><span class="s2">&quot;VOCdevkit/VOC2012/SegmentationClass&quot;</span><span class="p">,</span>
    +                    <span class="o">**</span><span class="n">kwargs</span><span class="p">,</span>
    +                <span class="p">),</span>
    +                <span class="n">PascalAUG2012SegmentationDataSet</span><span class="p">(</span>
    +                    <span class="n">list_file</span><span class="o">=</span><span class="s2">&quot;VOCaug/dataset/aug.txt&quot;</span><span class="p">,</span> <span class="n">samples_sub_directory</span><span class="o">=</span><span class="s2">&quot;VOCaug/dataset/img&quot;</span><span class="p">,</span> <span class="n">targets_sub_directory</span><span class="o">=</span><span class="s2">&quot;VOCaug/dataset/cls&quot;</span><span class="p">,</span> <span class="o">**</span><span class="n">kwarg
    +                <span class="p">),</span>
    +            <span class="p">]</span>
    +        <span class="p">)</span></div>
     </pre></div>
     </pre></div>
     
     
                </div>
                </div>
    @@ -213,4 +307,4 @@
       </script> 
       </script> 
     
     
     </body>
     </body>
    -</html>
    +</html>
    Discard
    Some lines were truncated since they exceed the maximum allowed length of 500, please use a local Git client to see the full diff.
    @@ -3,10 +3,11 @@
     <head>
     <head>
       <meta charset="utf-8" />
       <meta charset="utf-8" />
       <meta name="viewport" content="width=device-width, initial-scale=1.0" />
       <meta name="viewport" content="width=device-width, initial-scale=1.0" />
    -  <title>super_gradients.training.datasets.segmentation_datasets.segmentation_dataset &mdash; SuperGradients 1.0 documentation</title>
    +  <title>super_gradients.training.datasets.segmentation_datasets.segmentation_dataset &mdash; SuperGradients 3.0.3 documentation</title>
           <link rel="stylesheet" href="../../../../../_static/pygments.css" type="text/css" />
           <link rel="stylesheet" href="../../../../../_static/pygments.css" type="text/css" />
           <link rel="stylesheet" href="../../../../../_static/css/theme.css" type="text/css" />
           <link rel="stylesheet" href="../../../../../_static/css/theme.css" type="text/css" />
           <link rel="stylesheet" href="../../../../../_static/graphviz.css" type="text/css" />
           <link rel="stylesheet" href="../../../../../_static/graphviz.css" type="text/css" />
    +      <link rel="stylesheet" href="../../../../../_static/custom.css" type="text/css" />
       <!--[if lt IE 9]>
       <!--[if lt IE 9]>
         <script src="../../../../../_static/js/html5shiv.min.js"></script>
         <script src="../../../../../_static/js/html5shiv.min.js"></script>
       <![endif]-->
       <![endif]-->
    @@ -14,7 +15,9 @@
             <script data-url_root="../../../../../" id="documentation_options" src="../../../../../_static/documentation_options.js"></script>
             <script data-url_root="../../../../../" id="documentation_options" src="../../../../../_static/documentation_options.js"></script>
             <script src="../../../../../_static/jquery.js"></script>
             <script src="../../../../../_static/jquery.js"></script>
             <script src="../../../../../_static/underscore.js"></script>
             <script src="../../../../../_static/underscore.js"></script>
    +        <script src="../../../../../_static/_sphinx_javascript_frameworks_compat.js"></script>
             <script src="../../../../../_static/doctools.js"></script>
             <script src="../../../../../_static/doctools.js"></script>
    +        <script src="../../../../../_static/sphinx_highlight.js"></script>
         <script src="../../../../../_static/js/theme.js"></script>
         <script src="../../../../../_static/js/theme.js"></script>
         <link rel="index" title="Index" href="../../../../../genindex.html" />
         <link rel="index" title="Index" href="../../../../../genindex.html" />
         <link rel="search" title="Search" href="../../../../../search.html" /> 
         <link rel="search" title="Search" href="../../../../../search.html" /> 
    @@ -35,30 +38,29 @@
       </form>
       </form>
     </div>
     </div>
             </div><div class="wy-menu wy-menu-vertical" data-spy="affix" role="navigation" aria-label="Navigation menu">
             </div><div class="wy-menu wy-menu-vertical" data-spy="affix" role="navigation" aria-label="Navigation menu">
    -              <p class="caption"><span class="caption-text">Welcome To SuperGradients</span></p>
    +              <p class="caption" role="heading"><span class="caption-text">Welcome To SuperGradients</span></p>
     <ul>
     <ul>
    -<li class="toctree-l1"><a class="reference internal" href="../../../../../welcome.html">Fill our 4-question quick survey! We will raffle free SuperGradients swag between those who will participate -&gt; Fill Survey</a></li>
    -<li class="toctree-l1"><a class="reference internal" href="../../../../../welcome.html#supergradients">SuperGradients</a></li>
    +<li class="toctree-l1"><a class="reference internal" href="../../../../../welcome.html">Version 3 is out! Notebooks have been updated!</a></li>
    +<li class="toctree-l1"><a class="reference internal" href="../../../../../welcome.html#build-with-supergradients">Build with SuperGradients</a></li>
    +<li class="toctree-l1"><a class="reference internal" href="../../../../../welcome.html#quick-installation">Quick Installation</a></li>
    +<li class="toctree-l1"><a class="reference internal" href="../../../../../welcome.html#what-s-new">What’s New</a></li>
    +<li class="toctree-l1"><a class="reference internal" href="../../../../../welcome.html#coming-soon">Coming soon</a></li>
    +<li class="toctree-l1"><a class="reference internal" href="../../../../../welcome.html#table-of-content">Table of Content</a></li>
    +<li class="toctree-l1"><a class="reference internal" href="../../../../../welcome.html#getting-started">Getting Started</a></li>
    +<li class="toctree-l1"><a class="reference internal" href="../../../../../welcome.html#advanced-features">Advanced Features</a></li>
    +<li class="toctree-l1"><a class="reference internal" href="../../../../../welcome.html#installation-methods">Installation Methods</a></li>
    +<li class="toctree-l1"><a class="reference internal" href="../../../../../welcome.html#implemented-model-architectures">Implemented Model Architectures</a></li>
    +<li class="toctree-l1"><a class="reference internal" href="../../../../../welcome.html#documentation">Documentation</a></li>
    +<li class="toctree-l1"><a class="reference internal" href="../../../../../welcome.html#contributing">Contributing</a></li>
    +<li class="toctree-l1"><a class="reference internal" href="../../../../../welcome.html#citation">Citation</a></li>
    +<li class="toctree-l1"><a class="reference internal" href="../../../../../welcome.html#community">Community</a></li>
    +<li class="toctree-l1"><a class="reference internal" href="../../../../../welcome.html#license">License</a></li>
    +<li class="toctree-l1"><a class="reference internal" href="../../../../../welcome.html#deci-platform">Deci Platform</a></li>
     </ul>
     </ul>
    -<p class="caption"><span class="caption-text">Technical Documentation</span></p>
    +<p class="caption" role="heading"><span class="caption-text">Technical Documentation</span></p>
     <ul>
     <ul>
     <li class="toctree-l1"><a class="reference internal" href="../../../../../super_gradients.common.html">Common package</a></li>
     <li class="toctree-l1"><a class="reference internal" href="../../../../../super_gradients.common.html">Common package</a></li>
     <li class="toctree-l1"><a class="reference internal" href="../../../../../super_gradients.training.html">Training package</a></li>
     <li class="toctree-l1"><a class="reference internal" href="../../../../../super_gradients.training.html">Training package</a></li>
    -</ul>
    -<p class="caption"><span class="caption-text">User Guide</span></p>
    -<ul>
    -<li class="toctree-l1"><a class="reference internal" href="../../../../../user_guide.html">What is SuperGradients?</a></li>
    -<li class="toctree-l1"><a class="reference internal" href="../../../../../user_guide.html#introducing-the-supergradients-library">Introducing the SuperGradients library</a></li>
    -<li class="toctree-l1"><a class="reference internal" href="../../../../../user_guide.html#installation">Installation</a></li>
    -<li class="toctree-l1"><a class="reference internal" href="../../../../../user_guide.html#integrating-your-training-code-complete-walkthrough">Integrating Your Training Code - Complete Walkthrough</a></li>
    -<li class="toctree-l1"><a class="reference internal" href="../../../../../user_guide.html#training-parameters">Training Parameters</a></li>
    -<li class="toctree-l1"><a class="reference internal" href="../../../../../user_guide.html#logs-and-checkpoints">Logs and Checkpoints</a></li>
    -<li class="toctree-l1"><a class="reference internal" href="../../../../../user_guide.html#dataset-parameters">Dataset Parameters</a></li>
    -<li class="toctree-l1"><a class="reference internal" href="../../../../../user_guide.html#network-architectures">Network Architectures</a></li>
    -<li class="toctree-l1"><a class="reference internal" href="../../../../../user_guide.html#pretrained-models">Pretrained Models</a></li>
    -<li class="toctree-l1"><a class="reference internal" href="../../../../../user_guide.html#how-to-reproduce-our-training-recipes">How To Reproduce Our Training Recipes</a></li>
    -<li class="toctree-l1"><a class="reference internal" href="../../../../../user_guide.html#professional-tools-integration">Professional Tools Integration</a></li>
    -<li class="toctree-l1"><a class="reference internal" href="../../../../../user_guide.html#supergradients-faq">SuperGradients FAQ</a></li>
     </ul>
     </ul>
     
     
             </div>
             </div>
    @@ -87,71 +89,44 @@
                  
                  
       <h1>Source code for super_gradients.training.datasets.segmentation_datasets.segmentation_dataset</h1><div class="highlight"><pre>
       <h1>Source code for super_gradients.training.datasets.segmentation_datasets.segmentation_dataset</h1><div class="highlight"><pre>
     <span></span><span class="kn">import</span> <span class="nn">os</span>
     <span></span><span class="kn">import</span> <span class="nn">os</span>
    -<span class="kn">import</span> <span class="nn">torch</span>
    -<span class="kn">import</span> <span class="nn">random</span>
    +<span class="kn">from</span> <span class="nn">typing</span> <span class="kn">import</span> <span class="n">Callable</span><span class="p">,</span> <span class="n">Iterable</span>
    +
     <span class="kn">import</span> <span class="nn">numpy</span> <span class="k">as</span> <span class="nn">np</span>
     <span class="kn">import</span> <span class="nn">numpy</span> <span class="k">as</span> <span class="nn">np</span>
    -<span class="kn">from</span> <span class="nn">tqdm</span> <span class="kn">import</span> <span class="n">tqdm</span>
    -<span class="kn">from</span> <span class="nn">typing</span> <span class="kn">import</span> <span class="n">Callable</span>
    +<span class="kn">import</span> <span class="nn">torch</span>
     <span class="kn">import</span> <span class="nn">torchvision.transforms</span> <span class="k">as</span> <span class="nn">transform</span>
     <span class="kn">import</span> <span class="nn">torchvision.transforms</span> <span class="k">as</span> <span class="nn">transform</span>
     <span class="kn">from</span> <span class="nn">PIL</span> <span class="kn">import</span> <span class="n">Image</span>
     <span class="kn">from</span> <span class="nn">PIL</span> <span class="kn">import</span> <span class="n">Image</span>
    +<span class="kn">from</span> <span class="nn">tqdm</span> <span class="kn">import</span> <span class="n">tqdm</span>
     
     
     <span class="kn">from</span> <span class="nn">super_gradients.common.decorators.factory_decorator</span> <span class="kn">import</span> <span class="n">resolve_param</span>
     <span class="kn">from</span> <span class="nn">super_gradients.common.decorators.factory_decorator</span> <span class="kn">import</span> <span class="n">resolve_param</span>
     <span class="kn">from</span> <span class="nn">super_gradients.common.factories.transforms_factory</span> <span class="kn">import</span> <span class="n">TransformsFactory</span>
     <span class="kn">from</span> <span class="nn">super_gradients.common.factories.transforms_factory</span> <span class="kn">import</span> <span class="n">TransformsFactory</span>
     <span class="kn">from</span> <span class="nn">super_gradients.training.datasets.sg_dataset</span> <span class="kn">import</span> <span class="n">DirectoryDataSet</span><span class="p">,</span> <span class="n">ListDataset</span>
     <span class="kn">from</span> <span class="nn">super_gradients.training.datasets.sg_dataset</span> <span class="kn">import</span> <span class="n">DirectoryDataSet</span><span class="p">,</span> <span class="n">ListDataset</span>
    -<span class="kn">from</span> <span class="nn">super_gradients.training.transforms.transforms</span> <span class="kn">import</span> <span class="n">RandomFlip</span><span class="p">,</span> <span class="n">Rescale</span><span class="p">,</span> <span class="n">RandomRescale</span><span class="p">,</span> <span class="n">RandomRotate</span><span class="p">,</span> \
    -    <span class="n">CropImageAndMask</span><span class="p">,</span> <span class="n">RandomGaussianBlur</span><span class="p">,</span> <span class="n">PadShortToCropSize</span>
     
     
     
     
    -<div class="viewcode-block" id="SegmentationDataSet"><a class="viewcode-back" href="../../../../../super_gradients.training.datasets.segmentation_datasets.html#super_gradients.training.datasets.SegmentationDataSet">[docs]</a><span class="k">class</span> <span class="nc">SegmentationDataSet</span><span class="p">(</span><span class="n">DirectoryDataSet</span><span class="p">,</span> <span class="n">ListDataset</span><span class="p">):</span>
    +<div class="viewcode-block" id="SegmentationDataSet"><a class="viewcode-back" href="../../../../../super_gradients.training.html#super_gradients.training.datasets.SegmentationDataSet">[docs]</a><span class="k">class</span> <span class="nc">SegmentationDataSet</span><span class="p">(</span><span class="n">DirectoryDataSet</span><span class="p">,</span> <span class="n">ListDataset</span><span class="p">):</span>
     
     
    -    <span class="nd">@resolve_param</span><span class="p">(</span><span class="s1">&#39;image_mask_transforms&#39;</span><span class="p">,</span> <span class="n">factory</span><span class="o">=</span><span class="n">TransformsFactory</span><span class="p">())</span>
    -    <span class="nd">@resolve_param</span><span class="p">(</span><span class="s1">&#39;image_mask_transforms_aug&#39;</span><span class="p">,</span> <span class="n">factory</span><span class="o">=</span><span class="n">TransformsFactory</span><span class="p">())</span>
    +    <span class="nd">@resolve_param</span><span class="p">(</span><span class="s1">&#39;transforms&#39;</span><span class="p">,</span> <span class="n">factory</span><span class="o">=</span><span class="n">TransformsFactory</span><span class="p">())</span>
         <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">root</span><span class="p">:</span> <span class="nb">str</span><span class="p">,</span> <span class="n">list_file</span><span class="p">:</span> <span class="nb">str</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span> <span class="n">samples_sub_directory</span><span class="p">:</span> <span class="nb">str</s
         <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">root</span><span class="p">:</span> <span class="nb">str</span><span class="p">,</span> <span class="n">list_file</span><span class="p">:</span> <span class="nb">str</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span> <span class="n">samples_sub_directory</span><span class="p">:</span> <span class="nb">str</s
                      <span class="n">targets_sub_directory</span><span class="p">:</span> <span class="nb">str</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span>
                      <span class="n">targets_sub_directory</span><span class="p">:</span> <span class="nb">str</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span>
    -                 <span class="n">img_size</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="mi">608</span><span class="p">,</span> <span class="n">crop_size</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="mi">512</span><span class="p">,</span> <span class="n">batch_size</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="mi">16</span><span class="p">,</spa
    -                 <span class="n">dataset_hyper_params</span><span class="p">:</span> <span class="nb">dict</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span>
    -                 <span class="n">cache_labels</span><span class="p">:</span> <span class="nb">bool</span> <span class="o">=</span> <span class="kc">False</span><span class="p">,</span> <span class="n">cache_images</span><span class="p">:</span> <span class="nb">bool</span> <span class="o">=</span> <span class="kc">False</span><span class="p">,</span> <span class="n">sample_loader</span><span class="p">:</span> <span class="n">Callable</span> <span class="o">=</span> <span class="kc">None</span>
    -                 <span class="n">target_loader</span><span class="p">:</span> <span class="n">Callable</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span> <span class="n">collate_fn</span><span class="p">:</span> <span class="n">Callable</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span> <span class="n">target_extension</span><span class="p">:</span> <span class="nb">str</span> <span class="o">=</span> <span class="s1">&#39;.png
    -                 <span class="n">image_mask_transforms</span><span class="p">:</span> <span class="n">transform</span><span class="o">.</span><span class="n">Compose</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span> <span class="n">image_mask_transforms_aug</span><span class="p">:</span> <span class="n">transform</span><span class="o">.</span><span class="n">Compose</span> <span class="o">=</span> <span class="kc">None</span><span class="p">):</span>
    +                 <span class="n">cache_labels</span><span class="p">:</span> <span class="nb">bool</span> <span class="o">=</span> <span class="kc">False</span><span class="p">,</span> <span class="n">cache_images</span><span class="p">:</span> <span class="nb">bool</span> <span class="o">=</span> <span class="kc">False</span><span class="p">,</span>
    +                 <span class="n">collate_fn</span><span class="p">:</span> <span class="n">Callable</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span> <span class="n">target_extension</span><span class="p">:</span> <span class="nb">str</span> <span class="o">=</span> <span class="s1">&#39;.png&#39;</span><span class="p">,</span>
    +                 <span class="n">transforms</span><span class="p">:</span> <span class="n">Iterable</span> <span class="o">=</span> <span class="kc">None</span><span class="p">):</span>
             <span class="sd">&quot;&quot;&quot;</span>
             <span class="sd">&quot;&quot;&quot;</span>
     <span class="sd">        SegmentationDataSet</span>
     <span class="sd">        SegmentationDataSet</span>
    -<span class="sd">                                * Please use self.augment == True only for training</span>
    -
     <span class="sd">            :param root:                        Root folder of the Data Set</span>
     <span class="sd">            :param root:                        Root folder of the Data Set</span>
     <span class="sd">            :param list_file:                   Path to the file with the samples list</span>
     <span class="sd">            :param list_file:                   Path to the file with the samples list</span>
     <span class="sd">            :param samples_sub_directory:       name of the samples sub-directory</span>
     <span class="sd">            :param samples_sub_directory:       name of the samples sub-directory</span>
     <span class="sd">            :param targets_sub_directory:       name of the targets sub-directory</span>
     <span class="sd">            :param targets_sub_directory:       name of the targets sub-directory</span>
    -<span class="sd">            :param img_size:                    Image size of the Model that uses this Data Set</span>
    -<span class="sd">            :param crop_size:                   The size of the cropped image</span>
    -<span class="sd">            :param batch_size:                  Batch Size of the Model that uses this Data Set</span>
    -<span class="sd">            :param augment:                     True / False flag to allow Augmentation</span>
    -<span class="sd">            :param dataset_hyper_params:        Any hyper params required for the data set</span>
     <span class="sd">            :param cache_labels:                &quot;Caches&quot; the labels -&gt; Pre-Loads to memory as a list</span>
     <span class="sd">            :param cache_labels:                &quot;Caches&quot; the labels -&gt; Pre-Loads to memory as a list</span>
     <span class="sd">            :param cache_images:                &quot;Caches&quot; the images -&gt; Pre-Loads to memory as a list</span>
     <span class="sd">            :param cache_images:                &quot;Caches&quot; the images -&gt; Pre-Loads to memory as a list</span>
    -<span class="sd">            :param sample_loader:               A function that specifies how to load a sample</span>
    -<span class="sd">            :param target_loader:               A function that specifies how to load a target</span>
     <span class="sd">            :param collate_fn:                  collate_fn func to process batches for the Data Loader</span>
     <span class="sd">            :param collate_fn:                  collate_fn func to process batches for the Data Loader</span>
    -<span class="sd">            :param target_extension:            file extension of the targets (defualt is .png for PASCAL VOC 2012)</span>
    -<span class="sd">            :param image_mask_transforms        transforms to be applied on image and mask when augment=False</span>
    -<span class="sd">            :param image_mask_transforms_aug    transforms to be applied on image and mask when augment=True</span>
    +<span class="sd">            :param target_extension:            file extension of the targets (default is .png for PASCAL VOC 2012)</span>
    +<span class="sd">            :param transforms:                  transforms to be applied on image and mask</span>
    +
     <span class="sd">        &quot;&quot;&quot;</span>
     <span class="sd">        &quot;&quot;&quot;</span>
             <span class="bp">self</span><span class="o">.</span><span class="n">samples_sub_directory</span> <span class="o">=</span> <span class="n">samples_sub_directory</span>
             <span class="bp">self</span><span class="o">.</span><span class="n">samples_sub_directory</span> <span class="o">=</span> <span class="n">samples_sub_directory</span>
             <span class="bp">self</span><span class="o">.</span><span class="n">targets_sub_directory</span> <span class="o">=</span> <span class="n">targets_sub_directory</span>
             <span class="bp">self</span><span class="o">.</span><span class="n">targets_sub_directory</span> <span class="o">=</span> <span class="n">targets_sub_directory</span>
    -        <span class="bp">self</span><span class="o">.</span><span class="n">dataset_hyperparams</span> <span class="o">=</span> <span class="n">dataset_hyper_params</span>
             <span class="bp">self</span><span class="o">.</span><span class="n">cache_labels</span> <span class="o">=</span> <span class="n">cache_labels</span>
             <span class="bp">self</span><span class="o">.</span><span class="n">cache_labels</span> <span class="o">=</span> <span class="n">cache_labels</span>
             <span class="bp">self</span><span class="o">.</span><span class="n">cache_images</span> <span class="o">=</span> <span class="n">cache_images</span>
             <span class="bp">self</span><span class="o">.</span><span class="n">cache_images</span> <span class="o">=</span> <span class="n">cache_images</span>
    -        <span class="bp">self</span><span class="o">.</span><span class="n">batch_size</span> <span class="o">=</span> <span class="n">batch_size</span>
    -        <span class="bp">self</span><span class="o">.</span><span class="n">img_size</span> <span class="o">=</span> <span class="n">img_size</span>
    -        <span class="bp">self</span><span class="o">.</span><span class="n">crop_size</span> <span class="o">=</span> <span class="n">crop_size</span>
    -        <span class="bp">self</span><span class="o">.</span><span class="n">augment</span> <span class="o">=</span> <span class="n">augment</span>
    -        <span class="bp">self</span><span class="o">.</span><span class="n">batch_index</span> <span class="o">=</span> <span class="kc">None</span>
    -        <span class="bp">self</span><span class="o">.</span><span class="n">total_batches_num</span> <span class="o">=</span> <span class="kc">None</span>
    -
    -        <span class="c1"># ENABLES USING CUSTOM SAMPLE/TARGET LOADERS</span>
    -        <span class="k">if</span> <span class="n">sample_loader</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span><span class="p">:</span>
    -            <span class="bp">self</span><span class="o">.</span><span class="n">sample_loader</span> <span class="o">=</span> <span class="n">sample_loader</span>
    -        <span class="k">if</span> <span class="n">target_loader</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span><span class="p">:</span>
    -            <span class="bp">self</span><span class="o">.</span><span class="n">target_loader</span> <span class="o">=</span> <span class="n">target_loader</span>
     
     
             <span class="c1"># CREATE A DIRECTORY DATASET OR A LIST DATASET BASED ON THE list_file INPUT VARIABLE</span>
             <span class="c1"># CREATE A DIRECTORY DATASET OR A LIST DATASET BASED ON THE list_file INPUT VARIABLE</span>
             <span class="k">if</span> <span class="n">list_file</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span><span class="p">:</span>
             <span class="k">if</span> <span class="n">list_file</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span><span class="p">:</span>
    @@ -165,26 +140,8 @@
                                           <span class="n">sample_loader</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">sample_loader</span><span class="p">,</span> <span class="n">sample_transform</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">sample_transform</span><span class="p">,</span>
                                           <span class="n">sample_loader</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">sample_loader</span><span class="p">,</span> <span class="n">sample_transform</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">sample_transform</span><span class="p">,</span>
                                           <span class="n">target_loader</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">target_loader</span><span class="p">,</span> <span class="n">target_transform</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">target_transform</span><span class="p">,</span>
                                           <span class="n">target_loader</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">target_loader</span><span class="p">,</span> <span class="n">target_transform</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">target_transform</span><span class="p">,</span>
                                           <span class="n">collate_fn</span><span class="o">=</span><span class="n">collate_fn</span><span class="p">)</span>
                                           <span class="n">collate_fn</span><span class="o">=</span><span class="n">collate_fn</span><span class="p">)</span>
    -        <span class="c1"># DEFAULT TRANSFORMS</span>
    -        <span class="c1"># FIXME - Rescale before RandomRescale is kept for legacy support, consider removing it like most implementation</span>
    -        <span class="c1">#  papers regimes.</span>
    -        <span class="n">default_image_mask_transforms_aug</span> <span class="o">=</span> <span class="n">transform</span><span class="o">.</span><span class="n">Compose</span><span class="p">([</span><span class="n">RandomFlip</span><span class="p">(),</span>
    -                                                               <span class="n">Rescale</span><span class="p">(</span><span class="n">short_size</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">img_size</span><span class="p">),</span>
    -                                                               <span class="n">RandomRescale</span><span class="p">(</span><span class="n">scales</span><span class="o">=</span><span class="p">(</span><span class="mf">0.5</span><span class="p">,</span> <span class="mf">2.0</span><span class="p">)),</span>
    -                                                               <span class="n">RandomRotate</span><span class="p">(),</span>
    -                                                               <span class="n">PadShortToCropSize</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">crop_size</span><span class="p">),</span>
    -                                                               <span class="n">CropImageAndMask</span><span class="p">(</span><span class="n">crop_size</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">crop_size</span><span class="p">,</span>
    -                                                                                <span class="n">mode</span><span class="o">=</span><span class="s2">&quot;random&quot;</span><span class="p">),</span>
    -                                                               <span class="n">RandomGaussianBlur</span><span class="p">()])</span>
    -
    -        <span class="bp">self</span><span class="o">.</span><span class="n">image_mask_transforms_aug</span> <span class="o">=</span> <span class="n">image_mask_transforms_aug</span> <span class="ow">or</span> <span class="n">default_image_mask_transforms_aug</span>
    -        <span class="c1"># FIXME: CROP SIZE CANNOT BE PASSED WHEN LIST</span>
    -        <span class="k">if</span> <span class="n">image_mask_transforms</span> <span class="ow">is</span> <span class="kc">None</span><span class="p">:</span>
    -            <span class="n">image_mask_transforms</span> <span class="o">=</span> <span class="n">transform</span><span class="o">.</span><span class="n">Compose</span><span class="p">([</span><span class="n">Rescale</span><span class="p">(</span><span class="n">short_size</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">crop_size</span><span class="p">),</span>
    -                                                       <span class="n">CropImageAndMask</span><span class="p">(</span><span class="n">crop_size</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">crop_size</span><span class="p">,</span> <span class="n">mode</span><span class="o">=</span><span class="s2">&quot;center&quot;</span><span class="p">)</span>
    -                                                       <span class="p">])</span>
    -
    -        <span class="bp">self</span><span class="o">.</span><span class="n">image_mask_transforms</span> <span class="o">=</span> <span class="n">image_mask_transforms</span>
    +
    +        <span class="bp">self</span><span class="o">.</span><span class="n">transforms</span> <span class="o">=</span> <span class="n">transform</span><span class="o">.</span><span class="n">Compose</span><span class="p">(</span><span class="n">transforms</span> <span class="k">if</span> <span class="n">transforms</span> <span class="k">else</span> <span class="p">[])</span>
     
     
         <span class="k">def</span> <span class="fm">__getitem__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">index</span><span class="p">):</span>
         <span class="k">def</span> <span class="fm">__getitem__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">index</span><span class="p">):</span>
             <span class="n">sample_path</span><span class="p">,</span> <span class="n">target_path</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">samples_targets_tuples_list</span><span class="p">[</span><span class="n">index</span><span class="p">]</span>
             <span class="n">sample_path</span><span class="p">,</span> <span class="n">target_path</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">samples_targets_tuples_list</span><span class="p">[</span><span class="n">index</span><span class="p">]</span>
    @@ -206,7 +163,7 @@
     
     
             <span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">sample_transform</span><span class="p">(</span><span class="n">sample</span><span class="p">),</span> <span class="bp">self</span><span class="o">.</span><span class="n">target_transform</span><span class="p">(</span><span class="n">target</span><span class="p">)</span>
             <span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">sample_transform</span><span class="p">(</span><span class="n">sample</span><span class="p">),</span> <span class="bp">self</span><span class="o">.</span><span class="n">target_transform</span><span class="p">(</span><span class="n">target</span><span class="p">)</span>
     
     
    -<div class="viewcode-block" id="SegmentationDataSet.sample_loader"><a class="viewcode-back" href="../../../../../super_gradients.training.datasets.segmentation_datasets.html#super_gradients.training.datasets.SegmentationDataSet.sample_loader">[docs]</a>    <span class="nd">@staticmethod</span>
    +<div class="viewcode-block" id="SegmentationDataSet.sample_loader"><a class="viewcode-back" href="../../../../../super_gradients.training.html#super_gradients.training.datasets.SegmentationDataSet.sample_loader">[docs]</a>    <span class="nd">@staticmethod</span>
         <span class="k">def</span> <span class="nf">sample_loader</span><span class="p">(</span><span class="n">sample_path</span><span class="p">:</span> <span class="nb">str</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="n">Image</span><span class="p">:</span>
         <span class="k">def</span> <span class="nf">sample_loader</span><span class="p">(</span><span class="n">sample_path</span><span class="p">:</span> <span class="nb">str</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="n">Image</span><span class="p">:</span>
             <span class="sd">&quot;&quot;&quot;</span>
             <span class="sd">&quot;&quot;&quot;</span>
     <span class="sd">        sample_loader - Loads a dataset image from path using PIL</span>
     <span class="sd">        sample_loader - Loads a dataset image from path using PIL</span>
    @@ -216,7 +173,7 @@
             <span class="n">image</span> <span class="o">=</span> <span class="n">Image</span><span class="o">.</span><span class="n">open</span><span class="p">(</span><span class="n">sample_path</span><span class="p">)</span><span class="o">.</span><span class="n">convert</span><span class="p">(</span><span class="s1">&#39;RGB&#39;</span><span class="p">)</span>
             <span class="n">image</span> <span class="o">=</span> <span class="n">Image</span><span class="o">.</span><span class="n">open</span><span class="p">(</span><span class="n">sample_path</span><span class="p">)</span><span class="o">.</span><span class="n">convert</span><span class="p">(</span><span class="s1">&#39;RGB&#39;</span><span class="p">)</span>
             <span class="k">return</span> <span class="n">image</span></div>
             <span class="k">return</span> <span class="n">image</span></div>
     
     
    -<div class="viewcode-block" id="SegmentationDataSet.sample_transform"><a class="viewcode-back" href="../../../../../super_gradients.training.datasets.segmentation_datasets.html#super_gradients.training.datasets.SegmentationDataSet.sample_transform">[docs]</a>    <span class="nd">@staticmethod</span>
    +<div class="viewcode-block" id="SegmentationDataSet.sample_transform"><a class="viewcode-back" href="../../../../../super_gradients.training.html#super_gradients.training.datasets.SegmentationDataSet.sample_transform">[docs]</a>    <span class="nd">@staticmethod</span>
         <span class="k">def</span> <span class="nf">sample_transform</span><span class="p">(</span><span class="n">image</span><span class="p">):</span>
         <span class="k">def</span> <span class="nf">sample_transform</span><span class="p">(</span><span class="n">image</span><span class="p">):</span>
             <span class="sd">&quot;&quot;&quot;</span>
             <span class="sd">&quot;&quot;&quot;</span>
     <span class="sd">        sample_transform - Transforms the sample image</span>
     <span class="sd">        sample_transform - Transforms the sample image</span>
    @@ -230,7 +187,7 @@
     
     
             <span class="k">return</span> <span class="n">sample_transform</span><span class="p">(</span><span class="n">image</span><span class="p">)</span></div>
             <span class="k">return</span> <span class="n">sample_transform</span><span class="p">(</span><span class="n">image</span><span class="p">)</span></div>
     
     
    -<div class="viewcode-block" id="SegmentationDataSet.target_loader"><a class="viewcode-back" href="../../../../../super_gradients.training.datasets.segmentation_datasets.html#super_gradients.training.datasets.SegmentationDataSet.target_loader">[docs]</a>    <span class="nd">@staticmethod</span>
    +<div class="viewcode-block" id="SegmentationDataSet.target_loader"><a class="viewcode-back" href="../../../../../super_gradients.training.html#super_gradients.training.datasets.SegmentationDataSet.target_loader">[docs]</a>    <span class="nd">@staticmethod</span>
         <span class="k">def</span> <span class="nf">target_loader</span><span class="p">(</span><span class="n">target_path</span><span class="p">:</span> <span class="nb">str</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="n">Image</span><span class="p">:</span>
         <span class="k">def</span> <span class="nf">target_loader</span><span class="p">(</span><span class="n">target_path</span><span class="p">:</span> <span class="nb">str</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="n">Image</span><span class="p">:</span>
             <span class="sd">&quot;&quot;&quot;</span>
             <span class="sd">&quot;&quot;&quot;</span>
     <span class="sd">        target_loader</span>
     <span class="sd">        target_loader</span>
    @@ -240,7 +197,7 @@
             <span class="n">target</span> <span class="o">=</span> <span class="n">Image</span><span class="o">.</span><span class="n">open</span><span class="p">(</span><span class="n">target_path</span><span class="p">)</span>
             <span class="n">target</span> <span class="o">=</span> <span class="n">Image</span><span class="o">.</span><span class="n">open</span><span class="p">(</span><span class="n">target_path</span><span class="p">)</span>
             <span class="k">return</span> <span class="n">target</span></div>
             <span class="k">return</span> <span class="n">target</span></div>
     
     
    -<div class="viewcode-block" id="SegmentationDataSet.target_transform"><a class="viewcode-back" href="../../../../../super_gradients.training.datasets.segmentation_datasets.html#super_gradients.training.datasets.SegmentationDataSet.target_transform">[docs]</a>    <span class="nd">@staticmethod</span>
    +<div class="viewcode-block" id="SegmentationDataSet.target_transform"><a class="viewcode-back" href="../../../../../super_gradients.training.html#super_gradients.training.datasets.SegmentationDataSet.target_transform">[docs]</a>    <span class="nd">@staticmethod</span>
         <span class="k">def</span> <span class="nf">target_transform</span><span class="p">(</span><span class="n">target</span><span class="p">):</span>
         <span class="k">def</span> <span class="nf">target_transform</span><span class="p">(</span><span class="n">target</span><span class="p">):</span>
             <span class="sd">&quot;&quot;&quot;</span>
             <span class="sd">&quot;&quot;&quot;</span>
     <span class="sd">        target_transform - Transforms the sample image</span>
     <span class="sd">        target_transform - Transforms the sample image</span>
    @@ -258,9 +215,6 @@
             <span class="k">if</span> <span class="ow">not</span> <span class="bp">self</span><span class="o">.</span><span class="n">samples_targets_tuples_list</span><span class="p">:</span>
             <span class="k">if</span> <span class="ow">not</span> <span class="bp">self</span><span class="o">.</span><span class="n">samples_targets_tuples_list</span><span class="p">:</span>
                 <span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="n">_generate_samples_and_targets</span><span class="p">()</span>
                 <span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="n">_generate_samples_and_targets</span><span class="p">()</span>
     
     
    -        <span class="bp">self</span><span class="o">.</span><span class="n">batch_index</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">floor</span><span class="p">(</span><span class="n">np</span><span class="o">.</span><span class="n">arange</span><span class="p">(</span><span class="nb">len</span><span class="p">(</span><span class="bp">self</span><span class="p">))</span> <span class="o">/</span> <span class="bp">self</span><span class="o">.<
    -        <span class="bp">self</span><span class="o">.</span><span class="n">total_batches_num</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">batch_index</span><span class="p">[</span><span class="o">-</span><span class="mi">1</span><span class="p">]</span> <span class="o">+</span> <span class="mi">1</span>
    -
             <span class="c1"># EXTRACT THE LABELS FROM THE TUPLES LIST</span>
             <span class="c1"># EXTRACT THE LABELS FROM THE TUPLES LIST</span>
             <span class="n">image_files</span><span class="p">,</span> <span class="n">label_files</span> <span class="o">=</span> <span class="nb">map</span><span class="p">(</span><span class="nb">list</span><span class="p">,</span> <span class="nb">zip</span><span class="p">(</span><span class="o">*</span><span class="bp">self</span><span class="o">.</span><span class="n">samples_targets_tuples_list</span><span class="p">))</span>
             <span class="n">image_files</span><span class="p">,</span> <span class="n">label_files</span> <span class="o">=</span> <span class="nb">map</span><span class="p">(</span><span class="nb">list</span><span class="p">,</span> <span class="nb">zip</span><span class="p">(</span><span class="o">*</span><span class="bp">self</span><span class="o">.</span><span class="n">samples_targets_tuples_list</span><span class="p">))</span>
             <span class="n">image_indices_to_remove</span> <span class="o">=</span> <span class="p">[]</span>
             <span class="n">image_indices_to_remove</span> <span class="o">=</span> <span class="p">[]</span>
    @@ -309,66 +263,13 @@
                 <span class="bp">self</span><span class="o">.</span><span class="n">label_files</span> <span class="o">=</span> <span class="p">[</span><span class="n">e</span> <span class="k">for</span> <span class="n">i</span><span class="p">,</span> <span class="n">e</span> <span class="ow">in</span> <span class="nb">enumerate</span><span class="p">(</span><span class="n">label_files</span><span class="p">)</span> <span class="k">if</span> <span class="n">i</span> <span class="ow">not</span> <sp
                 <span class="bp">self</span><span class="o">.</span><span class="n">label_files</span> <span class="o">=</span> <span class="p">[</span><span class="n">e</span> <span class="k">for</span> <span class="n">i</span><span class="p">,</span> <span class="n">e</span> <span class="ow">in</span> <span class="nb">enumerate</span><span class="p">(</span><span class="n">label_files</span><span class="p">)</span> <span class="k">if</span> <span class="n">i</span> <span class="ow">not</span> <sp
                 <span class="bp">self</span><span class="o">.</span><span class="n">labels</span> <span class="o">=</span> <span class="p">[</span><span class="n">e</span> <span class="k">for</span> <span class="n">i</span><span class="p">,</span> <span class="n">e</span> <span class="ow">in</span> <span class="nb">enumerate</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">labels</span><span class="p">)</span> <span class="k">if</span> <span class="n
                 <span class="bp">self</span><span class="o">.</span><span class="n">labels</span> <span class="o">=</span> <span class="p">[</span><span class="n">e</span> <span class="k">for</span> <span class="n">i</span><span class="p">,</span> <span class="n">e</span> <span class="ow">in</span> <span class="nb">enumerate</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">labels</span><span class="p">)</span> <span class="k">if</span> <span class="n
     
     
    -    <span class="k">def</span> <span class="nf">_calculate_short_size</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">img</span><span class="p">):</span>
    -        <span class="sd">&quot;&quot;&quot;</span>
    -<span class="sd">        _calculate_crop</span>
    -<span class="sd">        :param img:</span>
    -<span class="sd">        :return:</span>
    -<span class="sd">        &quot;&quot;&quot;</span>
    -        <span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">augment</span><span class="p">:</span>
    -            <span class="c1"># RANDOM SCALE (SHORT EDGE FROM 480 TO 720)</span>
    -            <span class="n">short_size</span> <span class="o">=</span> <span class="n">random</span><span class="o">.</span><span class="n">randint</span><span class="p">(</span><span class="nb">int</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">img_size</span> <span class="o">*</span> <span class="mf">0.5</span><span class="p">),</span> <span class="nb">int</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span
    -        <span class="k">else</span><span class="p">:</span>
    -            <span class="n">short_size</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">crop_size</span>
    -
    -        <span class="n">w</span><span class="p">,</span> <span class="n">h</span> <span class="o">=</span> <span class="n">img</span><span class="o">.</span><span class="n">size</span>
    -        <span class="k">if</span> <span class="n">w</span> <span class="o">&gt;</span> <span class="n">h</span><span class="p">:</span>
    -            <span class="n">oh</span> <span class="o">=</span> <span class="n">short_size</span>
    -            <span class="n">ow</span> <span class="o">=</span> <span class="nb">int</span><span class="p">(</span><span class="mf">1.0</span> <span class="o">*</span> <span class="n">w</span> <span class="o">*</span> <span class="n">oh</span> <span class="o">/</span> <span class="n">h</span><span class="p">)</span>
    -        <span class="k">else</span><span class="p">:</span>
    -            <span class="n">ow</span> <span class="o">=</span> <span class="n">short_size</span>
    -            <span class="n">oh</span> <span class="o">=</span> <span class="nb">int</span><span class="p">(</span><span class="mf">1.0</span> <span class="o">*</span> <span class="n">h</span> <span class="o">*</span> <span class="n">ow</span> <span class="o">/</span> <span class="n">w</span><span class="p">)</span>
    -
    -        <span class="k">return</span> <span class="n">oh</span><span class="p">,</span> <span class="n">ow</span><span class="p">,</span> <span class="n">short_size</span>
    -
    -    <span class="k">def</span> <span class="nf">_get_center_crop</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">w</span><span class="p">,</span> <span class="n">h</span><span class="p">):</span>
    -        <span class="sd">&quot;&quot;&quot;</span>
    -
    -<span class="sd">        :param w:</span>
    -<span class="sd">        :param h:</span>
    -<span class="sd">        :return:</span>
    -<span class="sd">        &quot;&quot;&quot;</span>
    -        <span class="c1"># CENTER CROP</span>
    -        <span class="n">x1</span> <span class="o">=</span> <span class="nb">int</span><span class="p">(</span><span class="nb">round</span><span class="p">((</span><span class="n">w</span> <span class="o">-</span> <span class="bp">self</span><span class="o">.</span><span class="n">crop_size</span><span class="p">)</span> <span class="o">/</span> <span class="mf">2.</span><span class="p">))</span>
    -        <span class="n">y1</span> <span class="o">=</span> <span class="nb">int</span><span class="p">(</span><span class="nb">round</span><span class="p">((</span><span class="n">h</span> <span class="o">-</span> <span class="bp">self</span><span class="o">.</span><span class="n">crop_size</span><span class="p">)</span> <span class="o">/</span> <span class="mf">2.</span><span class="p">))</span>
    -
    -        <span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">augment</span><span class="p">:</span>
    -            <span class="c1"># RANDOM CROP CROP_SIZE</span>
    -            <span class="n">x1</span> <span class="o">=</span> <span class="n">random</span><span class="o">.</span><span class="n">randint</span><span class="p">(</span><span class="mi">0</span><span class="p">,</span> <span class="n">w</span> <span class="o">-</span> <span class="bp">self</span><span class="o">.</span><span class="n">crop_size</span><span class="p">)</span>
    -            <span class="n">y1</span> <span class="o">=</span> <span class="n">random</span><span class="o">.</span><span class="n">randint</span><span class="p">(</span><span class="mi">0</span><span class="p">,</span> <span class="n">h</span> <span class="o">-</span> <span class="bp">self</span><span class="o">.</span><span class="n">crop_size</span><span class="p">)</span>
    -
    -        <span class="k">return</span> <span class="n">x1</span><span class="p">,</span> <span class="n">y1</span>
    -
         <span class="k">def</span> <span class="nf">_transform_image_and_mask</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">image</span><span class="p">,</span> <span class="n">mask</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="nb">tuple</span><span class="p">:</span>
         <span class="k">def</span> <span class="nf">_transform_image_and_mask</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">image</span><span class="p">,</span> <span class="n">mask</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="nb">tuple</span><span class="p">:</span>
             <span class="sd">&quot;&quot;&quot;</span>
             <span class="sd">&quot;&quot;&quot;</span>
    -<span class="sd">        _transform -  Transforms the input (image, mask) in the following order:</span>
    -<span class="sd">                                1. FLIP (if augment==true)</span>
    -<span class="sd">                                2. RESIZE</span>
    -<span class="sd">                                3. ROTATE (if augment==true)</span>
    -<span class="sd">                                4. CROP</span>
    -<span class="sd">                                5. GAUSSIAN BLUR (if augment==true)</span>
    -
    -<span class="sd">                            * Please use self.augment == True only for training</span>
    -
     <span class="sd">            :param image:           The input image</span>
     <span class="sd">            :param image:           The input image</span>
     <span class="sd">            :param mask:            The input mask</span>
     <span class="sd">            :param mask:            The input mask</span>
     <span class="sd">            :return:                The transformed image, mask</span>
     <span class="sd">            :return:                The transformed image, mask</span>
     <span class="sd">        &quot;&quot;&quot;</span>
     <span class="sd">        &quot;&quot;&quot;</span>
    -        <span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">augment</span><span class="p">:</span>
    -            <span class="n">transformed</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">image_mask_transforms_aug</span><span class="p">({</span><span class="s2">&quot;image&quot;</span><span class="p">:</span> <span class="n">image</span><span class="p">,</span> <span class="s2">&quot;mask&quot;</span><span class="p">:</span> <span class="n">mask</span><span class="p">})</span>
    -        <span class="k">else</span><span class="p">:</span>
    -            <span class="n">transformed</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">image_mask_transforms</span><span class="p">({</span><span class="s2">&quot;image&quot;</span><span class="p">:</span> <span class="n">image</span><span class="p">,</span> <span class="s2">&quot;mask&quot;</span><span class="p">:</span> <span class="n">mask</span><span class="p">})</span>
    -
    +        <span class="n">transformed</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">transforms</span><span class="p">({</span><span class="s2">&quot;image&quot;</span><span class="p">:</span> <span class="n">image</span><span class="p">,</span> <span class="s2">&quot;mask&quot;</span><span class="p">:</span> <span class="n">mask</span><span class="p">})</span>
             <span class="k">return</span> <span class="n">transformed</span><span class="p">[</span><span class="s2">&quot;image&quot;</span><span class="p">],</span> <span class="n">transformed</span><span class="p">[</span><span class="s2">&quot;mask&quot;</span><span class="p">]</span></div>
             <span class="k">return</span> <span class="n">transformed</span><span class="p">[</span><span class="s2">&quot;image&quot;</span><span class="p">],</span> <span class="n">transformed</span><span class="p">[</span><span class="s2">&quot;mask&quot;</span><span class="p">]</span></div>
     </pre></div>
     </pre></div>
     
     
    @@ -399,4 +300,4 @@
       </script> 
       </script> 
     
     
     </body>
     </body>
    -</html>
    +</html>
    Discard
    @@ -3,10 +3,11 @@
     <head>
     <head>
       <meta charset="utf-8" />
       <meta charset="utf-8" />
       <meta name="viewport" content="width=device-width, initial-scale=1.0" />
       <meta name="viewport" content="width=device-width, initial-scale=1.0" />
    -  <title>super_gradients.training.datasets.segmentation_datasets.supervisely_persons_segmentation &mdash; SuperGradients 1.0 documentation</title>
    +  <title>super_gradients.training.datasets.segmentation_datasets.supervisely_persons_segmentation &mdash; SuperGradients 3.0.3 documentation</title>
           <link rel="stylesheet" href="../../../../../_static/pygments.css" type="text/css" />
           <link rel="stylesheet" href="../../../../../_static/pygments.css" type="text/css" />
           <link rel="stylesheet" href="../../../../../_static/css/theme.css" type="text/css" />
           <link rel="stylesheet" href="../../../../../_static/css/theme.css" type="text/css" />
           <link rel="stylesheet" href="../../../../../_static/graphviz.css" type="text/css" />
           <link rel="stylesheet" href="../../../../../_static/graphviz.css" type="text/css" />
    +      <link rel="stylesheet" href="../../../../../_static/custom.css" type="text/css" />
       <!--[if lt IE 9]>
       <!--[if lt IE 9]>
         <script src="../../../../../_static/js/html5shiv.min.js"></script>
         <script src="../../../../../_static/js/html5shiv.min.js"></script>
       <![endif]-->
       <![endif]-->
    @@ -14,7 +15,9 @@
             <script data-url_root="../../../../../" id="documentation_options" src="../../../../../_static/documentation_options.js"></script>
             <script data-url_root="../../../../../" id="documentation_options" src="../../../../../_static/documentation_options.js"></script>
             <script src="../../../../../_static/jquery.js"></script>
             <script src="../../../../../_static/jquery.js"></script>
             <script src="../../../../../_static/underscore.js"></script>
             <script src="../../../../../_static/underscore.js"></script>
    +        <script src="../../../../../_static/_sphinx_javascript_frameworks_compat.js"></script>
             <script src="../../../../../_static/doctools.js"></script>
             <script src="../../../../../_static/doctools.js"></script>
    +        <script src="../../../../../_static/sphinx_highlight.js"></script>
         <script src="../../../../../_static/js/theme.js"></script>
         <script src="../../../../../_static/js/theme.js"></script>
         <link rel="index" title="Index" href="../../../../../genindex.html" />
         <link rel="index" title="Index" href="../../../../../genindex.html" />
         <link rel="search" title="Search" href="../../../../../search.html" /> 
         <link rel="search" title="Search" href="../../../../../search.html" /> 
    @@ -35,30 +38,29 @@
       </form>
       </form>
     </div>
     </div>
             </div><div class="wy-menu wy-menu-vertical" data-spy="affix" role="navigation" aria-label="Navigation menu">
             </div><div class="wy-menu wy-menu-vertical" data-spy="affix" role="navigation" aria-label="Navigation menu">
    -              <p class="caption"><span class="caption-text">Welcome To SuperGradients</span></p>
    +              <p class="caption" role="heading"><span class="caption-text">Welcome To SuperGradients</span></p>
     <ul>
     <ul>
    -<li class="toctree-l1"><a class="reference internal" href="../../../../../welcome.html">Fill our 4-question quick survey! We will raffle free SuperGradients swag between those who will participate -&gt; Fill Survey</a></li>
    -<li class="toctree-l1"><a class="reference internal" href="../../../../../welcome.html#supergradients">SuperGradients</a></li>
    +<li class="toctree-l1"><a class="reference internal" href="../../../../../welcome.html">Version 3 is out! Notebooks have been updated!</a></li>
    +<li class="toctree-l1"><a class="reference internal" href="../../../../../welcome.html#build-with-supergradients">Build with SuperGradients</a></li>
    +<li class="toctree-l1"><a class="reference internal" href="../../../../../welcome.html#quick-installation">Quick Installation</a></li>
    +<li class="toctree-l1"><a class="reference internal" href="../../../../../welcome.html#what-s-new">What’s New</a></li>
    +<li class="toctree-l1"><a class="reference internal" href="../../../../../welcome.html#coming-soon">Coming soon</a></li>
    +<li class="toctree-l1"><a class="reference internal" href="../../../../../welcome.html#table-of-content">Table of Content</a></li>
    +<li class="toctree-l1"><a class="reference internal" href="../../../../../welcome.html#getting-started">Getting Started</a></li>
    +<li class="toctree-l1"><a class="reference internal" href="../../../../../welcome.html#advanced-features">Advanced Features</a></li>
    +<li class="toctree-l1"><a class="reference internal" href="../../../../../welcome.html#installation-methods">Installation Methods</a></li>
    +<li class="toctree-l1"><a class="reference internal" href="../../../../../welcome.html#implemented-model-architectures">Implemented Model Architectures</a></li>
    +<li class="toctree-l1"><a class="reference internal" href="../../../../../welcome.html#documentation">Documentation</a></li>
    +<li class="toctree-l1"><a class="reference internal" href="../../../../../welcome.html#contributing">Contributing</a></li>
    +<li class="toctree-l1"><a class="reference internal" href="../../../../../welcome.html#citation">Citation</a></li>
    +<li class="toctree-l1"><a class="reference internal" href="../../../../../welcome.html#community">Community</a></li>
    +<li class="toctree-l1"><a class="reference internal" href="../../../../../welcome.html#license">License</a></li>
    +<li class="toctree-l1"><a class="reference internal" href="../../../../../welcome.html#deci-platform">Deci Platform</a></li>
     </ul>
     </ul>
    -<p class="caption"><span class="caption-text">Technical Documentation</span></p>
    +<p class="caption" role="heading"><span class="caption-text">Technical Documentation</span></p>
     <ul>
     <ul>
     <li class="toctree-l1"><a class="reference internal" href="../../../../../super_gradients.common.html">Common package</a></li>
     <li class="toctree-l1"><a class="reference internal" href="../../../../../super_gradients.common.html">Common package</a></li>
     <li class="toctree-l1"><a class="reference internal" href="../../../../../super_gradients.training.html">Training package</a></li>
     <li class="toctree-l1"><a class="reference internal" href="../../../../../super_gradients.training.html">Training package</a></li>
    -</ul>
    -<p class="caption"><span class="caption-text">User Guide</span></p>
    -<ul>
    -<li class="toctree-l1"><a class="reference internal" href="../../../../../user_guide.html">What is SuperGradients?</a></li>
    -<li class="toctree-l1"><a class="reference internal" href="../../../../../user_guide.html#introducing-the-supergradients-library">Introducing the SuperGradients library</a></li>
    -<li class="toctree-l1"><a class="reference internal" href="../../../../../user_guide.html#installation">Installation</a></li>
    -<li class="toctree-l1"><a class="reference internal" href="../../../../../user_guide.html#integrating-your-training-code-complete-walkthrough">Integrating Your Training Code - Complete Walkthrough</a></li>
    -<li class="toctree-l1"><a class="reference internal" href="../../../../../user_guide.html#training-parameters">Training Parameters</a></li>
    -<li class="toctree-l1"><a class="reference internal" href="../../../../../user_guide.html#logs-and-checkpoints">Logs and Checkpoints</a></li>
    -<li class="toctree-l1"><a class="reference internal" href="../../../../../user_guide.html#dataset-parameters">Dataset Parameters</a></li>
    -<li class="toctree-l1"><a class="reference internal" href="../../../../../user_guide.html#network-architectures">Network Architectures</a></li>
    -<li class="toctree-l1"><a class="reference internal" href="../../../../../user_guide.html#pretrained-models">Pretrained Models</a></li>
    -<li class="toctree-l1"><a class="reference internal" href="../../../../../user_guide.html#how-to-reproduce-our-training-recipes">How To Reproduce Our Training Recipes</a></li>
    -<li class="toctree-l1"><a class="reference internal" href="../../../../../user_guide.html#professional-tools-integration">Professional Tools Integration</a></li>
    -<li class="toctree-l1"><a class="reference internal" href="../../../../../user_guide.html#supergradients-faq">SuperGradients FAQ</a></li>
     </ul>
     </ul>
     
     
             </div>
             </div>
    @@ -92,7 +94,7 @@
     <span class="kn">from</span> <span class="nn">super_gradients.training.datasets.segmentation_datasets.segmentation_dataset</span> <span class="kn">import</span> <span class="n">SegmentationDataSet</span>
     <span class="kn">from</span> <span class="nn">super_gradients.training.datasets.segmentation_datasets.segmentation_dataset</span> <span class="kn">import</span> <span class="n">SegmentationDataSet</span>
     
     
     
     
    -<div class="viewcode-block" id="SuperviselyPersonsDataset"><a class="viewcode-back" href="../../../../../super_gradients.training.datasets.segmentation_datasets.html#super_gradients.training.datasets.segmentation_datasets.SuperviselyPersonsDataset">[docs]</a><span class="k">class</span> <span class="nc">SuperviselyPersonsDataset</span><span class="p">(</span><span class="n">SegmentationDataSet</span><span class="p">):</span>
    +<div class="viewcode-block" id="SuperviselyPersonsDataset"><a class="viewcode-back" href="../../../../../super_gradients.training.html#super_gradients.training.datasets.SuperviselyPersonsDataset">[docs]</a><span class="k">class</span> <span class="nc">SuperviselyPersonsDataset</span><span class="p">(</span><span class="n">SegmentationDataSet</span><span class="p">):</span>
         <span class="sd">&quot;&quot;&quot;</span>
         <span class="sd">&quot;&quot;&quot;</span>
     <span class="sd">    SuperviselyPersonsDataset - Segmentation Data Set Class for Supervisely Persons Segmentation Data Set,</span>
     <span class="sd">    SuperviselyPersonsDataset - Segmentation Data Set Class for Supervisely Persons Segmentation Data Set,</span>
     <span class="sd">    main resolution of dataset: (600 x 800).</span>
     <span class="sd">    main resolution of dataset: (600 x 800).</span>
    @@ -126,7 +128,8 @@
                         <span class="bp">self</span><span class="o">.</span><span class="n">samples_targets_tuples_list</span><span class="o">.</span><span class="n">append</span><span class="p">((</span><span class="n">sample_path</span><span class="p">,</span> <span class="n">target_path</span><span class="p">))</span>
                         <span class="bp">self</span><span class="o">.</span><span class="n">samples_targets_tuples_list</span><span class="o">.</span><span class="n">append</span><span class="p">((</span><span class="n">sample_path</span><span class="p">,</span> <span class="n">target_path</span><span class="p">))</span>
                     <span class="k">else</span><span class="p">:</span>
                     <span class="k">else</span><span class="p">:</span>
                         <span class="k">raise</span> <span class="ne">AssertionError</span><span class="p">(</span><span class="sa">f</span><span class="s2">&quot;Sample and/or target file(s) not found or in illegal format &quot;</span>
                         <span class="k">raise</span> <span class="ne">AssertionError</span><span class="p">(</span><span class="sa">f</span><span class="s2">&quot;Sample and/or target file(s) not found or in illegal format &quot;</span>
    -                                         <span class="sa">f</span><span class="s2">&quot;(sample path: </span><span class="si">{</span><span class="n">sample_path</span><span class="si">}</span><span class="s2">, target path: </span><span class="si">{</span><span class="n">target_path</span><span class="si">}</span><span class="s2">)&quot;</span><span class="p">)</span></div>
    +                                         <span class="sa">f</span><span class="s2">&quot;(sample path: </span><span class="si">{</span><span class="n">sample_path</span><span class="si">}</span><span class="s2">, target path: </span><span class="si">{</span><span class="n">target_path</span><span class="si">}</span><span class="s2">)&quot;</span><span class="p">)</span>
    +        <span class="nb">super</span><span class="p">(</span><span class="n">SuperviselyPersonsDataset</span><span class="p">,</span> <span class="bp">self</span><span class="p">)</span><span class="o">.</span><span class="n">_generate_samples_and_targets</span><span class="p">()</span></div>
     </pre></div>
     </pre></div>
     
     
                </div>
                </div>
    @@ -156,4 +159,4 @@
       </script> 
       </script> 
     
     
     </body>
     </body>
    -</html>
    +</html>
    Discard
    Some lines were truncated since they exceed the maximum allowed length of 500, please use a local Git client to see the full diff.
    @@ -3,10 +3,11 @@
     <head>
     <head>
       <meta charset="utf-8" />
       <meta charset="utf-8" />
       <meta name="viewport" content="width=device-width, initial-scale=1.0" />
       <meta name="viewport" content="width=device-width, initial-scale=1.0" />
    -  <title>super_gradients.training.datasets.sg_dataset &mdash; SuperGradients 1.0 documentation</title>
    +  <title>super_gradients.training.datasets.sg_dataset &mdash; SuperGradients 3.0.3 documentation</title>
           <link rel="stylesheet" href="../../../../_static/pygments.css" type="text/css" />
           <link rel="stylesheet" href="../../../../_static/pygments.css" type="text/css" />
           <link rel="stylesheet" href="../../../../_static/css/theme.css" type="text/css" />
           <link rel="stylesheet" href="../../../../_static/css/theme.css" type="text/css" />
           <link rel="stylesheet" href="../../../../_static/graphviz.css" type="text/css" />
           <link rel="stylesheet" href="../../../../_static/graphviz.css" type="text/css" />
    +      <link rel="stylesheet" href="../../../../_static/custom.css" type="text/css" />
       <!--[if lt IE 9]>
       <!--[if lt IE 9]>
         <script src="../../../../_static/js/html5shiv.min.js"></script>
         <script src="../../../../_static/js/html5shiv.min.js"></script>
       <![endif]-->
       <![endif]-->
    @@ -14,7 +15,9 @@
             <script data-url_root="../../../../" id="documentation_options" src="../../../../_static/documentation_options.js"></script>
             <script data-url_root="../../../../" id="documentation_options" src="../../../../_static/documentation_options.js"></script>
             <script src="../../../../_static/jquery.js"></script>
             <script src="../../../../_static/jquery.js"></script>
             <script src="../../../../_static/underscore.js"></script>
             <script src="../../../../_static/underscore.js"></script>
    +        <script src="../../../../_static/_sphinx_javascript_frameworks_compat.js"></script>
             <script src="../../../../_static/doctools.js"></script>
             <script src="../../../../_static/doctools.js"></script>
    +        <script src="../../../../_static/sphinx_highlight.js"></script>
         <script src="../../../../_static/js/theme.js"></script>
         <script src="../../../../_static/js/theme.js"></script>
         <link rel="index" title="Index" href="../../../../genindex.html" />
         <link rel="index" title="Index" href="../../../../genindex.html" />
         <link rel="search" title="Search" href="../../../../search.html" /> 
         <link rel="search" title="Search" href="../../../../search.html" /> 
    @@ -35,30 +38,29 @@
       </form>
       </form>
     </div>
     </div>
             </div><div class="wy-menu wy-menu-vertical" data-spy="affix" role="navigation" aria-label="Navigation menu">
             </div><div class="wy-menu wy-menu-vertical" data-spy="affix" role="navigation" aria-label="Navigation menu">
    -              <p class="caption"><span class="caption-text">Welcome To SuperGradients</span></p>
    +              <p class="caption" role="heading"><span class="caption-text">Welcome To SuperGradients</span></p>
     <ul>
     <ul>
    -<li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html">Fill our 4-question quick survey! We will raffle free SuperGradients swag between those who will participate -&gt; Fill Survey</a></li>
    -<li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#supergradients">SuperGradients</a></li>
    +<li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html">Version 3 is out! Notebooks have been updated!</a></li>
    +<li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#build-with-supergradients">Build with SuperGradients</a></li>
    +<li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#quick-installation">Quick Installation</a></li>
    +<li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#what-s-new">What’s New</a></li>
    +<li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#coming-soon">Coming soon</a></li>
    +<li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#table-of-content">Table of Content</a></li>
    +<li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#getting-started">Getting Started</a></li>
    +<li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#advanced-features">Advanced Features</a></li>
    +<li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#installation-methods">Installation Methods</a></li>
    +<li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#implemented-model-architectures">Implemented Model Architectures</a></li>
    +<li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#documentation">Documentation</a></li>
    +<li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#contributing">Contributing</a></li>
    +<li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#citation">Citation</a></li>
    +<li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#community">Community</a></li>
    +<li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#license">License</a></li>
    +<li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#deci-platform">Deci Platform</a></li>
     </ul>
     </ul>
    -<p class="caption"><span class="caption-text">Technical Documentation</span></p>
    +<p class="caption" role="heading"><span class="caption-text">Technical Documentation</span></p>
     <ul>
     <ul>
     <li class="toctree-l1"><a class="reference internal" href="../../../../super_gradients.common.html">Common package</a></li>
     <li class="toctree-l1"><a class="reference internal" href="../../../../super_gradients.common.html">Common package</a></li>
     <li class="toctree-l1"><a class="reference internal" href="../../../../super_gradients.training.html">Training package</a></li>
     <li class="toctree-l1"><a class="reference internal" href="../../../../super_gradients.training.html">Training package</a></li>
    -</ul>
    -<p class="caption"><span class="caption-text">User Guide</span></p>
    -<ul>
    -<li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html">What is SuperGradients?</a></li>
    -<li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#introducing-the-supergradients-library">Introducing the SuperGradients library</a></li>
    -<li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#installation">Installation</a></li>
    -<li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#integrating-your-training-code-complete-walkthrough">Integrating Your Training Code - Complete Walkthrough</a></li>
    -<li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#training-parameters">Training Parameters</a></li>
    -<li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#logs-and-checkpoints">Logs and Checkpoints</a></li>
    -<li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#dataset-parameters">Dataset Parameters</a></li>
    -<li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#network-architectures">Network Architectures</a></li>
    -<li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#pretrained-models">Pretrained Models</a></li>
    -<li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#how-to-reproduce-our-training-recipes">How To Reproduce Our Training Recipes</a></li>
    -<li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#professional-tools-integration">Professional Tools Integration</a></li>
    -<li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#supergradients-faq">SuperGradients FAQ</a></li>
     </ul>
     </ul>
     
     
             </div>
             </div>
    @@ -97,7 +99,7 @@
     <span class="n">IMG_EXTENSIONS</span> <span class="o">=</span> <span class="p">(</span><span class="s1">&#39;.jpg&#39;</span><span class="p">,</span> <span class="s1">&#39;.jpeg&#39;</span><span class="p">,</span> <span class="s1">&#39;.png&#39;</span><span class="p">,</span> <span class="s1">&#39;.ppm&#39;</span><span class="p">,</span> <span class="s1">&#39;.bmp&#39;</span><span class="p">,</span> <span class="s1">&#39;.pgm&#39;</span><span class="p">,</span> <span class="s1">&#39;.tif&#39;</
     <span class="n">IMG_EXTENSIONS</span> <span class="o">=</span> <span class="p">(</span><span class="s1">&#39;.jpg&#39;</span><span class="p">,</span> <span class="s1">&#39;.jpeg&#39;</span><span class="p">,</span> <span class="s1">&#39;.png&#39;</span><span class="p">,</span> <span class="s1">&#39;.ppm&#39;</span><span class="p">,</span> <span class="s1">&#39;.bmp&#39;</span><span class="p">,</span> <span class="s1">&#39;.pgm&#39;</span><span class="p">,</span> <span class="s1">&#39;.tif&#39;</
     
     
     
     
    -<div class="viewcode-block" id="BaseSgVisionDataset"><a class="viewcode-back" href="../../../../super_gradients.training.datasets.html#super_gradients.training.datasets.BaseSgVisionDataset">[docs]</a><span class="k">class</span> <span class="nc">BaseSgVisionDataset</span><span class="p">(</span><span class="n">VisionDataset</span><span class="p">):</span>
    +<span class="k">class</span> <span class="nc">BaseSgVisionDataset</span><span class="p">(</span><span class="n">VisionDataset</span><span class="p">):</span>
         <span class="sd">&quot;&quot;&quot;</span>
         <span class="sd">&quot;&quot;&quot;</span>
     <span class="sd">    BaseSgVisionDataset</span>
     <span class="sd">    BaseSgVisionDataset</span>
     <span class="sd">    &quot;&quot;&quot;</span>
     <span class="sd">    &quot;&quot;&quot;</span>
    @@ -160,16 +162,16 @@
     
     
             <span class="k">return</span> <span class="kc">False</span>
             <span class="k">return</span> <span class="kc">False</span>
     
     
    -<div class="viewcode-block" id="BaseSgVisionDataset.numpy_loader_func"><a class="viewcode-back" href="../../../../super_gradients.training.datasets.html#super_gradients.training.datasets.BaseSgVisionDataset.numpy_loader_func">[docs]</a>    <span class="nd">@staticmethod</span>
    +    <span class="nd">@staticmethod</span>
         <span class="k">def</span> <span class="nf">numpy_loader_func</span><span class="p">(</span><span class="n">path</span><span class="p">):</span>
         <span class="k">def</span> <span class="nf">numpy_loader_func</span><span class="p">(</span><span class="n">path</span><span class="p">):</span>
             <span class="sd">&quot;&quot;&quot;</span>
             <span class="sd">&quot;&quot;&quot;</span>
     <span class="sd">        _numpy_loader_func - Uses numpy load func</span>
     <span class="sd">        _numpy_loader_func - Uses numpy load func</span>
     <span class="sd">            :param path:</span>
     <span class="sd">            :param path:</span>
     <span class="sd">            :return:</span>
     <span class="sd">            :return:</span>
     <span class="sd">        &quot;&quot;&quot;</span>
     <span class="sd">        &quot;&quot;&quot;</span>
    -        <span class="k">return</span> <span class="n">np</span><span class="o">.</span><span class="n">load</span><span class="p">(</span><span class="n">path</span><span class="p">)</span></div>
    +        <span class="k">return</span> <span class="n">np</span><span class="o">.</span><span class="n">load</span><span class="p">(</span><span class="n">path</span><span class="p">)</span>
     
     
    -<div class="viewcode-block" id="BaseSgVisionDataset.text_file_loader_func"><a class="viewcode-back" href="../../../../super_gradients.training.datasets.html#super_gradients.training.datasets.BaseSgVisionDataset.text_file_loader_func">[docs]</a>    <span class="nd">@staticmethod</span>
    +    <span class="nd">@staticmethod</span>
         <span class="k">def</span> <span class="nf">text_file_loader_func</span><span class="p">(</span><span class="n">text_file_path</span><span class="p">:</span> <span class="nb">str</span><span class="p">,</span> <span class="n">inline_splitter</span><span class="p">:</span> <span class="nb">str</span> <span class="o">=</span> <span class="s1">&#39; &#39;</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="nb">list</span><span class="p">:</span>
         <span class="k">def</span> <span class="nf">text_file_loader_func</span><span class="p">(</span><span class="n">text_file_path</span><span class="p">:</span> <span class="nb">str</span><span class="p">,</span> <span class="n">inline_splitter</span><span class="p">:</span> <span class="nb">str</span> <span class="o">=</span> <span class="s1">&#39; &#39;</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="nb">list</span><span class="p">:</span>
             <span class="sd">&quot;&quot;&quot;</span>
             <span class="sd">&quot;&quot;&quot;</span>
     <span class="sd">        text_file_loader_func - Uses a line by line based code to get vectorized data from a text-based file</span>
     <span class="sd">        text_file_loader_func - Uses a line by line based code to get vectorized data from a text-based file</span>
    @@ -184,10 +186,10 @@
             <span class="k">with</span> <span class="nb">open</span><span class="p">(</span><span class="n">text_file_path</span><span class="p">,</span> <span class="s2">&quot;r&quot;</span><span class="p">,</span> <span class="n">encoding</span><span class="o">=</span><span class="s2">&quot;utf-8&quot;</span><span class="p">)</span> <span class="k">as</span> <span class="n">text_file</span><span class="p">:</span>
             <span class="k">with</span> <span class="nb">open</span><span class="p">(</span><span class="n">text_file_path</span><span class="p">,</span> <span class="s2">&quot;r&quot;</span><span class="p">,</span> <span class="n">encoding</span><span class="o">=</span><span class="s2">&quot;utf-8&quot;</span><span class="p">)</span> <span class="k">as</span> <span class="n">text_file</span><span class="p">:</span>
                 <span class="n">targets_list</span> <span class="o">=</span> <span class="p">[</span><span class="nb">tuple</span><span class="p">(</span><span class="nb">map</span><span class="p">(</span><span class="nb">float</span><span class="p">,</span> <span class="n">line</span><span class="o">.</span><span class="n">split</span><span class="p">(</span><span class="n">inline_splitter</span><span class="p">)))</span> <span class="k">for</span> <span class="n">line</span> <span class="ow">in</
                 <span class="n">targets_list</span> <span class="o">=</span> <span class="p">[</span><span class="nb">tuple</span><span class="p">(</span><span class="nb">map</span><span class="p">(</span><span class="nb">float</span><span class="p">,</span> <span class="n">line</span><span class="o">.</span><span class="n">split</span><span class="p">(</span><span class="n">inline_splitter</span><span class="p">)))</span> <span class="k">for</span> <span class="n">line</span> <span class="ow">in</
     
     
    -        <span class="k">return</span> <span class="n">targets_list</span></div></div>
    +        <span class="k">return</span> <span class="n">targets_list</span>
     
     
     
     
    -<div class="viewcode-block" id="DirectoryDataSet"><a class="viewcode-back" href="../../../../super_gradients.training.datasets.html#super_gradients.training.datasets.DirectoryDataSet">[docs]</a><span class="k">class</span> <span class="nc">DirectoryDataSet</span><span class="p">(</span><span class="n">BaseSgVisionDataset</span><span class="p">):</span>
    +<div class="viewcode-block" id="DirectoryDataSet"><a class="viewcode-back" href="../../../../super_gradients.training.html#super_gradients.training.datasets.DirectoryDataSet">[docs]</a><span class="k">class</span> <span class="nc">DirectoryDataSet</span><span class="p">(</span><span class="n">BaseSgVisionDataset</span><span class="p">):</span>
         <span class="sd">&quot;&quot;&quot;</span>
         <span class="sd">&quot;&quot;&quot;</span>
     <span class="sd">    DirectoryDataSet - A PyTorch Vision Data Set extension that receives a root Dir and two separate sub directories:</span>
     <span class="sd">    DirectoryDataSet - A PyTorch Vision Data Set extension that receives a root Dir and two separate sub directories:</span>
     <span class="sd">                        - Sub-Directory for Samples</span>
     <span class="sd">                        - Sub-Directory for Samples</span>
    @@ -279,7 +281,7 @@
                     <span class="nb">print</span><span class="p">(</span><span class="vm">__name__</span> <span class="o">+</span> <span class="s1">&#39; There are &#39;</span> <span class="o">+</span> <span class="nb">str</span><span class="p">(</span><span class="n">missing_files_counter</span><span class="p">)</span> <span class="o">+</span> <span class="s1">&#39; missing  &#39;</span> <span class="o">+</span> <span class="n">counter_name</span><span class="p">)</span></div>
                     <span class="nb">print</span><span class="p">(</span><span class="vm">__name__</span> <span class="o">+</span> <span class="s1">&#39; There are &#39;</span> <span class="o">+</span> <span class="nb">str</span><span class="p">(</span><span class="n">missing_files_counter</span><span class="p">)</span> <span class="o">+</span> <span class="s1">&#39; missing  &#39;</span> <span class="o">+</span> <span class="n">counter_name</span><span class="p">)</span></div>
     
     
     
     
    -<div class="viewcode-block" id="ListDataset"><a class="viewcode-back" href="../../../../super_gradients.training.datasets.html#super_gradients.training.datasets.ListDataset">[docs]</a><span class="k">class</span> <span class="nc">ListDataset</span><span class="p">(</span><span class="n">BaseSgVisionDataset</span><span class="p">):</span>
    +<div class="viewcode-block" id="ListDataset"><a class="viewcode-back" href="../../../../super_gradients.training.html#super_gradients.training.datasets.ListDataset">[docs]</a><span class="k">class</span> <span class="nc">ListDataset</span><span class="p">(</span><span class="n">BaseSgVisionDataset</span><span class="p">):</span>
         <span class="sd">&quot;&quot;&quot;</span>
         <span class="sd">&quot;&quot;&quot;</span>
     <span class="sd">    ListDataset - A PyTorch Vision Data Set extension that receives a file with FULL PATH to each of the samples.</span>
     <span class="sd">    ListDataset - A PyTorch Vision Data Set extension that receives a file with FULL PATH to each of the samples.</span>
     <span class="sd">                  Then, the assumption is that for every sample, there is a * matching target * in the same</span>
     <span class="sd">                  Then, the assumption is that for every sample, there is a * matching target * in the same</span>
    @@ -383,4 +385,4 @@
       </script> 
       </script> 
     
     
     </body>
     </body>
    -</html>
    +</html>
    Discard
    1
    2
    3
    4
    5
    6
    7
    8
    9
    10
    11
    12
    13
    14
    15
    16
    17
    18
    19
    20
    21
    22
    23
    24
    25
    26
    27
    28
    29
    30
    31
    32
    33
    34
    35
    36
    37
    38
    39
    40
    41
    42
    43
    44
    45
    46
    47
    48
    49
    50
    51
    52
    53
    54
    55
    56
    57
    58
    59
    60
    61
    62
    63
    64
    65
    66
    67
    68
    69
    70
    71
    72
    73
    74
    75
    76
    77
    78
    79
    80
    81
    82
    83
    84
    85
    86
    87
    88
    89
    90
    91
    92
    93
    94
    95
    96
    97
    98
    99
    100
    101
    102
    103
    104
    105
    106
    107
    108
    109
    110
    111
    112
    113
    114
    115
    116
    117
    118
    119
    120
    121
    122
    123
    124
    125
    126
    127
    128
    129
    130
    131
    132
    133
    134
    135
    136
    137
    138
    139
    140
    141
    142
    143
    144
    145
    146
    147
    148
    149
    150
    151
    152
    153
    154
    155
    156
    157
    158
    1. <!DOCTYPE html>
    2. <html class="writer-html5" lang="en" >
    3. <head>
    4. <meta charset="utf-8" />
    5. <meta name="viewport" content="width=device-width, initial-scale=1.0" />
    6. <title>super_gradients.training.exceptions.dataset_exceptions &mdash; SuperGradients 1.0 documentation</title>
    7. <link rel="stylesheet" href="../../../../_static/pygments.css" type="text/css" />
    8. <link rel="stylesheet" href="../../../../_static/css/theme.css" type="text/css" />
    9. <link rel="stylesheet" href="../../../../_static/graphviz.css" type="text/css" />
    10. <!--[if lt IE 9]>
    11. <script src="../../../../_static/js/html5shiv.min.js"></script>
    12. <![endif]-->
    13. <script data-url_root="../../../../" id="documentation_options" src="../../../../_static/documentation_options.js"></script>
    14. <script src="../../../../_static/jquery.js"></script>
    15. <script src="../../../../_static/underscore.js"></script>
    16. <script src="../../../../_static/doctools.js"></script>
    17. <script src="../../../../_static/js/theme.js"></script>
    18. <link rel="index" title="Index" href="../../../../genindex.html" />
    19. <link rel="search" title="Search" href="../../../../search.html" />
    20. </head>
    21. <body class="wy-body-for-nav">
    22. <div class="wy-grid-for-nav">
    23. <nav data-toggle="wy-nav-shift" class="wy-nav-side">
    24. <div class="wy-side-scroll">
    25. <div class="wy-side-nav-search" >
    26. <a href="../../../../index.html" class="icon icon-home"> SuperGradients
    27. </a>
    28. <div role="search">
    29. <form id="rtd-search-form" class="wy-form" action="../../../../search.html" method="get">
    30. <input type="text" name="q" placeholder="Search docs" />
    31. <input type="hidden" name="check_keywords" value="yes" />
    32. <input type="hidden" name="area" value="default" />
    33. </form>
    34. </div>
    35. </div><div class="wy-menu wy-menu-vertical" data-spy="affix" role="navigation" aria-label="Navigation menu">
    36. <p class="caption"><span class="caption-text">Welcome To SuperGradients</span></p>
    37. <ul>
    38. <li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html">Fill our 4-question quick survey! We will raffle free SuperGradients swag between those who will participate -&gt; Fill Survey</a></li>
    39. <li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#supergradients">SuperGradients</a></li>
    40. </ul>
    41. <p class="caption"><span class="caption-text">Technical Documentation</span></p>
    42. <ul>
    43. <li class="toctree-l1"><a class="reference internal" href="../../../../super_gradients.common.html">Common package</a></li>
    44. <li class="toctree-l1"><a class="reference internal" href="../../../../super_gradients.training.html">Training package</a></li>
    45. </ul>
    46. <p class="caption"><span class="caption-text">User Guide</span></p>
    47. <ul>
    48. <li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html">What is SuperGradients?</a></li>
    49. <li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#introducing-the-supergradients-library">Introducing the SuperGradients library</a></li>
    50. <li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#installation">Installation</a></li>
    51. <li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#integrating-your-training-code-complete-walkthrough">Integrating Your Training Code - Complete Walkthrough</a></li>
    52. <li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#training-parameters">Training Parameters</a></li>
    53. <li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#logs-and-checkpoints">Logs and Checkpoints</a></li>
    54. <li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#dataset-parameters">Dataset Parameters</a></li>
    55. <li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#network-architectures">Network Architectures</a></li>
    56. <li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#pretrained-models">Pretrained Models</a></li>
    57. <li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#how-to-reproduce-our-training-recipes">How To Reproduce Our Training Recipes</a></li>
    58. <li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#professional-tools-integration">Professional Tools Integration</a></li>
    59. <li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#supergradients-faq">SuperGradients FAQ</a></li>
    60. </ul>
    61. </div>
    62. </div>
    63. </nav>
    64. <section data-toggle="wy-nav-shift" class="wy-nav-content-wrap"><nav class="wy-nav-top" aria-label="Mobile navigation menu" >
    65. <i data-toggle="wy-nav-top" class="fa fa-bars"></i>
    66. <a href="../../../../index.html">SuperGradients</a>
    67. </nav>
    68. <div class="wy-nav-content">
    69. <div class="rst-content">
    70. <div role="navigation" aria-label="Page navigation">
    71. <ul class="wy-breadcrumbs">
    72. <li><a href="../../../../index.html" class="icon icon-home"></a> &raquo;</li>
    73. <li><a href="../../../index.html">Module code</a> &raquo;</li>
    74. <li>super_gradients.training.exceptions.dataset_exceptions</li>
    75. <li class="wy-breadcrumbs-aside">
    76. </li>
    77. </ul>
    78. <hr/>
    79. </div>
    80. <div role="main" class="document" itemscope="itemscope" itemtype="http://schema.org/Article">
    81. <div itemprop="articleBody">
    82. <h1>Source code for super_gradients.training.exceptions.dataset_exceptions</h1><div class="highlight"><pre>
    83. <span></span>
    84. <div class="viewcode-block" id="IllegalDatasetParameterException"><a class="viewcode-back" href="../../../../super_gradients.training.exceptions.html#super_gradients.training.exceptions.dataset_exceptions.IllegalDatasetParameterException">[docs]</a><span class="k">class</span> <span class="nc">IllegalDatasetParameterException</span><span class="p">(</span><span class="ne">Exception</span><span class="p">):</span>
    85. <span class="sd">&quot;&quot;&quot;</span>
    86. <span class="sd"> Exception raised illegal dataset param.</span>
    87. <span class="sd"> Attributes:</span>
    88. <span class="sd"> message -- explanation of the error</span>
    89. <span class="sd"> &quot;&quot;&quot;</span>
    90. <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">desc</span><span class="p">):</span>
    91. <span class="bp">self</span><span class="o">.</span><span class="n">message</span> <span class="o">=</span> <span class="s2">&quot;Unsupported dataset parameter format: &quot;</span> <span class="o">+</span> <span class="n">desc</span>
    92. <span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">message</span><span class="p">)</span></div>
    93. <div class="viewcode-block" id="EmptyDatasetException"><a class="viewcode-back" href="../../../../super_gradients.training.exceptions.html#super_gradients.training.exceptions.dataset_exceptions.EmptyDatasetException">[docs]</a><span class="k">class</span> <span class="nc">EmptyDatasetException</span><span class="p">(</span><span class="ne">Exception</span><span class="p">):</span>
    94. <span class="sd">&quot;&quot;&quot;</span>
    95. <span class="sd"> Exception raised when a dataset does not have any image for a specific config</span>
    96. <span class="sd"> Attributes:</span>
    97. <span class="sd"> message -- explanation of the error</span>
    98. <span class="sd"> &quot;&quot;&quot;</span>
    99. <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">desc</span><span class="p">):</span>
    100. <span class="bp">self</span><span class="o">.</span><span class="n">message</span> <span class="o">=</span> <span class="s2">&quot;Empty Dataset: &quot;</span> <span class="o">+</span> <span class="n">desc</span>
    101. <span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">message</span><span class="p">)</span></div>
    102. <div class="viewcode-block" id="UnsupportedBatchItemsFormat"><a class="viewcode-back" href="../../../../super_gradients.training.exceptions.html#super_gradients.training.exceptions.dataset_exceptions.UnsupportedBatchItemsFormat">[docs]</a><span class="k">class</span> <span class="nc">UnsupportedBatchItemsFormat</span><span class="p">(</span><span class="ne">ValueError</span><span class="p">):</span>
    103. <span class="sd">&quot;&quot;&quot;Exception raised illegal batch items returned from data loader.</span>
    104. <span class="sd"> Attributes:</span>
    105. <span class="sd"> message -- explanation of the error</span>
    106. <span class="sd"> &quot;&quot;&quot;</span>
    107. <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span>
    108. <span class="bp">self</span><span class="o">.</span><span class="n">message</span> <span class="o">=</span> <span class="s2">&quot;Batch items returned by the data loader expected format: </span><span class="se">\n</span><span class="s2">&quot;</span> \
    109. <span class="s2">&quot;1. torch.Tensor or tuple, s.t inputs = batch_items[0], targets = batch_items[1] and len(&quot;</span> \
    110. <span class="s2">&quot;batch_items) = 2 </span><span class="se">\n</span><span class="s2">&quot;</span> \
    111. <span class="s2">&quot;2. tuple: (inputs, targets, additional_batch_items)&quot;</span>
    112. <span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">message</span><span class="p">)</span></div>
    113. </pre></div>
    114. </div>
    115. </div>
    116. <footer>
    117. <hr/>
    118. <div role="contentinfo">
    119. <p>&#169; Copyright 2021, SuperGradients team.</p>
    120. </div>
    121. Built with <a href="https://www.sphinx-doc.org/">Sphinx</a> using a
    122. <a href="https://github.com/readthedocs/sphinx_rtd_theme">theme</a>
    123. provided by <a href="https://readthedocs.org">Read the Docs</a>.
    124. </footer>
    125. </div>
    126. </div>
    127. </section>
    128. </div>
    129. <script>
    130. jQuery(function () {
    131. SphinxRtdTheme.Navigation.enable(true);
    132. });
    133. </script>
    134. </body>
    135. </html>
    Discard
    1
    2
    3
    4
    5
    6
    7
    8
    9
    10
    11
    12
    13
    14
    15
    16
    17
    18
    19
    20
    21
    22
    23
    24
    25
    26
    27
    28
    29
    30
    31
    32
    33
    34
    35
    36
    37
    38
    39
    40
    41
    42
    43
    44
    45
    46
    47
    48
    49
    50
    51
    52
    53
    54
    55
    56
    57
    58
    59
    60
    61
    62
    63
    64
    65
    66
    67
    68
    69
    70
    71
    72
    73
    74
    75
    76
    77
    78
    79
    80
    81
    82
    83
    84
    85
    86
    87
    88
    89
    90
    91
    92
    93
    94
    95
    96
    97
    98
    99
    100
    101
    102
    103
    104
    105
    106
    107
    108
    109
    110
    111
    112
    113
    114
    115
    116
    117
    118
    119
    120
    121
    122
    123
    124
    125
    126
    127
    128
    129
    130
    131
    132
    133
    134
    135
    136
    137
    138
    139
    140
    141
    142
    143
    144
    145
    146
    147
    148
    149
    150
    1. <!DOCTYPE html>
    2. <html class="writer-html5" lang="en" >
    3. <head>
    4. <meta charset="utf-8" />
    5. <meta name="viewport" content="width=device-width, initial-scale=1.0" />
    6. <title>super_gradients.training.exceptions.sg_model_exceptions &mdash; SuperGradients 1.0 documentation</title>
    7. <link rel="stylesheet" href="../../../../_static/pygments.css" type="text/css" />
    8. <link rel="stylesheet" href="../../../../_static/css/theme.css" type="text/css" />
    9. <link rel="stylesheet" href="../../../../_static/graphviz.css" type="text/css" />
    10. <!--[if lt IE 9]>
    11. <script src="../../../../_static/js/html5shiv.min.js"></script>
    12. <![endif]-->
    13. <script data-url_root="../../../../" id="documentation_options" src="../../../../_static/documentation_options.js"></script>
    14. <script src="../../../../_static/jquery.js"></script>
    15. <script src="../../../../_static/underscore.js"></script>
    16. <script src="../../../../_static/doctools.js"></script>
    17. <script src="../../../../_static/js/theme.js"></script>
    18. <link rel="index" title="Index" href="../../../../genindex.html" />
    19. <link rel="search" title="Search" href="../../../../search.html" />
    20. </head>
    21. <body class="wy-body-for-nav">
    22. <div class="wy-grid-for-nav">
    23. <nav data-toggle="wy-nav-shift" class="wy-nav-side">
    24. <div class="wy-side-scroll">
    25. <div class="wy-side-nav-search" >
    26. <a href="../../../../index.html" class="icon icon-home"> SuperGradients
    27. </a>
    28. <div role="search">
    29. <form id="rtd-search-form" class="wy-form" action="../../../../search.html" method="get">
    30. <input type="text" name="q" placeholder="Search docs" />
    31. <input type="hidden" name="check_keywords" value="yes" />
    32. <input type="hidden" name="area" value="default" />
    33. </form>
    34. </div>
    35. </div><div class="wy-menu wy-menu-vertical" data-spy="affix" role="navigation" aria-label="Navigation menu">
    36. <p class="caption"><span class="caption-text">Welcome To SuperGradients</span></p>
    37. <ul>
    38. <li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html">Fill our 4-question quick survey! We will raffle free SuperGradients swag between those who will participate -&gt; Fill Survey</a></li>
    39. <li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#supergradients">SuperGradients</a></li>
    40. </ul>
    41. <p class="caption"><span class="caption-text">Technical Documentation</span></p>
    42. <ul>
    43. <li class="toctree-l1"><a class="reference internal" href="../../../../super_gradients.common.html">Common package</a></li>
    44. <li class="toctree-l1"><a class="reference internal" href="../../../../super_gradients.training.html">Training package</a></li>
    45. </ul>
    46. <p class="caption"><span class="caption-text">User Guide</span></p>
    47. <ul>
    48. <li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html">What is SuperGradients?</a></li>
    49. <li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#introducing-the-supergradients-library">Introducing the SuperGradients library</a></li>
    50. <li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#installation">Installation</a></li>
    51. <li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#integrating-your-training-code-complete-walkthrough">Integrating Your Training Code - Complete Walkthrough</a></li>
    52. <li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#training-parameters">Training Parameters</a></li>
    53. <li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#logs-and-checkpoints">Logs and Checkpoints</a></li>
    54. <li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#dataset-parameters">Dataset Parameters</a></li>
    55. <li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#network-architectures">Network Architectures</a></li>
    56. <li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#pretrained-models">Pretrained Models</a></li>
    57. <li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#how-to-reproduce-our-training-recipes">How To Reproduce Our Training Recipes</a></li>
    58. <li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#professional-tools-integration">Professional Tools Integration</a></li>
    59. <li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#supergradients-faq">SuperGradients FAQ</a></li>
    60. </ul>
    61. </div>
    62. </div>
    63. </nav>
    64. <section data-toggle="wy-nav-shift" class="wy-nav-content-wrap"><nav class="wy-nav-top" aria-label="Mobile navigation menu" >
    65. <i data-toggle="wy-nav-top" class="fa fa-bars"></i>
    66. <a href="../../../../index.html">SuperGradients</a>
    67. </nav>
    68. <div class="wy-nav-content">
    69. <div class="rst-content">
    70. <div role="navigation" aria-label="Page navigation">
    71. <ul class="wy-breadcrumbs">
    72. <li><a href="../../../../index.html" class="icon icon-home"></a> &raquo;</li>
    73. <li><a href="../../../index.html">Module code</a> &raquo;</li>
    74. <li>super_gradients.training.exceptions.sg_model_exceptions</li>
    75. <li class="wy-breadcrumbs-aside">
    76. </li>
    77. </ul>
    78. <hr/>
    79. </div>
    80. <div role="main" class="document" itemscope="itemscope" itemtype="http://schema.org/Article">
    81. <div itemprop="articleBody">
    82. <h1>Source code for super_gradients.training.exceptions.sg_model_exceptions</h1><div class="highlight"><pre>
    83. <span></span>
    84. <div class="viewcode-block" id="UnsupportedTrainingParameterFormat"><a class="viewcode-back" href="../../../../super_gradients.training.exceptions.html#super_gradients.training.exceptions.sg_model_exceptions.UnsupportedTrainingParameterFormat">[docs]</a><span class="k">class</span> <span class="nc">UnsupportedTrainingParameterFormat</span><span class="p">(</span><span class="ne">Exception</span><span class="p">):</span>
    85. <span class="sd">&quot;&quot;&quot;Exception raised illegal training param format.</span>
    86. <span class="sd"> Attributes:</span>
    87. <span class="sd"> message -- explanation of the error</span>
    88. <span class="sd"> &quot;&quot;&quot;</span>
    89. <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">desc</span><span class="p">):</span>
    90. <span class="bp">self</span><span class="o">.</span><span class="n">message</span> <span class="o">=</span> <span class="s2">&quot;Unsupported training parameter format: &quot;</span> <span class="o">+</span> <span class="n">desc</span>
    91. <span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">message</span><span class="p">)</span></div>
    92. <div class="viewcode-block" id="UnsupportedOptimizerFormat"><a class="viewcode-back" href="../../../../super_gradients.training.exceptions.html#super_gradients.training.exceptions.sg_model_exceptions.UnsupportedOptimizerFormat">[docs]</a><span class="k">class</span> <span class="nc">UnsupportedOptimizerFormat</span><span class="p">(</span><span class="n">UnsupportedTrainingParameterFormat</span><span class="p">):</span>
    93. <span class="sd">&quot;&quot;&quot;Exception raised illegal optimizer format.</span>
    94. <span class="sd"> Attributes:</span>
    95. <span class="sd"> message -- explanation of the error</span>
    96. <span class="sd"> &quot;&quot;&quot;</span>
    97. <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span>
    98. <span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="fm">__init__</span><span class="p">(</span>
    99. <span class="s2">&quot;optimizer parameter expected one of [&#39;Adam&#39;,&#39;SGD&#39;,&#39;RMSProp&#39;], or torch.optim.Optimizer object&quot;</span><span class="p">)</span></div>
    100. <div class="viewcode-block" id="IllegalDataloaderInitialization"><a class="viewcode-back" href="../../../../super_gradients.training.exceptions.html#super_gradients.training.exceptions.sg_model_exceptions.IllegalDataloaderInitialization">[docs]</a><span class="k">class</span> <span class="nc">IllegalDataloaderInitialization</span><span class="p">(</span><span class="ne">Exception</span><span class="p">):</span>
    101. <span class="sd">&quot;&quot;&quot;Exception raised illegal data loaders.</span>
    102. <span class="sd"> &quot;&quot;&quot;</span>
    103. <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span>
    104. <span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="fm">__init__</span><span class="p">(</span>
    105. <span class="s2">&quot;train_loader, valid_loader and class parameters are required when initializing SgModel with data loaders&quot;</span><span class="p">)</span></div>
    106. </pre></div>
    107. </div>
    108. </div>
    109. <footer>
    110. <hr/>
    111. <div role="contentinfo">
    112. <p>&#169; Copyright 2021, SuperGradients team.</p>
    113. </div>
    114. Built with <a href="https://www.sphinx-doc.org/">Sphinx</a> using a
    115. <a href="https://github.com/readthedocs/sphinx_rtd_theme">theme</a>
    116. provided by <a href="https://readthedocs.org">Read the Docs</a>.
    117. </footer>
    118. </div>
    119. </div>
    120. </section>
    121. </div>
    122. <script>
    123. jQuery(function () {
    124. SphinxRtdTheme.Navigation.enable(true);
    125. });
    126. </script>
    127. </body>
    128. </html>
    Discard
    Some lines were truncated since they exceed the maximum allowed length of 500, please use a local Git client to see the full diff.
    @@ -3,10 +3,11 @@
     <head>
     <head>
       <meta charset="utf-8" />
       <meta charset="utf-8" />
       <meta name="viewport" content="width=device-width, initial-scale=1.0" />
       <meta name="viewport" content="width=device-width, initial-scale=1.0" />
    -  <title>super_gradients.training.kd_model.kd_model &mdash; SuperGradients 1.0 documentation</title>
    +  <title>super_gradients.training.kd_trainer.kd_trainer &mdash; SuperGradients 3.0.3 documentation</title>
           <link rel="stylesheet" href="../../../../_static/pygments.css" type="text/css" />
           <link rel="stylesheet" href="../../../../_static/pygments.css" type="text/css" />
           <link rel="stylesheet" href="../../../../_static/css/theme.css" type="text/css" />
           <link rel="stylesheet" href="../../../../_static/css/theme.css" type="text/css" />
           <link rel="stylesheet" href="../../../../_static/graphviz.css" type="text/css" />
           <link rel="stylesheet" href="../../../../_static/graphviz.css" type="text/css" />
    +      <link rel="stylesheet" href="../../../../_static/custom.css" type="text/css" />
       <!--[if lt IE 9]>
       <!--[if lt IE 9]>
         <script src="../../../../_static/js/html5shiv.min.js"></script>
         <script src="../../../../_static/js/html5shiv.min.js"></script>
       <![endif]-->
       <![endif]-->
    @@ -14,7 +15,9 @@
             <script data-url_root="../../../../" id="documentation_options" src="../../../../_static/documentation_options.js"></script>
             <script data-url_root="../../../../" id="documentation_options" src="../../../../_static/documentation_options.js"></script>
             <script src="../../../../_static/jquery.js"></script>
             <script src="../../../../_static/jquery.js"></script>
             <script src="../../../../_static/underscore.js"></script>
             <script src="../../../../_static/underscore.js"></script>
    +        <script src="../../../../_static/_sphinx_javascript_frameworks_compat.js"></script>
             <script src="../../../../_static/doctools.js"></script>
             <script src="../../../../_static/doctools.js"></script>
    +        <script src="../../../../_static/sphinx_highlight.js"></script>
         <script src="../../../../_static/js/theme.js"></script>
         <script src="../../../../_static/js/theme.js"></script>
         <link rel="index" title="Index" href="../../../../genindex.html" />
         <link rel="index" title="Index" href="../../../../genindex.html" />
         <link rel="search" title="Search" href="../../../../search.html" /> 
         <link rel="search" title="Search" href="../../../../search.html" /> 
    @@ -35,30 +38,29 @@
       </form>
       </form>
     </div>
     </div>
             </div><div class="wy-menu wy-menu-vertical" data-spy="affix" role="navigation" aria-label="Navigation menu">
             </div><div class="wy-menu wy-menu-vertical" data-spy="affix" role="navigation" aria-label="Navigation menu">
    -              <p class="caption"><span class="caption-text">Welcome To SuperGradients</span></p>
    +              <p class="caption" role="heading"><span class="caption-text">Welcome To SuperGradients</span></p>
     <ul>
     <ul>
    -<li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html">Fill our 4-question quick survey! We will raffle free SuperGradients swag between those who will participate -&gt; Fill Survey</a></li>
    -<li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#supergradients">SuperGradients</a></li>
    +<li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html">Version 3 is out! Notebooks have been updated!</a></li>
    +<li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#build-with-supergradients">Build with SuperGradients</a></li>
    +<li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#quick-installation">Quick Installation</a></li>
    +<li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#what-s-new">What’s New</a></li>
    +<li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#coming-soon">Coming soon</a></li>
    +<li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#table-of-content">Table of Content</a></li>
    +<li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#getting-started">Getting Started</a></li>
    +<li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#advanced-features">Advanced Features</a></li>
    +<li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#installation-methods">Installation Methods</a></li>
    +<li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#implemented-model-architectures">Implemented Model Architectures</a></li>
    +<li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#documentation">Documentation</a></li>
    +<li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#contributing">Contributing</a></li>
    +<li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#citation">Citation</a></li>
    +<li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#community">Community</a></li>
    +<li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#license">License</a></li>
    +<li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#deci-platform">Deci Platform</a></li>
     </ul>
     </ul>
    -<p class="caption"><span class="caption-text">Technical Documentation</span></p>
    +<p class="caption" role="heading"><span class="caption-text">Technical Documentation</span></p>
     <ul>
     <ul>
     <li class="toctree-l1"><a class="reference internal" href="../../../../super_gradients.common.html">Common package</a></li>
     <li class="toctree-l1"><a class="reference internal" href="../../../../super_gradients.common.html">Common package</a></li>
     <li class="toctree-l1"><a class="reference internal" href="../../../../super_gradients.training.html">Training package</a></li>
     <li class="toctree-l1"><a class="reference internal" href="../../../../super_gradients.training.html">Training package</a></li>
    -</ul>
    -<p class="caption"><span class="caption-text">User Guide</span></p>
    -<ul>
    -<li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html">What is SuperGradients?</a></li>
    -<li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#introducing-the-supergradients-library">Introducing the SuperGradients library</a></li>
    -<li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#installation">Installation</a></li>
    -<li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#integrating-your-training-code-complete-walkthrough">Integrating Your Training Code - Complete Walkthrough</a></li>
    -<li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#training-parameters">Training Parameters</a></li>
    -<li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#logs-and-checkpoints">Logs and Checkpoints</a></li>
    -<li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#dataset-parameters">Dataset Parameters</a></li>
    -<li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#network-architectures">Network Architectures</a></li>
    -<li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#pretrained-models">Pretrained Models</a></li>
    -<li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#how-to-reproduce-our-training-recipes">How To Reproduce Our Training Recipes</a></li>
    -<li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#professional-tools-integration">Professional Tools Integration</a></li>
    -<li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#supergradients-faq">SuperGradients FAQ</a></li>
     </ul>
     </ul>
     
     
             </div>
             </div>
    @@ -76,7 +78,7 @@
       <ul class="wy-breadcrumbs">
       <ul class="wy-breadcrumbs">
           <li><a href="../../../../index.html" class="icon icon-home"></a> &raquo;</li>
           <li><a href="../../../../index.html" class="icon icon-home"></a> &raquo;</li>
               <li><a href="../../../index.html">Module code</a> &raquo;</li>
               <li><a href="../../../index.html">Module code</a> &raquo;</li>
    -      <li>super_gradients.training.kd_model.kd_model</li>
    +      <li>super_gradients.training.kd_trainer.kd_trainer</li>
           <li class="wy-breadcrumbs-aside">
           <li class="wy-breadcrumbs-aside">
           </li>
           </li>
       </ul>
       </ul>
    @@ -85,101 +87,85 @@
               <div role="main" class="document" itemscope="itemscope" itemtype="http://schema.org/Article">
               <div role="main" class="document" itemscope="itemscope" itemtype="http://schema.org/Article">
                <div itemprop="articleBody">
                <div itemprop="articleBody">
                  
                  
    -  <h1>Source code for super_gradients.training.kd_model.kd_model</h1><div class="highlight"><pre>
    -<span></span><span class="kn">import</span> <span class="nn">torch.nn</span>
    -
    +  <h1>Source code for super_gradients.training.kd_trainer.kd_trainer</h1><div class="highlight"><pre>
    +<span></span><span class="kn">import</span> <span class="nn">hydra</span>
    +<span class="kn">import</span> <span class="nn">torch.nn</span>
    +<span class="kn">from</span> <span class="nn">omegaconf</span> <span class="kn">import</span> <span class="n">DictConfig</span>
    +<span class="kn">from</span> <span class="nn">torch.utils.data</span> <span class="kn">import</span> <span class="n">DataLoader</span>
    +
    +<span class="kn">from</span> <span class="nn">super_gradients.common</span> <span class="kn">import</span> <span class="n">MultiGPUMode</span>
    +<span class="kn">from</span> <span class="nn">super_gradients.training.dataloaders</span> <span class="kn">import</span> <span class="n">dataloaders</span>
    +<span class="kn">from</span> <span class="nn">super_gradients.training.models</span> <span class="kn">import</span> <span class="n">SgModule</span>
     <span class="kn">from</span> <span class="nn">super_gradients.training.models.all_architectures</span> <span class="kn">import</span> <span class="n">KD_ARCHITECTURES</span>
     <span class="kn">from</span> <span class="nn">super_gradients.training.models.all_architectures</span> <span class="kn">import</span> <span class="n">KD_ARCHITECTURES</span>
     <span class="kn">from</span> <span class="nn">super_gradients.training.models.kd_modules.kd_module</span> <span class="kn">import</span> <span class="n">KDModule</span>
     <span class="kn">from</span> <span class="nn">super_gradients.training.models.kd_modules.kd_module</span> <span class="kn">import</span> <span class="n">KDModule</span>
    -<span class="kn">from</span> <span class="nn">super_gradients.training.sg_model</span> <span class="kn">import</span> <span class="n">SgModel</span>
    +<span class="kn">from</span> <span class="nn">super_gradients.training.sg_trainer</span> <span class="kn">import</span> <span class="n">Trainer</span>
     <span class="kn">from</span> <span class="nn">typing</span> <span class="kn">import</span> <span class="n">Union</span>
     <span class="kn">from</span> <span class="nn">typing</span> <span class="kn">import</span> <span class="n">Union</span>
     <span class="kn">from</span> <span class="nn">super_gradients.common.abstractions.abstract_logger</span> <span class="kn">import</span> <span class="n">get_logger</span>
     <span class="kn">from</span> <span class="nn">super_gradients.common.abstractions.abstract_logger</span> <span class="kn">import</span> <span class="n">get_logger</span>
    -<span class="kn">from</span> <span class="nn">super_gradients.training</span> <span class="kn">import</span> <span class="n">utils</span> <span class="k">as</span> <span class="n">core_utils</span>
    +<span class="kn">from</span> <span class="nn">super_gradients.training</span> <span class="kn">import</span> <span class="n">utils</span> <span class="k">as</span> <span class="n">core_utils</span><span class="p">,</span> <span class="n">models</span>
     <span class="kn">from</span> <span class="nn">super_gradients.training.pretrained_models</span> <span class="kn">import</span> <span class="n">PRETRAINED_NUM_CLASSES</span>
     <span class="kn">from</span> <span class="nn">super_gradients.training.pretrained_models</span> <span class="kn">import</span> <span class="n">PRETRAINED_NUM_CLASSES</span>
    -<span class="kn">from</span> <span class="nn">super_gradients.training.utils</span> <span class="kn">import</span> <span class="n">get_param</span>
    +<span class="kn">from</span> <span class="nn">super_gradients.training.utils</span> <span class="kn">import</span> <span class="n">get_param</span><span class="p">,</span> <span class="n">HpmStruct</span>
     <span class="kn">from</span> <span class="nn">super_gradients.training.utils.checkpoint_utils</span> <span class="kn">import</span> <span class="n">read_ckpt_state_dict</span><span class="p">,</span> \
     <span class="kn">from</span> <span class="nn">super_gradients.training.utils.checkpoint_utils</span> <span class="kn">import</span> <span class="n">read_ckpt_state_dict</span><span class="p">,</span> \
         <span class="n">load_checkpoint_to_model</span>
         <span class="n">load_checkpoint_to_model</span>
    -<span class="kn">from</span> <span class="nn">super_gradients.training.exceptions.kd_model_exceptions</span> <span class="kn">import</span> <span class="n">ArchitectureKwargsException</span><span class="p">,</span> \
    +<span class="kn">from</span> <span class="nn">super_gradients.training.exceptions.kd_trainer_exceptions</span> <span class="kn">import</span> <span class="n">ArchitectureKwargsException</span><span class="p">,</span> \
         <span class="n">UnsupportedKDArchitectureException</span><span class="p">,</span> <span class="n">InconsistentParamsException</span><span class="p">,</span> <span class="n">UnsupportedKDModelArgException</span><span class="p">,</span> \
         <span class="n">UnsupportedKDArchitectureException</span><span class="p">,</span> <span class="n">InconsistentParamsException</span><span class="p">,</span> <span class="n">UnsupportedKDModelArgException</span><span class="p">,</span> \
         <span class="n">TeacherKnowledgeException</span><span class="p">,</span> <span class="n">UndefinedNumClassesException</span>
         <span class="n">TeacherKnowledgeException</span><span class="p">,</span> <span class="n">UndefinedNumClassesException</span>
     <span class="kn">from</span> <span class="nn">super_gradients.training.utils.callbacks</span> <span class="kn">import</span> <span class="n">KDModelMetricsUpdateCallback</span>
     <span class="kn">from</span> <span class="nn">super_gradients.training.utils.callbacks</span> <span class="kn">import</span> <span class="n">KDModelMetricsUpdateCallback</span>
     <span class="kn">from</span> <span class="nn">super_gradients.training.utils.ema</span> <span class="kn">import</span> <span class="n">KDModelEMA</span>
     <span class="kn">from</span> <span class="nn">super_gradients.training.utils.ema</span> <span class="kn">import</span> <span class="n">KDModelEMA</span>
    +<span class="kn">from</span> <span class="nn">super_gradients.training.utils.sg_trainer_utils</span> <span class="kn">import</span> <span class="n">parse_args</span>
    +
     <span class="n">logger</span> <span class="o">=</span> <span class="n">get_logger</span><span class="p">(</span><span class="vm">__name__</span><span class="p">)</span>
     <span class="n">logger</span> <span class="o">=</span> <span class="n">get_logger</span><span class="p">(</span><span class="vm">__name__</span><span class="p">)</span>
     
     
     
     
    -<div class="viewcode-block" id="KDModel"><a class="viewcode-back" href="../../../../super_gradients.training.html#super_gradients.training.KDModel">[docs]</a><span class="k">class</span> <span class="nc">KDModel</span><span class="p">(</span><span class="n">SgModel</span><span class="p">):</span>
    -    <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="o">*</span><span class="n">args</span><span class="p">,</span> <span class="o">**</span><span class="n">kwargs</span><span class="p">):</span>
    -        <span class="nb">super</span><span class="p">(</span><span class="n">KDModel</span><span class="p">,</span> <span class="bp">self</span><span class="p">)</span><span class="o">.</span><span class="fm">__init__</span><span class="p">(</span><span class="o">*</span><span class="n">args</span><span class="p">,</span> <span class="o">**</span><span class="n">kwargs</span><span class="p">)</span>
    +<div class="viewcode-block" id="KDTrainer"><a class="viewcode-back" href="../../../../super_gradients.training.html#super_gradients.training.KDTrainer">[docs]</a><span class="k">class</span> <span class="nc">KDTrainer</span><span class="p">(</span><span class="n">Trainer</span><span class="p">):</span>
    +    <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">experiment_name</span><span class="p">:</span> <span class="nb">str</span><span class="p">,</span> <span class="n">device</span><span class="p">:</span> <span class="nb">str</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span> <span class="n">multi_gpu</span><span class="p">:</span> <span class="n">Union</span
    +                 <span class="n">ckpt_root_dir</span><span class="p">:</span> <span class="nb">str</span> <span class="o">=</span> <span class="kc">None</span><span class="p">):</span>
    +        <span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="fm">__init__</span><span class="p">(</span><span class="n">experiment_name</span><span class="o">=</span><span class="n">experiment_name</span><span class="p">,</span> <span class="n">device</span><span class="o">=</span><span class="n">device</span><span class="p">,</span> <span class="n">multi_gpu</span><span class="o">=</span><span class="n">multi_gpu</span><span class="p">,</span> <span class=
             <span class="bp">self</span><span class="o">.</span><span class="n">student_architecture</span> <span class="o">=</span> <span class="kc">None</span>
             <span class="bp">self</span><span class="o">.</span><span class="n">student_architecture</span> <span class="o">=</span> <span class="kc">None</span>
             <span class="bp">self</span><span class="o">.</span><span class="n">teacher_architecture</span> <span class="o">=</span> <span class="kc">None</span>
             <span class="bp">self</span><span class="o">.</span><span class="n">teacher_architecture</span> <span class="o">=</span> <span class="kc">None</span>
             <span class="bp">self</span><span class="o">.</span><span class="n">student_arch_params</span> <span class="o">=</span> <span class="kc">None</span>
             <span class="bp">self</span><span class="o">.</span><span class="n">student_arch_params</span> <span class="o">=</span> <span class="kc">None</span>
             <span class="bp">self</span><span class="o">.</span><span class="n">teacher_arch_params</span> <span class="o">=</span> <span class="kc">None</span>
             <span class="bp">self</span><span class="o">.</span><span class="n">teacher_arch_params</span> <span class="o">=</span> <span class="kc">None</span>
     
     
    -<div class="viewcode-block" id="KDModel.build_model"><a class="viewcode-back" href="../../../../super_gradients.training.html#super_gradients.training.KDModel.build_model">[docs]</a>    <span class="k">def</span> <span class="nf">build_model</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span>
    -                    <span class="c1"># noqa: C901 - too complex</span>
    -                    <span class="n">architecture</span><span class="p">:</span> <span class="n">Union</span><span class="p">[</span><span class="nb">str</span><span class="p">,</span> <span class="n">KDModule</span><span class="p">]</span> <span class="o">=</span> <span class="s1">&#39;kd_module&#39;</span><span class="p">,</span>
    -                    <span class="n">arch_params</span><span class="o">=</span><span class="p">{},</span> <span class="n">checkpoint_params</span><span class="o">=</span><span class="p">{},</span>
    -                    <span class="o">*</span><span class="n">args</span><span class="p">,</span> <span class="o">**</span><span class="n">kwargs</span><span class="p">):</span>
    +<div class="viewcode-block" id="KDTrainer.train_from_config"><a class="viewcode-back" href="../../../../super_gradients.training.html#super_gradients.training.KDTrainer.train_from_config">[docs]</a>    <span class="nd">@classmethod</span>
    +    <span class="k">def</span> <span class="nf">train_from_config</span><span class="p">(</span><span class="bp">cls</span><span class="p">,</span> <span class="n">cfg</span><span class="p">:</span> <span class="n">Union</span><span class="p">[</span><span class="n">DictConfig</span><span class="p">,</span> <span class="nb">dict</span><span class="p">])</span> <span class="o">-&gt;</span> <span class="kc">None</span><span class="p">:</span>
             <span class="sd">&quot;&quot;&quot;</span>
             <span class="sd">&quot;&quot;&quot;</span>
    -<span class="sd">        :param architecture: (Union[str, KDModule]) Defines the network&#39;s architecture from models/KD_ARCHITECTURES</span>
    -<span class="sd">         (default=&#39;kd_module&#39;)</span>
    -
    -<span class="sd">        :param arch_params: (dict) Architecture H.P. e.g.: block, num_blocks, num_classes, etc to be passed to kd</span>
    -<span class="sd">            architecture class (discarded when architecture is KDModule instance)</span>
    -
    -<span class="sd">        :param checkpoint_params: (dict) A dictionary like object with the following keys/values:</span>
    -
    -<span class="sd">              student_pretrained_weights:   String describing the dataset of the pretrained weights (for example</span>
    -<span class="sd">              &quot;imagenent&quot;) for the student network.</span>
    -
    -<span class="sd">              teacher_pretrained_weights:   String describing the dataset of the pretrained weights (for example</span>
    -<span class="sd">              &quot;imagenent&quot;) for the teacher network.</span>
    -
    -<span class="sd">              teacher_checkpoint_path:    Local path to the teacher&#39;s checkpoint. Note that when passing pretrained_weights</span>
    -<span class="sd">                                   through teacher_arch_params these weights will be overridden by the</span>
    -<span class="sd">                                   pretrained checkpoint. (default=None)</span>
    -
    -<span class="sd">              load_kd_model_checkpoint:   Whether to load an entire KDModule checkpoint (used to continue KD training)</span>
    -<span class="sd">               (default=False)</span>
    -
    -<span class="sd">              kd_model_source_ckpt_folder_name: Folder name to load an entire KDModule checkpoint from</span>
    -<span class="sd">                (self.experiment_name if none is given) to resume KD training (default=None)</span>
    +<span class="sd">        Trains according to cfg recipe configuration.</span>
     
     
    -<span class="sd">              kd_model_external_checkpoint_path: The path to the external checkpoint to be loaded. Can be absolute or relative</span>
    -<span class="sd">                                               (ie: path/to/checkpoint.pth). If provided, will automatically attempt to</span>
    -<span class="sd">                                               load the checkpoint even if the load_checkpoint flag is not provided.</span>
    -<span class="sd">                                               (deafult=None)</span>
    -
    -<span class="sd">        :keyword student_architecture: (Union[str, SgModule]) Defines the student&#39;s architecture from</span>
    -<span class="sd">            models/ALL_ARCHITECTURES (when str), or directly defined the student network (when SgModule).</span>
    -
    -<span class="sd">        :keyword teacher_architecture: (Union[str, SgModule]) Defines the teacher&#39;s architecture from</span>
    -<span class="sd">            models/ALL_ARCHITECTURES (when str), or directly defined the teacher network (when SgModule).</span>
    -
    -<span class="sd">        :keyword student_arch_params: (dict) Architecture H.P. e.g.: block, num_blocks, num_classes, etc for student</span>
    -<span class="sd">            net. (deafult={})</span>
    +<span class="sd">        @param cfg: The parsed DictConfig from yaml recipe files</span>
    +<span class="sd">        @return: output of kd_trainer.train(...) (i.e results tuple)</span>
    +<span class="sd">        &quot;&quot;&quot;</span>
    +        <span class="c1"># INSTANTIATE ALL OBJECTS IN CFG</span>
    +        <span class="n">cfg</span> <span class="o">=</span> <span class="n">hydra</span><span class="o">.</span><span class="n">utils</span><span class="o">.</span><span class="n">instantiate</span><span class="p">(</span><span class="n">cfg</span><span class="p">)</span>
     
     
    -<span class="sd">        :keyword teacher_arch_params: (dict) Architecture H.P. e.g.: block, num_blocks, num_classes, etc for teacher</span>
    -<span class="sd">            net. (deafult={})</span>
    +        <span class="n">kwargs</span> <span class="o">=</span> <span class="n">parse_args</span><span class="p">(</span><span class="n">cfg</span><span class="p">,</span> <span class="bp">cls</span><span class="o">.</span><span class="fm">__init__</span><span class="p">)</span>
     
     
    -<span class="sd">        :keyword run_teacher_on_eval: (bool)- whether to run self.teacher at eval mode regardless of self.train(mode)</span>
    +        <span class="n">trainer</span> <span class="o">=</span> <span class="n">KDTrainer</span><span class="p">(</span><span class="o">**</span><span class="n">kwargs</span><span class="p">)</span>
     
     
    +        <span class="c1"># INSTANTIATE DATA LOADERS</span>
    +        <span class="n">train_dataloader</span> <span class="o">=</span> <span class="n">dataloaders</span><span class="o">.</span><span class="n">get</span><span class="p">(</span><span class="n">name</span><span class="o">=</span><span class="n">cfg</span><span class="o">.</span><span class="n">train_dataloader</span><span class="p">,</span>
    +                                           <span class="n">dataset_params</span><span class="o">=</span><span class="n">cfg</span><span class="o">.</span><span class="n">dataset_params</span><span class="o">.</span><span class="n">train_dataset_params</span><span class="p">,</span>
    +                                           <span class="n">dataloader_params</span><span class="o">=</span><span class="n">cfg</span><span class="o">.</span><span class="n">dataset_params</span><span class="o">.</span><span class="n">train_dataloader_params</span><span class="p">)</span>
     
     
    -<span class="sd">        &quot;&quot;&quot;</span>
    -        <span class="n">kwargs</span><span class="o">.</span><span class="n">setdefault</span><span class="p">(</span><span class="s2">&quot;student_architecture&quot;</span><span class="p">,</span> <span class="kc">None</span><span class="p">)</span>
    -        <span class="n">kwargs</span><span class="o">.</span><span class="n">setdefault</span><span class="p">(</span><span class="s2">&quot;teacher_architecture&quot;</span><span class="p">,</span> <span class="kc">None</span><span class="p">)</span>
    -        <span class="n">kwargs</span><span class="o">.</span><span class="n">setdefault</span><span class="p">(</span><span class="s2">&quot;student_arch_params&quot;</span><span class="p">,</span> <span class="p">{})</span>
    -        <span class="n">kwargs</span><span class="o">.</span><span class="n">setdefault</span><span class="p">(</span><span class="s2">&quot;teacher_arch_params&quot;</span><span class="p">,</span> <span class="p">{})</span>
    -        <span class="n">kwargs</span><span class="o">.</span><span class="n">setdefault</span><span class="p">(</span><span class="s2">&quot;run_teacher_on_eval&quot;</span><span class="p">,</span> <span class="kc">False</span><span class="p">)</span>
    +        <span class="n">val_dataloader</span> <span class="o">=</span> <span class="n">dataloaders</span><span class="o">.</span><span class="n">get</span><span class="p">(</span><span class="n">name</span><span class="o">=</span><span class="n">cfg</span><span class="o">.</span><span class="n">val_dataloader</span><span class="p">,</span>
    +                                         <span class="n">dataset_params</span><span class="o">=</span><span class="n">cfg</span><span class="o">.</span><span class="n">dataset_params</span><span class="o">.</span><span class="n">val_dataset_params</span><span class="p">,</span>
    +                                         <span class="n">dataloader_params</span><span class="o">=</span><span class="n">cfg</span><span class="o">.</span><span class="n">dataset_params</span><span class="o">.</span><span class="n">val_dataloader_params</span><span class="p">)</span>
     
     
    -        <span class="bp">self</span><span class="o">.</span><span class="n">_validate_args</span><span class="p">(</span><span class="n">arch_params</span><span class="p">,</span> <span class="n">architecture</span><span class="p">,</span> <span class="n">checkpoint_params</span><span class="p">,</span> <span class="o">**</span><span class="n">kwargs</span><span class="p">)</span>
    +        <span class="n">student</span> <span class="o">=</span> <span class="n">models</span><span class="o">.</span><span class="n">get</span><span class="p">(</span><span class="n">cfg</span><span class="o">.</span><span class="n">student_architecture</span><span class="p">,</span> <span class="n">arch_params</span><span class="o">=</span><span class="n">cfg</span><span class="o">.</span><span class="n">student_arch_params</span><span class="p">,</span>
    +                             <span class="n">strict_load</span><span class="o">=</span><span class="n">cfg</span><span class="o">.</span><span class="n">student_checkpoint_params</span><span class="o">.</span><span class="n">strict_load</span><span class="p">,</span>
    +                             <span class="n">pretrained_weights</span><span class="o">=</span><span class="n">cfg</span><span class="o">.</span><span class="n">student_checkpoint_params</span><span class="o">.</span><span class="n">pretrained_weights</span><span class="p">,</span>
    +                             <span class="n">checkpoint_path</span><span class="o">=</span><span class="n">cfg</span><span class="o">.</span><span class="n">student_checkpoint_params</span><span class="o">.</span><span class="n">checkpoint_path</span><span class="p">,</span>
    +                             <span class="n">load_backbone</span><span class="o">=</span><span class="n">cfg</span><span class="o">.</span><span class="n">student_checkpoint_params</span><span class="o">.</span><span class="n">load_backbone</span><span class="p">)</span>
     
     
    -        <span class="bp">self</span><span class="o">.</span><span class="n">student_architecture</span> <span class="o">=</span> <span class="n">kwargs</span><span class="o">.</span><span class="n">get</span><span class="p">(</span><span class="s2">&quot;student_architecture&quot;</span><span class="p">)</span>
    -        <span class="bp">self</span><span class="o">.</span><span class="n">teacher_architecture</span> <span class="o">=</span> <span class="n">kwargs</span><span class="o">.</span><span class="n">get</span><span class="p">(</span><span class="s2">&quot;teacher_architecture&quot;</span><span class="p">)</span>
    -        <span class="bp">self</span><span class="o">.</span><span class="n">student_arch_params</span> <span class="o">=</span> <span class="n">kwargs</span><span class="o">.</span><span class="n">get</span><span class="p">(</span><span class="s2">&quot;student_arch_params&quot;</span><span class="p">)</span>
    -        <span class="bp">self</span><span class="o">.</span><span class="n">teacher_arch_params</span> <span class="o">=</span> <span class="n">kwargs</span><span class="o">.</span><span class="n">get</span><span class="p">(</span><span class="s2">&quot;teacher_arch_params&quot;</span><span class="p">)</span>
    +        <span class="n">teacher</span> <span class="o">=</span> <span class="n">models</span><span class="o">.</span><span class="n">get</span><span class="p">(</span><span class="n">cfg</span><span class="o">.</span><span class="n">teacher_architecture</span><span class="p">,</span> <span class="n">arch_params</span><span class="o">=</span><span class="n">cfg</span><span class="o">.</span><span class="n">teacher_arch_params</span><span class="p">,</span>
    +                             <span class="n">strict_load</span><span class="o">=</span><span class="n">cfg</span><span class="o">.</span><span class="n">teacher_checkpoint_params</span><span class="o">.</span><span class="n">strict_load</span><span class="p">,</span>
    +                             <span class="n">pretrained_weights</span><span class="o">=</span><span class="n">cfg</span><span class="o">.</span><span class="n">teacher_checkpoint_params</span><span class="o">.</span><span class="n">pretrained_weights</span><span class="p">,</span>
    +                             <span class="n">checkpoint_path</span><span class="o">=</span><span class="n">cfg</span><span class="o">.</span><span class="n">teacher_checkpoint_params</span><span class="o">.</span><span class="n">checkpoint_path</span><span class="p">,</span>
    +                             <span class="n">load_backbone</span><span class="o">=</span><span class="n">cfg</span><span class="o">.</span><span class="n">teacher_checkpoint_params</span><span class="o">.</span><span class="n">load_backbone</span><span class="p">)</span>
     
     
    -        <span class="nb">super</span><span class="p">(</span><span class="n">KDModel</span><span class="p">,</span> <span class="bp">self</span><span class="p">)</span><span class="o">.</span><span class="n">build_model</span><span class="p">(</span><span class="n">architecture</span><span class="o">=</span><span class="n">architecture</span><span class="p">,</span> <span class="n">arch_params</span><span class="o">=</span><span class="n">arch_params</span><span class="p">,</span>
    -                                         <span class="n">checkpoint_params</span><span class="o">=</span><span class="n">checkpoint_params</span><span class="p">,</span> <span class="o">**</span><span class="n">kwargs</span><span class="p">)</span></div>
    +        <span class="c1"># TRAIN</span>
    +        <span class="n">trainer</span><span class="o">.</span><span class="n">train</span><span class="p">(</span><span class="n">training_params</span><span class="o">=</span><span class="n">cfg</span><span class="o">.</span><span class="n">training_hyperparams</span><span class="p">,</span> <span class="n">student</span><span class="o">=</span><span class="n">student</span><span class="p">,</span> <span class="n">teacher</span><span class="o">=</span><span class="n">teacher</span><span class=
    +                      <span class="n">kd_architecture</span><span class="o">=</span><span class="n">cfg</span><span class="o">.</span><span class="n">architecture</span><span class="p">,</span> <span class="n">kd_arch_params</span><span class="o">=</span><span class="n">cfg</span><span class="o">.</span><span class="n">arch_params</span><span class="p">,</span>
    +                      <span class="n">run_teacher_on_eval</span><span class="o">=</span><span class="n">cfg</span><span class="o">.</span><span class="n">run_teacher_on_eval</span><span class="p">,</span>
    +                      <span class="n">train_loader</span><span class="o">=</span><span class="n">train_dataloader</span><span class="p">,</span> <span class="n">valid_loader</span><span class="o">=</span><span class="n">val_dataloader</span><span class="p">)</span></div>
     
     
         <span class="k">def</span> <span class="nf">_validate_args</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">arch_params</span><span class="p">,</span> <span class="n">architecture</span><span class="p">,</span> <span class="n">checkpoint_params</span><span class="p">,</span> <span class="o">**</span><span class="n">kwargs</span><span class="p">):</span>
         <span class="k">def</span> <span class="nf">_validate_args</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">arch_params</span><span class="p">,</span> <span class="n">architecture</span><span class="p">,</span> <span class="n">checkpoint_params</span><span class="p">,</span> <span class="o">**</span><span class="n">kwargs</span><span class="p">):</span>
             <span class="n">student_architecture</span> <span class="o">=</span> <span class="n">get_param</span><span class="p">(</span><span class="n">kwargs</span><span class="p">,</span> <span class="s2">&quot;student_architecture&quot;</span><span class="p">)</span>
             <span class="n">student_architecture</span> <span class="o">=</span> <span class="n">get_param</span><span class="p">(</span><span class="n">kwargs</span><span class="p">,</span> <span class="s2">&quot;student_architecture&quot;</span><span class="p">)</span>
    @@ -215,7 +201,8 @@
             <span class="n">load_kd_model_checkpoint</span> <span class="o">=</span> <span class="n">get_param</span><span class="p">(</span><span class="n">checkpoint_params</span><span class="p">,</span> <span class="s2">&quot;load_checkpoint&quot;</span><span class="p">)</span>
             <span class="n">load_kd_model_checkpoint</span> <span class="o">=</span> <span class="n">get_param</span><span class="p">(</span><span class="n">checkpoint_params</span><span class="p">,</span> <span class="s2">&quot;load_checkpoint&quot;</span><span class="p">)</span>
     
     
             <span class="c1"># CHECK THAT TEACHER NETWORK HOLDS KNOWLEDGE FOR THE STUDENT TO LEARN FROM OR THAT WE ARE LOADING AN ENTIRE KD</span>
             <span class="c1"># CHECK THAT TEACHER NETWORK HOLDS KNOWLEDGE FOR THE STUDENT TO LEARN FROM OR THAT WE ARE LOADING AN ENTIRE KD</span>
    -        <span class="k">if</span> <span class="ow">not</span> <span class="p">(</span><span class="n">teacher_pretrained_weights</span> <span class="ow">or</span> <span class="n">teacher_checkpoint_path</span> <span class="ow">or</span> <span class="n">load_kd_model_checkpoint</span> <span class="ow">or</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">teacher_architecture</span><span class="p">,</span> <span class="n">torch</span><span class="o">.</span><span cla
    +        <span class="k">if</span> <span class="ow">not</span> <span class="p">(</span><span class="n">teacher_pretrained_weights</span> <span class="ow">or</span> <span class="n">teacher_checkpoint_path</span> <span class="ow">or</span> <span class="n">load_kd_model_checkpoint</span> <span class="ow">or</span> <span class="nb">isinstance</span><span class="p">(</span>
    +                <span class="n">teacher_architecture</span><span class="p">,</span> <span class="n">torch</span><span class="o">.</span><span class="n">nn</span><span class="o">.</span><span class="n">Module</span><span class="p">)):</span>
                 <span class="k">raise</span> <span class="n">TeacherKnowledgeException</span><span class="p">()</span>
                 <span class="k">raise</span> <span class="n">TeacherKnowledgeException</span><span class="p">()</span>
     
     
         <span class="k">def</span> <span class="nf">_validate_num_classes</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">student_arch_params</span><span class="p">,</span> <span class="n">teacher_arch_params</span><span class="p">):</span>
         <span class="k">def</span> <span class="nf">_validate_num_classes</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">student_arch_params</span><span class="p">,</span> <span class="n">teacher_arch_params</span><span class="p">):</span>
    @@ -277,6 +264,9 @@
     
     
             <span class="n">run_teacher_on_eval</span> <span class="o">=</span> <span class="n">get_param</span><span class="p">(</span><span class="n">kwargs</span><span class="p">,</span> <span class="s2">&quot;run_teacher_on_eval&quot;</span><span class="p">,</span> <span class="n">default_val</span><span class="o">=</span><span class="kc">False</span><span class="p">)</span>
             <span class="n">run_teacher_on_eval</span> <span class="o">=</span> <span class="n">get_param</span><span class="p">(</span><span class="n">kwargs</span><span class="p">,</span> <span class="s2">&quot;run_teacher_on_eval&quot;</span><span class="p">,</span> <span class="n">default_val</span><span class="o">=</span><span class="kc">False</span><span class="p">)</span>
     
     
    +        <span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">_instantiate_kd_net</span><span class="p">(</span><span class="n">arch_params</span><span class="p">,</span> <span class="n">architecture</span><span class="p">,</span> <span class="n">run_teacher_on_eval</span><span class="p">,</span> <span class="n">student</span><span class="p">,</span> <span class="n">teacher</span><span class="p">)</span>
    +
    +    <span class="k">def</span> <span class="nf">_instantiate_kd_net</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">arch_params</span><span class="p">,</span> <span class="n">architecture</span><span class="p">,</span> <span class="n">run_teacher_on_eval</span><span class="p">,</span> <span class="n">student</span><span class="p">,</span> <span class="n">teacher</span><span class="p">):</span>
             <span class="k">if</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">architecture</span><span class="p">,</span> <span class="nb">str</span><span class="p">):</span>
             <span class="k">if</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">architecture</span><span class="p">,</span> <span class="nb">str</span><span class="p">):</span>
                 <span class="n">architecture_cls</span> <span class="o">=</span> <span class="n">KD_ARCHITECTURES</span><span class="p">[</span><span class="n">architecture</span><span class="p">]</span>
                 <span class="n">architecture_cls</span> <span class="o">=</span> <span class="n">KD_ARCHITECTURES</span><span class="p">[</span><span class="n">architecture</span><span class="p">]</span>
                 <span class="n">net</span> <span class="o">=</span> <span class="n">architecture_cls</span><span class="p">(</span><span class="n">arch_params</span><span class="o">=</span><span class="n">arch_params</span><span class="p">,</span> <span class="n">student</span><span class="o">=</span><span class="n">student</span><span class="p">,</span> <span class="n">teacher</span><span class="o">=</span><span class="n">teacher</span><span class="p">,</span>
                 <span class="n">net</span> <span class="o">=</span> <span class="n">architecture_cls</span><span class="p">(</span><span class="n">arch_params</span><span class="o">=</span><span class="n">arch_params</span><span class="p">,</span> <span class="n">student</span><span class="o">=</span><span class="n">student</span><span class="p">,</span> <span class="n">teacher</span><span class="o">=</span><span class="n">teacher</span><span class="p">,</span>
    @@ -286,13 +276,12 @@
                                    <span class="n">run_teacher_on_eval</span><span class="o">=</span><span class="n">run_teacher_on_eval</span><span class="p">)</span>
                                    <span class="n">run_teacher_on_eval</span><span class="o">=</span><span class="n">run_teacher_on_eval</span><span class="p">)</span>
             <span class="k">else</span><span class="p">:</span>
             <span class="k">else</span><span class="p">:</span>
                 <span class="n">net</span> <span class="o">=</span> <span class="n">architecture</span>
                 <span class="n">net</span> <span class="o">=</span> <span class="n">architecture</span>
    -
             <span class="k">return</span> <span class="n">net</span>
             <span class="k">return</span> <span class="n">net</span>
     
     
         <span class="k">def</span> <span class="nf">_load_checkpoint_to_model</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span>
         <span class="k">def</span> <span class="nf">_load_checkpoint_to_model</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span>
             <span class="sd">&quot;&quot;&quot;</span>
             <span class="sd">&quot;&quot;&quot;</span>
     <span class="sd">        Initializes teacher weights with teacher_checkpoint_path if needed, then handles checkpoint loading for</span>
     <span class="sd">        Initializes teacher weights with teacher_checkpoint_path if needed, then handles checkpoint loading for</span>
    -<span class="sd">         the entire KD network following the same logic as in SgModel.</span>
    +<span class="sd">         the entire KD network following the same logic as in Trainer.</span>
     <span class="sd">        &quot;&quot;&quot;</span>
     <span class="sd">        &quot;&quot;&quot;</span>
             <span class="n">teacher_checkpoint_path</span> <span class="o">=</span> <span class="n">get_param</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">checkpoint_params</span><span class="p">,</span> <span class="s2">&quot;teacher_checkpoint_path&quot;</span><span class="p">)</span>
             <span class="n">teacher_checkpoint_path</span> <span class="o">=</span> <span class="n">get_param</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">checkpoint_params</span><span class="p">,</span> <span class="s2">&quot;teacher_checkpoint_path&quot;</span><span class="p">)</span>
             <span class="n">teacher_net</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">net</span><span class="o">.</span><span class="n">module</span><span class="o">.</span><span class="n">teacher</span>
             <span class="n">teacher_net</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">net</span><span class="o">.</span><span class="n">module</span><span class="o">.</span><span class="n">teacher</span>
    @@ -315,7 +304,7 @@
                                          <span class="n">load_weights_only</span><span class="o">=</span><span class="kc">True</span><span class="p">,</span>
                                          <span class="n">load_weights_only</span><span class="o">=</span><span class="kc">True</span><span class="p">,</span>
                                          <span class="n">load_ema_as_net</span><span class="o">=</span><span class="n">load_teachers_ema</span><span class="p">)</span>
                                          <span class="n">load_ema_as_net</span><span class="o">=</span><span class="n">load_teachers_ema</span><span class="p">)</span>
     
     
    -        <span class="nb">super</span><span class="p">(</span><span class="n">KDModel</span><span class="p">,</span> <span class="bp">self</span><span class="p">)</span><span class="o">.</span><span class="n">_load_checkpoint_to_model</span><span class="p">()</span>
    +        <span class="nb">super</span><span class="p">(</span><span class="n">KDTrainer</span><span class="p">,</span> <span class="bp">self</span><span class="p">)</span><span class="o">.</span><span class="n">_load_checkpoint_to_model</span><span class="p">()</span>
     
     
         <span class="k">def</span> <span class="nf">_add_metrics_update_callback</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">phase</span><span class="p">):</span>
         <span class="k">def</span> <span class="nf">_add_metrics_update_callback</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">phase</span><span class="p">):</span>
             <span class="sd">&quot;&quot;&quot;</span>
             <span class="sd">&quot;&quot;&quot;</span>
    @@ -337,7 +326,8 @@
                                        <span class="p">})</span>
                                        <span class="p">})</span>
             <span class="k">return</span> <span class="n">hyper_param_config</span>
             <span class="k">return</span> <span class="n">hyper_param_config</span>
     
     
    -    <span class="k">def</span> <span class="nf">_instantiate_ema_model</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">decay</span><span class="p">:</span> <span class="nb">float</span> <span class="o">=</span> <span class="mf">0.9999</span><span class="p">,</span> <span class="n">beta</span><span class="p">:</span> <span class="nb">float</span> <span class="o">=</span> <span class="mi">15</span><span class="p">,</span> <span class="n">exp_acti
    +    <span class="k">def</span> <span class="nf">_instantiate_ema_model</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">decay</span><span class="p">:</span> <span class="nb">float</span> <span class="o">=</span> <span class="mf">0.9999</span><span class="p">,</span> <span class="n">beta</span><span class="p">:</span> <span class="nb">float</span> <span class="o">=</span> <span class="mi">15</span><span class="p">,</span>
    +                               <span class="n">exp_activation</span><span class="p">:</span> <span class="nb">bool</span> <span class="o">=</span> <span class="kc">True</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="n">KDModelEMA</span><span class="p">:</span>
             <span class="sd">&quot;&quot;&quot;Instantiate KD ema model for KDModule.</span>
             <span class="sd">&quot;&quot;&quot;Instantiate KD ema model for KDModule.</span>
     
     
     <span class="sd">        If the model is of class KDModule, the instance will be adapted to work on knowledge distillation.</span>
     <span class="sd">        If the model is of class KDModule, the instance will be adapted to work on knowledge distillation.</span>
    @@ -360,7 +350,37 @@
                 <span class="n">best_net</span> <span class="o">=</span> <span class="n">core_utils</span><span class="o">.</span><span class="n">WrappedModel</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">net</span><span class="o">.</span><span class="n">module</span><span class="o">.</span><span class="n">student</span><span class="p">)</span>
                 <span class="n">best_net</span> <span class="o">=</span> <span class="n">core_utils</span><span class="o">.</span><span class="n">WrappedModel</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">net</span><span class="o">.</span><span class="n">module</span><span class="o">.</span><span class="n">student</span><span class="p">)</span>
     
     
             <span class="n">state</span><span class="p">[</span><span class="s2">&quot;net&quot;</span><span class="p">]</span> <span class="o">=</span> <span class="n">best_net</span><span class="o">.</span><span class="n">state_dict</span><span class="p">()</span>
             <span class="n">state</span><span class="p">[</span><span class="s2">&quot;net&quot;</span><span class="p">]</span> <span class="o">=</span> <span class="n">best_net</span><span class="o">.</span><span class="n">state_dict</span><span class="p">()</span>
    -        <span class="bp">self</span><span class="o">.</span><span class="n">sg_logger</span><span class="o">.</span><span class="n">add_checkpoint</span><span class="p">(</span><span class="n">tag</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">ckpt_best_name</span><span class="p">,</span> <span class="n">state_dict</span><span class="o">=</span><span class="n">state</span><span class="p">,</span> <span class="n">global_step</span><span class="o
    +        <span class="bp">self</span><span class="o">.</span><span class="n">sg_logger</span><span class="o">.</span><span class="n">add_checkpoint</span><span class="p">(</span><span class="n">tag</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">ckpt_best_name</span><span class="p">,</span> <span class="n">state_dict</span><span class="o">=</span><span class="n">state</span><span class="p">,</span> <span class="n">global_step</span><span class="o
    +
    +<div class="viewcode-block" id="KDTrainer.train"><a class="viewcode-back" href="../../../../super_gradients.training.html#super_gradients.training.KDTrainer.train">[docs]</a>    <span class="k">def</span> <span class="nf">train</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">model</span><span class="p">:</span> <span class="n">KDModule</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span> <span class="n">training_
    +              <span class="n">teacher</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">nn</span><span class="o">.</span><span class="n">Module</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span> <span class="n">kd_architecture</span><span class="p">:</span> <span class="n">Union</span><span class="p">[</span><span class="n">KDModule</span><span class="o">.</span><span class="vm">__class__</span><span class="p">,<
    +              <span class="n">kd_arch_params</span><span class="p">:</span> <span class="nb">dict</span> <span class="o">=</span> <span class="nb">dict</span><span class="p">(),</span> <span class="n">run_teacher_on_eval</span><span class="o">=</span><span class="kc">False</span><span class="p">,</span> <span class="n">train_loader</span><span class="p">:</span> <span class="n">DataLoader</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span>
    +              <span class="n">valid_loader</span><span class="p">:</span> <span class="n">DataLoader</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span> <span class="o">*</span><span class="n">args</span><span class="p">,</span> <span class="o">**</span><span class="n">kwargs</span><span class="p">):</span>
    +        <span class="sd">&quot;&quot;&quot;</span>
    +<span class="sd">        Trains the student network (wrapped in KDModule network).</span>
    +
    +<span class="sd">        :param model: KDModule, network to train. When none is given will initialize KDModule according to kd_architecture,</span>
    +<span class="sd">            student and teacher (default=None)</span>
    +<span class="sd">        :param training_params: dict, Same as in Trainer.train()</span>
    +<span class="sd">        :param student: SgModule - the student trainer</span>
    +<span class="sd">        :param teacher: torch.nn.Module- the teacher trainer</span>
    +<span class="sd">        :param kd_architecture: KDModule architecture to use, currently only &#39;kd_module&#39; is supported (default=&#39;kd_module&#39;).</span>
    +<span class="sd">        :param kd_arch_params: architecture params to pas to kd_architecture constructor.</span>
    +<span class="sd">        :param run_teacher_on_eval: bool- whether to run self.teacher at eval mode regardless of self.train(mode)</span>
    +<span class="sd">        :param train_loader: Dataloader for train set.</span>
    +<span class="sd">        :param valid_loader: Dataloader for validation.</span>
    +<span class="sd">        &quot;&quot;&quot;</span>
    +        <span class="n">kd_net</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">net</span> <span class="ow">or</span> <span class="n">model</span>
    +        <span class="k">if</span> <span class="n">kd_net</span> <span class="ow">is</span> <span class="kc">None</span><span class="p">:</span>
    +            <span class="k">if</span> <span class="n">student</span> <span class="ow">is</span> <span class="kc">None</span> <span class="ow">or</span> <span class="n">teacher</span> <span class="ow">is</span> <span class="kc">None</span><span class="p">:</span>
    +                <span class="k">raise</span> <span class="ne">ValueError</span><span class="p">(</span><span class="s2">&quot;Must pass student and teacher models or net (KDModule).&quot;</span><span class="p">)</span>
    +            <span class="n">kd_net</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">_instantiate_kd_net</span><span class="p">(</span><span class="n">arch_params</span><span class="o">=</span><span class="n">HpmStruct</span><span class="p">(</span><span class="o">**</span><span class="n">kd_arch_params</span><span class="p">),</span>
    +                                              <span class="n">architecture</span><span class="o">=</span><span class="n">kd_architecture</span><span class="p">,</span>
    +                                              <span class="n">run_teacher_on_eval</span><span class="o">=</span><span class="n">run_teacher_on_eval</span><span class="p">,</span>
    +                                              <span class="n">student</span><span class="o">=</span><span class="n">student</span><span class="p">,</span>
    +                                              <span class="n">teacher</span><span class="o">=</span><span class="n">teacher</span><span class="p">)</span>
    +        <span class="nb">super</span><span class="p">(</span><span class="n">KDTrainer</span><span class="p">,</span> <span class="bp">self</span><span class="p">)</span><span class="o">.</span><span class="n">train</span><span class="p">(</span><span class="n">model</span><span class="o">=</span><span class="n">kd_net</span><span class="p">,</span> <span class="n">training_params</span><span class="o">=</span><span class="n">training_params</span><span class="p">,</span>
    +                                     <span class="n">train_loader</span><span class="o">=</span><span class="n">train_loader</span><span class="p">,</span> <span class="n">valid_loader</span><span class="o">=</span><span class="n">valid_loader</span><span class="p">)</span></div></div>
     </pre></div>
     </pre></div>
     
     
                </div>
                </div>
    @@ -390,4 +410,4 @@
       </script> 
       </script> 
     
     
     </body>
     </body>
    -</html>
    +</html>
    Discard
    1
    2
    3
    4
    5
    6
    7
    8
    9
    10
    11
    12
    13
    14
    15
    16
    17
    18
    19
    20
    21
    22
    23
    24
    25
    26
    27
    28
    29
    30
    31
    32
    33
    34
    35
    36
    37
    38
    39
    40
    41
    42
    43
    44
    45
    46
    47
    48
    49
    50
    51
    52
    53
    54
    55
    56
    57
    58
    59
    60
    61
    62
    63
    64
    65
    66
    67
    68
    69
    70
    71
    72
    73
    74
    75
    76
    77
    78
    79
    80
    81
    82
    83
    84
    85
    86
    87
    88
    89
    90
    91
    92
    93
    94
    95
    96
    97
    98
    99
    100
    101
    102
    103
    104
    105
    106
    107
    108
    109
    110
    111
    112
    113
    114
    115
    116
    117
    118
    119
    120
    121
    122
    123
    124
    125
    126
    127
    128
    129
    130
    131
    132
    133
    134
    135
    136
    137
    138
    139
    140
    141
    142
    143
    144
    145
    146
    147
    148
    149
    150
    151
    152
    153
    154
    155
    156
    157
    158
    159
    160
    161
    162
    163
    164
    165
    166
    167
    168
    169
    170
    171
    172
    173
    174
    175
    176
    177
    178
    179
    180
    181
    182
    183
    184
    185
    186
    187
    188
    189
    190
    191
    192
    193
    194
    195
    196
    197
    198
    199
    200
    201
    202
    203
    204
    205
    206
    207
    208
    209
    210
    211
    212
    213
    214
    215
    216
    217
    218
    219
    220
    221
    222
    223
    224
    225
    226
    227
    228
    229
    230
    231
    232
    233
    234
    235
    236
    237
    238
    239
    240
    241
    242
    243
    244
    245
    246
    247
    248
    249
    250
    251
    1. <!DOCTYPE html>
    2. <html class="writer-html5" lang="en" >
    3. <head>
    4. <meta charset="utf-8" />
    5. <meta name="viewport" content="width=device-width, initial-scale=1.0" />
    6. <title>super_gradients.training.legacy.utils &mdash; SuperGradients 1.0 documentation</title>
    7. <link rel="stylesheet" href="../../../../_static/pygments.css" type="text/css" />
    8. <link rel="stylesheet" href="../../../../_static/css/theme.css" type="text/css" />
    9. <link rel="stylesheet" href="../../../../_static/graphviz.css" type="text/css" />
    10. <!--[if lt IE 9]>
    11. <script src="../../../../_static/js/html5shiv.min.js"></script>
    12. <![endif]-->
    13. <script data-url_root="../../../../" id="documentation_options" src="../../../../_static/documentation_options.js"></script>
    14. <script src="../../../../_static/jquery.js"></script>
    15. <script src="../../../../_static/underscore.js"></script>
    16. <script src="../../../../_static/doctools.js"></script>
    17. <script src="../../../../_static/js/theme.js"></script>
    18. <link rel="index" title="Index" href="../../../../genindex.html" />
    19. <link rel="search" title="Search" href="../../../../search.html" />
    20. </head>
    21. <body class="wy-body-for-nav">
    22. <div class="wy-grid-for-nav">
    23. <nav data-toggle="wy-nav-shift" class="wy-nav-side">
    24. <div class="wy-side-scroll">
    25. <div class="wy-side-nav-search" >
    26. <a href="../../../../index.html" class="icon icon-home"> SuperGradients
    27. </a>
    28. <div role="search">
    29. <form id="rtd-search-form" class="wy-form" action="../../../../search.html" method="get">
    30. <input type="text" name="q" placeholder="Search docs" />
    31. <input type="hidden" name="check_keywords" value="yes" />
    32. <input type="hidden" name="area" value="default" />
    33. </form>
    34. </div>
    35. </div><div class="wy-menu wy-menu-vertical" data-spy="affix" role="navigation" aria-label="Navigation menu">
    36. <p class="caption"><span class="caption-text">Welcome To SuperGradients</span></p>
    37. <ul>
    38. <li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html">Fill our 4-question quick survey! We will raffle free SuperGradients swag between those who will participate -&gt; Fill Survey</a></li>
    39. <li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#supergradients">SuperGradients</a></li>
    40. </ul>
    41. <p class="caption"><span class="caption-text">Technical Documentation</span></p>
    42. <ul>
    43. <li class="toctree-l1"><a class="reference internal" href="../../../../super_gradients.common.html">Common package</a></li>
    44. <li class="toctree-l1"><a class="reference internal" href="../../../../super_gradients.training.html">Training package</a></li>
    45. </ul>
    46. <p class="caption"><span class="caption-text">User Guide</span></p>
    47. <ul>
    48. <li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html">What is SuperGradients?</a></li>
    49. <li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#introducing-the-supergradients-library">Introducing the SuperGradients library</a></li>
    50. <li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#installation">Installation</a></li>
    51. <li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#integrating-your-training-code-complete-walkthrough">Integrating Your Training Code - Complete Walkthrough</a></li>
    52. <li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#training-parameters">Training Parameters</a></li>
    53. <li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#logs-and-checkpoints">Logs and Checkpoints</a></li>
    54. <li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#dataset-parameters">Dataset Parameters</a></li>
    55. <li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#network-architectures">Network Architectures</a></li>
    56. <li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#pretrained-models">Pretrained Models</a></li>
    57. <li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#how-to-reproduce-our-training-recipes">How To Reproduce Our Training Recipes</a></li>
    58. <li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#professional-tools-integration">Professional Tools Integration</a></li>
    59. <li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#supergradients-faq">SuperGradients FAQ</a></li>
    60. </ul>
    61. </div>
    62. </div>
    63. </nav>
    64. <section data-toggle="wy-nav-shift" class="wy-nav-content-wrap"><nav class="wy-nav-top" aria-label="Mobile navigation menu" >
    65. <i data-toggle="wy-nav-top" class="fa fa-bars"></i>
    66. <a href="../../../../index.html">SuperGradients</a>
    67. </nav>
    68. <div class="wy-nav-content">
    69. <div class="rst-content">
    70. <div role="navigation" aria-label="Page navigation">
    71. <ul class="wy-breadcrumbs">
    72. <li><a href="../../../../index.html" class="icon icon-home"></a> &raquo;</li>
    73. <li><a href="../../../index.html">Module code</a> &raquo;</li>
    74. <li>super_gradients.training.legacy.utils</li>
    75. <li class="wy-breadcrumbs-aside">
    76. </li>
    77. </ul>
    78. <hr/>
    79. </div>
    80. <div role="main" class="document" itemscope="itemscope" itemtype="http://schema.org/Article">
    81. <div itemprop="articleBody">
    82. <h1>Source code for super_gradients.training.legacy.utils</h1><div class="highlight"><pre>
    83. <span></span><span class="kn">import</span> <span class="nn">os</span>
    84. <span class="kn">import</span> <span class="nn">torch</span>
    85. <span class="kn">import</span> <span class="nn">torch.nn</span> <span class="k">as</span> <span class="nn">nn</span>
    86. <span class="kn">import</span> <span class="nn">torch.nn.init</span> <span class="k">as</span> <span class="nn">init</span>
    87. <div class="viewcode-block" id="prefetch_dataset"><a class="viewcode-back" href="../../../../super_gradients.training.legacy.html#super_gradients.training.legacy.utils.prefetch_dataset">[docs]</a><span class="k">def</span> <span class="nf">prefetch_dataset</span><span class="p">(</span><span class="n">dataset</span><span class="p">,</span> <span class="n">num_workers</span><span class="o">=</span><span class="mi">4</span><span class="p">,</span> <span class="n">batch_size</span><span class="o">=</span><span class="mi">32</span><span class="p">,</span> <span class="n">device</span><span class="o">=</span><span class="kc">None</span><span class="p">,</span> <span class="n">half</span><span class="o">=</span><span class="kc">False</span><span class="p">):</span>
    88. <span class="k">if</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">dataset</span><span class="p">,</span> <span class="nb">list</span><span class="p">)</span> <span class="ow">and</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">dataset</span><span class="p">[</span><span class="mi">0</span><span class="p">],</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">):</span>
    89. <span class="n">tensors</span> <span class="o">=</span> <span class="n">dataset</span>
    90. <span class="k">else</span><span class="p">:</span>
    91. <span class="n">dataloader</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">utils</span><span class="o">.</span><span class="n">data</span><span class="o">.</span><span class="n">DataLoader</span><span class="p">(</span>
    92. <span class="n">dataset</span><span class="p">,</span>
    93. <span class="n">batch_size</span><span class="o">=</span><span class="n">batch_size</span><span class="p">,</span>
    94. <span class="n">shuffle</span><span class="o">=</span><span class="kc">False</span><span class="p">,</span> <span class="n">drop_last</span><span class="o">=</span><span class="kc">False</span><span class="p">,</span>
    95. <span class="n">num_workers</span><span class="o">=</span><span class="n">num_workers</span><span class="p">,</span> <span class="n">pin_memory</span><span class="o">=</span><span class="kc">False</span>
    96. <span class="p">)</span>
    97. <span class="n">tensors</span> <span class="o">=</span> <span class="p">[</span><span class="n">t</span> <span class="k">for</span> <span class="n">t</span> <span class="ow">in</span> <span class="n">dataloader</span><span class="p">]</span>
    98. <span class="n">tensors</span> <span class="o">=</span> <span class="p">[</span><span class="n">torch</span><span class="o">.</span><span class="n">cat</span><span class="p">(</span><span class="n">t</span><span class="p">,</span> <span class="n">dim</span><span class="o">=</span><span class="mi">0</span><span class="p">)</span> <span class="k">for</span> <span class="n">t</span> <span class="ow">in</span> <span class="nb">zip</span><span class="p">(</span><span class="o">*</span><span class="n">tensors</span><span class="p">)]</span>
    99. <span class="k">if</span> <span class="n">device</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span><span class="p">:</span>
    100. <span class="n">tensors</span> <span class="o">=</span> <span class="p">[</span><span class="n">t</span><span class="o">.</span><span class="n">to</span><span class="p">(</span><span class="n">device</span><span class="o">=</span><span class="n">device</span><span class="p">)</span> <span class="k">for</span> <span class="n">t</span> <span class="ow">in</span> <span class="n">tensors</span><span class="p">]</span>
    101. <span class="k">if</span> <span class="n">half</span><span class="p">:</span>
    102. <span class="n">tensors</span> <span class="o">=</span> <span class="p">[</span><span class="n">t</span><span class="o">.</span><span class="n">half</span><span class="p">()</span> <span class="k">if</span> <span class="n">t</span><span class="o">.</span><span class="n">is_floating_point</span><span class="p">()</span> <span class="k">else</span> <span class="n">t</span> <span class="k">for</span> <span class="n">t</span> <span class="ow">in</span> <span class="n">tensors</span><span class="p">]</span>
    103. <span class="k">return</span> <span class="n">torch</span><span class="o">.</span><span class="n">utils</span><span class="o">.</span><span class="n">data</span><span class="o">.</span><span class="n">TensorDataset</span><span class="p">(</span><span class="o">*</span><span class="n">tensors</span><span class="p">)</span></div>
    104. <div class="viewcode-block" id="PrefetchDataLoader"><a class="viewcode-back" href="../../../../super_gradients.training.legacy.html#super_gradients.training.legacy.utils.PrefetchDataLoader">[docs]</a><span class="k">class</span> <span class="nc">PrefetchDataLoader</span><span class="p">:</span>
    105. <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">dataloader</span><span class="p">,</span> <span class="n">device</span><span class="p">,</span> <span class="n">half</span><span class="o">=</span><span class="kc">False</span><span class="p">):</span>
    106. <span class="bp">self</span><span class="o">.</span><span class="n">loader</span> <span class="o">=</span> <span class="n">dataloader</span>
    107. <span class="bp">self</span><span class="o">.</span><span class="n">iter</span> <span class="o">=</span> <span class="kc">None</span>
    108. <span class="bp">self</span><span class="o">.</span><span class="n">device</span> <span class="o">=</span> <span class="n">device</span>
    109. <span class="bp">self</span><span class="o">.</span><span class="n">dtype</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">float16</span> <span class="k">if</span> <span class="n">half</span> <span class="k">else</span> <span class="n">torch</span><span class="o">.</span><span class="n">float32</span>
    110. <span class="bp">self</span><span class="o">.</span><span class="n">stream</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">cuda</span><span class="o">.</span><span class="n">Stream</span><span class="p">()</span>
    111. <span class="bp">self</span><span class="o">.</span><span class="n">next_data</span> <span class="o">=</span> <span class="kc">None</span>
    112. <span class="k">def</span> <span class="fm">__len__</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span>
    113. <span class="k">return</span> <span class="nb">len</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">loader</span><span class="p">)</span>
    114. <div class="viewcode-block" id="PrefetchDataLoader.async_prefech"><a class="viewcode-back" href="../../../../super_gradients.training.legacy.html#super_gradients.training.legacy.utils.PrefetchDataLoader.async_prefech">[docs]</a> <span class="k">def</span> <span class="nf">async_prefech</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span>
    115. <span class="k">try</span><span class="p">:</span>
    116. <span class="bp">self</span><span class="o">.</span><span class="n">next_data</span> <span class="o">=</span> <span class="nb">next</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">iter</span><span class="p">)</span>
    117. <span class="k">except</span> <span class="ne">StopIteration</span><span class="p">:</span>
    118. <span class="bp">self</span><span class="o">.</span><span class="n">next_data</span> <span class="o">=</span> <span class="kc">None</span>
    119. <span class="k">return</span>
    120. <span class="k">with</span> <span class="n">torch</span><span class="o">.</span><span class="n">cuda</span><span class="o">.</span><span class="n">stream</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">stream</span><span class="p">):</span>
    121. <span class="k">if</span> <span class="nb">isinstance</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">next_data</span><span class="p">,</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">):</span>
    122. <span class="bp">self</span><span class="o">.</span><span class="n">next_data</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">next_data</span><span class="o">.</span><span class="n">to</span><span class="p">(</span><span class="n">dtype</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">dtype</span><span class="p">,</span> <span class="n">device</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">device</span><span class="p">,</span> <span class="n">non_blocking</span><span class="o">=</span><span class="kc">True</span><span class="p">)</span>
    123. <span class="k">elif</span> <span class="nb">isinstance</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">next_data</span><span class="p">,</span> <span class="p">(</span><span class="nb">list</span><span class="p">,</span> <span class="nb">tuple</span><span class="p">)):</span>
    124. <span class="bp">self</span><span class="o">.</span><span class="n">next_data</span> <span class="o">=</span> <span class="p">[</span>
    125. <span class="n">t</span><span class="o">.</span><span class="n">to</span><span class="p">(</span><span class="n">dtype</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">dtype</span><span class="p">,</span> <span class="n">device</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">device</span><span class="p">,</span> <span class="n">non_blocking</span><span class="o">=</span><span class="kc">True</span><span class="p">)</span> <span class="k">if</span> <span class="n">t</span><span class="o">.</span><span class="n">is_floating_point</span><span class="p">()</span> <span class="k">else</span> <span class="n">t</span><span class="o">.</span><span class="n">to</span><span class="p">(</span>
    126. <span class="n">device</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">device</span><span class="p">,</span> <span class="n">non_blocking</span><span class="o">=</span><span class="kc">True</span><span class="p">)</span> <span class="k">for</span> <span class="n">t</span> <span class="ow">in</span> <span class="bp">self</span><span class="o">.</span><span class="n">next_data</span>
    127. <span class="p">]</span></div>
    128. <span class="k">def</span> <span class="fm">__iter__</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span>
    129. <span class="bp">self</span><span class="o">.</span><span class="n">iter</span> <span class="o">=</span> <span class="nb">iter</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">loader</span><span class="p">)</span>
    130. <span class="bp">self</span><span class="o">.</span><span class="n">async_prefech</span><span class="p">()</span>
    131. <span class="k">while</span> <span class="bp">self</span><span class="o">.</span><span class="n">next_data</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span><span class="p">:</span>
    132. <span class="n">torch</span><span class="o">.</span><span class="n">cuda</span><span class="o">.</span><span class="n">current_stream</span><span class="p">()</span><span class="o">.</span><span class="n">wait_stream</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">stream</span><span class="p">)</span>
    133. <span class="n">data</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">next_data</span>
    134. <span class="bp">self</span><span class="o">.</span><span class="n">async_prefech</span><span class="p">()</span>
    135. <span class="k">yield</span> <span class="n">data</span></div>
    136. <div class="viewcode-block" id="init_params"><a class="viewcode-back" href="../../../../super_gradients.training.legacy.html#super_gradients.training.legacy.utils.init_params">[docs]</a><span class="k">def</span> <span class="nf">init_params</span><span class="p">(</span><span class="n">net</span><span class="p">):</span>
    137. <span class="sd">&quot;&quot;&quot;Init layer parameters.&quot;&quot;&quot;</span>
    138. <span class="k">for</span> <span class="n">m</span> <span class="ow">in</span> <span class="n">net</span><span class="o">.</span><span class="n">modules</span><span class="p">():</span>
    139. <span class="k">if</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">m</span><span class="p">,</span> <span class="n">nn</span><span class="o">.</span><span class="n">Conv2d</span><span class="p">):</span>
    140. <span class="n">init</span><span class="o">.</span><span class="n">kaiming_normal</span><span class="p">(</span><span class="n">m</span><span class="o">.</span><span class="n">weight</span><span class="p">,</span> <span class="n">mode</span><span class="o">=</span><span class="s1">&#39;fan_out&#39;</span><span class="p">)</span>
    141. <span class="c1"># if m.bias:</span>
    142. <span class="c1"># init.constant(m.bias, -5)</span>
    143. <span class="k">elif</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">m</span><span class="p">,</span> <span class="n">nn</span><span class="o">.</span><span class="n">BatchNorm2d</span><span class="p">):</span>
    144. <span class="n">init</span><span class="o">.</span><span class="n">constant</span><span class="p">(</span><span class="n">m</span><span class="o">.</span><span class="n">weight</span><span class="p">,</span> <span class="mi">1</span><span class="p">)</span>
    145. <span class="n">init</span><span class="o">.</span><span class="n">constant</span><span class="p">(</span><span class="n">m</span><span class="o">.</span><span class="n">bias</span><span class="p">,</span> <span class="mi">0</span><span class="p">)</span>
    146. <span class="k">elif</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">m</span><span class="p">,</span> <span class="n">nn</span><span class="o">.</span><span class="n">Linear</span><span class="p">):</span>
    147. <span class="n">init</span><span class="o">.</span><span class="n">normal</span><span class="p">(</span><span class="n">m</span><span class="o">.</span><span class="n">weight</span><span class="p">,</span> <span class="n">std</span><span class="o">=</span><span class="mf">1e-3</span><span class="p">)</span>
    148. <span class="k">if</span> <span class="n">m</span><span class="o">.</span><span class="n">bias</span><span class="p">:</span>
    149. <span class="n">init</span><span class="o">.</span><span class="n">constant</span><span class="p">(</span><span class="n">m</span><span class="o">.</span><span class="n">bias</span><span class="p">,</span> <span class="mi">0</span><span class="p">)</span></div>
    150. <div class="viewcode-block" id="format_time"><a class="viewcode-back" href="../../../../super_gradients.training.legacy.html#super_gradients.training.legacy.utils.format_time">[docs]</a><span class="k">def</span> <span class="nf">format_time</span><span class="p">(</span><span class="n">seconds</span><span class="p">):</span>
    151. <span class="n">days</span> <span class="o">=</span> <span class="nb">int</span><span class="p">(</span><span class="n">seconds</span> <span class="o">/</span> <span class="mi">3600</span> <span class="o">/</span> <span class="mi">24</span><span class="p">)</span>
    152. <span class="n">seconds</span> <span class="o">=</span> <span class="n">seconds</span> <span class="o">-</span> <span class="n">days</span> <span class="o">*</span> <span class="mi">3600</span> <span class="o">*</span> <span class="mi">24</span>
    153. <span class="n">hours</span> <span class="o">=</span> <span class="nb">int</span><span class="p">(</span><span class="n">seconds</span> <span class="o">/</span> <span class="mi">3600</span><span class="p">)</span>
    154. <span class="n">seconds</span> <span class="o">=</span> <span class="n">seconds</span> <span class="o">-</span> <span class="n">hours</span> <span class="o">*</span> <span class="mi">3600</span>
    155. <span class="n">minutes</span> <span class="o">=</span> <span class="nb">int</span><span class="p">(</span><span class="n">seconds</span> <span class="o">/</span> <span class="mi">60</span><span class="p">)</span>
    156. <span class="n">seconds</span> <span class="o">=</span> <span class="n">seconds</span> <span class="o">-</span> <span class="n">minutes</span> <span class="o">*</span> <span class="mi">60</span>
    157. <span class="n">secondsf</span> <span class="o">=</span> <span class="nb">int</span><span class="p">(</span><span class="n">seconds</span><span class="p">)</span>
    158. <span class="n">seconds</span> <span class="o">=</span> <span class="n">seconds</span> <span class="o">-</span> <span class="n">secondsf</span>
    159. <span class="n">millis</span> <span class="o">=</span> <span class="nb">int</span><span class="p">(</span><span class="n">seconds</span> <span class="o">*</span> <span class="mi">1000</span><span class="p">)</span>
    160. <span class="n">f</span> <span class="o">=</span> <span class="s1">&#39;&#39;</span>
    161. <span class="n">i</span> <span class="o">=</span> <span class="mi">1</span>
    162. <span class="k">if</span> <span class="n">days</span> <span class="o">&gt;</span> <span class="mi">0</span><span class="p">:</span>
    163. <span class="n">f</span> <span class="o">+=</span> <span class="nb">str</span><span class="p">(</span><span class="n">days</span><span class="p">)</span> <span class="o">+</span> <span class="s1">&#39;D&#39;</span>
    164. <span class="n">i</span> <span class="o">+=</span> <span class="mi">1</span>
    165. <span class="k">if</span> <span class="n">hours</span> <span class="o">&gt;</span> <span class="mi">0</span> <span class="ow">and</span> <span class="n">i</span> <span class="o">&lt;=</span> <span class="mi">2</span><span class="p">:</span>
    166. <span class="n">f</span> <span class="o">+=</span> <span class="nb">str</span><span class="p">(</span><span class="n">hours</span><span class="p">)</span> <span class="o">+</span> <span class="s1">&#39;h&#39;</span>
    167. <span class="n">i</span> <span class="o">+=</span> <span class="mi">1</span>
    168. <span class="k">if</span> <span class="n">minutes</span> <span class="o">&gt;</span> <span class="mi">0</span> <span class="ow">and</span> <span class="n">i</span> <span class="o">&lt;=</span> <span class="mi">2</span><span class="p">:</span>
    169. <span class="n">f</span> <span class="o">+=</span> <span class="nb">str</span><span class="p">(</span><span class="n">minutes</span><span class="p">)</span> <span class="o">+</span> <span class="s1">&#39;m&#39;</span>
    170. <span class="n">i</span> <span class="o">+=</span> <span class="mi">1</span>
    171. <span class="k">if</span> <span class="n">secondsf</span> <span class="o">&gt;</span> <span class="mi">0</span> <span class="ow">and</span> <span class="n">i</span> <span class="o">&lt;=</span> <span class="mi">2</span><span class="p">:</span>
    172. <span class="n">f</span> <span class="o">+=</span> <span class="nb">str</span><span class="p">(</span><span class="n">secondsf</span><span class="p">)</span> <span class="o">+</span> <span class="s1">&#39;s&#39;</span>
    173. <span class="n">i</span> <span class="o">+=</span> <span class="mi">1</span>
    174. <span class="k">if</span> <span class="n">millis</span> <span class="o">&gt;</span> <span class="mi">0</span> <span class="ow">and</span> <span class="n">i</span> <span class="o">&lt;=</span> <span class="mi">2</span><span class="p">:</span>
    175. <span class="n">f</span> <span class="o">+=</span> <span class="nb">str</span><span class="p">(</span><span class="n">millis</span><span class="p">)</span> <span class="o">+</span> <span class="s1">&#39;ms&#39;</span>
    176. <span class="n">i</span> <span class="o">+=</span> <span class="mi">1</span>
    177. <span class="k">if</span> <span class="n">f</span> <span class="o">==</span> <span class="s1">&#39;&#39;</span><span class="p">:</span>
    178. <span class="n">f</span> <span class="o">=</span> <span class="s1">&#39;0ms&#39;</span>
    179. <span class="k">return</span> <span class="n">f</span></div>
    180. <div class="viewcode-block" id="is_better"><a class="viewcode-back" href="../../../../super_gradients.training.legacy.html#super_gradients.training.legacy.utils.is_better">[docs]</a><span class="k">def</span> <span class="nf">is_better</span><span class="p">(</span><span class="n">new_metric</span><span class="p">,</span> <span class="n">current_best_metric</span><span class="p">,</span> <span class="n">metric_to_watch</span><span class="o">=</span><span class="s1">&#39;acc&#39;</span><span class="p">):</span>
    181. <span class="sd">&quot;&quot;&quot;</span>
    182. <span class="sd"> Determines which of the two metrics is better, the higher if watching acc or lower when watching loss</span>
    183. <span class="sd"> :param new_metric: the new metric</span>
    184. <span class="sd"> :param current_best_metric: the compared to metric</span>
    185. <span class="sd"> :param metric_to_watch: acc or loss</span>
    186. <span class="sd"> :return: bool, True if new metric is better than current</span>
    187. <span class="sd"> &quot;&quot;&quot;</span>
    188. <span class="k">return</span> <span class="n">metric_to_watch</span> <span class="o">==</span> <span class="s1">&#39;acc&#39;</span> <span class="ow">and</span> <span class="n">new_metric</span> <span class="o">&gt;</span> <span class="n">current_best_metric</span> <span class="ow">or</span> <span class="p">(</span><span class="n">metric_to_watch</span> <span class="o">==</span> <span class="s1">&#39;loss&#39;</span> <span class="ow">and</span> <span class="n">current_best_metric</span> <span class="o">&gt;</span> <span class="n">new_metric</span><span class="p">)</span></div>
    189. <div class="viewcode-block" id="makedirs_if_not_exists"><a class="viewcode-back" href="../../../../super_gradients.training.legacy.html#super_gradients.training.legacy.utils.makedirs_if_not_exists">[docs]</a><span class="k">def</span> <span class="nf">makedirs_if_not_exists</span><span class="p">(</span><span class="n">dir_path</span><span class="p">:</span> <span class="nb">str</span><span class="p">):</span>
    190. <span class="sd">&quot;&quot;&quot;</span>
    191. <span class="sd"> make new directory in dir_path if it doesn&#39;t exists</span>
    192. <span class="sd"> :param dir_path - full path of directory</span>
    193. <span class="sd"> &quot;&quot;&quot;</span>
    194. <span class="k">if</span> <span class="ow">not</span> <span class="n">os</span><span class="o">.</span><span class="n">path</span><span class="o">.</span><span class="n">exists</span><span class="p">(</span><span class="n">dir_path</span><span class="p">):</span>
    195. <span class="n">os</span><span class="o">.</span><span class="n">makedirs</span><span class="p">(</span><span class="n">dir_path</span><span class="p">)</span></div>
    196. </pre></div>
    197. </div>
    198. </div>
    199. <footer>
    200. <hr/>
    201. <div role="contentinfo">
    202. <p>&#169; Copyright 2021, SuperGradients team.</p>
    203. </div>
    204. Built with <a href="https://www.sphinx-doc.org/">Sphinx</a> using a
    205. <a href="https://github.com/readthedocs/sphinx_rtd_theme">theme</a>
    206. provided by <a href="https://readthedocs.org">Read the Docs</a>.
    207. </footer>
    208. </div>
    209. </div>
    210. </section>
    211. </div>
    212. <script>
    213. jQuery(function () {
    214. SphinxRtdTheme.Navigation.enable(true);
    215. });
    216. </script>
    217. </body>
    218. </html>
    Discard
    Some lines were truncated since they exceed the maximum allowed length of 500, please use a local Git client to see the full diff.
    @@ -3,10 +3,11 @@
     <head>
     <head>
       <meta charset="utf-8" />
       <meta charset="utf-8" />
       <meta name="viewport" content="width=device-width, initial-scale=1.0" />
       <meta name="viewport" content="width=device-width, initial-scale=1.0" />
    -  <title>super_gradients.training.losses.bce_dice_loss &mdash; SuperGradients 1.0 documentation</title>
    +  <title>super_gradients.training.losses.bce_dice_loss &mdash; SuperGradients 3.0.3 documentation</title>
           <link rel="stylesheet" href="../../../../_static/pygments.css" type="text/css" />
           <link rel="stylesheet" href="../../../../_static/pygments.css" type="text/css" />
           <link rel="stylesheet" href="../../../../_static/css/theme.css" type="text/css" />
           <link rel="stylesheet" href="../../../../_static/css/theme.css" type="text/css" />
           <link rel="stylesheet" href="../../../../_static/graphviz.css" type="text/css" />
           <link rel="stylesheet" href="../../../../_static/graphviz.css" type="text/css" />
    +      <link rel="stylesheet" href="../../../../_static/custom.css" type="text/css" />
       <!--[if lt IE 9]>
       <!--[if lt IE 9]>
         <script src="../../../../_static/js/html5shiv.min.js"></script>
         <script src="../../../../_static/js/html5shiv.min.js"></script>
       <![endif]-->
       <![endif]-->
    @@ -14,7 +15,9 @@
             <script data-url_root="../../../../" id="documentation_options" src="../../../../_static/documentation_options.js"></script>
             <script data-url_root="../../../../" id="documentation_options" src="../../../../_static/documentation_options.js"></script>
             <script src="../../../../_static/jquery.js"></script>
             <script src="../../../../_static/jquery.js"></script>
             <script src="../../../../_static/underscore.js"></script>
             <script src="../../../../_static/underscore.js"></script>
    +        <script src="../../../../_static/_sphinx_javascript_frameworks_compat.js"></script>
             <script src="../../../../_static/doctools.js"></script>
             <script src="../../../../_static/doctools.js"></script>
    +        <script src="../../../../_static/sphinx_highlight.js"></script>
         <script src="../../../../_static/js/theme.js"></script>
         <script src="../../../../_static/js/theme.js"></script>
         <link rel="index" title="Index" href="../../../../genindex.html" />
         <link rel="index" title="Index" href="../../../../genindex.html" />
         <link rel="search" title="Search" href="../../../../search.html" /> 
         <link rel="search" title="Search" href="../../../../search.html" /> 
    @@ -35,30 +38,29 @@
       </form>
       </form>
     </div>
     </div>
             </div><div class="wy-menu wy-menu-vertical" data-spy="affix" role="navigation" aria-label="Navigation menu">
             </div><div class="wy-menu wy-menu-vertical" data-spy="affix" role="navigation" aria-label="Navigation menu">
    -              <p class="caption"><span class="caption-text">Welcome To SuperGradients</span></p>
    +              <p class="caption" role="heading"><span class="caption-text">Welcome To SuperGradients</span></p>
     <ul>
     <ul>
    -<li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html">Fill our 4-question quick survey! We will raffle free SuperGradients swag between those who will participate -&gt; Fill Survey</a></li>
    -<li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#supergradients">SuperGradients</a></li>
    +<li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html">Version 3 is out! Notebooks have been updated!</a></li>
    +<li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#build-with-supergradients">Build with SuperGradients</a></li>
    +<li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#quick-installation">Quick Installation</a></li>
    +<li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#what-s-new">What’s New</a></li>
    +<li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#coming-soon">Coming soon</a></li>
    +<li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#table-of-content">Table of Content</a></li>
    +<li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#getting-started">Getting Started</a></li>
    +<li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#advanced-features">Advanced Features</a></li>
    +<li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#installation-methods">Installation Methods</a></li>
    +<li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#implemented-model-architectures">Implemented Model Architectures</a></li>
    +<li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#documentation">Documentation</a></li>
    +<li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#contributing">Contributing</a></li>
    +<li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#citation">Citation</a></li>
    +<li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#community">Community</a></li>
    +<li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#license">License</a></li>
    +<li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#deci-platform">Deci Platform</a></li>
     </ul>
     </ul>
    -<p class="caption"><span class="caption-text">Technical Documentation</span></p>
    +<p class="caption" role="heading"><span class="caption-text">Technical Documentation</span></p>
     <ul>
     <ul>
     <li class="toctree-l1"><a class="reference internal" href="../../../../super_gradients.common.html">Common package</a></li>
     <li class="toctree-l1"><a class="reference internal" href="../../../../super_gradients.common.html">Common package</a></li>
     <li class="toctree-l1"><a class="reference internal" href="../../../../super_gradients.training.html">Training package</a></li>
     <li class="toctree-l1"><a class="reference internal" href="../../../../super_gradients.training.html">Training package</a></li>
    -</ul>
    -<p class="caption"><span class="caption-text">User Guide</span></p>
    -<ul>
    -<li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html">What is SuperGradients?</a></li>
    -<li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#introducing-the-supergradients-library">Introducing the SuperGradients library</a></li>
    -<li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#installation">Installation</a></li>
    -<li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#integrating-your-training-code-complete-walkthrough">Integrating Your Training Code - Complete Walkthrough</a></li>
    -<li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#training-parameters">Training Parameters</a></li>
    -<li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#logs-and-checkpoints">Logs and Checkpoints</a></li>
    -<li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#dataset-parameters">Dataset Parameters</a></li>
    -<li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#network-architectures">Network Architectures</a></li>
    -<li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#pretrained-models">Pretrained Models</a></li>
    -<li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#how-to-reproduce-our-training-recipes">How To Reproduce Our Training Recipes</a></li>
    -<li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#professional-tools-integration">Professional Tools Integration</a></li>
    -<li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#supergradients-faq">SuperGradients FAQ</a></li>
     </ul>
     </ul>
     
     
             </div>
             </div>
    @@ -92,7 +94,7 @@
     <span class="kn">from</span> <span class="nn">super_gradients.training.losses.dice_loss</span> <span class="kn">import</span> <span class="n">BinaryDiceLoss</span>
     <span class="kn">from</span> <span class="nn">super_gradients.training.losses.dice_loss</span> <span class="kn">import</span> <span class="n">BinaryDiceLoss</span>
     
     
     
     
    -<div class="viewcode-block" id="BCEDiceLoss"><a class="viewcode-back" href="../../../../super_gradients.training.losses.html#super_gradients.training.losses.BCEDiceLoss">[docs]</a><span class="k">class</span> <span class="nc">BCEDiceLoss</span><span class="p">(</span><span class="n">torch</span><span class="o">.</span><span class="n">nn</span><span class="o">.</span><span class="n">Module</span><span class="p">):</span>
    +<div class="viewcode-block" id="BCEDiceLoss"><a class="viewcode-back" href="../../../../super_gradients.training.html#super_gradients.training.losses.BCEDiceLoss">[docs]</a><span class="k">class</span> <span class="nc">BCEDiceLoss</span><span class="p">(</span><span class="n">torch</span><span class="o">.</span><span class="n">nn</span><span class="o">.</span><span class="n">Module</span><span class="p">):</span>
         <span class="sd">&quot;&quot;&quot;</span>
         <span class="sd">&quot;&quot;&quot;</span>
     <span class="sd">    Binary Cross Entropy + Dice Loss</span>
     <span class="sd">    Binary Cross Entropy + Dice Loss</span>
     
     
    @@ -108,7 +110,7 @@
             <span class="bp">self</span><span class="o">.</span><span class="n">bce</span> <span class="o">=</span> <span class="n">BCE</span><span class="p">()</span>
             <span class="bp">self</span><span class="o">.</span><span class="n">bce</span> <span class="o">=</span> <span class="n">BCE</span><span class="p">()</span>
             <span class="bp">self</span><span class="o">.</span><span class="n">dice</span> <span class="o">=</span> <span class="n">BinaryDiceLoss</span><span class="p">(</span><span class="n">apply_sigmoid</span><span class="o">=</span><span class="n">logits</span><span class="p">)</span>
             <span class="bp">self</span><span class="o">.</span><span class="n">dice</span> <span class="o">=</span> <span class="n">BinaryDiceLoss</span><span class="p">(</span><span class="n">apply_sigmoid</span><span class="o">=</span><span class="n">logits</span><span class="p">)</span>
     
     
    -<div class="viewcode-block" id="BCEDiceLoss.forward"><a class="viewcode-back" href="../../../../super_gradients.training.losses.html#super_gradients.training.losses.BCEDiceLoss.forward">[docs]</a>    <span class="k">def</span> <span class="nf">forward</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="nb">input</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">,</span> <spa
    +<div class="viewcode-block" id="BCEDiceLoss.forward"><a class="viewcode-back" href="../../../../super_gradients.training.html#super_gradients.training.losses.BCEDiceLoss.forward">[docs]</a>    <span class="k">def</span> <span class="nf">forward</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="nb">input</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">,</span> <span class
             <span class="sd">&quot;&quot;&quot;</span>
             <span class="sd">&quot;&quot;&quot;</span>
     
     
     <span class="sd">        @param input: Network&#39;s raw output shaped (N,1,H,W)</span>
     <span class="sd">        @param input: Network&#39;s raw output shaped (N,1,H,W)</span>
    @@ -145,4 +147,4 @@
       </script> 
       </script> 
     
     
     </body>
     </body>
    -</html>
    +</html>
    Discard
    1
    2
    3
    4
    5
    6
    7
    8
    9
    10
    11
    12
    13
    14
    15
    16
    17
    18
    19
    20
    21
    22
    23
    24
    25
    26
    27
    28
    29
    30
    31
    32
    33
    34
    35
    36
    37
    38
    39
    40
    41
    42
    43
    44
    45
    46
    47
    48
    49
    50
    51
    52
    53
    54
    55
    56
    57
    58
    59
    60
    61
    62
    63
    64
    65
    66
    67
    68
    69
    70
    71
    72
    73
    74
    75
    76
    77
    78
    79
    80
    81
    82
    83
    84
    85
    86
    87
    88
    89
    90
    91
    92
    93
    94
    95
    96
    97
    98
    99
    100
    101
    102
    103
    104
    105
    106
    107
    108
    109
    110
    111
    112
    113
    114
    115
    116
    117
    118
    119
    120
    121
    122
    123
    124
    125
    126
    127
    128
    129
    130
    131
    132
    133
    134
    135
    136
    137
    138
    139
    140
    141
    142
    143
    144
    145
    146
    147
    148
    149
    150
    151
    152
    153
    154
    155
    156
    157
    158
    159
    160
    161
    162
    163
    164
    165
    166
    1. <!DOCTYPE html>
    2. <html class="writer-html5" lang="en" >
    3. <head>
    4. <meta charset="utf-8" />
    5. <meta name="viewport" content="width=device-width, initial-scale=1.0" />
    6. <title>super_gradients.training.losses.ddrnet_loss &mdash; SuperGradients 1.0 documentation</title>
    7. <link rel="stylesheet" href="../../../../_static/pygments.css" type="text/css" />
    8. <link rel="stylesheet" href="../../../../_static/css/theme.css" type="text/css" />
    9. <link rel="stylesheet" href="../../../../_static/graphviz.css" type="text/css" />
    10. <!--[if lt IE 9]>
    11. <script src="../../../../_static/js/html5shiv.min.js"></script>
    12. <![endif]-->
    13. <script data-url_root="../../../../" id="documentation_options" src="../../../../_static/documentation_options.js"></script>
    14. <script src="../../../../_static/jquery.js"></script>
    15. <script src="../../../../_static/underscore.js"></script>
    16. <script src="../../../../_static/doctools.js"></script>
    17. <script src="../../../../_static/js/theme.js"></script>
    18. <link rel="index" title="Index" href="../../../../genindex.html" />
    19. <link rel="search" title="Search" href="../../../../search.html" />
    20. </head>
    21. <body class="wy-body-for-nav">
    22. <div class="wy-grid-for-nav">
    23. <nav data-toggle="wy-nav-shift" class="wy-nav-side">
    24. <div class="wy-side-scroll">
    25. <div class="wy-side-nav-search" >
    26. <a href="../../../../index.html" class="icon icon-home"> SuperGradients
    27. </a>
    28. <div role="search">
    29. <form id="rtd-search-form" class="wy-form" action="../../../../search.html" method="get">
    30. <input type="text" name="q" placeholder="Search docs" />
    31. <input type="hidden" name="check_keywords" value="yes" />
    32. <input type="hidden" name="area" value="default" />
    33. </form>
    34. </div>
    35. </div><div class="wy-menu wy-menu-vertical" data-spy="affix" role="navigation" aria-label="Navigation menu">
    36. <p class="caption"><span class="caption-text">Welcome To SuperGradients</span></p>
    37. <ul>
    38. <li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html">Fill our 4-question quick survey! We will raffle free SuperGradients swag between those who will participate -&gt; Fill Survey</a></li>
    39. <li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#supergradients">SuperGradients</a></li>
    40. </ul>
    41. <p class="caption"><span class="caption-text">Technical Documentation</span></p>
    42. <ul>
    43. <li class="toctree-l1"><a class="reference internal" href="../../../../super_gradients.common.html">Common package</a></li>
    44. <li class="toctree-l1"><a class="reference internal" href="../../../../super_gradients.training.html">Training package</a></li>
    45. </ul>
    46. <p class="caption"><span class="caption-text">User Guide</span></p>
    47. <ul>
    48. <li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html">What is SuperGradients?</a></li>
    49. <li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#introducing-the-supergradients-library">Introducing the SuperGradients library</a></li>
    50. <li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#installation">Installation</a></li>
    51. <li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#integrating-your-training-code-complete-walkthrough">Integrating Your Training Code - Complete Walkthrough</a></li>
    52. <li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#training-parameters">Training Parameters</a></li>
    53. <li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#logs-and-checkpoints">Logs and Checkpoints</a></li>
    54. <li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#dataset-parameters">Dataset Parameters</a></li>
    55. <li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#network-architectures">Network Architectures</a></li>
    56. <li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#pretrained-models">Pretrained Models</a></li>
    57. <li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#how-to-reproduce-our-training-recipes">How To Reproduce Our Training Recipes</a></li>
    58. <li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#professional-tools-integration">Professional Tools Integration</a></li>
    59. <li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#supergradients-faq">SuperGradients FAQ</a></li>
    60. </ul>
    61. </div>
    62. </div>
    63. </nav>
    64. <section data-toggle="wy-nav-shift" class="wy-nav-content-wrap"><nav class="wy-nav-top" aria-label="Mobile navigation menu" >
    65. <i data-toggle="wy-nav-top" class="fa fa-bars"></i>
    66. <a href="../../../../index.html">SuperGradients</a>
    67. </nav>
    68. <div class="wy-nav-content">
    69. <div class="rst-content">
    70. <div role="navigation" aria-label="Page navigation">
    71. <ul class="wy-breadcrumbs">
    72. <li><a href="../../../../index.html" class="icon icon-home"></a> &raquo;</li>
    73. <li><a href="../../../index.html">Module code</a> &raquo;</li>
    74. <li>super_gradients.training.losses.ddrnet_loss</li>
    75. <li class="wy-breadcrumbs-aside">
    76. </li>
    77. </ul>
    78. <hr/>
    79. </div>
    80. <div role="main" class="document" itemscope="itemscope" itemtype="http://schema.org/Article">
    81. <div itemprop="articleBody">
    82. <h1>Source code for super_gradients.training.losses.ddrnet_loss</h1><div class="highlight"><pre>
    83. <span></span><span class="kn">import</span> <span class="nn">torch</span>
    84. <span class="kn">from</span> <span class="nn">typing</span> <span class="kn">import</span> <span class="n">Union</span>
    85. <span class="kn">from</span> <span class="nn">super_gradients.training.losses.ohem_ce_loss</span> <span class="kn">import</span> <span class="n">OhemCELoss</span>
    86. <div class="viewcode-block" id="DDRNetLoss"><a class="viewcode-back" href="../../../../super_gradients.training.losses.html#super_gradients.training.losses.ddrnet_loss.DDRNetLoss">[docs]</a><span class="k">class</span> <span class="nc">DDRNetLoss</span><span class="p">(</span><span class="n">OhemCELoss</span><span class="p">):</span>
    87. <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span>
    88. <span class="n">threshold</span><span class="p">:</span> <span class="nb">float</span> <span class="o">=</span> <span class="mf">0.7</span><span class="p">,</span>
    89. <span class="n">ohem_percentage</span><span class="p">:</span> <span class="nb">float</span> <span class="o">=</span> <span class="mf">0.1</span><span class="p">,</span>
    90. <span class="n">weights</span><span class="p">:</span> <span class="nb">list</span> <span class="o">=</span> <span class="p">[</span><span class="mf">1.0</span><span class="p">,</span> <span class="mf">0.4</span><span class="p">],</span>
    91. <span class="n">ignore_label</span><span class="o">=</span><span class="mi">255</span><span class="p">,</span>
    92. <span class="n">num_pixels_exclude_ignored</span><span class="p">:</span> <span class="nb">bool</span> <span class="o">=</span> <span class="kc">False</span><span class="p">):</span>
    93. <span class="sd">&quot;&quot;&quot;</span>
    94. <span class="sd"> This loss is an extension of the Ohem (Online Hard Example Mining Cross Entropy) Loss.</span>
    95. <span class="sd"> as define in paper:</span>
    96. <span class="sd"> Accurate Semantic Segmentation of Road Scenes ( https://arxiv.org/pdf/2101.06085.pdf )</span>
    97. <span class="sd"> :param threshold: threshold to th hard example mining algorithm</span>
    98. <span class="sd"> :param ohem_percentage: minimum percentage of total pixels for the hard example mining algorithm</span>
    99. <span class="sd"> (taking only the largest) losses</span>
    100. <span class="sd"> :param weights: weights per each input of the loss. This loss supports a multi output (like in DDRNet with</span>
    101. <span class="sd"> an auxiliary head). the losses of each head can be weighted.</span>
    102. <span class="sd"> :param ignore_label: targets label to be ignored</span>
    103. <span class="sd"> :param num_pixels_exclude_ignored: whether to exclude ignore pixels when calculating the mining percentage.</span>
    104. <span class="sd"> see OhemCELoss doc for more details.</span>
    105. <span class="sd"> &quot;&quot;&quot;</span>
    106. <span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="fm">__init__</span><span class="p">(</span><span class="n">threshold</span><span class="o">=</span><span class="n">threshold</span><span class="p">,</span> <span class="n">mining_percent</span><span class="o">=</span><span class="n">ohem_percentage</span><span class="p">,</span> <span class="n">ignore_lb</span><span class="o">=</span><span class="n">ignore_label</span><span class="p">,</span>
    107. <span class="n">num_pixels_exclude_ignored</span><span class="o">=</span><span class="n">num_pixels_exclude_ignored</span><span class="p">)</span>
    108. <span class="bp">self</span><span class="o">.</span><span class="n">weights</span> <span class="o">=</span> <span class="n">weights</span>
    109. <div class="viewcode-block" id="DDRNetLoss.forward"><a class="viewcode-back" href="../../../../super_gradients.training.losses.html#super_gradients.training.losses.ddrnet_loss.DDRNetLoss.forward">[docs]</a> <span class="k">def</span> <span class="nf">forward</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">predictions_list</span><span class="p">:</span> <span class="n">Union</span><span class="p">[</span><span class="nb">list</span><span class="p">,</span> <span class="nb">tuple</span><span class="p">,</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">],</span>
    110. <span class="n">targets</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">):</span>
    111. <span class="k">if</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">predictions_list</span><span class="p">,</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">):</span>
    112. <span class="n">predictions_list</span> <span class="o">=</span> <span class="p">(</span><span class="n">predictions_list</span><span class="p">,)</span>
    113. <span class="k">assert</span> <span class="nb">len</span><span class="p">(</span><span class="n">predictions_list</span><span class="p">)</span> <span class="o">==</span> <span class="nb">len</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">weights</span><span class="p">),</span> <span class="s2">&quot;num of prediction must be the same as num of loss weights&quot;</span>
    114. <span class="n">losses</span> <span class="o">=</span> <span class="p">[]</span>
    115. <span class="n">unweighted_losses</span> <span class="o">=</span> <span class="p">[]</span>
    116. <span class="k">for</span> <span class="n">predictions</span><span class="p">,</span> <span class="n">weight</span> <span class="ow">in</span> <span class="nb">zip</span><span class="p">(</span><span class="n">predictions_list</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">weights</span><span class="p">):</span>
    117. <span class="n">unweighted_loss</span> <span class="o">=</span> <span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="n">forward</span><span class="p">(</span><span class="n">predictions</span><span class="p">,</span> <span class="n">targets</span><span class="p">)</span>
    118. <span class="n">unweighted_losses</span><span class="o">.</span><span class="n">append</span><span class="p">(</span><span class="n">unweighted_loss</span><span class="p">)</span>
    119. <span class="n">losses</span><span class="o">.</span><span class="n">append</span><span class="p">(</span><span class="n">unweighted_loss</span> <span class="o">*</span> <span class="n">weight</span><span class="p">)</span>
    120. <span class="n">total_loss</span> <span class="o">=</span> <span class="nb">sum</span><span class="p">(</span><span class="n">losses</span><span class="p">)</span>
    121. <span class="n">unweighted_losses</span><span class="o">.</span><span class="n">append</span><span class="p">(</span><span class="n">total_loss</span><span class="p">)</span>
    122. <span class="k">return</span> <span class="n">total_loss</span><span class="p">,</span> <span class="n">torch</span><span class="o">.</span><span class="n">stack</span><span class="p">(</span><span class="n">unweighted_losses</span><span class="p">,</span> <span class="n">dim</span><span class="o">=</span><span class="mi">0</span><span class="p">)</span><span class="o">.</span><span class="n">detach</span><span class="p">()</span></div></div>
    123. </pre></div>
    124. </div>
    125. </div>
    126. <footer>
    127. <hr/>
    128. <div role="contentinfo">
    129. <p>&#169; Copyright 2021, SuperGradients team.</p>
    130. </div>
    131. Built with <a href="https://www.sphinx-doc.org/">Sphinx</a> using a
    132. <a href="https://github.com/readthedocs/sphinx_rtd_theme">theme</a>
    133. provided by <a href="https://readthedocs.org">Read the Docs</a>.
    134. </footer>
    135. </div>
    136. </div>
    137. </section>
    138. </div>
    139. <script>
    140. jQuery(function () {
    141. SphinxRtdTheme.Navigation.enable(true);
    142. });
    143. </script>
    144. </body>
    145. </html>
    Discard
    Some lines were truncated since they exceed the maximum allowed length of 500, please use a local Git client to see the full diff.
    @@ -3,10 +3,11 @@
     <head>
     <head>
       <meta charset="utf-8" />
       <meta charset="utf-8" />
       <meta name="viewport" content="width=device-width, initial-scale=1.0" />
       <meta name="viewport" content="width=device-width, initial-scale=1.0" />
    -  <title>super_gradients.training.losses.dice_ce_edge_loss &mdash; SuperGradients 1.0 documentation</title>
    +  <title>super_gradients.training.losses.dice_ce_edge_loss &mdash; SuperGradients 3.0.3 documentation</title>
           <link rel="stylesheet" href="../../../../_static/pygments.css" type="text/css" />
           <link rel="stylesheet" href="../../../../_static/pygments.css" type="text/css" />
           <link rel="stylesheet" href="../../../../_static/css/theme.css" type="text/css" />
           <link rel="stylesheet" href="../../../../_static/css/theme.css" type="text/css" />
           <link rel="stylesheet" href="../../../../_static/graphviz.css" type="text/css" />
           <link rel="stylesheet" href="../../../../_static/graphviz.css" type="text/css" />
    +      <link rel="stylesheet" href="../../../../_static/custom.css" type="text/css" />
       <!--[if lt IE 9]>
       <!--[if lt IE 9]>
         <script src="../../../../_static/js/html5shiv.min.js"></script>
         <script src="../../../../_static/js/html5shiv.min.js"></script>
       <![endif]-->
       <![endif]-->
    @@ -14,7 +15,9 @@
             <script data-url_root="../../../../" id="documentation_options" src="../../../../_static/documentation_options.js"></script>
             <script data-url_root="../../../../" id="documentation_options" src="../../../../_static/documentation_options.js"></script>
             <script src="../../../../_static/jquery.js"></script>
             <script src="../../../../_static/jquery.js"></script>
             <script src="../../../../_static/underscore.js"></script>
             <script src="../../../../_static/underscore.js"></script>
    +        <script src="../../../../_static/_sphinx_javascript_frameworks_compat.js"></script>
             <script src="../../../../_static/doctools.js"></script>
             <script src="../../../../_static/doctools.js"></script>
    +        <script src="../../../../_static/sphinx_highlight.js"></script>
         <script src="../../../../_static/js/theme.js"></script>
         <script src="../../../../_static/js/theme.js"></script>
         <link rel="index" title="Index" href="../../../../genindex.html" />
         <link rel="index" title="Index" href="../../../../genindex.html" />
         <link rel="search" title="Search" href="../../../../search.html" /> 
         <link rel="search" title="Search" href="../../../../search.html" /> 
    @@ -35,30 +38,29 @@
       </form>
       </form>
     </div>
     </div>
             </div><div class="wy-menu wy-menu-vertical" data-spy="affix" role="navigation" aria-label="Navigation menu">
             </div><div class="wy-menu wy-menu-vertical" data-spy="affix" role="navigation" aria-label="Navigation menu">
    -              <p class="caption"><span class="caption-text">Welcome To SuperGradients</span></p>
    +              <p class="caption" role="heading"><span class="caption-text">Welcome To SuperGradients</span></p>
     <ul>
     <ul>
    -<li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html">Fill our 4-question quick survey! We will raffle free SuperGradients swag between those who will participate -&gt; Fill Survey</a></li>
    -<li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#supergradients">SuperGradients</a></li>
    +<li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html">Version 3 is out! Notebooks have been updated!</a></li>
    +<li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#build-with-supergradients">Build with SuperGradients</a></li>
    +<li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#quick-installation">Quick Installation</a></li>
    +<li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#what-s-new">What’s New</a></li>
    +<li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#coming-soon">Coming soon</a></li>
    +<li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#table-of-content">Table of Content</a></li>
    +<li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#getting-started">Getting Started</a></li>
    +<li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#advanced-features">Advanced Features</a></li>
    +<li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#installation-methods">Installation Methods</a></li>
    +<li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#implemented-model-architectures">Implemented Model Architectures</a></li>
    +<li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#documentation">Documentation</a></li>
    +<li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#contributing">Contributing</a></li>
    +<li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#citation">Citation</a></li>
    +<li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#community">Community</a></li>
    +<li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#license">License</a></li>
    +<li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#deci-platform">Deci Platform</a></li>
     </ul>
     </ul>
    -<p class="caption"><span class="caption-text">Technical Documentation</span></p>
    +<p class="caption" role="heading"><span class="caption-text">Technical Documentation</span></p>
     <ul>
     <ul>
     <li class="toctree-l1"><a class="reference internal" href="../../../../super_gradients.common.html">Common package</a></li>
     <li class="toctree-l1"><a class="reference internal" href="../../../../super_gradients.common.html">Common package</a></li>
     <li class="toctree-l1"><a class="reference internal" href="../../../../super_gradients.training.html">Training package</a></li>
     <li class="toctree-l1"><a class="reference internal" href="../../../../super_gradients.training.html">Training package</a></li>
    -</ul>
    -<p class="caption"><span class="caption-text">User Guide</span></p>
    -<ul>
    -<li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html">What is SuperGradients?</a></li>
    -<li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#introducing-the-supergradients-library">Introducing the SuperGradients library</a></li>
    -<li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#installation">Installation</a></li>
    -<li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#integrating-your-training-code-complete-walkthrough">Integrating Your Training Code - Complete Walkthrough</a></li>
    -<li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#training-parameters">Training Parameters</a></li>
    -<li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#logs-and-checkpoints">Logs and Checkpoints</a></li>
    -<li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#dataset-parameters">Dataset Parameters</a></li>
    -<li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#network-architectures">Network Architectures</a></li>
    -<li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#pretrained-models">Pretrained Models</a></li>
    -<li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#how-to-reproduce-our-training-recipes">How To Reproduce Our Training Recipes</a></li>
    -<li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#professional-tools-integration">Professional Tools Integration</a></li>
    -<li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#supergradients-faq">SuperGradients FAQ</a></li>
     </ul>
     </ul>
     
     
             </div>
             </div>
    @@ -95,7 +97,7 @@
     <span class="kn">from</span> <span class="nn">super_gradients.training.losses.mask_loss</span> <span class="kn">import</span> <span class="n">MaskAttentionLoss</span>
     <span class="kn">from</span> <span class="nn">super_gradients.training.losses.mask_loss</span> <span class="kn">import</span> <span class="n">MaskAttentionLoss</span>
     
     
     
     
    -<div class="viewcode-block" id="DiceCEEdgeLoss"><a class="viewcode-back" href="../../../../super_gradients.training.losses.html#super_gradients.training.losses.DiceCEEdgeLoss">[docs]</a><span class="k">class</span> <span class="nc">DiceCEEdgeLoss</span><span class="p">(</span><span class="n">_Loss</span><span class="p">):</span>
    +<div class="viewcode-block" id="DiceCEEdgeLoss"><a class="viewcode-back" href="../../../../super_gradients.training.html#super_gradients.training.losses.DiceCEEdgeLoss">[docs]</a><span class="k">class</span> <span class="nc">DiceCEEdgeLoss</span><span class="p">(</span><span class="n">_Loss</span><span class="p">):</span>
         <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span>
         <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span>
                      <span class="n">num_classes</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span>
                      <span class="n">num_classes</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span>
                      <span class="n">num_aux_heads</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="mi">2</span><span class="p">,</span>
                      <span class="n">num_aux_heads</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="mi">2</span><span class="p">,</span>
    @@ -151,7 +153,22 @@
             <span class="p">)</span>
             <span class="p">)</span>
             <span class="bp">self</span><span class="o">.</span><span class="n">dice_loss</span> <span class="o">=</span> <span class="n">DiceLoss</span><span class="p">(</span><span class="n">apply_softmax</span><span class="o">=</span><span class="kc">True</span><span class="p">,</span> <span class="n">ignore_index</span><span class="o">=</span><span class="n">ignore_index</span><span class="p">)</span>
             <span class="bp">self</span><span class="o">.</span><span class="n">dice_loss</span> <span class="o">=</span> <span class="n">DiceLoss</span><span class="p">(</span><span class="n">apply_softmax</span><span class="o">=</span><span class="kc">True</span><span class="p">,</span> <span class="n">ignore_index</span><span class="o">=</span><span class="n">ignore_index</span><span class="p">)</span>
     
     
    -<div class="viewcode-block" id="DiceCEEdgeLoss.forward"><a class="viewcode-back" href="../../../../super_gradients.training.losses.html#super_gradients.training.losses.DiceCEEdgeLoss.forward">[docs]</a>    <span class="k">def</span> <span class="nf">forward</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">preds</span><span class="p">:</span> <span class="n">Tuple</span><span class="p">[</span><span class="n">torch</span><span class="o">.</span><
    +    <span class="nd">@property</span>
    +    <span class="k">def</span> <span class="nf">component_names</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span>
    +        <span class="sd">&quot;&quot;&quot;</span>
    +<span class="sd">        Component names for logging during training.</span>
    +<span class="sd">        These correspond to 2nd item in the tuple returned in self.forward(...).</span>
    +<span class="sd">        See super_gradients.Trainer.train() docs for more info.</span>
    +<span class="sd">        &quot;&quot;&quot;</span>
    +        <span class="n">names</span> <span class="o">=</span> <span class="p">[</span><span class="s2">&quot;main_loss&quot;</span><span class="p">]</span>
    +        <span class="c1"># Append aux losses names</span>
    +        <span class="n">names</span> <span class="o">+=</span> <span class="p">[</span><span class="sa">f</span><span class="s2">&quot;aux_loss</span><span class="si">{</span><span class="n">i</span><span class="si">}</span><span class="s2">&quot;</span> <span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">num_aux_heads</span><span class="p">)]</span
    +        <span class="c1"># Append detail losses names</span>
    +        <span class="n">names</span> <span class="o">+=</span> <span class="p">[</span><span class="sa">f</span><span class="s2">&quot;detail_loss</span><span class="si">{</span><span class="n">i</span><span class="si">}</span><span class="s2">&quot;</span> <span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">num_detail_heads</span><span class="p">)]
    +        <span class="n">names</span> <span class="o">+=</span> <span class="p">[</span><span class="s2">&quot;loss&quot;</span><span class="p">]</span>
    +        <span class="k">return</span> <span class="n">names</span>
    +
    +<div class="viewcode-block" id="DiceCEEdgeLoss.forward"><a class="viewcode-back" href="../../../../super_gradients.training.html#super_gradients.training.losses.DiceCEEdgeLoss.forward">[docs]</a>    <span class="k">def</span> <span class="nf">forward</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">preds</span><span class="p">:</span> <span class="n">Tuple</span><span class="p">[</span><span class="n">torch</span><span class="o">.</span><span cl
             <span class="sd">&quot;&quot;&quot;</span>
             <span class="sd">&quot;&quot;&quot;</span>
     <span class="sd">        :param preds: Model output predictions, must be in the followed format:</span>
     <span class="sd">        :param preds: Model output predictions, must be in the followed format:</span>
     <span class="sd">         [Main-feats, Aux-feats[0], ..., Aux-feats[num_auxs-1], Detail-feats[0], ..., Detail-feats[num_details-1]</span>
     <span class="sd">         [Main-feats, Aux-feats[0], ..., Aux-feats[num_auxs-1], Detail-feats[0], ..., Detail-feats[num_details-1]</span>
    @@ -214,4 +231,4 @@
       </script> 
       </script> 
     
     
     </body>
     </body>
    -</html>
    +</html>
    Discard
    Some lines were truncated since they exceed the maximum allowed length of 500, please use a local Git client to see the full diff.
    @@ -3,10 +3,11 @@
     <head>
     <head>
       <meta charset="utf-8" />
       <meta charset="utf-8" />
       <meta name="viewport" content="width=device-width, initial-scale=1.0" />
       <meta name="viewport" content="width=device-width, initial-scale=1.0" />
    -  <title>super_gradients.training.losses.focal_loss &mdash; SuperGradients 1.0 documentation</title>
    +  <title>super_gradients.training.losses.focal_loss &mdash; SuperGradients 3.0.3 documentation</title>
           <link rel="stylesheet" href="../../../../_static/pygments.css" type="text/css" />
           <link rel="stylesheet" href="../../../../_static/pygments.css" type="text/css" />
           <link rel="stylesheet" href="../../../../_static/css/theme.css" type="text/css" />
           <link rel="stylesheet" href="../../../../_static/css/theme.css" type="text/css" />
           <link rel="stylesheet" href="../../../../_static/graphviz.css" type="text/css" />
           <link rel="stylesheet" href="../../../../_static/graphviz.css" type="text/css" />
    +      <link rel="stylesheet" href="../../../../_static/custom.css" type="text/css" />
       <!--[if lt IE 9]>
       <!--[if lt IE 9]>
         <script src="../../../../_static/js/html5shiv.min.js"></script>
         <script src="../../../../_static/js/html5shiv.min.js"></script>
       <![endif]-->
       <![endif]-->
    @@ -14,7 +15,9 @@
             <script data-url_root="../../../../" id="documentation_options" src="../../../../_static/documentation_options.js"></script>
             <script data-url_root="../../../../" id="documentation_options" src="../../../../_static/documentation_options.js"></script>
             <script src="../../../../_static/jquery.js"></script>
             <script src="../../../../_static/jquery.js"></script>
             <script src="../../../../_static/underscore.js"></script>
             <script src="../../../../_static/underscore.js"></script>
    +        <script src="../../../../_static/_sphinx_javascript_frameworks_compat.js"></script>
             <script src="../../../../_static/doctools.js"></script>
             <script src="../../../../_static/doctools.js"></script>
    +        <script src="../../../../_static/sphinx_highlight.js"></script>
         <script src="../../../../_static/js/theme.js"></script>
         <script src="../../../../_static/js/theme.js"></script>
         <link rel="index" title="Index" href="../../../../genindex.html" />
         <link rel="index" title="Index" href="../../../../genindex.html" />
         <link rel="search" title="Search" href="../../../../search.html" /> 
         <link rel="search" title="Search" href="../../../../search.html" /> 
    @@ -35,30 +38,29 @@
       </form>
       </form>
     </div>
     </div>
             </div><div class="wy-menu wy-menu-vertical" data-spy="affix" role="navigation" aria-label="Navigation menu">
             </div><div class="wy-menu wy-menu-vertical" data-spy="affix" role="navigation" aria-label="Navigation menu">
    -              <p class="caption"><span class="caption-text">Welcome To SuperGradients</span></p>
    +              <p class="caption" role="heading"><span class="caption-text">Welcome To SuperGradients</span></p>
     <ul>
     <ul>
    -<li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html">Fill our 4-question quick survey! We will raffle free SuperGradients swag between those who will participate -&gt; Fill Survey</a></li>
    -<li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#supergradients">SuperGradients</a></li>
    +<li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html">Version 3 is out! Notebooks have been updated!</a></li>
    +<li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#build-with-supergradients">Build with SuperGradients</a></li>
    +<li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#quick-installation">Quick Installation</a></li>
    +<li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#what-s-new">What’s New</a></li>
    +<li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#coming-soon">Coming soon</a></li>
    +<li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#table-of-content">Table of Content</a></li>
    +<li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#getting-started">Getting Started</a></li>
    +<li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#advanced-features">Advanced Features</a></li>
    +<li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#installation-methods">Installation Methods</a></li>
    +<li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#implemented-model-architectures">Implemented Model Architectures</a></li>
    +<li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#documentation">Documentation</a></li>
    +<li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#contributing">Contributing</a></li>
    +<li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#citation">Citation</a></li>
    +<li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#community">Community</a></li>
    +<li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#license">License</a></li>
    +<li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#deci-platform">Deci Platform</a></li>
     </ul>
     </ul>
    -<p class="caption"><span class="caption-text">Technical Documentation</span></p>
    +<p class="caption" role="heading"><span class="caption-text">Technical Documentation</span></p>
     <ul>
     <ul>
     <li class="toctree-l1"><a class="reference internal" href="../../../../super_gradients.common.html">Common package</a></li>
     <li class="toctree-l1"><a class="reference internal" href="../../../../super_gradients.common.html">Common package</a></li>
     <li class="toctree-l1"><a class="reference internal" href="../../../../super_gradients.training.html">Training package</a></li>
     <li class="toctree-l1"><a class="reference internal" href="../../../../super_gradients.training.html">Training package</a></li>
    -</ul>
    -<p class="caption"><span class="caption-text">User Guide</span></p>
    -<ul>
    -<li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html">What is SuperGradients?</a></li>
    -<li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#introducing-the-supergradients-library">Introducing the SuperGradients library</a></li>
    -<li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#installation">Installation</a></li>
    -<li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#integrating-your-training-code-complete-walkthrough">Integrating Your Training Code - Complete Walkthrough</a></li>
    -<li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#training-parameters">Training Parameters</a></li>
    -<li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#logs-and-checkpoints">Logs and Checkpoints</a></li>
    -<li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#dataset-parameters">Dataset Parameters</a></li>
    -<li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#network-architectures">Network Architectures</a></li>
    -<li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#pretrained-models">Pretrained Models</a></li>
    -<li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#how-to-reproduce-our-training-recipes">How To Reproduce Our Training Recipes</a></li>
    -<li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#professional-tools-integration">Professional Tools Integration</a></li>
    -<li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#supergradients-faq">SuperGradients FAQ</a></li>
     </ul>
     </ul>
     
     
             </div>
             </div>
    @@ -91,7 +93,7 @@
     <span class="kn">from</span> <span class="nn">torch.nn.modules.loss</span> <span class="kn">import</span> <span class="n">_Loss</span>
     <span class="kn">from</span> <span class="nn">torch.nn.modules.loss</span> <span class="kn">import</span> <span class="n">_Loss</span>
     
     
     
     
    -<div class="viewcode-block" id="FocalLoss"><a class="viewcode-back" href="../../../../super_gradients.training.losses.html#super_gradients.training.losses.FocalLoss">[docs]</a><span class="k">class</span> <span class="nc">FocalLoss</span><span class="p">(</span><span class="n">_Loss</span><span class="p">):</span>
    +<div class="viewcode-block" id="FocalLoss"><a class="viewcode-back" href="../../../../super_gradients.training.html#super_gradients.training.losses.FocalLoss">[docs]</a><span class="k">class</span> <span class="nc">FocalLoss</span><span class="p">(</span><span class="n">_Loss</span><span class="p">):</span>
         <span class="sd">&quot;&quot;&quot;Wraps focal loss around existing loss_fcn(), i.e. criteria = FocalLoss(nn.BCEWithLogitsLoss(), gamma=1.5)&quot;&quot;&quot;</span>
         <span class="sd">&quot;&quot;&quot;Wraps focal loss around existing loss_fcn(), i.e. criteria = FocalLoss(nn.BCEWithLogitsLoss(), gamma=1.5)&quot;&quot;&quot;</span>
     
     
         <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">loss_fcn</span><span class="p">:</span> <span class="n">nn</span><span class="o">.</span><span class="n">BCEWithLogitsLoss</span><span class="p">,</span> <span class="n">gamma</span><span class="o">=</span><span class="mf">1.5</span><span class="p">,</span> <span class="n">alpha</span><span class="o">=</span><span class="mf">0.25</span><span
         <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">loss_fcn</span><span class="p">:</span> <span class="n">nn</span><span class="o">.</span><span class="n">BCEWithLogitsLoss</span><span class="p">,</span> <span class="n">gamma</span><span class="o">=</span><span class="mf">1.5</span><span class="p">,</span> <span class="n">alpha</span><span class="o">=</span><span class="mf">0.25</span><span
    @@ -102,7 +104,7 @@
             <span class="bp">self</span><span class="o">.</span><span class="n">reduction</span> <span class="o">=</span> <span class="n">loss_fcn</span><span class="o">.</span><span class="n">reduction</span>
             <span class="bp">self</span><span class="o">.</span><span class="n">reduction</span> <span class="o">=</span> <span class="n">loss_fcn</span><span class="o">.</span><span class="n">reduction</span>
             <span class="bp">self</span><span class="o">.</span><span class="n">loss_fcn</span><span class="o">.</span><span class="n">reduction</span> <span class="o">=</span> <span class="s1">&#39;none&#39;</span>  <span class="c1"># required to apply FocalLoss to each element</span>
             <span class="bp">self</span><span class="o">.</span><span class="n">loss_fcn</span><span class="o">.</span><span class="n">reduction</span> <span class="o">=</span> <span class="s1">&#39;none&#39;</span>  <span class="c1"># required to apply FocalLoss to each element</span>
     
     
    -<div class="viewcode-block" id="FocalLoss.forward"><a class="viewcode-back" href="../../../../super_gradients.training.losses.html#super_gradients.training.losses.FocalLoss.forward">[docs]</a>    <span class="k">def</span> <span class="nf">forward</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">pred</span><span class="p">,</span> <span class="n">true</span><span class="p">):</span>
    +<div class="viewcode-block" id="FocalLoss.forward"><a class="viewcode-back" href="../../../../super_gradients.training.html#super_gradients.training.losses.FocalLoss.forward">[docs]</a>    <span class="k">def</span> <span class="nf">forward</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">pred</span><span class="p">,</span> <span class="n">true</span><span class="p">):</span>
             <span class="n">loss</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">loss_fcn</span><span class="p">(</span><span class="n">pred</span><span class="p">,</span> <span class="n">true</span><span class="p">)</span>
             <span class="n">loss</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">loss_fcn</span><span class="p">(</span><span class="n">pred</span><span class="p">,</span> <span class="n">true</span><span class="p">)</span>
     
     
             <span class="n">pred_prob</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">sigmoid</span><span class="p">(</span><span class="n">pred</span><span class="p">)</span>  <span class="c1"># prob from logits</span>
             <span class="n">pred_prob</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">sigmoid</span><span class="p">(</span><span class="n">pred</span><span class="p">)</span>  <span class="c1"># prob from logits</span>
    @@ -146,4 +148,4 @@
       </script> 
       </script> 
     
     
     </body>
     </body>
    -</html>
    +</html>
    Discard
    Some lines were truncated since they exceed the maximum allowed length of 500, please use a local Git client to see the full diff.
    @@ -3,10 +3,11 @@
     <head>
     <head>
       <meta charset="utf-8" />
       <meta charset="utf-8" />
       <meta name="viewport" content="width=device-width, initial-scale=1.0" />
       <meta name="viewport" content="width=device-width, initial-scale=1.0" />
    -  <title>super_gradients.training.losses.kd_losses &mdash; SuperGradients 1.0 documentation</title>
    +  <title>super_gradients.training.losses.kd_losses &mdash; SuperGradients 3.0.3 documentation</title>
           <link rel="stylesheet" href="../../../../_static/pygments.css" type="text/css" />
           <link rel="stylesheet" href="../../../../_static/pygments.css" type="text/css" />
           <link rel="stylesheet" href="../../../../_static/css/theme.css" type="text/css" />
           <link rel="stylesheet" href="../../../../_static/css/theme.css" type="text/css" />
           <link rel="stylesheet" href="../../../../_static/graphviz.css" type="text/css" />
           <link rel="stylesheet" href="../../../../_static/graphviz.css" type="text/css" />
    +      <link rel="stylesheet" href="../../../../_static/custom.css" type="text/css" />
       <!--[if lt IE 9]>
       <!--[if lt IE 9]>
         <script src="../../../../_static/js/html5shiv.min.js"></script>
         <script src="../../../../_static/js/html5shiv.min.js"></script>
       <![endif]-->
       <![endif]-->
    @@ -14,7 +15,9 @@
             <script data-url_root="../../../../" id="documentation_options" src="../../../../_static/documentation_options.js"></script>
             <script data-url_root="../../../../" id="documentation_options" src="../../../../_static/documentation_options.js"></script>
             <script src="../../../../_static/jquery.js"></script>
             <script src="../../../../_static/jquery.js"></script>
             <script src="../../../../_static/underscore.js"></script>
             <script src="../../../../_static/underscore.js"></script>
    +        <script src="../../../../_static/_sphinx_javascript_frameworks_compat.js"></script>
             <script src="../../../../_static/doctools.js"></script>
             <script src="../../../../_static/doctools.js"></script>
    +        <script src="../../../../_static/sphinx_highlight.js"></script>
         <script src="../../../../_static/js/theme.js"></script>
         <script src="../../../../_static/js/theme.js"></script>
         <link rel="index" title="Index" href="../../../../genindex.html" />
         <link rel="index" title="Index" href="../../../../genindex.html" />
         <link rel="search" title="Search" href="../../../../search.html" /> 
         <link rel="search" title="Search" href="../../../../search.html" /> 
    @@ -35,30 +38,29 @@
       </form>
       </form>
     </div>
     </div>
             </div><div class="wy-menu wy-menu-vertical" data-spy="affix" role="navigation" aria-label="Navigation menu">
             </div><div class="wy-menu wy-menu-vertical" data-spy="affix" role="navigation" aria-label="Navigation menu">
    -              <p class="caption"><span class="caption-text">Welcome To SuperGradients</span></p>
    +              <p class="caption" role="heading"><span class="caption-text">Welcome To SuperGradients</span></p>
     <ul>
     <ul>
    -<li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html">Fill our 4-question quick survey! We will raffle free SuperGradients swag between those who will participate -&gt; Fill Survey</a></li>
    -<li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#supergradients">SuperGradients</a></li>
    +<li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html">Version 3 is out! Notebooks have been updated!</a></li>
    +<li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#build-with-supergradients">Build with SuperGradients</a></li>
    +<li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#quick-installation">Quick Installation</a></li>
    +<li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#what-s-new">What’s New</a></li>
    +<li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#coming-soon">Coming soon</a></li>
    +<li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#table-of-content">Table of Content</a></li>
    +<li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#getting-started">Getting Started</a></li>
    +<li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#advanced-features">Advanced Features</a></li>
    +<li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#installation-methods">Installation Methods</a></li>
    +<li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#implemented-model-architectures">Implemented Model Architectures</a></li>
    +<li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#documentation">Documentation</a></li>
    +<li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#contributing">Contributing</a></li>
    +<li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#citation">Citation</a></li>
    +<li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#community">Community</a></li>
    +<li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#license">License</a></li>
    +<li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#deci-platform">Deci Platform</a></li>
     </ul>
     </ul>
    -<p class="caption"><span class="caption-text">Technical Documentation</span></p>
    +<p class="caption" role="heading"><span class="caption-text">Technical Documentation</span></p>
     <ul>
     <ul>
     <li class="toctree-l1"><a class="reference internal" href="../../../../super_gradients.common.html">Common package</a></li>
     <li class="toctree-l1"><a class="reference internal" href="../../../../super_gradients.common.html">Common package</a></li>
     <li class="toctree-l1"><a class="reference internal" href="../../../../super_gradients.training.html">Training package</a></li>
     <li class="toctree-l1"><a class="reference internal" href="../../../../super_gradients.training.html">Training package</a></li>
    -</ul>
    -<p class="caption"><span class="caption-text">User Guide</span></p>
    -<ul>
    -<li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html">What is SuperGradients?</a></li>
    -<li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#introducing-the-supergradients-library">Introducing the SuperGradients library</a></li>
    -<li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#installation">Installation</a></li>
    -<li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#integrating-your-training-code-complete-walkthrough">Integrating Your Training Code - Complete Walkthrough</a></li>
    -<li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#training-parameters">Training Parameters</a></li>
    -<li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#logs-and-checkpoints">Logs and Checkpoints</a></li>
    -<li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#dataset-parameters">Dataset Parameters</a></li>
    -<li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#network-architectures">Network Architectures</a></li>
    -<li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#pretrained-models">Pretrained Models</a></li>
    -<li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#how-to-reproduce-our-training-recipes">How To Reproduce Our Training Recipes</a></li>
    -<li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#professional-tools-integration">Professional Tools Integration</a></li>
    -<li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#supergradients-faq">SuperGradients FAQ</a></li>
     </ul>
     </ul>
     
     
             </div>
             </div>
    @@ -100,7 +102,7 @@
                                                     <span class="n">torch</span><span class="o">.</span><span class="n">softmax</span><span class="p">(</span><span class="n">teacher_output</span><span class="p">,</span> <span class="n">dim</span><span class="o">=</span><span class="mi">1</span><span class="p">))</span>
                                                     <span class="n">torch</span><span class="o">.</span><span class="n">softmax</span><span class="p">(</span><span class="n">teacher_output</span><span class="p">,</span> <span class="n">dim</span><span class="o">=</span><span class="mi">1</span><span class="p">))</span>
     
     
     
     
    -<div class="viewcode-block" id="KDLogitsLoss"><a class="viewcode-back" href="../../../../super_gradients.training.losses.html#super_gradients.training.losses.KDLogitsLoss">[docs]</a><span class="k">class</span> <span class="nc">KDLogitsLoss</span><span class="p">(</span><span class="n">_Loss</span><span class="p">):</span>
    +<div class="viewcode-block" id="KDLogitsLoss"><a class="viewcode-back" href="../../../../super_gradients.training.html#super_gradients.training.losses.KDLogitsLoss">[docs]</a><span class="k">class</span> <span class="nc">KDLogitsLoss</span><span class="p">(</span><span class="n">_Loss</span><span class="p">):</span>
         <span class="sd">&quot;&quot;&quot; Knowledge distillation loss, wraps the task loss and distillation loss &quot;&quot;&quot;</span>
         <span class="sd">&quot;&quot;&quot; Knowledge distillation loss, wraps the task loss and distillation loss &quot;&quot;&quot;</span>
         <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">task_loss_fn</span><span class="p">:</span> <span class="n">_Loss</span><span class="p">,</span> <span class="n">distillation_loss_fn</span><span class="p">:</span> <span class="n">_Loss</span> <span class="o">=</span> <span class="n">KDklDivLoss</span><span class="p">(),</span> <span class="n">distillation_loss_coeff</span><span class="p">:
         <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">task_loss_fn</span><span class="p">:</span> <span class="n">_Loss</span><span class="p">,</span> <span class="n">distillation_loss_fn</span><span class="p">:</span> <span class="n">_Loss</span> <span class="o">=</span> <span class="n">KDklDivLoss</span><span class="p">(),</span> <span class="n">distillation_loss_coeff</span><span class="p">:
             <span class="sd">&#39;&#39;&#39;</span>
             <span class="sd">&#39;&#39;&#39;</span>
    @@ -114,7 +116,16 @@
             <span class="bp">self</span><span class="o">.</span><span class="n">distillation_loss_fn</span> <span class="o">=</span> <span class="n">distillation_loss_fn</span>
             <span class="bp">self</span><span class="o">.</span><span class="n">distillation_loss_fn</span> <span class="o">=</span> <span class="n">distillation_loss_fn</span>
             <span class="bp">self</span><span class="o">.</span><span class="n">distillation_loss_coeff</span> <span class="o">=</span> <span class="n">distillation_loss_coeff</span>
             <span class="bp">self</span><span class="o">.</span><span class="n">distillation_loss_coeff</span> <span class="o">=</span> <span class="n">distillation_loss_coeff</span>
     
     
    -<div class="viewcode-block" id="KDLogitsLoss.forward"><a class="viewcode-back" href="../../../../super_gradients.training.losses.html#super_gradients.training.losses.KDLogitsLoss.forward">[docs]</a>    <span class="k">def</span> <span class="nf">forward</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">kd_module_output</span><span class="p">,</span> <span class="n">target</span><span class="p">):</span>
    +    <span class="nd">@property</span>
    +    <span class="k">def</span> <span class="nf">component_names</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span>
    +        <span class="sd">&quot;&quot;&quot;</span>
    +<span class="sd">        Component names for logging during training.</span>
    +<span class="sd">        These correspond to 2nd item in the tuple returned in self.forward(...).</span>
    +<span class="sd">        See super_gradients.Trainer.train() docs for more info.</span>
    +<span class="sd">        &quot;&quot;&quot;</span>
    +        <span class="k">return</span> <span class="p">[</span><span class="s2">&quot;Loss&quot;</span><span class="p">,</span> <span class="s2">&quot;Task Loss&quot;</span><span class="p">,</span> <span class="s2">&quot;Distillation Loss&quot;</span><span class="p">]</span>
    +
    +<div class="viewcode-block" id="KDLogitsLoss.forward"><a class="viewcode-back" href="../../../../super_gradients.training.html#super_gradients.training.losses.KDLogitsLoss.forward">[docs]</a>    <span class="k">def</span> <span class="nf">forward</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">kd_module_output</span><span class="p">,</span> <span class="n">target</span><span class="p">):</span>
             <span class="n">task_loss</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">task_loss_fn</span><span class="p">(</span><span class="n">kd_module_output</span><span class="o">.</span><span class="n">student_output</span><span class="p">,</span> <span class="n">target</span><span class="p">)</span>
             <span class="n">task_loss</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">task_loss_fn</span><span class="p">(</span><span class="n">kd_module_output</span><span class="o">.</span><span class="n">student_output</span><span class="p">,</span> <span class="n">target</span><span class="p">)</span>
             <span class="k">if</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">task_loss</span><span class="p">,</span> <span class="nb">tuple</span><span class="p">):</span>  <span class="c1"># SOME LOSS FUNCTIONS RETURNS LOSS AND LOG_ITEMS</span>
             <span class="k">if</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">task_loss</span><span class="p">,</span> <span class="nb">tuple</span><span class="p">):</span>  <span class="c1"># SOME LOSS FUNCTIONS RETURNS LOSS AND LOG_ITEMS</span>
                 <span class="n">task_loss</span> <span class="o">=</span> <span class="n">task_loss</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span>
                 <span class="n">task_loss</span> <span class="o">=</span> <span class="n">task_loss</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span>
    @@ -151,4 +162,4 @@
       </script> 
       </script> 
     
     
     </body>
     </body>
    -</html>
    +</html>
    Discard
    Some lines were truncated since they exceed the maximum allowed length of 500, please use a local Git client to see the full diff.
    @@ -3,10 +3,11 @@
     <head>
     <head>
       <meta charset="utf-8" />
       <meta charset="utf-8" />
       <meta name="viewport" content="width=device-width, initial-scale=1.0" />
       <meta name="viewport" content="width=device-width, initial-scale=1.0" />
    -  <title>super_gradients.training.losses.label_smoothing_cross_entropy_loss &mdash; SuperGradients 1.0 documentation</title>
    +  <title>super_gradients.training.losses.label_smoothing_cross_entropy_loss &mdash; SuperGradients 3.0.3 documentation</title>
           <link rel="stylesheet" href="../../../../_static/pygments.css" type="text/css" />
           <link rel="stylesheet" href="../../../../_static/pygments.css" type="text/css" />
           <link rel="stylesheet" href="../../../../_static/css/theme.css" type="text/css" />
           <link rel="stylesheet" href="../../../../_static/css/theme.css" type="text/css" />
           <link rel="stylesheet" href="../../../../_static/graphviz.css" type="text/css" />
           <link rel="stylesheet" href="../../../../_static/graphviz.css" type="text/css" />
    +      <link rel="stylesheet" href="../../../../_static/custom.css" type="text/css" />
       <!--[if lt IE 9]>
       <!--[if lt IE 9]>
         <script src="../../../../_static/js/html5shiv.min.js"></script>
         <script src="../../../../_static/js/html5shiv.min.js"></script>
       <![endif]-->
       <![endif]-->
    @@ -14,7 +15,9 @@
             <script data-url_root="../../../../" id="documentation_options" src="../../../../_static/documentation_options.js"></script>
             <script data-url_root="../../../../" id="documentation_options" src="../../../../_static/documentation_options.js"></script>
             <script src="../../../../_static/jquery.js"></script>
             <script src="../../../../_static/jquery.js"></script>
             <script src="../../../../_static/underscore.js"></script>
             <script src="../../../../_static/underscore.js"></script>
    +        <script src="../../../../_static/_sphinx_javascript_frameworks_compat.js"></script>
             <script src="../../../../_static/doctools.js"></script>
             <script src="../../../../_static/doctools.js"></script>
    +        <script src="../../../../_static/sphinx_highlight.js"></script>
         <script src="../../../../_static/js/theme.js"></script>
         <script src="../../../../_static/js/theme.js"></script>
         <link rel="index" title="Index" href="../../../../genindex.html" />
         <link rel="index" title="Index" href="../../../../genindex.html" />
         <link rel="search" title="Search" href="../../../../search.html" /> 
         <link rel="search" title="Search" href="../../../../search.html" /> 
    @@ -35,30 +38,29 @@
       </form>
       </form>
     </div>
     </div>
             </div><div class="wy-menu wy-menu-vertical" data-spy="affix" role="navigation" aria-label="Navigation menu">
             </div><div class="wy-menu wy-menu-vertical" data-spy="affix" role="navigation" aria-label="Navigation menu">
    -              <p class="caption"><span class="caption-text">Welcome To SuperGradients</span></p>
    +              <p class="caption" role="heading"><span class="caption-text">Welcome To SuperGradients</span></p>
     <ul>
     <ul>
    -<li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html">Fill our 4-question quick survey! We will raffle free SuperGradients swag between those who will participate -&gt; Fill Survey</a></li>
    -<li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#supergradients">SuperGradients</a></li>
    +<li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html">Version 3 is out! Notebooks have been updated!</a></li>
    +<li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#build-with-supergradients">Build with SuperGradients</a></li>
    +<li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#quick-installation">Quick Installation</a></li>
    +<li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#what-s-new">What’s New</a></li>
    +<li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#coming-soon">Coming soon</a></li>
    +<li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#table-of-content">Table of Content</a></li>
    +<li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#getting-started">Getting Started</a></li>
    +<li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#advanced-features">Advanced Features</a></li>
    +<li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#installation-methods">Installation Methods</a></li>
    +<li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#implemented-model-architectures">Implemented Model Architectures</a></li>
    +<li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#documentation">Documentation</a></li>
    +<li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#contributing">Contributing</a></li>
    +<li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#citation">Citation</a></li>
    +<li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#community">Community</a></li>
    +<li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#license">License</a></li>
    +<li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#deci-platform">Deci Platform</a></li>
     </ul>
     </ul>
    -<p class="caption"><span class="caption-text">Technical Documentation</span></p>
    +<p class="caption" role="heading"><span class="caption-text">Technical Documentation</span></p>
     <ul>
     <ul>
     <li class="toctree-l1"><a class="reference internal" href="../../../../super_gradients.common.html">Common package</a></li>
     <li class="toctree-l1"><a class="reference internal" href="../../../../super_gradients.common.html">Common package</a></li>
     <li class="toctree-l1"><a class="reference internal" href="../../../../super_gradients.training.html">Training package</a></li>
     <li class="toctree-l1"><a class="reference internal" href="../../../../super_gradients.training.html">Training package</a></li>
    -</ul>
    -<p class="caption"><span class="caption-text">User Guide</span></p>
    -<ul>
    -<li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html">What is SuperGradients?</a></li>
    -<li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#introducing-the-supergradients-library">Introducing the SuperGradients library</a></li>
    -<li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#installation">Installation</a></li>
    -<li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#integrating-your-training-code-complete-walkthrough">Integrating Your Training Code - Complete Walkthrough</a></li>
    -<li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#training-parameters">Training Parameters</a></li>
    -<li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#logs-and-checkpoints">Logs and Checkpoints</a></li>
    -<li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#dataset-parameters">Dataset Parameters</a></li>
    -<li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#network-architectures">Network Architectures</a></li>
    -<li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#pretrained-models">Pretrained Models</a></li>
    -<li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#how-to-reproduce-our-training-recipes">How To Reproduce Our Training Recipes</a></li>
    -<li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#professional-tools-integration">Professional Tools Integration</a></li>
    -<li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#supergradients-faq">SuperGradients FAQ</a></li>
     </ul>
     </ul>
     
     
             </div>
             </div>
    @@ -91,7 +93,7 @@
     <span class="kn">import</span> <span class="nn">torch.nn.functional</span> <span class="k">as</span> <span class="nn">F</span>
     <span class="kn">import</span> <span class="nn">torch.nn.functional</span> <span class="k">as</span> <span class="nn">F</span>
     
     
     
     
    -<div class="viewcode-block" id="onehot"><a class="viewcode-back" href="../../../../super_gradients.training.losses.html#super_gradients.training.losses.onehot">[docs]</a><span class="k">def</span> <span class="nf">onehot</span><span class="p">(</span><span class="n">indexes</span><span class="p">,</span> <span class="n">N</span><span class="o">=</span><span class="kc">None</span><span class="p">,</span> <span class="n">ignore_index</span><span class="o">=</span><span class="kc">None</span><span
    +<span class="k">def</span> <span class="nf">onehot</span><span class="p">(</span><span class="n">indexes</span><span class="p">,</span> <span class="n">N</span><span class="o">=</span><span class="kc">None</span><span class="p">,</span> <span class="n">ignore_index</span><span class="o">=</span><span class="kc">None</span><span class="p">):</span>
         <span class="sd">&quot;&quot;&quot;</span>
         <span class="sd">&quot;&quot;&quot;</span>
     <span class="sd">    Creates a one-hot representation of indexes with N possible entries</span>
     <span class="sd">    Creates a one-hot representation of indexes with N possible entries</span>
     <span class="sd">    if N is not specified, it will suit the maximum index appearing.</span>
     <span class="sd">    if N is not specified, it will suit the maximum index appearing.</span>
    @@ -105,7 +107,7 @@
         <span class="n">output</span><span class="o">.</span><span class="n">scatter_</span><span class="p">(</span><span class="o">-</span><span class="mi">1</span><span class="p">,</span> <span class="n">indexes</span><span class="o">.</span><span class="n">unsqueeze</span><span class="p">(</span><span class="o">-</span><span class="mi">1</span><span class="p">),</span> <span class="mi">1</span><span class="p">)</span>
         <span class="n">output</span><span class="o">.</span><span class="n">scatter_</span><span class="p">(</span><span class="o">-</span><span class="mi">1</span><span class="p">,</span> <span class="n">indexes</span><span class="o">.</span><span class="n">unsqueeze</span><span class="p">(</span><span class="o">-</span><span class="mi">1</span><span class="p">),</span> <span class="mi">1</span><span class="p">)</span>
         <span class="k">if</span> <span class="n">ignore_index</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span> <span class="ow">and</span> <span class="n">ignore_index</span> <span class="o">&gt;=</span> <span class="mi">0</span><span class="p">:</span>
         <span class="k">if</span> <span class="n">ignore_index</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span> <span class="ow">and</span> <span class="n">ignore_index</span> <span class="o">&gt;=</span> <span class="mi">0</span><span class="p">:</span>
             <span class="n">output</span><span class="o">.</span><span class="n">masked_fill_</span><span class="p">(</span><span class="n">indexes</span><span class="o">.</span><span class="n">eq</span><span class="p">(</span><span class="n">ignore_index</span><span class="p">)</span><span class="o">.</span><span class="n">unsqueeze</span><span class="p">(</span><span class="o">-</span><span class="mi">1</span><span class="p">),</span> <span class="mi">0</span><span class="p">)</span>
             <span class="n">output</span><span class="o">.</span><span class="n">masked_fill_</span><span class="p">(</span><span class="n">indexes</span><span class="o">.</span><span class="n">eq</span><span class="p">(</span><span class="n">ignore_index</span><span class="p">)</span><span class="o">.</span><span class="n">unsqueeze</span><span class="p">(</span><span class="o">-</span><span class="mi">1</span><span class="p">),</span> <span class="mi">0</span><span class="p">)</span>
    -    <span class="k">return</span> <span class="n">output</span></div>
    +    <span class="k">return</span> <span class="n">output</span>
     
     
     
     
     <span class="k">def</span> <span class="nf">_is_long</span><span class="p">(</span><span class="n">x</span><span class="p">):</span>
     <span class="k">def</span> <span class="nf">_is_long</span><span class="p">(</span><span class="n">x</span><span class="p">):</span>
    @@ -114,7 +116,7 @@
         <span class="k">return</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">x</span><span class="p">,</span> <span class="n">torch</span><span class="o">.</span><span class="n">LongTensor</span><span class="p">)</span> <span class="ow">or</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">x</span><span class="p">,</span> <span class="n">torch</span><span class="o">.</span><span class="n">cuda</span><span class="o">.</span><span class
         <span class="k">return</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">x</span><span class="p">,</span> <span class="n">torch</span><span class="o">.</span><span class="n">LongTensor</span><span class="p">)</span> <span class="ow">or</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">x</span><span class="p">,</span> <span class="n">torch</span><span class="o">.</span><span class="n">cuda</span><span class="o">.</span><span class
     
     
     
     
    -<div class="viewcode-block" id="cross_entropy"><a class="viewcode-back" href="../../../../super_gradients.training.losses.html#super_gradients.training.losses.cross_entropy">[docs]</a><span class="k">def</span> <span class="nf">cross_entropy</span><span class="p">(</span><span class="n">inputs</span><span class="p">,</span> <span class="n">target</span><span class="p">,</span> <span class="n">weight</span><span class="o">=</span><span class="kc">None</span><span class="p">,</span> <span class="
    +<span class="k">def</span> <span class="nf">cross_entropy</span><span class="p">(</span><span class="n">inputs</span><span class="p">,</span> <span class="n">target</span><span class="p">,</span> <span class="n">weight</span><span class="o">=</span><span class="kc">None</span><span class="p">,</span> <span class="n">ignore_index</span><span class="o">=-</span><span class="mi">100</span><span class="p">,</span> <span class="n">reduction</span><span class="o">=</span><span class="s1">&#39;mean&#3
                       <span class="n">smooth_eps</span><span class="o">=</span><span class="kc">None</span><span class="p">,</span> <span class="n">smooth_dist</span><span class="o">=</span><span class="kc">None</span><span class="p">,</span> <span class="n">from_logits</span><span class="o">=</span><span class="kc">True</span><span class="p">):</span>
                       <span class="n">smooth_eps</span><span class="o">=</span><span class="kc">None</span><span class="p">,</span> <span class="n">smooth_dist</span><span class="o">=</span><span class="kc">None</span><span class="p">,</span> <span class="n">from_logits</span><span class="o">=</span><span class="kc">True</span><span class="p">):</span>
         <span class="sd">&quot;&quot;&quot;cross entropy loss, with support for target distributions and label smoothing https://arxiv.org/abs/1512.00567&quot;&quot;&quot;</span>
         <span class="sd">&quot;&quot;&quot;cross entropy loss, with support for target distributions and label smoothing https://arxiv.org/abs/1512.00567&quot;&quot;&quot;</span>
         <span class="n">smooth_eps</span> <span class="o">=</span> <span class="n">smooth_eps</span> <span class="ow">or</span> <span class="mi">0</span>
         <span class="n">smooth_eps</span> <span class="o">=</span> <span class="n">smooth_eps</span> <span class="ow">or</span> <span class="mi">0</span>
    @@ -166,10 +168,10 @@
             <span class="k">else</span><span class="p">:</span>
             <span class="k">else</span><span class="p">:</span>
                 <span class="n">loss</span> <span class="o">=</span> <span class="n">loss</span><span class="o">.</span><span class="n">sum</span><span class="p">()</span> <span class="o">/</span> <span class="nb">float</span><span class="p">(</span><span class="n">loss</span><span class="o">.</span><span class="n">size</span><span class="p">(</span><span class="mi">0</span><span class="p">)</span> <span class="o">-</span> <span class="n">masked_indices</span><span class="o">.</span><span class="n"
                 <span class="n">loss</span> <span class="o">=</span> <span class="n">loss</span><span class="o">.</span><span class="n">sum</span><span class="p">()</span> <span class="o">/</span> <span class="nb">float</span><span class="p">(</span><span class="n">loss</span><span class="o">.</span><span class="n">size</span><span class="p">(</span><span class="mi">0</span><span class="p">)</span> <span class="o">-</span> <span class="n">masked_indices</span><span class="o">.</span><span class="n"
     
     
    -    <span class="k">return</span> <span class="n">loss</span></div>
    +    <span class="k">return</span> <span class="n">loss</span>
     
     
     
     
    -<div class="viewcode-block" id="LabelSmoothingCrossEntropyLoss"><a class="viewcode-back" href="../../../../super_gradients.training.losses.html#super_gradients.training.losses.LabelSmoothingCrossEntropyLoss">[docs]</a><span class="k">class</span> <span class="nc">LabelSmoothingCrossEntropyLoss</span><span class="p">(</span><span class="n">nn</span><span class="o">.</span><span class="n">CrossEntropyLoss</span><span class="p">):</span>
    +<div class="viewcode-block" id="LabelSmoothingCrossEntropyLoss"><a class="viewcode-back" href="../../../../super_gradients.training.html#super_gradients.training.losses.LabelSmoothingCrossEntropyLoss">[docs]</a><span class="k">class</span> <span class="nc">LabelSmoothingCrossEntropyLoss</span><span class="p">(</span><span class="n">nn</span><span class="o">.</span><span class="n">CrossEntropyLoss</span><span class="p">):</span>
         <span class="sd">&quot;&quot;&quot;CrossEntropyLoss - with ability to recieve distrbution as targets, and optional label smoothing&quot;&quot;&quot;</span>
         <span class="sd">&quot;&quot;&quot;CrossEntropyLoss - with ability to recieve distrbution as targets, and optional label smoothing&quot;&quot;&quot;</span>
     
     
         <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">weight</span><span class="o">=</span><span class="kc">None</span><span class="p">,</span> <span class="n">ignore_index</span><span class="o">=-</span><span class="mi">100</span><span class="p">,</span> <span class="n">reduction</span><span class="o">=</span><span class="s1">&#39;mean&#39;</span><span class="p">,</span> <span class="n">smooth
         <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">weight</span><span class="o">=</span><span class="kc">None</span><span class="p">,</span> <span class="n">ignore_index</span><span class="o">=-</span><span class="mi">100</span><span class="p">,</span> <span class="n">reduction</span><span class="o">=</span><span class="s1">&#39;mean&#39;</span><span class="p">,</span> <span class="n">smooth
    @@ -180,7 +182,7 @@
             <span class="bp">self</span><span class="o">.</span><span class="n">smooth_dist</span> <span class="o">=</span> <span class="n">smooth_dist</span>
             <span class="bp">self</span><span class="o">.</span><span class="n">smooth_dist</span> <span class="o">=</span> <span class="n">smooth_dist</span>
             <span class="bp">self</span><span class="o">.</span><span class="n">from_logits</span> <span class="o">=</span> <span class="n">from_logits</span>
             <span class="bp">self</span><span class="o">.</span><span class="n">from_logits</span> <span class="o">=</span> <span class="n">from_logits</span>
     
     
    -<div class="viewcode-block" id="LabelSmoothingCrossEntropyLoss.forward"><a class="viewcode-back" href="../../../../super_gradients.training.losses.html#super_gradients.training.losses.LabelSmoothingCrossEntropyLoss.forward">[docs]</a>    <span class="k">def</span> <span class="nf">forward</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="nb">input</span><span class="p">,</span> <span class="n">target</span><span class="p">,</span> <span class="n">sm
    +<div class="viewcode-block" id="LabelSmoothingCrossEntropyLoss.forward"><a class="viewcode-back" href="../../../../super_gradients.training.html#super_gradients.training.losses.LabelSmoothingCrossEntropyLoss.forward">[docs]</a>    <span class="k">def</span> <span class="nf">forward</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="nb">input</span><span class="p">,</span> <span class="n">target</span><span class="p">,</span> <span class="n">smooth_di
             <span class="k">if</span> <span class="n">smooth_dist</span> <span class="ow">is</span> <span class="kc">None</span><span class="p">:</span>
             <span class="k">if</span> <span class="n">smooth_dist</span> <span class="ow">is</span> <span class="kc">None</span><span class="p">:</span>
                 <span class="n">smooth_dist</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">smooth_dist</span>
                 <span class="n">smooth_dist</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">smooth_dist</span>
             <span class="n">loss</span> <span class="o">=</span> <span class="n">cross_entropy</span><span class="p">(</span><span class="nb">input</span><span class="p">,</span> <span class="n">target</span><span class="p">,</span> <span class="n">weight</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">weight</span><span class="p">,</span> <span class="n">ignore_index</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span>
             <span class="n">loss</span> <span class="o">=</span> <span class="n">cross_entropy</span><span class="p">(</span><span class="nb">input</span><span class="p">,</span> <span class="n">target</span><span class="p">,</span> <span class="n">weight</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">weight</span><span class="p">,</span> <span class="n">ignore_index</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span>
    @@ -218,4 +220,4 @@
       </script> 
       </script> 
     
     
     </body>
     </body>
    -</html>
    +</html>
    Discard
    1
    2
    3
    4
    5
    6
    7
    8
    9
    10
    11
    12
    13
    14
    15
    16
    17
    18
    19
    20
    21
    22
    23
    24
    25
    26
    27
    28
    29
    30
    31
    32
    33
    34
    35
    36
    37
    38
    39
    40
    41
    42
    43
    44
    45
    46
    47
    48
    49
    50
    51
    52
    53
    54
    55
    56
    57
    58
    59
    60
    61
    62
    63
    64
    65
    66
    67
    68
    69
    70
    71
    72
    73
    74
    75
    76
    77
    78
    79
    80
    81
    82
    83
    84
    85
    86
    87
    88
    89
    90
    91
    92
    93
    94
    95
    96
    97
    98
    99
    100
    101
    102
    103
    104
    105
    106
    107
    108
    109
    110
    111
    112
    113
    114
    115
    116
    117
    118
    119
    120
    121
    122
    123
    124
    125
    126
    127
    128
    129
    130
    131
    132
    133
    134
    135
    136
    137
    138
    139
    140
    141
    142
    143
    144
    145
    146
    147
    148
    149
    150
    151
    152
    153
    154
    155
    156
    157
    158
    159
    160
    161
    162
    163
    164
    165
    166
    167
    168
    169
    170
    171
    172
    173
    174
    175
    176
    177
    178
    179
    180
    181
    182
    183
    184
    185
    186
    187
    188
    189
    190
    191
    192
    193
    194
    195
    196
    197
    198
    199
    200
    201
    202
    203
    204
    205
    206
    207
    208
    209
    210
    211
    212
    213
    214
    215
    216
    217
    218
    219
    220
    221
    222
    223
    224
    225
    226
    227
    1. <!DOCTYPE html>
    2. <html class="writer-html5" lang="en" >
    3. <head>
    4. <meta charset="utf-8" />
    5. <meta name="viewport" content="width=device-width, initial-scale=1.0" />
    6. <title>super_gradients.training.losses.ohem_ce_loss &mdash; SuperGradients 1.0 documentation</title>
    7. <link rel="stylesheet" href="../../../../_static/pygments.css" type="text/css" />
    8. <link rel="stylesheet" href="../../../../_static/css/theme.css" type="text/css" />
    9. <link rel="stylesheet" href="../../../../_static/graphviz.css" type="text/css" />
    10. <!--[if lt IE 9]>
    11. <script src="../../../../_static/js/html5shiv.min.js"></script>
    12. <![endif]-->
    13. <script data-url_root="../../../../" id="documentation_options" src="../../../../_static/documentation_options.js"></script>
    14. <script src="../../../../_static/jquery.js"></script>
    15. <script src="../../../../_static/underscore.js"></script>
    16. <script src="../../../../_static/doctools.js"></script>
    17. <script src="../../../../_static/js/theme.js"></script>
    18. <link rel="index" title="Index" href="../../../../genindex.html" />
    19. <link rel="search" title="Search" href="../../../../search.html" />
    20. </head>
    21. <body class="wy-body-for-nav">
    22. <div class="wy-grid-for-nav">
    23. <nav data-toggle="wy-nav-shift" class="wy-nav-side">
    24. <div class="wy-side-scroll">
    25. <div class="wy-side-nav-search" >
    26. <a href="../../../../index.html" class="icon icon-home"> SuperGradients
    27. </a>
    28. <div role="search">
    29. <form id="rtd-search-form" class="wy-form" action="../../../../search.html" method="get">
    30. <input type="text" name="q" placeholder="Search docs" />
    31. <input type="hidden" name="check_keywords" value="yes" />
    32. <input type="hidden" name="area" value="default" />
    33. </form>
    34. </div>
    35. </div><div class="wy-menu wy-menu-vertical" data-spy="affix" role="navigation" aria-label="Navigation menu">
    36. <p class="caption"><span class="caption-text">Welcome To SuperGradients</span></p>
    37. <ul>
    38. <li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html">Fill our 4-question quick survey! We will raffle free SuperGradients swag between those who will participate -&gt; Fill Survey</a></li>
    39. <li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#supergradients">SuperGradients</a></li>
    40. </ul>
    41. <p class="caption"><span class="caption-text">Technical Documentation</span></p>
    42. <ul>
    43. <li class="toctree-l1"><a class="reference internal" href="../../../../super_gradients.common.html">Common package</a></li>
    44. <li class="toctree-l1"><a class="reference internal" href="../../../../super_gradients.training.html">Training package</a></li>
    45. </ul>
    46. <p class="caption"><span class="caption-text">User Guide</span></p>
    47. <ul>
    48. <li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html">What is SuperGradients?</a></li>
    49. <li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#introducing-the-supergradients-library">Introducing the SuperGradients library</a></li>
    50. <li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#installation">Installation</a></li>
    51. <li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#integrating-your-training-code-complete-walkthrough">Integrating Your Training Code - Complete Walkthrough</a></li>
    52. <li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#training-parameters">Training Parameters</a></li>
    53. <li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#logs-and-checkpoints">Logs and Checkpoints</a></li>
    54. <li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#dataset-parameters">Dataset Parameters</a></li>
    55. <li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#network-architectures">Network Architectures</a></li>
    56. <li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#pretrained-models">Pretrained Models</a></li>
    57. <li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#how-to-reproduce-our-training-recipes">How To Reproduce Our Training Recipes</a></li>
    58. <li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#professional-tools-integration">Professional Tools Integration</a></li>
    59. <li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#supergradients-faq">SuperGradients FAQ</a></li>
    60. </ul>
    61. </div>
    62. </div>
    63. </nav>
    64. <section data-toggle="wy-nav-shift" class="wy-nav-content-wrap"><nav class="wy-nav-top" aria-label="Mobile navigation menu" >
    65. <i data-toggle="wy-nav-top" class="fa fa-bars"></i>
    66. <a href="../../../../index.html">SuperGradients</a>
    67. </nav>
    68. <div class="wy-nav-content">
    69. <div class="rst-content">
    70. <div role="navigation" aria-label="Page navigation">
    71. <ul class="wy-breadcrumbs">
    72. <li><a href="../../../../index.html" class="icon icon-home"></a> &raquo;</li>
    73. <li><a href="../../../index.html">Module code</a> &raquo;</li>
    74. <li>super_gradients.training.losses.ohem_ce_loss</li>
    75. <li class="wy-breadcrumbs-aside">
    76. </li>
    77. </ul>
    78. <hr/>
    79. </div>
    80. <div role="main" class="document" itemscope="itemscope" itemtype="http://schema.org/Article">
    81. <div itemprop="articleBody">
    82. <h1>Source code for super_gradients.training.losses.ohem_ce_loss</h1><div class="highlight"><pre>
    83. <span></span><span class="kn">import</span> <span class="nn">torch</span>
    84. <span class="kn">from</span> <span class="nn">torch</span> <span class="kn">import</span> <span class="n">nn</span>
    85. <span class="kn">from</span> <span class="nn">torch.nn.modules.loss</span> <span class="kn">import</span> <span class="n">_Loss</span>
    86. <span class="kn">from</span> <span class="nn">super_gradients.training.exceptions.loss_exceptions</span> <span class="kn">import</span> <span class="n">IllegalRangeForLossAttributeException</span><span class="p">,</span> <span class="n">RequiredLossComponentReductionException</span>
    87. <div class="viewcode-block" id="OhemLoss"><a class="viewcode-back" href="../../../../super_gradients.training.losses.html#super_gradients.training.losses.ohem_ce_loss.OhemLoss">[docs]</a><span class="k">class</span> <span class="nc">OhemLoss</span><span class="p">(</span><span class="n">_Loss</span><span class="p">):</span>
    88. <span class="sd">&quot;&quot;&quot;</span>
    89. <span class="sd"> OhemLoss - Online Hard Example Mining Cross Entropy Loss</span>
    90. <span class="sd"> &quot;&quot;&quot;</span>
    91. <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span>
    92. <span class="n">threshold</span><span class="p">:</span> <span class="nb">float</span><span class="p">,</span>
    93. <span class="n">mining_percent</span><span class="p">:</span> <span class="nb">float</span> <span class="o">=</span> <span class="mf">0.1</span><span class="p">,</span>
    94. <span class="n">ignore_lb</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="o">-</span><span class="mi">100</span><span class="p">,</span>
    95. <span class="n">num_pixels_exclude_ignored</span><span class="p">:</span> <span class="nb">bool</span> <span class="o">=</span> <span class="kc">True</span><span class="p">,</span>
    96. <span class="n">criteria</span><span class="p">:</span> <span class="n">_Loss</span> <span class="o">=</span> <span class="kc">None</span><span class="p">):</span>
    97. <span class="sd">&quot;&quot;&quot;</span>
    98. <span class="sd"> :param threshold: Sample below probability threshold, is considered hard.</span>
    99. <span class="sd"> :param num_pixels_exclude_ignored: How to calculate total pixels from which extract mining percent of the</span>
    100. <span class="sd"> samples.</span>
    101. <span class="sd"> :param ignore_lb: label index to be ignored in loss calculation.</span>
    102. <span class="sd"> :param criteria: loss to mine the examples from.</span>
    103. <span class="sd"> i.e for num_pixels=100, ignore_pixels=30, mining_percent=0.1:</span>
    104. <span class="sd"> num_pixels_exclude_ignored=False =&gt; num_mining = 100 * 0.1 = 10</span>
    105. <span class="sd"> num_pixels_exclude_ignored=True =&gt; num_mining = (100 - 30) * 0.1 = 7</span>
    106. <span class="sd"> &quot;&quot;&quot;</span>
    107. <span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="fm">__init__</span><span class="p">()</span>
    108. <span class="k">if</span> <span class="n">mining_percent</span> <span class="o">&lt;</span> <span class="mi">0</span> <span class="ow">or</span> <span class="n">mining_percent</span> <span class="o">&gt;</span> <span class="mi">1</span><span class="p">:</span>
    109. <span class="k">raise</span> <span class="n">IllegalRangeForLossAttributeException</span><span class="p">((</span><span class="mi">0</span><span class="p">,</span> <span class="mi">1</span><span class="p">),</span> <span class="s2">&quot;mining percent&quot;</span><span class="p">)</span>
    110. <span class="bp">self</span><span class="o">.</span><span class="n">thresh</span> <span class="o">=</span> <span class="o">-</span><span class="n">torch</span><span class="o">.</span><span class="n">log</span><span class="p">(</span><span class="n">torch</span><span class="o">.</span><span class="n">tensor</span><span class="p">(</span><span class="n">threshold</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="n">torch</span><span class="o">.</span><span class="n">float</span><span class="p">))</span>
    111. <span class="bp">self</span><span class="o">.</span><span class="n">mining_percent</span> <span class="o">=</span> <span class="n">mining_percent</span>
    112. <span class="bp">self</span><span class="o">.</span><span class="n">ignore_lb</span> <span class="o">=</span> <span class="n">ignore_lb</span>
    113. <span class="bp">self</span><span class="o">.</span><span class="n">num_pixels_exclude_ignored</span> <span class="o">=</span> <span class="n">num_pixels_exclude_ignored</span>
    114. <span class="k">if</span> <span class="n">criteria</span><span class="o">.</span><span class="n">reduction</span> <span class="o">!=</span> <span class="s1">&#39;none&#39;</span><span class="p">:</span>
    115. <span class="k">raise</span> <span class="n">RequiredLossComponentReductionException</span><span class="p">(</span><span class="s2">&quot;criteria&quot;</span><span class="p">,</span> <span class="n">criteria</span><span class="o">.</span><span class="n">reduction</span><span class="p">,</span> <span class="s1">&#39;none&#39;</span><span class="p">)</span>
    116. <span class="bp">self</span><span class="o">.</span><span class="n">criteria</span> <span class="o">=</span> <span class="n">criteria</span>
    117. <div class="viewcode-block" id="OhemLoss.forward"><a class="viewcode-back" href="../../../../super_gradients.training.losses.html#super_gradients.training.losses.ohem_ce_loss.OhemLoss.forward">[docs]</a> <span class="k">def</span> <span class="nf">forward</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">logits</span><span class="p">,</span> <span class="n">labels</span><span class="p">):</span>
    118. <span class="n">loss</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">criteria</span><span class="p">(</span><span class="n">logits</span><span class="p">,</span> <span class="n">labels</span><span class="p">)</span><span class="o">.</span><span class="n">view</span><span class="p">(</span><span class="o">-</span><span class="mi">1</span><span class="p">)</span>
    119. <span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">num_pixels_exclude_ignored</span><span class="p">:</span>
    120. <span class="c1"># remove ignore label elements</span>
    121. <span class="n">loss</span> <span class="o">=</span> <span class="n">loss</span><span class="p">[</span><span class="n">labels</span><span class="o">.</span><span class="n">view</span><span class="p">(</span><span class="o">-</span><span class="mi">1</span><span class="p">)</span> <span class="o">!=</span> <span class="bp">self</span><span class="o">.</span><span class="n">ignore_lb</span><span class="p">]</span>
    122. <span class="c1"># num pixels in a batch -&gt; num_pixels = batch_size * width * height - ignore_pixels</span>
    123. <span class="n">num_pixels</span> <span class="o">=</span> <span class="n">loss</span><span class="o">.</span><span class="n">numel</span><span class="p">()</span>
    124. <span class="k">else</span><span class="p">:</span>
    125. <span class="n">num_pixels</span> <span class="o">=</span> <span class="n">labels</span><span class="o">.</span><span class="n">numel</span><span class="p">()</span>
    126. <span class="c1"># if all pixels are ignore labels, return empty loss tensor</span>
    127. <span class="k">if</span> <span class="n">num_pixels</span> <span class="o">==</span> <span class="mi">0</span><span class="p">:</span>
    128. <span class="k">return</span> <span class="n">torch</span><span class="o">.</span><span class="n">tensor</span><span class="p">([</span><span class="mf">0.</span><span class="p">])</span><span class="o">.</span><span class="n">requires_grad_</span><span class="p">(</span><span class="kc">True</span><span class="p">)</span>
    129. <span class="n">num_mining</span> <span class="o">=</span> <span class="nb">int</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">mining_percent</span> <span class="o">*</span> <span class="n">num_pixels</span><span class="p">)</span>
    130. <span class="c1"># in case mining_percent=1, prevent out of bound exception</span>
    131. <span class="n">num_mining</span> <span class="o">=</span> <span class="nb">min</span><span class="p">(</span><span class="n">num_mining</span><span class="p">,</span> <span class="n">num_pixels</span> <span class="o">-</span> <span class="mi">1</span><span class="p">)</span>
    132. <span class="bp">self</span><span class="o">.</span><span class="n">thresh</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">thresh</span><span class="o">.</span><span class="n">to</span><span class="p">(</span><span class="n">logits</span><span class="o">.</span><span class="n">device</span><span class="p">)</span>
    133. <span class="n">loss</span><span class="p">,</span> <span class="n">_</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">sort</span><span class="p">(</span><span class="n">loss</span><span class="p">,</span> <span class="n">descending</span><span class="o">=</span><span class="kc">True</span><span class="p">)</span>
    134. <span class="k">if</span> <span class="n">loss</span><span class="p">[</span><span class="n">num_mining</span><span class="p">]</span> <span class="o">&gt;</span> <span class="bp">self</span><span class="o">.</span><span class="n">thresh</span><span class="p">:</span>
    135. <span class="n">loss</span> <span class="o">=</span> <span class="n">loss</span><span class="p">[</span><span class="n">loss</span> <span class="o">&gt;</span> <span class="bp">self</span><span class="o">.</span><span class="n">thresh</span><span class="p">]</span>
    136. <span class="k">else</span><span class="p">:</span>
    137. <span class="n">loss</span> <span class="o">=</span> <span class="n">loss</span><span class="p">[:</span><span class="n">num_mining</span><span class="p">]</span>
    138. <span class="k">return</span> <span class="n">torch</span><span class="o">.</span><span class="n">mean</span><span class="p">(</span><span class="n">loss</span><span class="p">)</span></div></div>
    139. <div class="viewcode-block" id="OhemCELoss"><a class="viewcode-back" href="../../../../super_gradients.training.losses.html#super_gradients.training.losses.ohem_ce_loss.OhemCELoss">[docs]</a><span class="k">class</span> <span class="nc">OhemCELoss</span><span class="p">(</span><span class="n">OhemLoss</span><span class="p">):</span>
    140. <span class="sd">&quot;&quot;&quot;</span>
    141. <span class="sd"> OhemLoss - Online Hard Example Mining Cross Entropy Loss</span>
    142. <span class="sd"> &quot;&quot;&quot;</span>
    143. <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span>
    144. <span class="n">threshold</span><span class="p">:</span> <span class="nb">float</span><span class="p">,</span>
    145. <span class="n">mining_percent</span><span class="p">:</span> <span class="nb">float</span> <span class="o">=</span> <span class="mf">0.1</span><span class="p">,</span>
    146. <span class="n">ignore_lb</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="o">-</span><span class="mi">100</span><span class="p">,</span>
    147. <span class="n">num_pixels_exclude_ignored</span><span class="p">:</span> <span class="nb">bool</span> <span class="o">=</span> <span class="kc">True</span><span class="p">):</span>
    148. <span class="n">ignore_lb</span> <span class="o">=</span> <span class="o">-</span><span class="mi">100</span> <span class="k">if</span> <span class="n">ignore_lb</span> <span class="ow">is</span> <span class="kc">None</span> <span class="ow">or</span> <span class="n">ignore_lb</span> <span class="o">&lt;</span> <span class="mi">0</span> <span class="k">else</span> <span class="n">ignore_lb</span>
    149. <span class="n">criteria</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">CrossEntropyLoss</span><span class="p">(</span><span class="n">ignore_index</span><span class="o">=</span><span class="n">ignore_lb</span><span class="p">,</span> <span class="n">reduction</span><span class="o">=</span><span class="s1">&#39;none&#39;</span><span class="p">)</span>
    150. <span class="nb">super</span><span class="p">(</span><span class="n">OhemCELoss</span><span class="p">,</span> <span class="bp">self</span><span class="p">)</span><span class="o">.</span><span class="fm">__init__</span><span class="p">(</span><span class="n">threshold</span><span class="o">=</span><span class="n">threshold</span><span class="p">,</span>
    151. <span class="n">mining_percent</span><span class="o">=</span><span class="n">mining_percent</span><span class="p">,</span>
    152. <span class="n">ignore_lb</span><span class="o">=</span><span class="n">ignore_lb</span><span class="p">,</span>
    153. <span class="n">num_pixels_exclude_ignored</span><span class="o">=</span><span class="n">num_pixels_exclude_ignored</span><span class="p">,</span>
    154. <span class="n">criteria</span><span class="o">=</span><span class="n">criteria</span><span class="p">)</span></div>
    155. <div class="viewcode-block" id="OhemBCELoss"><a class="viewcode-back" href="../../../../super_gradients.training.losses.html#super_gradients.training.losses.ohem_ce_loss.OhemBCELoss">[docs]</a><span class="k">class</span> <span class="nc">OhemBCELoss</span><span class="p">(</span><span class="n">OhemLoss</span><span class="p">):</span>
    156. <span class="sd">&quot;&quot;&quot;</span>
    157. <span class="sd"> OhemBCELoss - Online Hard Example Mining Binary Cross Entropy Loss</span>
    158. <span class="sd"> &quot;&quot;&quot;</span>
    159. <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span>
    160. <span class="n">threshold</span><span class="p">:</span> <span class="nb">float</span><span class="p">,</span>
    161. <span class="n">mining_percent</span><span class="p">:</span> <span class="nb">float</span> <span class="o">=</span> <span class="mf">0.1</span><span class="p">,</span>
    162. <span class="n">ignore_lb</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="o">-</span><span class="mi">100</span><span class="p">,</span>
    163. <span class="n">num_pixels_exclude_ignored</span><span class="p">:</span> <span class="nb">bool</span> <span class="o">=</span> <span class="kc">True</span><span class="p">,</span> <span class="p">):</span>
    164. <span class="nb">super</span><span class="p">(</span><span class="n">OhemBCELoss</span><span class="p">,</span> <span class="bp">self</span><span class="p">)</span><span class="o">.</span><span class="fm">__init__</span><span class="p">(</span><span class="n">threshold</span><span class="o">=</span><span class="n">threshold</span><span class="p">,</span>
    165. <span class="n">mining_percent</span><span class="o">=</span><span class="n">mining_percent</span><span class="p">,</span>
    166. <span class="n">ignore_lb</span><span class="o">=</span><span class="n">ignore_lb</span><span class="p">,</span>
    167. <span class="n">num_pixels_exclude_ignored</span><span class="o">=</span><span class="n">num_pixels_exclude_ignored</span><span class="p">,</span>
    168. <span class="n">criteria</span><span class="o">=</span><span class="n">nn</span><span class="o">.</span><span class="n">BCEWithLogitsLoss</span><span class="p">(</span><span class="n">reduction</span><span class="o">=</span><span class="s1">&#39;none&#39;</span><span class="p">))</span>
    169. <div class="viewcode-block" id="OhemBCELoss.forward"><a class="viewcode-back" href="../../../../super_gradients.training.losses.html#super_gradients.training.losses.ohem_ce_loss.OhemBCELoss.forward">[docs]</a> <span class="k">def</span> <span class="nf">forward</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">logits</span><span class="p">,</span> <span class="n">labels</span><span class="p">):</span>
    170. <span class="c1"># REMOVE SINGLE CLASS CHANNEL WHEN DEALING WITH BINARY DATA</span>
    171. <span class="k">if</span> <span class="n">logits</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">1</span><span class="p">]</span> <span class="o">==</span> <span class="mi">1</span><span class="p">:</span>
    172. <span class="n">logits</span> <span class="o">=</span> <span class="n">logits</span><span class="o">.</span><span class="n">squeeze</span><span class="p">(</span><span class="mi">1</span><span class="p">)</span>
    173. <span class="k">return</span> <span class="nb">super</span><span class="p">(</span><span class="n">OhemBCELoss</span><span class="p">,</span> <span class="bp">self</span><span class="p">)</span><span class="o">.</span><span class="n">forward</span><span class="p">(</span><span class="n">logits</span><span class="p">,</span> <span class="n">labels</span><span class="o">.</span><span class="n">float</span><span class="p">())</span></div></div>
    174. </pre></div>
    175. </div>
    176. </div>
    177. <footer>
    178. <hr/>
    179. <div role="contentinfo">
    180. <p>&#169; Copyright 2021, SuperGradients team.</p>
    181. </div>
    182. Built with <a href="https://www.sphinx-doc.org/">Sphinx</a> using a
    183. <a href="https://github.com/readthedocs/sphinx_rtd_theme">theme</a>
    184. provided by <a href="https://readthedocs.org">Read the Docs</a>.
    185. </footer>
    186. </div>
    187. </div>
    188. </section>
    189. </div>
    190. <script>
    191. jQuery(function () {
    192. SphinxRtdTheme.Navigation.enable(true);
    193. });
    194. </script>
    195. </body>
    196. </html>
    Discard
    @@ -3,10 +3,11 @@
     <head>
     <head>
       <meta charset="utf-8" />
       <meta charset="utf-8" />
       <meta name="viewport" content="width=device-width, initial-scale=1.0" />
       <meta name="viewport" content="width=device-width, initial-scale=1.0" />
    -  <title>super_gradients.training.losses.r_squared_loss &mdash; SuperGradients 1.0 documentation</title>
    +  <title>super_gradients.training.losses.r_squared_loss &mdash; SuperGradients 3.0.3 documentation</title>
           <link rel="stylesheet" href="../../../../_static/pygments.css" type="text/css" />
           <link rel="stylesheet" href="../../../../_static/pygments.css" type="text/css" />
           <link rel="stylesheet" href="../../../../_static/css/theme.css" type="text/css" />
           <link rel="stylesheet" href="../../../../_static/css/theme.css" type="text/css" />
           <link rel="stylesheet" href="../../../../_static/graphviz.css" type="text/css" />
           <link rel="stylesheet" href="../../../../_static/graphviz.css" type="text/css" />
    +      <link rel="stylesheet" href="../../../../_static/custom.css" type="text/css" />
       <!--[if lt IE 9]>
       <!--[if lt IE 9]>
         <script src="../../../../_static/js/html5shiv.min.js"></script>
         <script src="../../../../_static/js/html5shiv.min.js"></script>
       <![endif]-->
       <![endif]-->
    @@ -14,7 +15,9 @@
             <script data-url_root="../../../../" id="documentation_options" src="../../../../_static/documentation_options.js"></script>
             <script data-url_root="../../../../" id="documentation_options" src="../../../../_static/documentation_options.js"></script>
             <script src="../../../../_static/jquery.js"></script>
             <script src="../../../../_static/jquery.js"></script>
             <script src="../../../../_static/underscore.js"></script>
             <script src="../../../../_static/underscore.js"></script>
    +        <script src="../../../../_static/_sphinx_javascript_frameworks_compat.js"></script>
             <script src="../../../../_static/doctools.js"></script>
             <script src="../../../../_static/doctools.js"></script>
    +        <script src="../../../../_static/sphinx_highlight.js"></script>
         <script src="../../../../_static/js/theme.js"></script>
         <script src="../../../../_static/js/theme.js"></script>
         <link rel="index" title="Index" href="../../../../genindex.html" />
         <link rel="index" title="Index" href="../../../../genindex.html" />
         <link rel="search" title="Search" href="../../../../search.html" /> 
         <link rel="search" title="Search" href="../../../../search.html" /> 
    @@ -35,30 +38,29 @@
       </form>
       </form>
     </div>
     </div>
             </div><div class="wy-menu wy-menu-vertical" data-spy="affix" role="navigation" aria-label="Navigation menu">
             </div><div class="wy-menu wy-menu-vertical" data-spy="affix" role="navigation" aria-label="Navigation menu">
    -              <p class="caption"><span class="caption-text">Welcome To SuperGradients</span></p>
    +              <p class="caption" role="heading"><span class="caption-text">Welcome To SuperGradients</span></p>
     <ul>
     <ul>
    -<li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html">Fill our 4-question quick survey! We will raffle free SuperGradients swag between those who will participate -&gt; Fill Survey</a></li>
    -<li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#supergradients">SuperGradients</a></li>
    +<li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html">Version 3 is out! Notebooks have been updated!</a></li>
    +<li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#build-with-supergradients">Build with SuperGradients</a></li>
    +<li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#quick-installation">Quick Installation</a></li>
    +<li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#what-s-new">What’s New</a></li>
    +<li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#coming-soon">Coming soon</a></li>
    +<li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#table-of-content">Table of Content</a></li>
    +<li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#getting-started">Getting Started</a></li>
    +<li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#advanced-features">Advanced Features</a></li>
    +<li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#installation-methods">Installation Methods</a></li>
    +<li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#implemented-model-architectures">Implemented Model Architectures</a></li>
    +<li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#documentation">Documentation</a></li>
    +<li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#contributing">Contributing</a></li>
    +<li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#citation">Citation</a></li>
    +<li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#community">Community</a></li>
    +<li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#license">License</a></li>
    +<li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#deci-platform">Deci Platform</a></li>
     </ul>
     </ul>
    -<p class="caption"><span class="caption-text">Technical Documentation</span></p>
    +<p class="caption" role="heading"><span class="caption-text">Technical Documentation</span></p>
     <ul>
     <ul>
     <li class="toctree-l1"><a class="reference internal" href="../../../../super_gradients.common.html">Common package</a></li>
     <li class="toctree-l1"><a class="reference internal" href="../../../../super_gradients.common.html">Common package</a></li>
     <li class="toctree-l1"><a class="reference internal" href="../../../../super_gradients.training.html">Training package</a></li>
     <li class="toctree-l1"><a class="reference internal" href="../../../../super_gradients.training.html">Training package</a></li>
    -</ul>
    -<p class="caption"><span class="caption-text">User Guide</span></p>
    -<ul>
    -<li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html">What is SuperGradients?</a></li>
    -<li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#introducing-the-supergradients-library">Introducing the SuperGradients library</a></li>
    -<li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#installation">Installation</a></li>
    -<li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#integrating-your-training-code-complete-walkthrough">Integrating Your Training Code - Complete Walkthrough</a></li>
    -<li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#training-parameters">Training Parameters</a></li>
    -<li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#logs-and-checkpoints">Logs and Checkpoints</a></li>
    -<li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#dataset-parameters">Dataset Parameters</a></li>
    -<li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#network-architectures">Network Architectures</a></li>
    -<li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#pretrained-models">Pretrained Models</a></li>
    -<li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#how-to-reproduce-our-training-recipes">How To Reproduce Our Training Recipes</a></li>
    -<li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#professional-tools-integration">Professional Tools Integration</a></li>
    -<li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#supergradients-faq">SuperGradients FAQ</a></li>
     </ul>
     </ul>
     
     
             </div>
             </div>
    @@ -95,9 +97,9 @@
     <span class="kn">from</span> <span class="nn">super_gradients.training.utils</span> <span class="kn">import</span> <span class="n">convert_to_tensor</span>
     <span class="kn">from</span> <span class="nn">super_gradients.training.utils</span> <span class="kn">import</span> <span class="n">convert_to_tensor</span>
     
     
     
     
    -<div class="viewcode-block" id="RSquaredLoss"><a class="viewcode-back" href="../../../../super_gradients.training.losses.html#super_gradients.training.losses.RSquaredLoss">[docs]</a><span class="k">class</span> <span class="nc">RSquaredLoss</span><span class="p">(</span><span class="n">_Loss</span><span class="p">):</span>
    +<div class="viewcode-block" id="RSquaredLoss"><a class="viewcode-back" href="../../../../super_gradients.training.html#super_gradients.training.losses.RSquaredLoss">[docs]</a><span class="k">class</span> <span class="nc">RSquaredLoss</span><span class="p">(</span><span class="n">_Loss</span><span class="p">):</span>
     
     
    -<div class="viewcode-block" id="RSquaredLoss.forward"><a class="viewcode-back" href="../../../../super_gradients.training.losses.html#super_gradients.training.losses.RSquaredLoss.forward">[docs]</a>    <span class="k">def</span> <span class="nf">forward</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">output</span><span class="p">,</span> <span class="n">target</span><span class="p">):</span>
    +<div class="viewcode-block" id="RSquaredLoss.forward"><a class="viewcode-back" href="../../../../super_gradients.training.html#super_gradients.training.losses.RSquaredLoss.forward">[docs]</a>    <span class="k">def</span> <span class="nf">forward</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">output</span><span class="p">,</span> <span class="n">target</span><span class="p">):</span>
             <span class="c1"># FIXME - THIS NEEDS TO BE CHANGED SUCH THAT THIS CLASS INHERETS FROM _Loss (TAKE A LOOK AT YoLoV3DetectionLoss)</span>
             <span class="c1"># FIXME - THIS NEEDS TO BE CHANGED SUCH THAT THIS CLASS INHERETS FROM _Loss (TAKE A LOOK AT YoLoV3DetectionLoss)</span>
             <span class="sd">&quot;&quot;&quot;Computes the R-squared for the output and target values</span>
             <span class="sd">&quot;&quot;&quot;Computes the R-squared for the output and target values</span>
     <span class="sd">        :param output: Tensor / Numpy / List</span>
     <span class="sd">        :param output: Tensor / Numpy / List</span>
    @@ -140,4 +142,4 @@
       </script> 
       </script> 
     
     
     </body>
     </body>
    -</html>
    +</html>
    Discard
    Some lines were truncated since they exceed the maximum allowed length of 500, please use a local Git client to see the full diff.
    @@ -3,10 +3,11 @@
     <head>
     <head>
       <meta charset="utf-8" />
       <meta charset="utf-8" />
       <meta name="viewport" content="width=device-width, initial-scale=1.0" />
       <meta name="viewport" content="width=device-width, initial-scale=1.0" />
    -  <title>super_gradients.training.losses.shelfnet_ohem_loss &mdash; SuperGradients 1.0 documentation</title>
    +  <title>super_gradients.training.losses.shelfnet_ohem_loss &mdash; SuperGradients 3.0.3 documentation</title>
           <link rel="stylesheet" href="../../../../_static/pygments.css" type="text/css" />
           <link rel="stylesheet" href="../../../../_static/pygments.css" type="text/css" />
           <link rel="stylesheet" href="../../../../_static/css/theme.css" type="text/css" />
           <link rel="stylesheet" href="../../../../_static/css/theme.css" type="text/css" />
           <link rel="stylesheet" href="../../../../_static/graphviz.css" type="text/css" />
           <link rel="stylesheet" href="../../../../_static/graphviz.css" type="text/css" />
    +      <link rel="stylesheet" href="../../../../_static/custom.css" type="text/css" />
       <!--[if lt IE 9]>
       <!--[if lt IE 9]>
         <script src="../../../../_static/js/html5shiv.min.js"></script>
         <script src="../../../../_static/js/html5shiv.min.js"></script>
       <![endif]-->
       <![endif]-->
    @@ -14,7 +15,9 @@
             <script data-url_root="../../../../" id="documentation_options" src="../../../../_static/documentation_options.js"></script>
             <script data-url_root="../../../../" id="documentation_options" src="../../../../_static/documentation_options.js"></script>
             <script src="../../../../_static/jquery.js"></script>
             <script src="../../../../_static/jquery.js"></script>
             <script src="../../../../_static/underscore.js"></script>
             <script src="../../../../_static/underscore.js"></script>
    +        <script src="../../../../_static/_sphinx_javascript_frameworks_compat.js"></script>
             <script src="../../../../_static/doctools.js"></script>
             <script src="../../../../_static/doctools.js"></script>
    +        <script src="../../../../_static/sphinx_highlight.js"></script>
         <script src="../../../../_static/js/theme.js"></script>
         <script src="../../../../_static/js/theme.js"></script>
         <link rel="index" title="Index" href="../../../../genindex.html" />
         <link rel="index" title="Index" href="../../../../genindex.html" />
         <link rel="search" title="Search" href="../../../../search.html" /> 
         <link rel="search" title="Search" href="../../../../search.html" /> 
    @@ -35,30 +38,29 @@
       </form>
       </form>
     </div>
     </div>
             </div><div class="wy-menu wy-menu-vertical" data-spy="affix" role="navigation" aria-label="Navigation menu">
             </div><div class="wy-menu wy-menu-vertical" data-spy="affix" role="navigation" aria-label="Navigation menu">
    -              <p class="caption"><span class="caption-text">Welcome To SuperGradients</span></p>
    +              <p class="caption" role="heading"><span class="caption-text">Welcome To SuperGradients</span></p>
     <ul>
     <ul>
    -<li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html">Fill our 4-question quick survey! We will raffle free SuperGradients swag between those who will participate -&gt; Fill Survey</a></li>
    -<li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#supergradients">SuperGradients</a></li>
    +<li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html">Version 3 is out! Notebooks have been updated!</a></li>
    +<li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#build-with-supergradients">Build with SuperGradients</a></li>
    +<li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#quick-installation">Quick Installation</a></li>
    +<li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#what-s-new">What’s New</a></li>
    +<li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#coming-soon">Coming soon</a></li>
    +<li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#table-of-content">Table of Content</a></li>
    +<li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#getting-started">Getting Started</a></li>
    +<li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#advanced-features">Advanced Features</a></li>
    +<li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#installation-methods">Installation Methods</a></li>
    +<li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#implemented-model-architectures">Implemented Model Architectures</a></li>
    +<li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#documentation">Documentation</a></li>
    +<li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#contributing">Contributing</a></li>
    +<li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#citation">Citation</a></li>
    +<li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#community">Community</a></li>
    +<li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#license">License</a></li>
    +<li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#deci-platform">Deci Platform</a></li>
     </ul>
     </ul>
    -<p class="caption"><span class="caption-text">Technical Documentation</span></p>
    +<p class="caption" role="heading"><span class="caption-text">Technical Documentation</span></p>
     <ul>
     <ul>
     <li class="toctree-l1"><a class="reference internal" href="../../../../super_gradients.common.html">Common package</a></li>
     <li class="toctree-l1"><a class="reference internal" href="../../../../super_gradients.common.html">Common package</a></li>
     <li class="toctree-l1"><a class="reference internal" href="../../../../super_gradients.training.html">Training package</a></li>
     <li class="toctree-l1"><a class="reference internal" href="../../../../super_gradients.training.html">Training package</a></li>
    -</ul>
    -<p class="caption"><span class="caption-text">User Guide</span></p>
    -<ul>
    -<li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html">What is SuperGradients?</a></li>
    -<li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#introducing-the-supergradients-library">Introducing the SuperGradients library</a></li>
    -<li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#installation">Installation</a></li>
    -<li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#integrating-your-training-code-complete-walkthrough">Integrating Your Training Code - Complete Walkthrough</a></li>
    -<li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#training-parameters">Training Parameters</a></li>
    -<li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#logs-and-checkpoints">Logs and Checkpoints</a></li>
    -<li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#dataset-parameters">Dataset Parameters</a></li>
    -<li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#network-architectures">Network Architectures</a></li>
    -<li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#pretrained-models">Pretrained Models</a></li>
    -<li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#how-to-reproduce-our-training-recipes">How To Reproduce Our Training Recipes</a></li>
    -<li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#professional-tools-integration">Professional Tools Integration</a></li>
    -<li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#supergradients-faq">SuperGradients FAQ</a></li>
     </ul>
     </ul>
     
     
             </div>
             </div>
    @@ -91,7 +93,7 @@
     <span class="kn">from</span> <span class="nn">super_gradients.training.losses.ohem_ce_loss</span> <span class="kn">import</span> <span class="n">OhemCELoss</span>
     <span class="kn">from</span> <span class="nn">super_gradients.training.losses.ohem_ce_loss</span> <span class="kn">import</span> <span class="n">OhemCELoss</span>
     
     
     
     
    -<div class="viewcode-block" id="ShelfNetOHEMLoss"><a class="viewcode-back" href="../../../../super_gradients.training.losses.html#super_gradients.training.losses.ShelfNetOHEMLoss">[docs]</a><span class="k">class</span> <span class="nc">ShelfNetOHEMLoss</span><span class="p">(</span><span class="n">OhemCELoss</span><span class="p">):</span>
    +<div class="viewcode-block" id="ShelfNetOHEMLoss"><a class="viewcode-back" href="../../../../super_gradients.training.html#super_gradients.training.losses.ShelfNetOHEMLoss">[docs]</a><span class="k">class</span> <span class="nc">ShelfNetOHEMLoss</span><span class="p">(</span><span class="n">OhemCELoss</span><span class="p">):</span>
         <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">threshold</span><span class="p">:</span> <span class="nb">float</span> <span class="o">=</span> <span class="mf">0.7</span><span class="p">,</span> <span class="n">mining_percent</span><span class="p">:</span> <span class="nb">float</span> <span class="o">=</span> <span class="mf">1e-4</span><span class="p">,</span> <span class="n">ignore_lb
         <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">threshold</span><span class="p">:</span> <span class="nb">float</span> <span class="o">=</span> <span class="mf">0.7</span><span class="p">,</span> <span class="n">mining_percent</span><span class="p">:</span> <span class="nb">float</span> <span class="o">=</span> <span class="mf">1e-4</span><span class="p">,</span> <span class="n">ignore_lb
             <span class="sd">&quot;&quot;&quot;</span>
             <span class="sd">&quot;&quot;&quot;</span>
     <span class="sd">        This loss is an extension of the Ohem (Online Hard Example Mining Cross Entropy) Loss.</span>
     <span class="sd">        This loss is an extension of the Ohem (Online Hard Example Mining Cross Entropy) Loss.</span>
    @@ -104,14 +106,23 @@
     <span class="sd">        &quot;&quot;&quot;</span>
     <span class="sd">        &quot;&quot;&quot;</span>
             <span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="fm">__init__</span><span class="p">(</span><span class="n">threshold</span><span class="o">=</span><span class="n">threshold</span><span class="p">,</span> <span class="n">mining_percent</span><span class="o">=</span><span class="n">mining_percent</span><span class="p">,</span> <span class="n">ignore_lb</span><span class="o">=</span><span class="n">ignore_lb</span><span class="p">)</span>
             <span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="fm">__init__</span><span class="p">(</span><span class="n">threshold</span><span class="o">=</span><span class="n">threshold</span><span class="p">,</span> <span class="n">mining_percent</span><span class="o">=</span><span class="n">mining_percent</span><span class="p">,</span> <span class="n">ignore_lb</span><span class="o">=</span><span class="n">ignore_lb</span><span class="p">)</span>
     
     
    -<div class="viewcode-block" id="ShelfNetOHEMLoss.forward"><a class="viewcode-back" href="../../../../super_gradients.training.losses.html#super_gradients.training.losses.ShelfNetOHEMLoss.forward">[docs]</a>    <span class="k">def</span> <span class="nf">forward</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">predictions_list</span><span class="p">:</span> <span class="nb">list</span><span class="p">,</span> <span class="n">targets</span><span c
    +<div class="viewcode-block" id="ShelfNetOHEMLoss.forward"><a class="viewcode-back" href="../../../../super_gradients.training.html#super_gradients.training.losses.ShelfNetOHEMLoss.forward">[docs]</a>    <span class="k">def</span> <span class="nf">forward</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">predictions_list</span><span class="p">:</span> <span class="nb">list</span><span class="p">,</span> <span class="n">targets</span><span class="p
             <span class="n">losses</span> <span class="o">=</span> <span class="p">[]</span>
             <span class="n">losses</span> <span class="o">=</span> <span class="p">[]</span>
             <span class="k">for</span> <span class="n">predictions</span> <span class="ow">in</span> <span class="n">predictions_list</span><span class="p">:</span>
             <span class="k">for</span> <span class="n">predictions</span> <span class="ow">in</span> <span class="n">predictions_list</span><span class="p">:</span>
                 <span class="n">losses</span><span class="o">.</span><span class="n">append</span><span class="p">(</span><span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="n">forward</span><span class="p">(</span><span class="n">predictions</span><span class="p">,</span> <span class="n">targets</span><span class="p">))</span>
                 <span class="n">losses</span><span class="o">.</span><span class="n">append</span><span class="p">(</span><span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="n">forward</span><span class="p">(</span><span class="n">predictions</span><span class="p">,</span> <span class="n">targets</span><span class="p">))</span>
             <span class="n">total_loss</span> <span class="o">=</span> <span class="nb">sum</span><span class="p">(</span><span class="n">losses</span><span class="p">)</span>
             <span class="n">total_loss</span> <span class="o">=</span> <span class="nb">sum</span><span class="p">(</span><span class="n">losses</span><span class="p">)</span>
             <span class="n">losses</span><span class="o">.</span><span class="n">append</span><span class="p">(</span><span class="n">total_loss</span><span class="p">)</span>
             <span class="n">losses</span><span class="o">.</span><span class="n">append</span><span class="p">(</span><span class="n">total_loss</span><span class="p">)</span>
     
     
    -        <span class="k">return</span> <span class="n">total_loss</span><span class="p">,</span> <span class="n">torch</span><span class="o">.</span><span class="n">stack</span><span class="p">(</span><span class="n">losses</span><span class="p">,</span> <span class="n">dim</span><span class="o">=</span><span class="mi">0</span><span class="p">)</span><span class="o">.</span><span class="n">detach</span><span class="p">()</span></div></div>
    +        <span class="k">return</span> <span class="n">total_loss</span><span class="p">,</span> <span class="n">torch</span><span class="o">.</span><span class="n">stack</span><span class="p">(</span><span class="n">losses</span><span class="p">,</span> <span class="n">dim</span><span class="o">=</span><span class="mi">0</span><span class="p">)</span><span class="o">.</span><span class="n">detach</span><span class="p">()</span></div>
    +
    +    <span class="nd">@property</span>
    +    <span class="k">def</span> <span class="nf">component_names</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span>
    +        <span class="sd">&quot;&quot;&quot;</span>
    +<span class="sd">        Component names for logging during training.</span>
    +<span class="sd">        These correspond to 2nd item in the tuple returned in self.forward(...).</span>
    +<span class="sd">        See super_gradients.Trainer.train() docs for more info.</span>
    +<span class="sd">        &quot;&quot;&quot;</span>
    +        <span class="k">return</span> <span class="p">[</span><span class="s2">&quot;Loss1/4&quot;</span><span class="p">,</span> <span class="s2">&quot;Loss1/8&quot;</span><span class="p">,</span> <span class="s2">&quot;Loss1/16&quot;</span><span class="p">,</span> <span class="s2">&quot;Loss&quot;</span><span class="p">]</span></div>
     </pre></div>
     </pre></div>
     
     
                </div>
                </div>
    @@ -141,4 +152,4 @@
       </script> 
       </script> 
     
     
     </body>
     </body>
    -</html>
    +</html>
    Discard
    Some lines were truncated since they exceed the maximum allowed length of 500, please use a local Git client to see the full diff.
    @@ -3,10 +3,11 @@
     <head>
     <head>
       <meta charset="utf-8" />
       <meta charset="utf-8" />
       <meta name="viewport" content="width=device-width, initial-scale=1.0" />
       <meta name="viewport" content="width=device-width, initial-scale=1.0" />
    -  <title>super_gradients.training.losses.shelfnet_semantic_encoding_loss &mdash; SuperGradients 1.0 documentation</title>
    +  <title>super_gradients.training.losses.shelfnet_semantic_encoding_loss &mdash; SuperGradients 3.0.3 documentation</title>
           <link rel="stylesheet" href="../../../../_static/pygments.css" type="text/css" />
           <link rel="stylesheet" href="../../../../_static/pygments.css" type="text/css" />
           <link rel="stylesheet" href="../../../../_static/css/theme.css" type="text/css" />
           <link rel="stylesheet" href="../../../../_static/css/theme.css" type="text/css" />
           <link rel="stylesheet" href="../../../../_static/graphviz.css" type="text/css" />
           <link rel="stylesheet" href="../../../../_static/graphviz.css" type="text/css" />
    +      <link rel="stylesheet" href="../../../../_static/custom.css" type="text/css" />
       <!--[if lt IE 9]>
       <!--[if lt IE 9]>
         <script src="../../../../_static/js/html5shiv.min.js"></script>
         <script src="../../../../_static/js/html5shiv.min.js"></script>
       <![endif]-->
       <![endif]-->
    @@ -14,7 +15,9 @@
             <script data-url_root="../../../../" id="documentation_options" src="../../../../_static/documentation_options.js"></script>
             <script data-url_root="../../../../" id="documentation_options" src="../../../../_static/documentation_options.js"></script>
             <script src="../../../../_static/jquery.js"></script>
             <script src="../../../../_static/jquery.js"></script>
             <script src="../../../../_static/underscore.js"></script>
             <script src="../../../../_static/underscore.js"></script>
    +        <script src="../../../../_static/_sphinx_javascript_frameworks_compat.js"></script>
             <script src="../../../../_static/doctools.js"></script>
             <script src="../../../../_static/doctools.js"></script>
    +        <script src="../../../../_static/sphinx_highlight.js"></script>
         <script src="../../../../_static/js/theme.js"></script>
         <script src="../../../../_static/js/theme.js"></script>
         <link rel="index" title="Index" href="../../../../genindex.html" />
         <link rel="index" title="Index" href="../../../../genindex.html" />
         <link rel="search" title="Search" href="../../../../search.html" /> 
         <link rel="search" title="Search" href="../../../../search.html" /> 
    @@ -35,30 +38,29 @@
       </form>
       </form>
     </div>
     </div>
             </div><div class="wy-menu wy-menu-vertical" data-spy="affix" role="navigation" aria-label="Navigation menu">
             </div><div class="wy-menu wy-menu-vertical" data-spy="affix" role="navigation" aria-label="Navigation menu">
    -              <p class="caption"><span class="caption-text">Welcome To SuperGradients</span></p>
    +              <p class="caption" role="heading"><span class="caption-text">Welcome To SuperGradients</span></p>
     <ul>
     <ul>
    -<li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html">Fill our 4-question quick survey! We will raffle free SuperGradients swag between those who will participate -&gt; Fill Survey</a></li>
    -<li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#supergradients">SuperGradients</a></li>
    +<li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html">Version 3 is out! Notebooks have been updated!</a></li>
    +<li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#build-with-supergradients">Build with SuperGradients</a></li>
    +<li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#quick-installation">Quick Installation</a></li>
    +<li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#what-s-new">What’s New</a></li>
    +<li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#coming-soon">Coming soon</a></li>
    +<li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#table-of-content">Table of Content</a></li>
    +<li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#getting-started">Getting Started</a></li>
    +<li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#advanced-features">Advanced Features</a></li>
    +<li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#installation-methods">Installation Methods</a></li>
    +<li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#implemented-model-architectures">Implemented Model Architectures</a></li>
    +<li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#documentation">Documentation</a></li>
    +<li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#contributing">Contributing</a></li>
    +<li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#citation">Citation</a></li>
    +<li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#community">Community</a></li>
    +<li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#license">License</a></li>
    +<li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#deci-platform">Deci Platform</a></li>
     </ul>
     </ul>
    -<p class="caption"><span class="caption-text">Technical Documentation</span></p>
    +<p class="caption" role="heading"><span class="caption-text">Technical Documentation</span></p>
     <ul>
     <ul>
     <li class="toctree-l1"><a class="reference internal" href="../../../../super_gradients.common.html">Common package</a></li>
     <li class="toctree-l1"><a class="reference internal" href="../../../../super_gradients.common.html">Common package</a></li>
     <li class="toctree-l1"><a class="reference internal" href="../../../../super_gradients.training.html">Training package</a></li>
     <li class="toctree-l1"><a class="reference internal" href="../../../../super_gradients.training.html">Training package</a></li>
    -</ul>
    -<p class="caption"><span class="caption-text">User Guide</span></p>
    -<ul>
    -<li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html">What is SuperGradients?</a></li>
    -<li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#introducing-the-supergradients-library">Introducing the SuperGradients library</a></li>
    -<li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#installation">Installation</a></li>
    -<li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#integrating-your-training-code-complete-walkthrough">Integrating Your Training Code - Complete Walkthrough</a></li>
    -<li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#training-parameters">Training Parameters</a></li>
    -<li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#logs-and-checkpoints">Logs and Checkpoints</a></li>
    -<li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#dataset-parameters">Dataset Parameters</a></li>
    -<li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#network-architectures">Network Architectures</a></li>
    -<li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#pretrained-models">Pretrained Models</a></li>
    -<li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#how-to-reproduce-our-training-recipes">How To Reproduce Our Training Recipes</a></li>
    -<li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#professional-tools-integration">Professional Tools Integration</a></li>
    -<li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#supergradients-faq">SuperGradients FAQ</a></li>
     </ul>
     </ul>
     
     
             </div>
             </div>
    @@ -91,7 +93,7 @@
     <span class="kn">from</span> <span class="nn">torch.autograd</span> <span class="kn">import</span> <span class="n">Variable</span>
     <span class="kn">from</span> <span class="nn">torch.autograd</span> <span class="kn">import</span> <span class="n">Variable</span>
     
     
     
     
    -<div class="viewcode-block" id="ShelfNetSemanticEncodingLoss"><a class="viewcode-back" href="../../../../super_gradients.training.losses.html#super_gradients.training.losses.ShelfNetSemanticEncodingLoss">[docs]</a><span class="k">class</span> <span class="nc">ShelfNetSemanticEncodingLoss</span><span class="p">(</span><span class="n">nn</span><span class="o">.</span><span class="n">CrossEntropyLoss</span><span class="p">):</span>
    +<div class="viewcode-block" id="ShelfNetSemanticEncodingLoss"><a class="viewcode-back" href="../../../../super_gradients.training.html#super_gradients.training.losses.ShelfNetSemanticEncodingLoss">[docs]</a><span class="k">class</span> <span class="nc">ShelfNetSemanticEncodingLoss</span><span class="p">(</span><span class="n">nn</span><span class="o">.</span><span class="n">CrossEntropyLoss</span><span class="p">):</span>
         <span class="sd">&quot;&quot;&quot;2D Cross Entropy Loss with Auxilary Loss&quot;&quot;&quot;</span>
         <span class="sd">&quot;&quot;&quot;2D Cross Entropy Loss with Auxilary Loss&quot;&quot;&quot;</span>
     
     
         <span class="c1"># FIXME - THIS LOSS SHOULD BE CHANGED TO SUPPORT APEX</span>
         <span class="c1"># FIXME - THIS LOSS SHOULD BE CHANGED TO SUPPORT APEX</span>
    @@ -104,7 +106,7 @@
             <span class="c1"># FIXME - TEST CODE LOTEM, CHANGED IN ORDER TO WORK WITH apex.amp</span>
             <span class="c1"># FIXME - TEST CODE LOTEM, CHANGED IN ORDER TO WORK WITH apex.amp</span>
             <span class="bp">self</span><span class="o">.</span><span class="n">bcewithlogitsloss</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">BCELoss</span><span class="p">(</span><span class="n">weight</span><span class="p">)</span>
             <span class="bp">self</span><span class="o">.</span><span class="n">bcewithlogitsloss</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">BCELoss</span><span class="p">(</span><span class="n">weight</span><span class="p">)</span>
     
     
    -<div class="viewcode-block" id="ShelfNetSemanticEncodingLoss.forward"><a class="viewcode-back" href="../../../../super_gradients.training.losses.html#super_gradients.training.losses.ShelfNetSemanticEncodingLoss.forward">[docs]</a>    <span class="k">def</span> <span class="nf">forward</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">logits</span><span class="p">,</span> <span class="n">labels</span><span class="p">):</span>
    +<div class="viewcode-block" id="ShelfNetSemanticEncodingLoss.forward"><a class="viewcode-back" href="../../../../super_gradients.training.html#super_gradients.training.losses.ShelfNetSemanticEncodingLoss.forward">[docs]</a>    <span class="k">def</span> <span class="nf">forward</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">logits</span><span class="p">,</span> <span class="n">labels</span><span class="p">):</span>
             <span class="n">pred1</span><span class="p">,</span> <span class="n">se_pred</span><span class="p">,</span> <span class="n">pred2</span> <span class="o">=</span> <span class="n">logits</span>
             <span class="n">pred1</span><span class="p">,</span> <span class="n">se_pred</span><span class="p">,</span> <span class="n">pred2</span> <span class="o">=</span> <span class="n">logits</span>
     
     
             <span class="n">batch</span> <span class="o">=</span> <span class="n">labels</span><span class="o">.</span><span class="n">size</span><span class="p">(</span><span class="mi">0</span><span class="p">)</span>
             <span class="n">batch</span> <span class="o">=</span> <span class="n">labels</span><span class="o">.</span><span class="n">size</span><span class="p">(</span><span class="mi">0</span><span class="p">)</span>
    @@ -122,7 +124,16 @@
             <span class="n">loss3</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">bcewithlogitsloss</span><span class="p">(</span><span class="n">torch</span><span class="o">.</span><span class="n">sigmoid</span><span class="p">(</span><span class="n">se_pred</span><span class="p">),</span> <span class="n">se_target</span><span class="o">.</span><span class="n">data</span><span class="o">.</span><span class="n">cuda</span><span class="p">())</span
             <span class="n">loss3</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">bcewithlogitsloss</span><span class="p">(</span><span class="n">torch</span><span class="o">.</span><span class="n">sigmoid</span><span class="p">(</span><span class="n">se_pred</span><span class="p">),</span> <span class="n">se_target</span><span class="o">.</span><span class="n">data</span><span class="o">.</span><span class="n">cuda</span><span class="p">())</span
             <span class="n">total_loss</span> <span class="o">=</span> <span class="n">loss1</span> <span class="o">+</span> <span class="bp">self</span><span class="o">.</span><span class="n">aux_weight</span> <span class="o">*</span> <span class="n">loss2</span> <span class="o">+</span> <span class="bp">self</span><span class="o">.</span><span class="n">se_weight</span> <span class="o">*</span> <span class="n">loss3</span>
             <span class="n">total_loss</span> <span class="o">=</span> <span class="n">loss1</span> <span class="o">+</span> <span class="bp">self</span><span class="o">.</span><span class="n">aux_weight</span> <span class="o">*</span> <span class="n">loss2</span> <span class="o">+</span> <span class="bp">self</span><span class="o">.</span><span class="n">se_weight</span> <span class="o">*</span> <span class="n">loss3</span>
             <span class="n">losses</span> <span class="o">=</span> <span class="p">[</span><span class="n">loss1</span><span class="p">,</span> <span class="n">loss2</span><span class="p">,</span> <span class="n">loss3</span><span class="p">,</span> <span class="n">total_loss</span><span class="p">]</span>
             <span class="n">losses</span> <span class="o">=</span> <span class="p">[</span><span class="n">loss1</span><span class="p">,</span> <span class="n">loss2</span><span class="p">,</span> <span class="n">loss3</span><span class="p">,</span> <span class="n">total_loss</span><span class="p">]</span>
    -        <span class="k">return</span> <span class="n">total_loss</span><span class="p">,</span> <span class="n">torch</span><span class="o">.</span><span class="n">stack</span><span class="p">(</span><span class="n">losses</span><span class="p">,</span> <span class="n">dim</span><span class="o">=</span><span class="mi">0</span><span class="p">)</span><span class="o">.</span><span class="n">detach</span><span class="p">()</span></div></div>
    +        <span class="k">return</span> <span class="n">total_loss</span><span class="p">,</span> <span class="n">torch</span><span class="o">.</span><span class="n">stack</span><span class="p">(</span><span class="n">losses</span><span class="p">,</span> <span class="n">dim</span><span class="o">=</span><span class="mi">0</span><span class="p">)</span><span class="o">.</span><span class="n">detach</span><span class="p">()</span></div>
    +
    +    <span class="nd">@property</span>
    +    <span class="k">def</span> <span class="nf">component_names</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span>
    +        <span class="sd">&quot;&quot;&quot;</span>
    +<span class="sd">        Component names for logging during training.</span>
    +<span class="sd">        These correspond to 2nd item in the tuple returned in self.forward(...).</span>
    +<span class="sd">        See super_gradients.Trainer.train() docs for more info.</span>
    +<span class="sd">        &quot;&quot;&quot;</span>
    +        <span class="k">return</span> <span class="p">[</span><span class="s2">&quot;loss1&quot;</span><span class="p">,</span> <span class="s2">&quot;loss2&quot;</span><span class="p">,</span> <span class="s2">&quot;loss3&quot;</span><span class="p">,</span> <span class="s2">&quot;total_loss&quot;</span><span class="p">]</span></div>
     </pre></div>
     </pre></div>
     
     
                </div>
                </div>
    @@ -152,4 +163,4 @@
       </script> 
       </script> 
     
     
     </body>
     </body>
    -</html>
    +</html>
    Discard
    Some lines were truncated since they exceed the maximum allowed length of 500, please use a local Git client to see the full diff.
    @@ -3,10 +3,11 @@
     <head>
     <head>
       <meta charset="utf-8" />
       <meta charset="utf-8" />
       <meta name="viewport" content="width=device-width, initial-scale=1.0" />
       <meta name="viewport" content="width=device-width, initial-scale=1.0" />
    -  <title>super_gradients.training.losses.ssd_loss &mdash; SuperGradients 1.0 documentation</title>
    +  <title>super_gradients.training.losses.ssd_loss &mdash; SuperGradients 3.0.3 documentation</title>
           <link rel="stylesheet" href="../../../../_static/pygments.css" type="text/css" />
           <link rel="stylesheet" href="../../../../_static/pygments.css" type="text/css" />
           <link rel="stylesheet" href="../../../../_static/css/theme.css" type="text/css" />
           <link rel="stylesheet" href="../../../../_static/css/theme.css" type="text/css" />
           <link rel="stylesheet" href="../../../../_static/graphviz.css" type="text/css" />
           <link rel="stylesheet" href="../../../../_static/graphviz.css" type="text/css" />
    +      <link rel="stylesheet" href="../../../../_static/custom.css" type="text/css" />
       <!--[if lt IE 9]>
       <!--[if lt IE 9]>
         <script src="../../../../_static/js/html5shiv.min.js"></script>
         <script src="../../../../_static/js/html5shiv.min.js"></script>
       <![endif]-->
       <![endif]-->
    @@ -14,7 +15,9 @@
             <script data-url_root="../../../../" id="documentation_options" src="../../../../_static/documentation_options.js"></script>
             <script data-url_root="../../../../" id="documentation_options" src="../../../../_static/documentation_options.js"></script>
             <script src="../../../../_static/jquery.js"></script>
             <script src="../../../../_static/jquery.js"></script>
             <script src="../../../../_static/underscore.js"></script>
             <script src="../../../../_static/underscore.js"></script>
    +        <script src="../../../../_static/_sphinx_javascript_frameworks_compat.js"></script>
             <script src="../../../../_static/doctools.js"></script>
             <script src="../../../../_static/doctools.js"></script>
    +        <script src="../../../../_static/sphinx_highlight.js"></script>
         <script src="../../../../_static/js/theme.js"></script>
         <script src="../../../../_static/js/theme.js"></script>
         <link rel="index" title="Index" href="../../../../genindex.html" />
         <link rel="index" title="Index" href="../../../../genindex.html" />
         <link rel="search" title="Search" href="../../../../search.html" /> 
         <link rel="search" title="Search" href="../../../../search.html" /> 
    @@ -35,30 +38,29 @@
       </form>
       </form>
     </div>
     </div>
             </div><div class="wy-menu wy-menu-vertical" data-spy="affix" role="navigation" aria-label="Navigation menu">
             </div><div class="wy-menu wy-menu-vertical" data-spy="affix" role="navigation" aria-label="Navigation menu">
    -              <p class="caption"><span class="caption-text">Welcome To SuperGradients</span></p>
    +              <p class="caption" role="heading"><span class="caption-text">Welcome To SuperGradients</span></p>
     <ul>
     <ul>
    -<li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html">Fill our 4-question quick survey! We will raffle free SuperGradients swag between those who will participate -&gt; Fill Survey</a></li>
    -<li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#supergradients">SuperGradients</a></li>
    +<li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html">Version 3 is out! Notebooks have been updated!</a></li>
    +<li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#build-with-supergradients">Build with SuperGradients</a></li>
    +<li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#quick-installation">Quick Installation</a></li>
    +<li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#what-s-new">What’s New</a></li>
    +<li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#coming-soon">Coming soon</a></li>
    +<li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#table-of-content">Table of Content</a></li>
    +<li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#getting-started">Getting Started</a></li>
    +<li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#advanced-features">Advanced Features</a></li>
    +<li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#installation-methods">Installation Methods</a></li>
    +<li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#implemented-model-architectures">Implemented Model Architectures</a></li>
    +<li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#documentation">Documentation</a></li>
    +<li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#contributing">Contributing</a></li>
    +<li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#citation">Citation</a></li>
    +<li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#community">Community</a></li>
    +<li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#license">License</a></li>
    +<li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#deci-platform">Deci Platform</a></li>
     </ul>
     </ul>
    -<p class="caption"><span class="caption-text">Technical Documentation</span></p>
    +<p class="caption" role="heading"><span class="caption-text">Technical Documentation</span></p>
     <ul>
     <ul>
     <li class="toctree-l1"><a class="reference internal" href="../../../../super_gradients.common.html">Common package</a></li>
     <li class="toctree-l1"><a class="reference internal" href="../../../../super_gradients.common.html">Common package</a></li>
     <li class="toctree-l1"><a class="reference internal" href="../../../../super_gradients.training.html">Training package</a></li>
     <li class="toctree-l1"><a class="reference internal" href="../../../../super_gradients.training.html">Training package</a></li>
    -</ul>
    -<p class="caption"><span class="caption-text">User Guide</span></p>
    -<ul>
    -<li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html">What is SuperGradients?</a></li>
    -<li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#introducing-the-supergradients-library">Introducing the SuperGradients library</a></li>
    -<li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#installation">Installation</a></li>
    -<li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#integrating-your-training-code-complete-walkthrough">Integrating Your Training Code - Complete Walkthrough</a></li>
    -<li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#training-parameters">Training Parameters</a></li>
    -<li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#logs-and-checkpoints">Logs and Checkpoints</a></li>
    -<li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#dataset-parameters">Dataset Parameters</a></li>
    -<li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#network-architectures">Network Architectures</a></li>
    -<li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#pretrained-models">Pretrained Models</a></li>
    -<li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#how-to-reproduce-our-training-recipes">How To Reproduce Our Training Recipes</a></li>
    -<li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#professional-tools-integration">Professional Tools Integration</a></li>
    -<li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#supergradients-faq">SuperGradients FAQ</a></li>
     </ul>
     </ul>
     
     
             </div>
             </div>
    @@ -96,7 +98,7 @@
     <span class="kn">from</span> <span class="nn">super_gradients.training.utils.ssd_utils</span> <span class="kn">import</span> <span class="n">DefaultBoxes</span>
     <span class="kn">from</span> <span class="nn">super_gradients.training.utils.ssd_utils</span> <span class="kn">import</span> <span class="n">DefaultBoxes</span>
     
     
     
     
    -<div class="viewcode-block" id="HardMiningCrossEntropyLoss"><a class="viewcode-back" href="../../../../super_gradients.training.losses.html#super_gradients.training.losses.HardMiningCrossEntropyLoss">[docs]</a><span class="k">class</span> <span class="nc">HardMiningCrossEntropyLoss</span><span class="p">(</span><span class="n">_Loss</span><span class="p">):</span>
    +<span class="k">class</span> <span class="nc">HardMiningCrossEntropyLoss</span><span class="p">(</span><span class="n">_Loss</span><span class="p">):</span>
         <span class="sd">&quot;&quot;&quot;</span>
         <span class="sd">&quot;&quot;&quot;</span>
     <span class="sd">    L_cls = [CE of all positives] + [CE of the hardest backgrounds]</span>
     <span class="sd">    L_cls = [CE of all positives] + [CE of the hardest backgrounds]</span>
     <span class="sd">    where the second term is built from [neg_pos_ratio * positive pairs] background cells with the highest CE</span>
     <span class="sd">    where the second term is built from [neg_pos_ratio * positive pairs] background cells with the highest CE</span>
    @@ -113,7 +115,7 @@
             <span class="bp">self</span><span class="o">.</span><span class="n">neg_pos_ratio</span> <span class="o">=</span> <span class="n">neg_pos_ratio</span>
             <span class="bp">self</span><span class="o">.</span><span class="n">neg_pos_ratio</span> <span class="o">=</span> <span class="n">neg_pos_ratio</span>
             <span class="bp">self</span><span class="o">.</span><span class="n">ce</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">CrossEntropyLoss</span><span class="p">(</span><span class="n">reduce</span><span class="o">=</span><span class="kc">False</span><span class="p">)</span>
             <span class="bp">self</span><span class="o">.</span><span class="n">ce</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">CrossEntropyLoss</span><span class="p">(</span><span class="n">reduce</span><span class="o">=</span><span class="kc">False</span><span class="p">)</span>
     
     
    -<div class="viewcode-block" id="HardMiningCrossEntropyLoss.forward"><a class="viewcode-back" href="../../../../super_gradients.training.losses.html#super_gradients.training.losses.HardMiningCrossEntropyLoss.forward">[docs]</a>    <span class="k">def</span> <span class="nf">forward</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">pred_labels</span><span class="p">,</span> <span class="n">target_labels</span><span class="p">):</span>
    +    <span class="k">def</span> <span class="nf">forward</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">pred_labels</span><span class="p">,</span> <span class="n">target_labels</span><span class="p">):</span>
             <span class="n">mask</span> <span class="o">=</span> <span class="n">target_labels</span> <span class="o">&gt;</span> <span class="mi">0</span>  <span class="c1"># not background</span>
             <span class="n">mask</span> <span class="o">=</span> <span class="n">target_labels</span> <span class="o">&gt;</span> <span class="mi">0</span>  <span class="c1"># not background</span>
             <span class="n">pos_num</span> <span class="o">=</span> <span class="n">mask</span><span class="o">.</span><span class="n">sum</span><span class="p">(</span><span class="n">dim</span><span class="o">=</span><span class="mi">1</span><span class="p">)</span>
             <span class="n">pos_num</span> <span class="o">=</span> <span class="n">mask</span><span class="o">.</span><span class="n">sum</span><span class="p">(</span><span class="n">dim</span><span class="o">=</span><span class="mi">1</span><span class="p">)</span>
     
     
    @@ -135,10 +137,10 @@
             <span class="n">neg_mask</span> <span class="o">=</span> <span class="n">con_rank</span> <span class="o">&lt;</span> <span class="n">neg_num</span>
             <span class="n">neg_mask</span> <span class="o">=</span> <span class="n">con_rank</span> <span class="o">&lt;</span> <span class="n">neg_num</span>
     
     
             <span class="n">closs</span> <span class="o">=</span> <span class="p">(</span><span class="n">con</span> <span class="o">*</span> <span class="p">(</span><span class="n">mask</span><span class="o">.</span><span class="n">float</span><span class="p">()</span> <span class="o">+</span> <span class="n">neg_mask</span><span class="o">.</span><span class="n">float</span><span class="p">()))</span><span class="o">.</span><span class="n">sum</span><span class="p">(</span><span class="n">dim</sp
             <span class="n">closs</span> <span class="o">=</span> <span class="p">(</span><span class="n">con</span> <span class="o">*</span> <span class="p">(</span><span class="n">mask</span><span class="o">.</span><span class="n">float</span><span class="p">()</span> <span class="o">+</span> <span class="n">neg_mask</span><span class="o">.</span><span class="n">float</span><span class="p">()))</span><span class="o">.</span><span class="n">sum</span><span class="p">(</span><span class="n">dim</sp
    -        <span class="k">return</span> <span class="n">closs</span></div></div>
    +        <span class="k">return</span> <span class="n">closs</span>
     
     
     
     
    -<div class="viewcode-block" id="SSDLoss"><a class="viewcode-back" href="../../../../super_gradients.training.losses.html#super_gradients.training.losses.SSDLoss">[docs]</a><span class="k">class</span> <span class="nc">SSDLoss</span><span class="p">(</span><span class="n">_Loss</span><span class="p">):</span>
    +<div class="viewcode-block" id="SSDLoss"><a class="viewcode-back" href="../../../../super_gradients.training.html#super_gradients.training.losses.SSDLoss">[docs]</a><span class="k">class</span> <span class="nc">SSDLoss</span><span class="p">(</span><span class="n">_Loss</span><span class="p">):</span>
         <span class="sd">&quot;&quot;&quot;</span>
         <span class="sd">&quot;&quot;&quot;</span>
     <span class="sd">        Implements the loss as the sum of the followings:</span>
     <span class="sd">        Implements the loss as the sum of the followings:</span>
     <span class="sd">        1. Confidence Loss: All labels, with hard negative mining</span>
     <span class="sd">        1. Confidence Loss: All labels, with hard negative mining</span>
    @@ -167,6 +169,15 @@
             <span class="bp">self</span><span class="o">.</span><span class="n">con_loss</span> <span class="o">=</span> <span class="n">HardMiningCrossEntropyLoss</span><span class="p">(</span><span class="n">neg_pos_ratio</span><span class="p">)</span>
             <span class="bp">self</span><span class="o">.</span><span class="n">con_loss</span> <span class="o">=</span> <span class="n">HardMiningCrossEntropyLoss</span><span class="p">(</span><span class="n">neg_pos_ratio</span><span class="p">)</span>
             <span class="bp">self</span><span class="o">.</span><span class="n">iou_thresh</span> <span class="o">=</span> <span class="n">iou_thresh</span>
             <span class="bp">self</span><span class="o">.</span><span class="n">iou_thresh</span> <span class="o">=</span> <span class="n">iou_thresh</span>
     
     
    +    <span class="nd">@property</span>
    +    <span class="k">def</span> <span class="nf">component_names</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span>
    +        <span class="sd">&quot;&quot;&quot;</span>
    +<span class="sd">        Component names for logging during training.</span>
    +<span class="sd">        These correspond to 2nd item in the tuple returned in self.forward(...).</span>
    +<span class="sd">        See super_gradients.Trainer.train() docs for more info.</span>
    +<span class="sd">        &quot;&quot;&quot;</span>
    +        <span class="k">return</span> <span class="p">[</span><span class="s2">&quot;smooth_l1&quot;</span><span class="p">,</span> <span class="s2">&quot;closs&quot;</span><span class="p">,</span> <span class="s2">&quot;Loss&quot;</span><span class="p">]</span>
    +
         <span class="k">def</span> <span class="nf">_norm_relative_bbox</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">loc</span><span class="p">):</span>
         <span class="k">def</span> <span class="nf">_norm_relative_bbox</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">loc</span><span class="p">):</span>
             <span class="sd">&quot;&quot;&quot;</span>
             <span class="sd">&quot;&quot;&quot;</span>
     <span class="sd">        convert bbox locations into relative locations (relative to the dboxes)</span>
     <span class="sd">        convert bbox locations into relative locations (relative to the dboxes)</span>
    @@ -176,7 +187,7 @@
             <span class="n">gwh</span> <span class="o">=</span> <span class="p">(</span><span class="n">loc</span><span class="p">[:,</span> <span class="mi">2</span><span class="p">:,</span> <span class="p">:]</span> <span class="o">/</span> <span class="bp">self</span><span class="o">.</span><span class="n">dboxes</span><span class="p">[:,</span> <span class="mi">2</span><span class="p">:,</span> <span class="p">:])</span><span class="o">.</span><span class="n">log</span><span class="p">()</span>
             <span class="n">gwh</span> <span class="o">=</span> <span class="p">(</span><span class="n">loc</span><span class="p">[:,</span> <span class="mi">2</span><span class="p">:,</span> <span class="p">:]</span> <span class="o">/</span> <span class="bp">self</span><span class="o">.</span><span class="n">dboxes</span><span class="p">[:,</span> <span class="mi">2</span><span class="p">:,</span> <span class="p">:])</span><span class="o">.</span><span class="n">log</span><span class="p">()</span>
             <span class="k">return</span> <span class="n">torch</span><span class="o">.</span><span class="n">cat</span><span class="p">((</span><span class="n">gxy</span><span class="p">,</span> <span class="n">gwh</span><span class="p">),</span> <span class="n">dim</span><span class="o">=</span><span class="mi">1</span><span class="p">)</span><span class="o">.</span><span class="n">contiguous</span><span class="p">()</span>
             <span class="k">return</span> <span class="n">torch</span><span class="o">.</span><span class="n">cat</span><span class="p">((</span><span class="n">gxy</span><span class="p">,</span> <span class="n">gwh</span><span class="p">),</span> <span class="n">dim</span><span class="o">=</span><span class="mi">1</span><span class="p">)</span><span class="o">.</span><span class="n">contiguous</span><span class="p">()</span>
     
     
    -<div class="viewcode-block" id="SSDLoss.match_dboxes"><a class="viewcode-back" href="../../../../super_gradients.training.losses.html#super_gradients.training.losses.SSDLoss.match_dboxes">[docs]</a>    <span class="k">def</span> <span class="nf">match_dboxes</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">targets</span><span class="p">):</span>
    +<div class="viewcode-block" id="SSDLoss.match_dboxes"><a class="viewcode-back" href="../../../../super_gradients.training.html#super_gradients.training.losses.SSDLoss.match_dboxes">[docs]</a>    <span class="k">def</span> <span class="nf">match_dboxes</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">targets</span><span class="p">):</span>
             <span class="sd">&quot;&quot;&quot;</span>
             <span class="sd">&quot;&quot;&quot;</span>
     <span class="sd">        creates tensors with target boxes and labels for each dboxes, so with the same len as dboxes.</span>
     <span class="sd">        creates tensors with target boxes and labels for each dboxes, so with the same len as dboxes.</span>
     
     
    @@ -220,7 +231,7 @@
     
     
             <span class="k">return</span> <span class="n">each_cell_target_locations</span><span class="p">,</span> <span class="n">each_cell_target_labels</span></div>
             <span class="k">return</span> <span class="n">each_cell_target_locations</span><span class="p">,</span> <span class="n">each_cell_target_labels</span></div>
     
     
    -<div class="viewcode-block" id="SSDLoss.forward"><a class="viewcode-back" href="../../../../super_gradients.training.losses.html#super_gradients.training.losses.SSDLoss.forward">[docs]</a>    <span class="k">def</span> <span class="nf">forward</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">predictions</span><span class="p">:</span> <span class="n">Tuple</span><span class="p">,</span> <span class="n">targets</span><span class="p">):</span>
    +<div class="viewcode-block" id="SSDLoss.forward"><a class="viewcode-back" href="../../../../super_gradients.training.html#super_gradients.training.losses.SSDLoss.forward">[docs]</a>    <span class="k">def</span> <span class="nf">forward</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">predictions</span><span class="p">:</span> <span class="n">Tuple</span><span class="p">,</span> <span class="n">targets</span><span class="p">):</span>
             <span class="sd">&quot;&quot;&quot;</span>
             <span class="sd">&quot;&quot;&quot;</span>
     <span class="sd">        Compute the loss</span>
     <span class="sd">        Compute the loss</span>
     <span class="sd">            :param predictions - predictions tensor coming from the network,</span>
     <span class="sd">            :param predictions - predictions tensor coming from the network,</span>
    @@ -289,4 +300,4 @@
       </script> 
       </script> 
     
     
     </body>
     </body>
    -</html>
    +</html>
    Discard
    1
    2
    3
    4
    5
    6
    7
    8
    9
    10
    11
    12
    13
    14
    15
    16
    17
    18
    19
    20
    21
    22
    23
    24
    25
    26
    27
    28
    29
    30
    31
    32
    33
    34
    35
    36
    37
    38
    39
    40
    41
    42
    43
    44
    45
    46
    47
    48
    49
    50
    51
    52
    53
    54
    55
    56
    57
    58
    59
    60
    61
    62
    63
    64
    65
    66
    67
    68
    69
    70
    71
    72
    73
    74
    75
    76
    77
    78
    79
    80
    81
    82
    83
    84
    85
    86
    87
    88
    89
    90
    91
    92
    93
    94
    95
    96
    97
    98
    99
    100
    101
    102
    103
    104
    105
    106
    107
    108
    109
    110
    111
    112
    113
    114
    115
    116
    117
    118
    119
    120
    121
    122
    123
    124
    125
    126
    127
    128
    129
    130
    131
    132
    133
    134
    135
    136
    137
    138
    139
    140
    141
    142
    143
    144
    145
    146
    147
    148
    149
    150
    151
    152
    153
    154
    155
    156
    157
    158
    159
    160
    161
    162
    163
    164
    165
    166
    167
    168
    169
    170
    171
    172
    173
    174
    175
    176
    177
    178
    179
    180
    1. <!DOCTYPE html>
    2. <html class="writer-html5" lang="en" >
    3. <head>
    4. <meta charset="utf-8" />
    5. <meta name="viewport" content="width=device-width, initial-scale=1.0" />
    6. <title>super_gradients.training.losses.yolo_v3_loss &mdash; SuperGradients 1.0 documentation</title>
    7. <link rel="stylesheet" href="../../../../_static/pygments.css" type="text/css" />
    8. <link rel="stylesheet" href="../../../../_static/css/theme.css" type="text/css" />
    9. <link rel="stylesheet" href="../../../../_static/graphviz.css" type="text/css" />
    10. <!--[if lt IE 9]>
    11. <script src="../../../../_static/js/html5shiv.min.js"></script>
    12. <![endif]-->
    13. <script data-url_root="../../../../" id="documentation_options" src="../../../../_static/documentation_options.js"></script>
    14. <script src="../../../../_static/jquery.js"></script>
    15. <script src="../../../../_static/underscore.js"></script>
    16. <script src="../../../../_static/doctools.js"></script>
    17. <script src="../../../../_static/js/theme.js"></script>
    18. <link rel="index" title="Index" href="../../../../genindex.html" />
    19. <link rel="search" title="Search" href="../../../../search.html" />
    20. </head>
    21. <body class="wy-body-for-nav">
    22. <div class="wy-grid-for-nav">
    23. <nav data-toggle="wy-nav-shift" class="wy-nav-side">
    24. <div class="wy-side-scroll">
    25. <div class="wy-side-nav-search" >
    26. <a href="../../../../index.html" class="icon icon-home"> SuperGradients
    27. </a>
    28. <div role="search">
    29. <form id="rtd-search-form" class="wy-form" action="../../../../search.html" method="get">
    30. <input type="text" name="q" placeholder="Search docs" />
    31. <input type="hidden" name="check_keywords" value="yes" />
    32. <input type="hidden" name="area" value="default" />
    33. </form>
    34. </div>
    35. </div><div class="wy-menu wy-menu-vertical" data-spy="affix" role="navigation" aria-label="Navigation menu">
    36. <p class="caption"><span class="caption-text">SuperGradients</span></p>
    37. <ul>
    38. <li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html">SuperGradients</a></li>
    39. <li class="toctree-l1"><a class="reference internal" href="../../../../super_gradients.common.html">Common</a></li>
    40. <li class="toctree-l1"><a class="reference internal" href="../../../../super_gradients.training.html">Training</a></li>
    41. </ul>
    42. </div>
    43. </div>
    44. </nav>
    45. <section data-toggle="wy-nav-shift" class="wy-nav-content-wrap"><nav class="wy-nav-top" aria-label="Mobile navigation menu" >
    46. <i data-toggle="wy-nav-top" class="fa fa-bars"></i>
    47. <a href="../../../../index.html">SuperGradients</a>
    48. </nav>
    49. <div class="wy-nav-content">
    50. <div class="rst-content">
    51. <div role="navigation" aria-label="Page navigation">
    52. <ul class="wy-breadcrumbs">
    53. <li><a href="../../../../index.html" class="icon icon-home"></a> &raquo;</li>
    54. <li><a href="../../../index.html">Module code</a> &raquo;</li>
    55. <li>super_gradients.training.losses.yolo_v3_loss</li>
    56. <li class="wy-breadcrumbs-aside">
    57. </li>
    58. </ul>
    59. <hr/>
    60. </div>
    61. <div role="main" class="document" itemscope="itemscope" itemtype="http://schema.org/Article">
    62. <div itemprop="articleBody">
    63. <h1>Source code for super_gradients.training.losses.yolo_v3_loss</h1><div class="highlight"><pre>
    64. <span></span><span class="kn">import</span> <span class="nn">torch</span>
    65. <span class="kn">from</span> <span class="nn">torch</span> <span class="kn">import</span> <span class="n">nn</span>
    66. <span class="kn">from</span> <span class="nn">torch.nn.modules.loss</span> <span class="kn">import</span> <span class="n">_Loss</span>
    67. <span class="kn">from</span> <span class="nn">super_gradients.training.utils.detection_utils</span> <span class="kn">import</span> <span class="n">build_detection_targets</span><span class="p">,</span> <span class="n">calculate_bbox_iou_elementwise</span>
    68. <div class="viewcode-block" id="YoLoV3DetectionLoss"><a class="viewcode-back" href="../../../../super_gradients.training.losses.html#super_gradients.training.losses.yolo_v3_loss.YoLoV3DetectionLoss">[docs]</a><span class="k">class</span> <span class="nc">YoLoV3DetectionLoss</span><span class="p">(</span><span class="n">_Loss</span><span class="p">):</span>
    69. <span class="sd">&quot;&quot;&quot;</span>
    70. <span class="sd"> YoLoV3DetectionLoss - Loss Class for Object Detection</span>
    71. <span class="sd"> &quot;&quot;&quot;</span>
    72. <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">model</span><span class="p">:</span> <span class="n">nn</span><span class="o">.</span><span class="n">Module</span><span class="p">,</span> <span class="n">cls_pw</span><span class="p">:</span> <span class="nb">float</span> <span class="o">=</span> <span class="mf">1.</span><span class="p">,</span> <span class="n">obj_pw</span><span class="p">:</span> <span class="nb">float</span> <span class="o">=</span> <span class="mf">1.</span><span class="p">,</span> <span class="n">giou</span><span class="p">:</span> <span class="nb">float</span> <span class="o">=</span> <span class="mf">3.54</span><span class="p">,</span> <span class="n">obj</span><span class="p">:</span> <span class="nb">float</span> <span class="o">=</span> <span class="mf">64.3</span><span class="p">,</span>
    73. <span class="bp">cls</span><span class="p">:</span> <span class="nb">float</span> <span class="o">=</span> <span class="mf">37.4</span><span class="p">):</span>
    74. <span class="nb">super</span><span class="p">(</span><span class="n">YoLoV3DetectionLoss</span><span class="p">,</span> <span class="bp">self</span><span class="p">)</span><span class="o">.</span><span class="fm">__init__</span><span class="p">()</span>
    75. <span class="bp">self</span><span class="o">.</span><span class="n">model</span> <span class="o">=</span> <span class="n">model</span>
    76. <span class="bp">self</span><span class="o">.</span><span class="n">cls_pw</span> <span class="o">=</span> <span class="n">cls_pw</span>
    77. <span class="bp">self</span><span class="o">.</span><span class="n">obj_pw</span> <span class="o">=</span> <span class="n">obj_pw</span>
    78. <span class="bp">self</span><span class="o">.</span><span class="n">giou</span> <span class="o">=</span> <span class="n">giou</span>
    79. <span class="bp">self</span><span class="o">.</span><span class="n">obj</span> <span class="o">=</span> <span class="n">obj</span>
    80. <span class="bp">self</span><span class="o">.</span><span class="n">cls</span> <span class="o">=</span> <span class="bp">cls</span>
    81. <span class="bp">self</span><span class="o">.</span><span class="n">classes_num</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">model</span><span class="o">.</span><span class="n">net</span><span class="o">.</span><span class="n">module</span><span class="o">.</span><span class="n">num_classes</span>
    82. <div class="viewcode-block" id="YoLoV3DetectionLoss.forward"><a class="viewcode-back" href="../../../../super_gradients.training.losses.html#super_gradients.training.losses.yolo_v3_loss.YoLoV3DetectionLoss.forward">[docs]</a> <span class="k">def</span> <span class="nf">forward</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">model_output</span><span class="p">,</span> <span class="n">targets</span><span class="p">):</span>
    83. <span class="k">if</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">model_output</span><span class="p">,</span> <span class="nb">tuple</span><span class="p">)</span> <span class="ow">and</span> <span class="nb">len</span><span class="p">(</span><span class="n">model_output</span><span class="p">)</span> <span class="o">==</span> <span class="mi">2</span><span class="p">:</span>
    84. <span class="c1"># in test/eval mode the Yolo v3 model output a tuple where the second item is the raw predictions</span>
    85. <span class="n">_</span><span class="p">,</span> <span class="n">predictions</span> <span class="o">=</span> <span class="n">model_output</span>
    86. <span class="k">else</span><span class="p">:</span>
    87. <span class="n">predictions</span> <span class="o">=</span> <span class="n">model_output</span>
    88. <span class="n">detection_targets</span> <span class="o">=</span> <span class="n">build_detection_targets</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">model</span><span class="o">.</span><span class="n">net</span><span class="o">.</span><span class="n">module</span><span class="p">,</span> <span class="n">targets</span><span class="p">)</span>
    89. <span class="n">float_tensor</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">cuda</span><span class="o">.</span><span class="n">FloatTensor</span> <span class="k">if</span> <span class="n">predictions</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span><span class="o">.</span><span class="n">is_cuda</span> <span class="k">else</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span>
    90. <span class="n">class_loss</span><span class="p">,</span> <span class="n">giou_loss</span><span class="p">,</span> <span class="n">objectness_loss</span> <span class="o">=</span> <span class="n">float_tensor</span><span class="p">([</span><span class="mi">0</span><span class="p">]),</span> <span class="n">float_tensor</span><span class="p">([</span><span class="mi">0</span><span class="p">]),</span> <span class="n">float_tensor</span><span class="p">([</span><span class="mi">0</span><span class="p">])</span>
    91. <span class="n">target_class</span><span class="p">,</span> <span class="n">target_box</span><span class="p">,</span> <span class="n">indices</span><span class="p">,</span> <span class="n">anchor_vec</span> <span class="o">=</span> <span class="n">detection_targets</span>
    92. <span class="n">reduction</span> <span class="o">=</span> <span class="s1">&#39;mean&#39;</span> <span class="c1"># Loss reduction (sum or mean)</span>
    93. <span class="c1"># DEFINE CRITERIA</span>
    94. <span class="n">BCEcls</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">BCEWithLogitsLoss</span><span class="p">(</span><span class="n">pos_weight</span><span class="o">=</span><span class="n">float_tensor</span><span class="p">([</span><span class="bp">self</span><span class="o">.</span><span class="n">cls_pw</span><span class="p">]),</span> <span class="n">reduction</span><span class="o">=</span><span class="n">reduction</span><span class="p">)</span>
    95. <span class="n">BCEobj</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">BCEWithLogitsLoss</span><span class="p">(</span><span class="n">pos_weight</span><span class="o">=</span><span class="n">float_tensor</span><span class="p">([</span><span class="bp">self</span><span class="o">.</span><span class="n">obj_pw</span><span class="p">]),</span> <span class="n">reduction</span><span class="o">=</span><span class="n">reduction</span><span class="p">)</span>
    96. <span class="c1"># COMPUTE THE LOSSES BASED ON EACH ONE OF THE YOLO LAYERS PREDICTIONS</span>
    97. <span class="n">grid_points_num</span><span class="p">,</span> <span class="n">targets_num</span> <span class="o">=</span> <span class="mi">0</span><span class="p">,</span> <span class="mi">0</span>
    98. <span class="k">for</span> <span class="n">yolo_layer_index</span><span class="p">,</span> <span class="n">yolo_layer_prediction</span> <span class="ow">in</span> <span class="nb">enumerate</span><span class="p">(</span><span class="n">predictions</span><span class="p">):</span>
    99. <span class="n">image</span><span class="p">,</span> <span class="n">anchor</span><span class="p">,</span> <span class="n">grid_y</span><span class="p">,</span> <span class="n">grid_x</span> <span class="o">=</span> <span class="n">indices</span><span class="p">[</span><span class="n">yolo_layer_index</span><span class="p">]</span>
    100. <span class="n">target_object</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">zeros_like</span><span class="p">(</span><span class="n">yolo_layer_prediction</span><span class="p">[</span><span class="o">...</span><span class="p">,</span> <span class="mi">0</span><span class="p">])</span>
    101. <span class="n">grid_points_num</span> <span class="o">+=</span> <span class="n">target_object</span><span class="o">.</span><span class="n">numel</span><span class="p">()</span>
    102. <span class="c1"># COMPUTE LOSSES</span>
    103. <span class="n">nb</span> <span class="o">=</span> <span class="nb">len</span><span class="p">(</span><span class="n">image</span><span class="p">)</span>
    104. <span class="k">if</span> <span class="n">nb</span><span class="p">:</span> <span class="c1"># number of targets</span>
    105. <span class="n">targets_num</span> <span class="o">+=</span> <span class="n">nb</span>
    106. <span class="n">predictions_for_targets</span> <span class="o">=</span> <span class="n">yolo_layer_prediction</span><span class="p">[</span><span class="n">image</span><span class="p">,</span> <span class="n">anchor</span><span class="p">,</span> <span class="n">grid_y</span><span class="p">,</span> <span class="n">grid_x</span><span class="p">]</span>
    107. <span class="n">target_object</span><span class="p">[</span><span class="n">image</span><span class="p">,</span> <span class="n">anchor</span><span class="p">,</span> <span class="n">grid_y</span><span class="p">,</span> <span class="n">grid_x</span><span class="p">]</span> <span class="o">=</span> <span class="mf">1.0</span>
    108. <span class="c1"># GIoU LOSS CALCULATION</span>
    109. <span class="n">pxy</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">sigmoid</span><span class="p">(</span>
    110. <span class="n">predictions_for_targets</span><span class="p">[:,</span> <span class="mi">0</span><span class="p">:</span><span class="mi">2</span><span class="p">])</span> <span class="c1"># pxy = pxy * s - (s - 1) / 2, s = 1.5 (scale_xy)</span>
    111. <span class="n">bbox_prediction</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">cat</span><span class="p">(</span>
    112. <span class="p">(</span><span class="n">pxy</span><span class="p">,</span> <span class="n">torch</span><span class="o">.</span><span class="n">exp</span><span class="p">(</span><span class="n">predictions_for_targets</span><span class="p">[:,</span> <span class="mi">2</span><span class="p">:</span><span class="mi">4</span><span class="p">])</span><span class="o">.</span><span class="n">clamp</span><span class="p">(</span><span class="nb">max</span><span class="o">=</span><span class="mf">1E3</span><span class="p">)</span> <span class="o">*</span> <span class="n">anchor_vec</span><span class="p">[</span><span class="n">yolo_layer_index</span><span class="p">]),</span> <span class="mi">1</span><span class="p">)</span>
    113. <span class="n">giou</span> <span class="o">=</span> <span class="mf">1.0</span> <span class="o">-</span> <span class="n">calculate_bbox_iou_elementwise</span><span class="p">(</span><span class="n">bbox_prediction</span><span class="o">.</span><span class="n">t</span><span class="p">(),</span> <span class="n">target_box</span><span class="p">[</span><span class="n">yolo_layer_index</span><span class="p">],</span>
    114. <span class="n">x1y1x2y2</span><span class="o">=</span><span class="kc">False</span><span class="p">,</span> <span class="n">GIoU</span><span class="o">=</span><span class="kc">True</span><span class="p">)</span>
    115. <span class="n">giou_loss</span> <span class="o">+=</span> <span class="n">giou</span><span class="o">.</span><span class="n">sum</span><span class="p">()</span> <span class="k">if</span> <span class="n">reduction</span> <span class="o">==</span> <span class="s1">&#39;sum&#39;</span> <span class="k">else</span> <span class="n">giou</span><span class="o">.</span><span class="n">mean</span><span class="p">()</span>
    116. <span class="c1"># ONLY RELEVANT TO MULTIPLE CLASSES</span>
    117. <span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">classes_num</span> <span class="o">&gt;</span> <span class="mi">1</span><span class="p">:</span>
    118. <span class="n">class_targets</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">zeros_like</span><span class="p">(</span><span class="n">predictions_for_targets</span><span class="p">[:,</span> <span class="mi">5</span><span class="p">:])</span>
    119. <span class="n">class_targets</span><span class="p">[</span><span class="nb">range</span><span class="p">(</span><span class="n">nb</span><span class="p">),</span> <span class="n">target_class</span><span class="p">[</span><span class="n">yolo_layer_index</span><span class="p">]]</span> <span class="o">=</span> <span class="mf">1.0</span>
    120. <span class="n">class_loss</span> <span class="o">+=</span> <span class="n">BCEcls</span><span class="p">(</span><span class="n">predictions_for_targets</span><span class="p">[:,</span> <span class="mi">5</span><span class="p">:],</span> <span class="n">class_targets</span><span class="p">)</span>
    121. <span class="n">objectness_loss</span> <span class="o">+=</span> <span class="n">BCEobj</span><span class="p">(</span><span class="n">yolo_layer_prediction</span><span class="p">[</span><span class="o">...</span><span class="p">,</span> <span class="mi">4</span><span class="p">],</span> <span class="n">target_object</span><span class="p">)</span>
    122. <span class="k">if</span> <span class="n">reduction</span> <span class="o">==</span> <span class="s1">&#39;sum&#39;</span><span class="p">:</span>
    123. <span class="n">giou_loss</span> <span class="o">*=</span> <span class="mi">3</span> <span class="o">/</span> <span class="n">targets_num</span>
    124. <span class="n">objectness_loss</span> <span class="o">*=</span> <span class="mi">3</span> <span class="o">/</span> <span class="n">grid_points_num</span>
    125. <span class="n">class_loss</span> <span class="o">*=</span> <span class="mi">3</span> <span class="o">/</span> <span class="n">targets_num</span> <span class="o">/</span> <span class="bp">self</span><span class="o">.</span><span class="n">classes_num</span>
    126. <span class="n">loss</span> <span class="o">=</span> <span class="n">giou_loss</span> <span class="o">*</span> <span class="bp">self</span><span class="o">.</span><span class="n">giou</span> <span class="o">+</span> <span class="n">objectness_loss</span> <span class="o">*</span> <span class="bp">self</span><span class="o">.</span><span class="n">obj</span> <span class="o">+</span> <span class="n">class_loss</span> <span class="o">*</span> <span class="bp">self</span><span class="o">.</span><span class="n">cls</span>
    127. <span class="k">return</span> <span class="n">loss</span><span class="p">,</span> <span class="n">torch</span><span class="o">.</span><span class="n">cat</span><span class="p">((</span><span class="n">giou_loss</span><span class="p">,</span> <span class="n">objectness_loss</span><span class="p">,</span> <span class="n">class_loss</span><span class="p">,</span> <span class="n">loss</span><span class="p">))</span><span class="o">.</span><span class="n">detach</span><span class="p">()</span></div></div>
    128. </pre></div>
    129. </div>
    130. </div>
    131. <footer>
    132. <hr/>
    133. <div role="contentinfo">
    134. <p>&#169; Copyright 2021, SuperGradients team.</p>
    135. </div>
    136. Built with <a href="https://www.sphinx-doc.org/">Sphinx</a> using a
    137. <a href="https://github.com/readthedocs/sphinx_rtd_theme">theme</a>
    138. provided by <a href="https://readthedocs.org">Read the Docs</a>.
    139. </footer>
    140. </div>
    141. </div>
    142. </section>
    143. </div>
    144. <script>
    145. jQuery(function () {
    146. SphinxRtdTheme.Navigation.enable(true);
    147. });
    148. </script>
    149. </body>
    150. </html>
    Discard
    1
    2
    3
    4
    5
    6
    7
    8
    9
    10
    11
    12
    13
    14
    15
    16
    17
    18
    19
    20
    21
    22
    23
    24
    25
    26
    27
    28
    29
    30
    31
    32
    33
    34
    35
    36
    37
    38
    39
    40
    41
    42
    43
    44
    45
    46
    47
    48
    49
    50
    51
    52
    53
    54
    55
    56
    57
    58
    59
    60
    61
    62
    63
    64
    65
    66
    67
    68
    69
    70
    71
    72
    73
    74
    75
    76
    77
    78
    79
    80
    81
    82
    83
    84
    85
    86
    87
    88
    89
    90
    91
    92
    93
    94
    95
    96
    97
    98
    99
    100
    101
    102
    103
    104
    105
    106
    107
    108
    109
    110
    111
    112
    113
    114
    115
    116
    117
    118
    119
    120
    121
    122
    123
    124
    125
    126
    127
    128
    129
    130
    131
    132
    133
    134
    135
    136
    137
    138
    139
    140
    141
    142
    143
    144
    145
    146
    147
    148
    149
    150
    151
    152
    153
    154
    155
    156
    157
    158
    159
    160
    161
    162
    163
    164
    165
    166
    167
    168
    169
    170
    171
    172
    173
    174
    175
    176
    177
    178
    179
    180
    181
    182
    183
    184
    185
    186
    187
    188
    189
    190
    191
    192
    193
    194
    195
    196
    197
    198
    199
    200
    201
    202
    203
    204
    205
    206
    207
    208
    209
    210
    211
    212
    213
    214
    215
    216
    217
    218
    219
    220
    221
    222
    223
    224
    225
    226
    227
    228
    229
    230
    231
    232
    233
    234
    235
    236
    237
    238
    239
    240
    241
    242
    243
    244
    245
    246
    247
    248
    249
    250
    251
    252
    253
    254
    255
    256
    257
    258
    259
    260
    261
    262
    263
    264
    265
    266
    267
    268
    269
    270
    271
    272
    273
    274
    275
    276
    277
    278
    279
    280
    281
    282
    283
    284
    285
    286
    287
    288
    289
    290
    291
    292
    293
    294
    295
    296
    297
    298
    299
    300
    301
    302
    303
    304
    305
    306
    307
    308
    309
    310
    311
    312
    313
    314
    315
    316
    317
    318
    319
    320
    321
    322
    323
    324
    325
    326
    327
    328
    329
    330
    331
    332
    333
    334
    335
    336
    337
    338
    339
    340
    1. <!DOCTYPE html>
    2. <html class="writer-html5" lang="en" >
    3. <head>
    4. <meta charset="utf-8" />
    5. <meta name="viewport" content="width=device-width, initial-scale=1.0" />
    6. <title>super_gradients.training.losses.yolo_v5_loss &mdash; SuperGradients 1.0 documentation</title>
    7. <link rel="stylesheet" href="../../../../_static/pygments.css" type="text/css" />
    8. <link rel="stylesheet" href="../../../../_static/css/theme.css" type="text/css" />
    9. <link rel="stylesheet" href="../../../../_static/graphviz.css" type="text/css" />
    10. <!--[if lt IE 9]>
    11. <script src="../../../../_static/js/html5shiv.min.js"></script>
    12. <![endif]-->
    13. <script data-url_root="../../../../" id="documentation_options" src="../../../../_static/documentation_options.js"></script>
    14. <script src="../../../../_static/jquery.js"></script>
    15. <script src="../../../../_static/underscore.js"></script>
    16. <script src="../../../../_static/doctools.js"></script>
    17. <script src="../../../../_static/js/theme.js"></script>
    18. <link rel="index" title="Index" href="../../../../genindex.html" />
    19. <link rel="search" title="Search" href="../../../../search.html" />
    20. </head>
    21. <body class="wy-body-for-nav">
    22. <div class="wy-grid-for-nav">
    23. <nav data-toggle="wy-nav-shift" class="wy-nav-side">
    24. <div class="wy-side-scroll">
    25. <div class="wy-side-nav-search" >
    26. <a href="../../../../index.html" class="icon icon-home"> SuperGradients
    27. </a>
    28. <div role="search">
    29. <form id="rtd-search-form" class="wy-form" action="../../../../search.html" method="get">
    30. <input type="text" name="q" placeholder="Search docs" />
    31. <input type="hidden" name="check_keywords" value="yes" />
    32. <input type="hidden" name="area" value="default" />
    33. </form>
    34. </div>
    35. </div><div class="wy-menu wy-menu-vertical" data-spy="affix" role="navigation" aria-label="Navigation menu">
    36. <p class="caption"><span class="caption-text">Welcome To SuperGradients</span></p>
    37. <ul>
    38. <li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html">SuperGradients</a></li>
    39. </ul>
    40. <p class="caption"><span class="caption-text">Technical Documentation</span></p>
    41. <ul>
    42. <li class="toctree-l1"><a class="reference internal" href="../../../../super_gradients.common.html">Common package</a></li>
    43. <li class="toctree-l1"><a class="reference internal" href="../../../../super_gradients.training.html">Training package</a></li>
    44. </ul>
    45. <p class="caption"><span class="caption-text">User Guide</span></p>
    46. <ul>
    47. <li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html">What is SuperGradients?</a></li>
    48. <li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#introducing-the-supergradients-library">Introducing the SuperGradients library</a></li>
    49. <li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#installation">Installation</a></li>
    50. <li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#integrating-your-training-code-complete-walkthrough">Integrating Your Training Code - Complete Walkthrough</a></li>
    51. <li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#training-parameters">Training Parameters</a></li>
    52. <li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#logs-and-checkpoints">Logs and Checkpoints</a></li>
    53. <li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#dataset-parameters">Dataset Parameters</a></li>
    54. <li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#network-architectures">Network Architectures</a></li>
    55. <li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#pretrained-models">Pretrained Models</a></li>
    56. <li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#how-to-reproduce-our-training-recipes">How To Reproduce Our Training Recipes</a></li>
    57. <li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#supergradients-faq">SuperGradients FAQ</a></li>
    58. </ul>
    59. </div>
    60. </div>
    61. </nav>
    62. <section data-toggle="wy-nav-shift" class="wy-nav-content-wrap"><nav class="wy-nav-top" aria-label="Mobile navigation menu" >
    63. <i data-toggle="wy-nav-top" class="fa fa-bars"></i>
    64. <a href="../../../../index.html">SuperGradients</a>
    65. </nav>
    66. <div class="wy-nav-content">
    67. <div class="rst-content">
    68. <div role="navigation" aria-label="Page navigation">
    69. <ul class="wy-breadcrumbs">
    70. <li><a href="../../../../index.html" class="icon icon-home"></a> &raquo;</li>
    71. <li><a href="../../../index.html">Module code</a> &raquo;</li>
    72. <li>super_gradients.training.losses.yolo_v5_loss</li>
    73. <li class="wy-breadcrumbs-aside">
    74. </li>
    75. </ul>
    76. <hr/>
    77. </div>
    78. <div role="main" class="document" itemscope="itemscope" itemtype="http://schema.org/Article">
    79. <div itemprop="articleBody">
    80. <h1>Source code for super_gradients.training.losses.yolo_v5_loss</h1><div class="highlight"><pre>
    81. <span></span><span class="kn">from</span> <span class="nn">typing</span> <span class="kn">import</span> <span class="n">List</span><span class="p">,</span> <span class="n">Tuple</span><span class="p">,</span> <span class="n">Union</span>
    82. <span class="kn">import</span> <span class="nn">torch</span>
    83. <span class="kn">from</span> <span class="nn">torch</span> <span class="kn">import</span> <span class="n">nn</span>
    84. <span class="kn">from</span> <span class="nn">torch.nn.modules.loss</span> <span class="kn">import</span> <span class="n">_Loss</span>
    85. <span class="kn">from</span> <span class="nn">super_gradients.training.losses.focal_loss</span> <span class="kn">import</span> <span class="n">FocalLoss</span>
    86. <span class="kn">from</span> <span class="nn">super_gradients.training.utils.detection_utils</span> <span class="kn">import</span> <span class="n">calculate_bbox_iou_elementwise</span><span class="p">,</span> <span class="n">Anchors</span>
    87. <div class="viewcode-block" id="YoLoV5DetectionLoss"><a class="viewcode-back" href="../../../../super_gradients.training.losses.html#super_gradients.training.losses.YoLoV5DetectionLoss">[docs]</a><span class="k">class</span> <span class="nc">YoLoV5DetectionLoss</span><span class="p">(</span><span class="n">_Loss</span><span class="p">):</span>
    88. <span class="sd">&quot;&quot;&quot;</span>
    89. <span class="sd"> Calculate YOLO V5 loss:</span>
    90. <span class="sd"> L = L_objectivness + L_boxes + L_classification</span>
    91. <span class="sd"> &quot;&quot;&quot;</span>
    92. <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">anchors</span><span class="p">:</span> <span class="n">Anchors</span><span class="p">,</span>
    93. <span class="n">cls_pos_weight</span><span class="p">:</span> <span class="n">Union</span><span class="p">[</span><span class="nb">float</span><span class="p">,</span> <span class="n">List</span><span class="p">[</span><span class="nb">float</span><span class="p">]]</span> <span class="o">=</span> <span class="mf">1.0</span><span class="p">,</span> <span class="n">obj_pos_weight</span><span class="p">:</span> <span class="nb">float</span> <span class="o">=</span> <span class="mf">1.0</span><span class="p">,</span>
    94. <span class="n">obj_loss_gain</span><span class="p">:</span> <span class="nb">float</span> <span class="o">=</span> <span class="mf">1.0</span><span class="p">,</span> <span class="n">box_loss_gain</span><span class="p">:</span> <span class="nb">float</span> <span class="o">=</span> <span class="mf">0.05</span><span class="p">,</span> <span class="n">cls_loss_gain</span><span class="p">:</span> <span class="nb">float</span> <span class="o">=</span> <span class="mf">0.5</span><span class="p">,</span>
    95. <span class="n">focal_loss_gamma</span><span class="p">:</span> <span class="nb">float</span> <span class="o">=</span> <span class="mf">0.0</span><span class="p">,</span>
    96. <span class="n">cls_objectness_weights</span><span class="p">:</span> <span class="n">Union</span><span class="p">[</span><span class="n">List</span><span class="p">[</span><span class="nb">float</span><span class="p">],</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">]</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span> <span class="n">anchor_threshold</span><span class="o">=</span><span class="mf">4.0</span><span class="p">):</span>
    97. <span class="sd">&quot;&quot;&quot;</span>
    98. <span class="sd"> :param anchors: the anchors of the model (same anchors used for training)</span>
    99. <span class="sd"> :param cls_pos_weight: pos_weight for BCE in L_classification,</span>
    100. <span class="sd"> can be one value for all positives or a list of weights for each class</span>
    101. <span class="sd"> :param obj_pos_weight: pos_weight for BCE in L_objectivness</span>
    102. <span class="sd"> :param obj_loss_gain: coef for L_objectivness</span>
    103. <span class="sd"> :param box_loss_gain: coef for L_boxes</span>
    104. <span class="sd"> :param cls_loss_gain: coef for L_classification</span>
    105. <span class="sd"> :param focal_loss_gamma: gamma for a focal loss, 0 to train with a usual BCE</span>
    106. <span class="sd"> :param cls_objectness_weights: class-based weight for L_objectivness that will be applied in each cell that</span>
    107. <span class="sd"> has a GT assigned to it.</span>
    108. <span class="sd"> Note: default weight for objectness loss in each cell is 1.</span>
    109. <span class="sd"> :param anchor_threshold: ratio defining a size range of an appropriate anchor.</span>
    110. <span class="sd"> &quot;&quot;&quot;</span>
    111. <span class="nb">super</span><span class="p">(</span><span class="n">YoLoV5DetectionLoss</span><span class="p">,</span> <span class="bp">self</span><span class="p">)</span><span class="o">.</span><span class="fm">__init__</span><span class="p">()</span>
    112. <span class="bp">self</span><span class="o">.</span><span class="n">cls_pos_weight</span> <span class="o">=</span> <span class="n">cls_pos_weight</span>
    113. <span class="bp">self</span><span class="o">.</span><span class="n">obj_pos_weight</span> <span class="o">=</span> <span class="n">obj_pos_weight</span>
    114. <span class="bp">self</span><span class="o">.</span><span class="n">obj_loss_gain</span> <span class="o">=</span> <span class="n">obj_loss_gain</span>
    115. <span class="bp">self</span><span class="o">.</span><span class="n">box_loss_gain</span> <span class="o">=</span> <span class="n">box_loss_gain</span>
    116. <span class="bp">self</span><span class="o">.</span><span class="n">cls_loss_gain</span> <span class="o">=</span> <span class="n">cls_loss_gain</span>
    117. <span class="bp">self</span><span class="o">.</span><span class="n">focal_loss_gamma</span> <span class="o">=</span> <span class="n">focal_loss_gamma</span>
    118. <span class="bp">self</span><span class="o">.</span><span class="n">anchor_threshold</span> <span class="o">=</span> <span class="n">anchor_threshold</span>
    119. <span class="bp">self</span><span class="o">.</span><span class="n">anchors</span> <span class="o">=</span> <span class="n">anchors</span>
    120. <span class="bp">self</span><span class="o">.</span><span class="n">cls_obj_weights</span> <span class="o">=</span> <span class="n">cls_objectness_weights</span>
    121. <span class="k">if</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">cls_objectness_weights</span><span class="p">,</span> <span class="nb">list</span><span class="p">):</span>
    122. <span class="bp">self</span><span class="o">.</span><span class="n">cls_obj_weights</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">nn</span><span class="o">.</span><span class="n">Parameter</span><span class="p">(</span><span class="n">torch</span><span class="o">.</span><span class="n">tensor</span><span class="p">(</span><span class="n">cls_objectness_weights</span><span class="p">))</span>
    123. <div class="viewcode-block" id="YoLoV5DetectionLoss.forward"><a class="viewcode-back" href="../../../../super_gradients.training.losses.html#super_gradients.training.losses.YoLoV5DetectionLoss.forward">[docs]</a> <span class="k">def</span> <span class="nf">forward</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">model_output</span><span class="p">,</span> <span class="n">targets</span><span class="p">):</span>
    124. <span class="k">if</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">model_output</span><span class="p">,</span> <span class="nb">tuple</span><span class="p">)</span> <span class="ow">and</span> <span class="nb">len</span><span class="p">(</span><span class="n">model_output</span><span class="p">)</span> <span class="o">==</span> <span class="mi">2</span><span class="p">:</span>
    125. <span class="c1"># in test/eval mode the Yolo v5 model output a tuple where the second item is the raw predictions</span>
    126. <span class="n">_</span><span class="p">,</span> <span class="n">predictions</span> <span class="o">=</span> <span class="n">model_output</span>
    127. <span class="k">else</span><span class="p">:</span>
    128. <span class="n">predictions</span> <span class="o">=</span> <span class="n">model_output</span>
    129. <span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">compute_loss</span><span class="p">(</span><span class="n">predictions</span><span class="p">,</span> <span class="n">targets</span><span class="p">)</span></div>
    130. <div class="viewcode-block" id="YoLoV5DetectionLoss.build_targets"><a class="viewcode-back" href="../../../../super_gradients.training.losses.html#super_gradients.training.losses.YoLoV5DetectionLoss.build_targets">[docs]</a> <span class="k">def</span> <span class="nf">build_targets</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">predictions</span><span class="p">:</span> <span class="n">List</span><span class="p">[</span><span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">],</span> <span class="n">targets</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">)</span> \
    131. <span class="o">-&gt;</span> <span class="n">Tuple</span><span class="p">[</span><span class="n">List</span><span class="p">[</span><span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">],</span> <span class="n">List</span><span class="p">[</span><span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">],</span> <span class="n">List</span><span class="p">[</span><span class="n">Tuple</span><span class="p">[</span><span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">]],</span> <span class="n">List</span><span class="p">[</span><span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">]]:</span>
    132. <span class="sd">&quot;&quot;&quot;</span>
    133. <span class="sd"> Assign targets to anchors to use in L_boxes &amp; L_classification calculation:</span>
    134. <span class="sd"> * each target can be assigned to a few anchors,</span>
    135. <span class="sd"> all anchors that are within [1/self.anchor_threshold, self.anchor_threshold] times target size range</span>
    136. <span class="sd"> * each anchor can be assigned to a few targets</span>
    137. <span class="sd"> :param predictions: Yolo predictions</span>
    138. <span class="sd"> :param targets: ground truth targets</span>
    139. <span class="sd"> :return: each of 4 outputs contains one element for each Yolo output,</span>
    140. <span class="sd"> correspondences are raveled over the whole batch and all anchors:</span>
    141. <span class="sd"> * classes of the targets;</span>
    142. <span class="sd"> * boxes of the targets;</span>
    143. <span class="sd"> * image id in a batch, anchor id, grid y, grid x coordinates;</span>
    144. <span class="sd"> * anchor sizes.</span>
    145. <span class="sd"> All the above can be indexed in parallel to get the selected correspondences</span>
    146. <span class="sd"> &quot;&quot;&quot;</span>
    147. <span class="n">num_anchors</span><span class="p">,</span> <span class="n">num_targets</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">anchors</span><span class="o">.</span><span class="n">num_anchors</span><span class="p">,</span> <span class="n">targets</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span>
    148. <span class="n">target_classes</span><span class="p">,</span> <span class="n">target_boxes</span><span class="p">,</span> <span class="n">indices</span><span class="p">,</span> <span class="n">anchors</span> <span class="o">=</span> <span class="p">[],</span> <span class="p">[],</span> <span class="p">[],</span> <span class="p">[]</span>
    149. <span class="n">gain</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">ones</span><span class="p">(</span><span class="mi">7</span><span class="p">,</span> <span class="n">device</span><span class="o">=</span><span class="n">targets</span><span class="o">.</span><span class="n">device</span><span class="p">)</span> <span class="c1"># normalized to gridspace gain</span>
    150. <span class="n">anchor_indices</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">arange</span><span class="p">(</span><span class="n">num_anchors</span><span class="p">,</span> <span class="n">device</span><span class="o">=</span><span class="n">targets</span><span class="o">.</span><span class="n">device</span><span class="p">)</span>
    151. <span class="n">anchor_indices</span> <span class="o">=</span> <span class="n">anchor_indices</span><span class="o">.</span><span class="n">float</span><span class="p">()</span><span class="o">.</span><span class="n">view</span><span class="p">(</span><span class="n">num_anchors</span><span class="p">,</span> <span class="mi">1</span><span class="p">)</span><span class="o">.</span><span class="n">repeat</span><span class="p">(</span><span class="mi">1</span><span class="p">,</span> <span class="n">num_targets</span><span class="p">)</span>
    152. <span class="c1"># repeat all targets for each anchor and append a corresponding anchor index</span>
    153. <span class="n">targets</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">cat</span><span class="p">((</span><span class="n">targets</span><span class="o">.</span><span class="n">repeat</span><span class="p">(</span><span class="n">num_anchors</span><span class="p">,</span> <span class="mi">1</span><span class="p">,</span> <span class="mi">1</span><span class="p">),</span> <span class="n">anchor_indices</span><span class="p">[:,</span> <span class="p">:,</span> <span class="kc">None</span><span class="p">]),</span> <span class="mi">2</span><span class="p">)</span>
    154. <span class="n">bias</span> <span class="o">=</span> <span class="mf">0.5</span>
    155. <span class="n">off</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">tensor</span><span class="p">([[</span><span class="mi">0</span><span class="p">,</span> <span class="mi">0</span><span class="p">],</span>
    156. <span class="p">[</span><span class="mi">1</span><span class="p">,</span> <span class="mi">0</span><span class="p">],</span> <span class="p">[</span><span class="mi">0</span><span class="p">,</span> <span class="mi">1</span><span class="p">],</span> <span class="p">[</span><span class="o">-</span><span class="mi">1</span><span class="p">,</span> <span class="mi">0</span><span class="p">],</span> <span class="p">[</span><span class="mi">0</span><span class="p">,</span> <span class="o">-</span><span class="mi">1</span><span class="p">],</span> <span class="c1"># j,k,l,m</span>
    157. <span class="p">],</span> <span class="n">device</span><span class="o">=</span><span class="n">targets</span><span class="o">.</span><span class="n">device</span><span class="p">)</span><span class="o">.</span><span class="n">float</span><span class="p">()</span> <span class="o">*</span> <span class="n">bias</span> <span class="c1"># offsets</span>
    158. <span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">anchors</span><span class="o">.</span><span class="n">detection_layers_num</span><span class="p">):</span>
    159. <span class="n">anch</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">anchors</span><span class="o">.</span><span class="n">anchors</span><span class="p">[</span><span class="n">i</span><span class="p">]</span>
    160. <span class="n">gain</span><span class="p">[</span><span class="mi">2</span><span class="p">:</span><span class="mi">6</span><span class="p">]</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">tensor</span><span class="p">(</span><span class="n">predictions</span><span class="p">[</span><span class="n">i</span><span class="p">]</span><span class="o">.</span><span class="n">shape</span><span class="p">)[[</span><span class="mi">3</span><span class="p">,</span> <span class="mi">2</span><span class="p">,</span> <span class="mi">3</span><span class="p">,</span> <span class="mi">2</span><span class="p">]]</span> <span class="c1"># xyxy gain</span>
    161. <span class="c1"># Convert target coordinates from [0, 1] range to coordinates in [0, GridY], [0, GridX] ranges</span>
    162. <span class="n">t</span> <span class="o">=</span> <span class="n">targets</span> <span class="o">*</span> <span class="n">gain</span>
    163. <span class="k">if</span> <span class="n">num_targets</span><span class="p">:</span>
    164. <span class="c1"># Match: filter targets by anchor size ratio</span>
    165. <span class="n">r</span> <span class="o">=</span> <span class="n">t</span><span class="p">[:,</span> <span class="p">:,</span> <span class="mi">4</span><span class="p">:</span><span class="mi">6</span><span class="p">]</span> <span class="o">/</span> <span class="n">anch</span><span class="p">[:,</span> <span class="kc">None</span><span class="p">]</span> <span class="c1"># wh ratio</span>
    166. <span class="n">filtered_targets_ids</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">max</span><span class="p">(</span><span class="n">r</span><span class="p">,</span> <span class="mf">1.</span> <span class="o">/</span> <span class="n">r</span><span class="p">)</span><span class="o">.</span><span class="n">max</span><span class="p">(</span><span class="mi">2</span><span class="p">)[</span><span class="mi">0</span><span class="p">]</span> <span class="o">&lt;</span> <span class="bp">self</span><span class="o">.</span><span class="n">anchor_threshold</span> <span class="c1"># compare</span>
    167. <span class="n">t</span> <span class="o">=</span> <span class="n">t</span><span class="p">[</span><span class="n">filtered_targets_ids</span><span class="p">]</span>
    168. <span class="c1"># Find coordinates of targets on a grid</span>
    169. <span class="n">gxy</span> <span class="o">=</span> <span class="n">t</span><span class="p">[:,</span> <span class="mi">2</span><span class="p">:</span><span class="mi">4</span><span class="p">]</span> <span class="c1"># grid xy</span>
    170. <span class="n">gxi</span> <span class="o">=</span> <span class="n">gain</span><span class="p">[[</span><span class="mi">2</span><span class="p">,</span> <span class="mi">3</span><span class="p">]]</span> <span class="o">-</span> <span class="n">gxy</span> <span class="c1"># inverse</span>
    171. <span class="n">j</span><span class="p">,</span> <span class="n">k</span> <span class="o">=</span> <span class="p">((</span><span class="n">gxy</span> <span class="o">%</span> <span class="mf">1.</span> <span class="o">&lt;</span> <span class="n">bias</span><span class="p">)</span> <span class="o">&amp;</span> <span class="p">(</span><span class="n">gxy</span> <span class="o">&gt;</span> <span class="mf">1.</span><span class="p">))</span><span class="o">.</span><span class="n">T</span>
    172. <span class="n">l</span><span class="p">,</span> <span class="n">m</span> <span class="o">=</span> <span class="p">((</span><span class="n">gxi</span> <span class="o">%</span> <span class="mf">1.</span> <span class="o">&lt;</span> <span class="n">bias</span><span class="p">)</span> <span class="o">&amp;</span> <span class="p">(</span><span class="n">gxi</span> <span class="o">&gt;</span> <span class="mf">1.</span><span class="p">))</span><span class="o">.</span><span class="n">T</span>
    173. <span class="n">j</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">stack</span><span class="p">((</span><span class="n">torch</span><span class="o">.</span><span class="n">ones_like</span><span class="p">(</span><span class="n">j</span><span class="p">),</span> <span class="n">j</span><span class="p">,</span> <span class="n">k</span><span class="p">,</span> <span class="n">l</span><span class="p">,</span> <span class="n">m</span><span class="p">))</span>
    174. <span class="n">t</span> <span class="o">=</span> <span class="n">t</span><span class="o">.</span><span class="n">repeat</span><span class="p">((</span><span class="mi">5</span><span class="p">,</span> <span class="mi">1</span><span class="p">,</span> <span class="mi">1</span><span class="p">))[</span><span class="n">j</span><span class="p">]</span>
    175. <span class="n">offsets</span> <span class="o">=</span> <span class="p">(</span><span class="n">torch</span><span class="o">.</span><span class="n">zeros_like</span><span class="p">(</span><span class="n">gxy</span><span class="p">)[</span><span class="kc">None</span><span class="p">]</span> <span class="o">+</span> <span class="n">off</span><span class="p">[:,</span> <span class="kc">None</span><span class="p">])[</span><span class="n">j</span><span class="p">]</span>
    176. <span class="k">else</span><span class="p">:</span>
    177. <span class="n">t</span> <span class="o">=</span> <span class="n">targets</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span>
    178. <span class="n">offsets</span> <span class="o">=</span> <span class="mi">0</span>
    179. <span class="c1"># Define</span>
    180. <span class="n">b</span><span class="p">,</span> <span class="n">c</span> <span class="o">=</span> <span class="n">t</span><span class="p">[:,</span> <span class="p">:</span><span class="mi">2</span><span class="p">]</span><span class="o">.</span><span class="n">long</span><span class="p">()</span><span class="o">.</span><span class="n">T</span> <span class="c1"># image, class</span>
    181. <span class="n">gxy</span> <span class="o">=</span> <span class="n">t</span><span class="p">[:,</span> <span class="mi">2</span><span class="p">:</span><span class="mi">4</span><span class="p">]</span> <span class="c1"># grid xy</span>
    182. <span class="n">gwh</span> <span class="o">=</span> <span class="n">t</span><span class="p">[:,</span> <span class="mi">4</span><span class="p">:</span><span class="mi">6</span><span class="p">]</span> <span class="c1"># grid wh</span>
    183. <span class="n">gij</span> <span class="o">=</span> <span class="p">(</span><span class="n">gxy</span> <span class="o">-</span> <span class="n">offsets</span><span class="p">)</span><span class="o">.</span><span class="n">long</span><span class="p">()</span>
    184. <span class="n">gi</span><span class="p">,</span> <span class="n">gj</span> <span class="o">=</span> <span class="n">gij</span><span class="o">.</span><span class="n">T</span> <span class="c1"># grid xy indices</span>
    185. <span class="c1"># prevent coordinates from going out of bounds</span>
    186. <span class="n">gi</span><span class="p">,</span> <span class="n">gj</span> <span class="o">=</span> <span class="n">gi</span><span class="o">.</span><span class="n">clamp_</span><span class="p">(</span><span class="mi">0</span><span class="p">,</span> <span class="n">gain</span><span class="p">[</span><span class="mi">2</span><span class="p">]</span> <span class="o">-</span> <span class="mi">1</span><span class="p">),</span> <span class="n">gj</span><span class="o">.</span><span class="n">clamp_</span><span class="p">(</span><span class="mi">0</span><span class="p">,</span> <span class="n">gain</span><span class="p">[</span><span class="mi">3</span><span class="p">]</span> <span class="o">-</span> <span class="mi">1</span><span class="p">)</span>
    187. <span class="c1"># Append</span>
    188. <span class="n">a</span> <span class="o">=</span> <span class="n">t</span><span class="p">[:,</span> <span class="mi">6</span><span class="p">]</span><span class="o">.</span><span class="n">long</span><span class="p">()</span> <span class="c1"># anchor indices</span>
    189. <span class="n">indices</span><span class="o">.</span><span class="n">append</span><span class="p">((</span><span class="n">b</span><span class="p">,</span> <span class="n">a</span><span class="p">,</span> <span class="n">gj</span><span class="p">,</span> <span class="n">gi</span><span class="p">))</span> <span class="c1"># image, anchor, grid indices</span>
    190. <span class="n">target_boxes</span><span class="o">.</span><span class="n">append</span><span class="p">(</span><span class="n">torch</span><span class="o">.</span><span class="n">cat</span><span class="p">((</span><span class="n">gxy</span> <span class="o">-</span> <span class="n">gij</span><span class="p">,</span> <span class="n">gwh</span><span class="p">),</span> <span class="mi">1</span><span class="p">))</span> <span class="c1"># box</span>
    191. <span class="n">anchors</span><span class="o">.</span><span class="n">append</span><span class="p">(</span><span class="n">anch</span><span class="p">[</span><span class="n">a</span><span class="p">])</span> <span class="c1"># anchors</span>
    192. <span class="n">target_classes</span><span class="o">.</span><span class="n">append</span><span class="p">(</span><span class="n">c</span><span class="p">)</span> <span class="c1"># class</span>
    193. <span class="k">return</span> <span class="n">target_classes</span><span class="p">,</span> <span class="n">target_boxes</span><span class="p">,</span> <span class="n">indices</span><span class="p">,</span> <span class="n">anchors</span></div>
    194. <div class="viewcode-block" id="YoLoV5DetectionLoss.compute_loss"><a class="viewcode-back" href="../../../../super_gradients.training.losses.html#super_gradients.training.losses.YoLoV5DetectionLoss.compute_loss">[docs]</a> <span class="k">def</span> <span class="nf">compute_loss</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">predictions</span><span class="p">:</span> <span class="n">List</span><span class="p">[</span><span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">],</span> <span class="n">targets</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">,</span> <span class="n">giou_loss_ratio</span><span class="p">:</span> <span class="nb">float</span> <span class="o">=</span> <span class="mf">1.0</span><span class="p">)</span> \
    195. <span class="o">-&gt;</span> <span class="n">Tuple</span><span class="p">[</span><span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">,</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">]:</span>
    196. <span class="sd">&quot;&quot;&quot;</span>
    197. <span class="sd"> L = L_objectivness + L_boxes + L_classification</span>
    198. <span class="sd"> where:</span>
    199. <span class="sd"> * L_boxes and L_classification are calculated only between anchors and targets that suit them;</span>
    200. <span class="sd"> * L_objectivness is calculated on all anchors.</span>
    201. <span class="sd"> L_classification:</span>
    202. <span class="sd"> for anchors that have suitable ground truths in their grid locations add BCEs</span>
    203. <span class="sd"> to force max probability for each GT class in a multi-label way</span>
    204. <span class="sd"> Coef: self.cls_loss_gain</span>
    205. <span class="sd"> L_boxes:</span>
    206. <span class="sd"> for anchors that have suitable ground truths in their grid locations</span>
    207. <span class="sd"> add (1 - IoU), IoU between a predicted box and each GT box, force maximum IoU</span>
    208. <span class="sd"> Coef: self.box_loss_gain</span>
    209. <span class="sd"> L_objectness:</span>
    210. <span class="sd"> for each anchor add BCE to force a prediction of (1 - giou_loss_ratio) + giou_loss_ratio * IoU,</span>
    211. <span class="sd"> IoU between a predicted box and random GT in it</span>
    212. <span class="sd"> Coef: self.obj_loss_gain, loss from each YOLO grid is additionally multiplied by balance = [4.0, 1.0, 0.4]</span>
    213. <span class="sd"> to balance different contributions coming from different numbers of grid cells</span>
    214. <span class="sd"> :param predictions: output from all Yolo levels, each of shape</span>
    215. <span class="sd"> [Batch x Num_Anchors x GridSizeY x GridSizeX x (4 + 1 + Num_classes)]</span>
    216. <span class="sd"> :param targets: [Num_targets x (4 + 2)], values on dim 1 are: image id in a batch, class, box x y w h</span>
    217. <span class="sd"> :param giou_loss_ratio: a coef in L_objectness defining what should be predicted as objecness</span>
    218. <span class="sd"> in a call with a target: can be a value in [IoU, 1] range</span>
    219. <span class="sd"> :return: loss, all losses separately in a detached tensor</span>
    220. <span class="sd"> &quot;&quot;&quot;</span>
    221. <span class="n">device</span> <span class="o">=</span> <span class="n">targets</span><span class="o">.</span><span class="n">device</span>
    222. <span class="n">loss_classification</span><span class="p">,</span> <span class="n">loss_boxes</span><span class="p">,</span> <span class="n">loss_objectivness</span> <span class="o">=</span> \
    223. <span class="n">torch</span><span class="o">.</span><span class="n">zeros</span><span class="p">(</span><span class="mi">1</span><span class="p">,</span> <span class="n">device</span><span class="o">=</span><span class="n">device</span><span class="p">),</span> <span class="n">torch</span><span class="o">.</span><span class="n">zeros</span><span class="p">(</span><span class="mi">1</span><span class="p">,</span> <span class="n">device</span><span class="o">=</span><span class="n">device</span><span class="p">),</span> <span class="n">torch</span><span class="o">.</span><span class="n">zeros</span><span class="p">(</span><span class="mi">1</span><span class="p">,</span> <span class="n">device</span><span class="o">=</span><span class="n">device</span><span class="p">)</span>
    224. <span class="n">target_classes</span><span class="p">,</span> <span class="n">target_boxes</span><span class="p">,</span> <span class="n">indices</span><span class="p">,</span> <span class="n">anchors</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">build_targets</span><span class="p">(</span><span class="n">predictions</span><span class="p">,</span> <span class="n">targets</span><span class="p">)</span> <span class="c1"># targets</span>
    225. <span class="c1"># Define criteria</span>
    226. <span class="n">BCEcls</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">BCEWithLogitsLoss</span><span class="p">(</span><span class="n">pos_weight</span><span class="o">=</span><span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">([</span><span class="bp">self</span><span class="o">.</span><span class="n">cls_pos_weight</span><span class="p">]))</span><span class="o">.</span><span class="n">to</span><span class="p">(</span><span class="n">device</span><span class="p">)</span>
    227. <span class="n">BCEobj</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">BCEWithLogitsLoss</span><span class="p">(</span><span class="n">pos_weight</span><span class="o">=</span><span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">([</span><span class="bp">self</span><span class="o">.</span><span class="n">obj_pos_weight</span><span class="p">]),</span> <span class="n">reduction</span><span class="o">=</span><span class="s1">&#39;none&#39;</span><span class="p">)</span><span class="o">.</span><span class="n">to</span><span class="p">(</span><span class="n">device</span><span class="p">)</span>
    228. <span class="c1"># Focal loss</span>
    229. <span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">focal_loss_gamma</span> <span class="o">&gt;</span> <span class="mi">0</span><span class="p">:</span>
    230. <span class="n">BCEcls</span><span class="p">,</span> <span class="n">BCEobj</span> <span class="o">=</span> <span class="n">FocalLoss</span><span class="p">(</span><span class="n">BCEcls</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">focal_loss_gamma</span><span class="p">),</span> <span class="n">FocalLoss</span><span class="p">(</span><span class="n">BCEobj</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">focal_loss_gamma</span><span class="p">)</span>
    231. <span class="c1"># Losses</span>
    232. <span class="n">num_targets</span> <span class="o">=</span> <span class="mi">0</span>
    233. <span class="n">num_predictions</span> <span class="o">=</span> <span class="nb">len</span><span class="p">(</span><span class="n">predictions</span><span class="p">)</span>
    234. <span class="n">balance</span> <span class="o">=</span> <span class="p">[</span><span class="mf">4.0</span><span class="p">,</span> <span class="mf">1.0</span><span class="p">,</span> <span class="mf">0.4</span><span class="p">]</span> <span class="k">if</span> <span class="n">num_predictions</span> <span class="o">==</span> <span class="mi">3</span> <span class="k">else</span> <span class="p">[</span><span class="mf">4.0</span><span class="p">,</span> <span class="mf">1.0</span><span class="p">,</span> <span class="mf">0.4</span><span class="p">,</span> <span class="mf">0.1</span><span class="p">]</span> <span class="c1"># P3-5 or P3-6</span>
    235. <span class="k">for</span> <span class="n">i</span><span class="p">,</span> <span class="n">prediction</span> <span class="ow">in</span> <span class="nb">enumerate</span><span class="p">(</span><span class="n">predictions</span><span class="p">):</span> <span class="c1"># layer index, layer predictions</span>
    236. <span class="n">image</span><span class="p">,</span> <span class="n">anchor</span><span class="p">,</span> <span class="n">grid_y</span><span class="p">,</span> <span class="n">grid_x</span> <span class="o">=</span> <span class="n">indices</span><span class="p">[</span><span class="n">i</span><span class="p">]</span>
    237. <span class="n">target_obj</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">zeros_like</span><span class="p">(</span><span class="n">prediction</span><span class="p">[</span><span class="o">...</span><span class="p">,</span> <span class="mi">0</span><span class="p">],</span> <span class="n">device</span><span class="o">=</span><span class="n">device</span><span class="p">)</span>
    238. <span class="n">weight_obj</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">ones_like</span><span class="p">(</span><span class="n">prediction</span><span class="p">[</span><span class="o">...</span><span class="p">,</span> <span class="mi">0</span><span class="p">],</span> <span class="n">device</span><span class="o">=</span><span class="n">device</span><span class="p">)</span>
    239. <span class="n">n</span> <span class="o">=</span> <span class="n">image</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span> <span class="c1"># number of targets</span>
    240. <span class="k">if</span> <span class="n">n</span><span class="p">:</span>
    241. <span class="n">num_targets</span> <span class="o">+=</span> <span class="n">n</span> <span class="c1"># cumulative targets</span>
    242. <span class="n">ps</span> <span class="o">=</span> <span class="n">prediction</span><span class="p">[</span><span class="n">image</span><span class="p">,</span> <span class="n">anchor</span><span class="p">,</span> <span class="n">grid_y</span><span class="p">,</span> <span class="n">grid_x</span><span class="p">]</span> <span class="c1"># prediction subset corresponding to targets</span>
    243. <span class="c1"># Boxes loss</span>
    244. <span class="n">pxy</span> <span class="o">=</span> <span class="n">ps</span><span class="p">[:,</span> <span class="p">:</span><span class="mi">2</span><span class="p">]</span><span class="o">.</span><span class="n">sigmoid</span><span class="p">()</span> <span class="o">*</span> <span class="mf">2.</span> <span class="o">-</span> <span class="mf">0.5</span>
    245. <span class="n">pwh</span> <span class="o">=</span> <span class="p">(</span><span class="n">ps</span><span class="p">[:,</span> <span class="mi">2</span><span class="p">:</span><span class="mi">4</span><span class="p">]</span><span class="o">.</span><span class="n">sigmoid</span><span class="p">()</span> <span class="o">*</span> <span class="mi">2</span><span class="p">)</span> <span class="o">**</span> <span class="mi">2</span> <span class="o">*</span> <span class="n">anchors</span><span class="p">[</span><span class="n">i</span><span class="p">]</span>
    246. <span class="n">pbox</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">cat</span><span class="p">((</span><span class="n">pxy</span><span class="p">,</span> <span class="n">pwh</span><span class="p">),</span> <span class="mi">1</span><span class="p">)</span><span class="o">.</span><span class="n">to</span><span class="p">(</span><span class="n">device</span><span class="p">)</span> <span class="c1"># predicted box</span>
    247. <span class="n">iou</span> <span class="o">=</span> <span class="n">calculate_bbox_iou_elementwise</span><span class="p">(</span><span class="n">pbox</span><span class="o">.</span><span class="n">T</span><span class="p">,</span> <span class="n">target_boxes</span><span class="p">[</span><span class="n">i</span><span class="p">],</span> <span class="n">x1y1x2y2</span><span class="o">=</span><span class="kc">False</span><span class="p">,</span> <span class="n">CIoU</span><span class="o">=</span><span class="kc">True</span><span class="p">)</span>
    248. <span class="n">loss_boxes</span> <span class="o">+=</span> <span class="p">(</span><span class="mf">1.0</span> <span class="o">-</span> <span class="n">iou</span><span class="p">)</span><span class="o">.</span><span class="n">mean</span><span class="p">()</span> <span class="c1"># iou loss</span>
    249. <span class="c1"># Objectness loss target</span>
    250. <span class="n">target_obj</span><span class="p">[</span><span class="n">image</span><span class="p">,</span> <span class="n">anchor</span><span class="p">,</span> <span class="n">grid_y</span><span class="p">,</span> <span class="n">grid_x</span><span class="p">]</span> <span class="o">=</span> \
    251. <span class="p">(</span><span class="mf">1.0</span> <span class="o">-</span> <span class="n">giou_loss_ratio</span><span class="p">)</span> <span class="o">+</span> <span class="n">giou_loss_ratio</span> <span class="o">*</span> <span class="n">iou</span><span class="o">.</span><span class="n">detach</span><span class="p">()</span><span class="o">.</span><span class="n">clamp</span><span class="p">(</span><span class="mi">0</span><span class="p">)</span><span class="o">.</span><span class="n">type</span><span class="p">(</span><span class="n">target_obj</span><span class="o">.</span><span class="n">dtype</span><span class="p">)</span>
    252. <span class="c1"># Weights for weighted objectness</span>
    253. <span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">cls_obj_weights</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span><span class="p">:</span>
    254. <span class="c1"># NOTE: for grid cells that have a few ground truths with different classes assigned to them</span>
    255. <span class="c1"># objectness weight will be picked randomly from one of these classes</span>
    256. <span class="n">weight_obj</span><span class="p">[</span><span class="n">image</span><span class="p">,</span> <span class="n">anchor</span><span class="p">,</span> <span class="n">grid_y</span><span class="p">,</span> <span class="n">grid_x</span><span class="p">]</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">cls_obj_weights</span><span class="p">[</span><span class="n">target_classes</span><span class="p">[</span><span class="n">i</span><span class="p">]]</span>
    257. <span class="c1"># Classification loss</span>
    258. <span class="k">if</span> <span class="n">ps</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">1</span><span class="p">]</span> <span class="o">&gt;</span> <span class="mi">6</span><span class="p">:</span> <span class="c1"># cls loss (only if multiple classes)</span>
    259. <span class="n">t</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">full_like</span><span class="p">(</span><span class="n">ps</span><span class="p">[:,</span> <span class="mi">5</span><span class="p">:],</span> <span class="mi">0</span><span class="p">,</span> <span class="n">device</span><span class="o">=</span><span class="n">device</span><span class="p">)</span> <span class="c1"># targets</span>
    260. <span class="n">t</span><span class="p">[</span><span class="nb">range</span><span class="p">(</span><span class="n">n</span><span class="p">),</span> <span class="n">target_classes</span><span class="p">[</span><span class="n">i</span><span class="p">]]</span> <span class="o">=</span> <span class="mi">1</span>
    261. <span class="n">loss_classification</span> <span class="o">+=</span> <span class="n">BCEcls</span><span class="p">(</span><span class="n">ps</span><span class="p">[:,</span> <span class="mi">5</span><span class="p">:],</span> <span class="n">t</span><span class="p">)</span> <span class="c1"># BCE</span>
    262. <span class="c1"># Objectness loss</span>
    263. <span class="n">loss_obj_cur_head</span> <span class="o">=</span> <span class="n">BCEobj</span><span class="p">(</span><span class="n">prediction</span><span class="p">[</span><span class="o">...</span><span class="p">,</span> <span class="mi">4</span><span class="p">],</span> <span class="n">target_obj</span><span class="p">)</span>
    264. <span class="n">loss_obj_cur_head</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">sum</span><span class="p">(</span><span class="n">loss_obj_cur_head</span> <span class="o">*</span> <span class="n">weight_obj</span> <span class="o">/</span> <span class="n">torch</span><span class="o">.</span><span class="n">sum</span><span class="p">(</span><span class="n">weight_obj</span><span class="p">))</span>
    265. <span class="n">loss_objectivness</span> <span class="o">+=</span> <span class="n">loss_obj_cur_head</span> <span class="o">*</span> <span class="n">balance</span><span class="p">[</span><span class="n">i</span><span class="p">]</span> <span class="c1"># obj loss</span>
    266. <span class="n">batch_size</span> <span class="o">=</span> <span class="n">prediction</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span> <span class="c1"># batch size</span>
    267. <span class="n">loss</span> <span class="o">=</span> <span class="n">loss_boxes</span> <span class="o">*</span> <span class="bp">self</span><span class="o">.</span><span class="n">box_loss_gain</span> <span class="o">+</span> <span class="n">loss_objectivness</span> <span class="o">*</span> <span class="bp">self</span><span class="o">.</span><span class="n">obj_loss_gain</span> <span class="o">+</span> <span class="n">loss_classification</span> <span class="o">*</span> <span class="bp">self</span><span class="o">.</span><span class="n">cls_loss_gain</span>
    268. <span class="c1"># IMPORTANT: box, obj and cls loss are logged scaled by gain in ultralytics</span>
    269. <span class="c1"># and are logged unscaled in our codebase</span>
    270. <span class="k">return</span> <span class="n">loss</span> <span class="o">*</span> <span class="n">batch_size</span><span class="p">,</span> <span class="n">torch</span><span class="o">.</span><span class="n">cat</span><span class="p">((</span><span class="n">loss_boxes</span><span class="p">,</span> <span class="n">loss_objectivness</span><span class="p">,</span> <span class="n">loss_classification</span><span class="p">,</span> <span class="n">loss</span><span class="p">))</span><span class="o">.</span><span class="n">detach</span><span class="p">()</span></div></div>
    271. </pre></div>
    272. </div>
    273. </div>
    274. <footer>
    275. <hr/>
    276. <div role="contentinfo">
    277. <p>&#169; Copyright 2021, SuperGradients team.</p>
    278. </div>
    279. Built with <a href="https://www.sphinx-doc.org/">Sphinx</a> using a
    280. <a href="https://github.com/readthedocs/sphinx_rtd_theme">theme</a>
    281. provided by <a href="https://readthedocs.org">Read the Docs</a>.
    282. </footer>
    283. </div>
    284. </div>
    285. </section>
    286. </div>
    287. <script>
    288. jQuery(function () {
    289. SphinxRtdTheme.Navigation.enable(true);
    290. });
    291. </script>
    292. </body>
    293. </html>
    Discard
    Some lines were truncated since they exceed the maximum allowed length of 500, please use a local Git client to see the full diff.
    @@ -3,10 +3,11 @@
     <head>
     <head>
       <meta charset="utf-8" />
       <meta charset="utf-8" />
       <meta name="viewport" content="width=device-width, initial-scale=1.0" />
       <meta name="viewport" content="width=device-width, initial-scale=1.0" />
    -  <title>super_gradients.training.losses.yolox_loss &mdash; SuperGradients 1.0 documentation</title>
    +  <title>super_gradients.training.losses.yolox_loss &mdash; SuperGradients 3.0.3 documentation</title>
           <link rel="stylesheet" href="../../../../_static/pygments.css" type="text/css" />
           <link rel="stylesheet" href="../../../../_static/pygments.css" type="text/css" />
           <link rel="stylesheet" href="../../../../_static/css/theme.css" type="text/css" />
           <link rel="stylesheet" href="../../../../_static/css/theme.css" type="text/css" />
           <link rel="stylesheet" href="../../../../_static/graphviz.css" type="text/css" />
           <link rel="stylesheet" href="../../../../_static/graphviz.css" type="text/css" />
    +      <link rel="stylesheet" href="../../../../_static/custom.css" type="text/css" />
       <!--[if lt IE 9]>
       <!--[if lt IE 9]>
         <script src="../../../../_static/js/html5shiv.min.js"></script>
         <script src="../../../../_static/js/html5shiv.min.js"></script>
       <![endif]-->
       <![endif]-->
    @@ -14,7 +15,9 @@
             <script data-url_root="../../../../" id="documentation_options" src="../../../../_static/documentation_options.js"></script>
             <script data-url_root="../../../../" id="documentation_options" src="../../../../_static/documentation_options.js"></script>
             <script src="../../../../_static/jquery.js"></script>
             <script src="../../../../_static/jquery.js"></script>
             <script src="../../../../_static/underscore.js"></script>
             <script src="../../../../_static/underscore.js"></script>
    +        <script src="../../../../_static/_sphinx_javascript_frameworks_compat.js"></script>
             <script src="../../../../_static/doctools.js"></script>
             <script src="../../../../_static/doctools.js"></script>
    +        <script src="../../../../_static/sphinx_highlight.js"></script>
         <script src="../../../../_static/js/theme.js"></script>
         <script src="../../../../_static/js/theme.js"></script>
         <link rel="index" title="Index" href="../../../../genindex.html" />
         <link rel="index" title="Index" href="../../../../genindex.html" />
         <link rel="search" title="Search" href="../../../../search.html" /> 
         <link rel="search" title="Search" href="../../../../search.html" /> 
    @@ -35,30 +38,29 @@
       </form>
       </form>
     </div>
     </div>
             </div><div class="wy-menu wy-menu-vertical" data-spy="affix" role="navigation" aria-label="Navigation menu">
             </div><div class="wy-menu wy-menu-vertical" data-spy="affix" role="navigation" aria-label="Navigation menu">
    -              <p class="caption"><span class="caption-text">Welcome To SuperGradients</span></p>
    +              <p class="caption" role="heading"><span class="caption-text">Welcome To SuperGradients</span></p>
     <ul>
     <ul>
    -<li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html">Fill our 4-question quick survey! We will raffle free SuperGradients swag between those who will participate -&gt; Fill Survey</a></li>
    -<li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#supergradients">SuperGradients</a></li>
    +<li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html">Version 3 is out! Notebooks have been updated!</a></li>
    +<li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#build-with-supergradients">Build with SuperGradients</a></li>
    +<li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#quick-installation">Quick Installation</a></li>
    +<li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#what-s-new">What’s New</a></li>
    +<li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#coming-soon">Coming soon</a></li>
    +<li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#table-of-content">Table of Content</a></li>
    +<li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#getting-started">Getting Started</a></li>
    +<li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#advanced-features">Advanced Features</a></li>
    +<li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#installation-methods">Installation Methods</a></li>
    +<li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#implemented-model-architectures">Implemented Model Architectures</a></li>
    +<li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#documentation">Documentation</a></li>
    +<li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#contributing">Contributing</a></li>
    +<li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#citation">Citation</a></li>
    +<li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#community">Community</a></li>
    +<li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#license">License</a></li>
    +<li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#deci-platform">Deci Platform</a></li>
     </ul>
     </ul>
    -<p class="caption"><span class="caption-text">Technical Documentation</span></p>
    +<p class="caption" role="heading"><span class="caption-text">Technical Documentation</span></p>
     <ul>
     <ul>
     <li class="toctree-l1"><a class="reference internal" href="../../../../super_gradients.common.html">Common package</a></li>
     <li class="toctree-l1"><a class="reference internal" href="../../../../super_gradients.common.html">Common package</a></li>
     <li class="toctree-l1"><a class="reference internal" href="../../../../super_gradients.training.html">Training package</a></li>
     <li class="toctree-l1"><a class="reference internal" href="../../../../super_gradients.training.html">Training package</a></li>
    -</ul>
    -<p class="caption"><span class="caption-text">User Guide</span></p>
    -<ul>
    -<li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html">What is SuperGradients?</a></li>
    -<li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#introducing-the-supergradients-library">Introducing the SuperGradients library</a></li>
    -<li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#installation">Installation</a></li>
    -<li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#integrating-your-training-code-complete-walkthrough">Integrating Your Training Code - Complete Walkthrough</a></li>
    -<li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#training-parameters">Training Parameters</a></li>
    -<li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#logs-and-checkpoints">Logs and Checkpoints</a></li>
    -<li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#dataset-parameters">Dataset Parameters</a></li>
    -<li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#network-architectures">Network Architectures</a></li>
    -<li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#pretrained-models">Pretrained Models</a></li>
    -<li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#how-to-reproduce-our-training-recipes">How To Reproduce Our Training Recipes</a></li>
    -<li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#professional-tools-integration">Professional Tools Integration</a></li>
    -<li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#supergradients-faq">SuperGradients FAQ</a></li>
     </ul>
     </ul>
     
     
             </div>
             </div>
    @@ -95,10 +97,13 @@
     <span class="kn">from</span> <span class="nn">typing</span> <span class="kn">import</span> <span class="n">List</span><span class="p">,</span> <span class="n">Tuple</span><span class="p">,</span> <span class="n">Union</span>
     <span class="kn">from</span> <span class="nn">typing</span> <span class="kn">import</span> <span class="n">List</span><span class="p">,</span> <span class="n">Tuple</span><span class="p">,</span> <span class="n">Union</span>
     
     
     <span class="kn">import</span> <span class="nn">torch</span>
     <span class="kn">import</span> <span class="nn">torch</span>
    +<span class="kn">import</span> <span class="nn">torch.distributed</span> <span class="k">as</span> <span class="nn">dist</span>
     <span class="kn">from</span> <span class="nn">torch</span> <span class="kn">import</span> <span class="n">nn</span>
     <span class="kn">from</span> <span class="nn">torch</span> <span class="kn">import</span> <span class="n">nn</span>
     <span class="kn">from</span> <span class="nn">torch.nn.modules.loss</span> <span class="kn">import</span> <span class="n">_Loss</span>
     <span class="kn">from</span> <span class="nn">torch.nn.modules.loss</span> <span class="kn">import</span> <span class="n">_Loss</span>
     <span class="kn">import</span> <span class="nn">torch.nn.functional</span> <span class="k">as</span> <span class="nn">F</span>
     <span class="kn">import</span> <span class="nn">torch.nn.functional</span> <span class="k">as</span> <span class="nn">F</span>
    +
     <span class="kn">from</span> <span class="nn">super_gradients.common.abstractions.abstract_logger</span> <span class="kn">import</span> <span class="n">get_logger</span>
     <span class="kn">from</span> <span class="nn">super_gradients.common.abstractions.abstract_logger</span> <span class="kn">import</span> <span class="n">get_logger</span>
    +<span class="kn">from</span> <span class="nn">super_gradients.training.utils</span> <span class="kn">import</span> <span class="n">torch_version_is_greater_or_equal</span>
     <span class="kn">from</span> <span class="nn">super_gradients.training.utils.detection_utils</span> <span class="kn">import</span> <span class="n">calculate_bbox_iou_matrix</span>
     <span class="kn">from</span> <span class="nn">super_gradients.training.utils.detection_utils</span> <span class="kn">import</span> <span class="n">calculate_bbox_iou_matrix</span>
     
     
     <span class="n">logger</span> <span class="o">=</span> <span class="n">get_logger</span><span class="p">(</span><span class="vm">__name__</span><span class="p">)</span>
     <span class="n">logger</span> <span class="o">=</span> <span class="n">get_logger</span><span class="p">(</span><span class="vm">__name__</span><span class="p">)</span>
    @@ -127,10 +132,9 @@
             <span class="n">supported_losses</span> <span class="o">=</span> <span class="p">[</span><span class="s2">&quot;iou&quot;</span><span class="p">,</span> <span class="s2">&quot;giou&quot;</span><span class="p">]</span>
             <span class="n">supported_losses</span> <span class="o">=</span> <span class="p">[</span><span class="s2">&quot;iou&quot;</span><span class="p">,</span> <span class="s2">&quot;giou&quot;</span><span class="p">]</span>
             <span class="n">supported_reductions</span> <span class="o">=</span> <span class="p">[</span><span class="s2">&quot;mean&quot;</span><span class="p">,</span> <span class="s2">&quot;sum&quot;</span><span class="p">,</span> <span class="s2">&quot;none&quot;</span><span class="p">]</span>
             <span class="n">supported_reductions</span> <span class="o">=</span> <span class="p">[</span><span class="s2">&quot;mean&quot;</span><span class="p">,</span> <span class="s2">&quot;sum&quot;</span><span class="p">,</span> <span class="s2">&quot;none&quot;</span><span class="p">]</span>
             <span class="k">if</span> <span class="n">loss_type</span> <span class="ow">not</span> <span class="ow">in</span> <span class="n">supported_losses</span><span class="p">:</span>
             <span class="k">if</span> <span class="n">loss_type</span> <span class="ow">not</span> <span class="ow">in</span> <span class="n">supported_losses</span><span class="p">:</span>
    -            <span class="k">raise</span> <span class="ne">ValueError</span><span class="p">(</span><span class="s2">&quot;Illegal loss_type value: &quot;</span> <span class="o">+</span> <span class="n">loss_type</span> <span class="o">+</span> <span class="s1">&#39;, expected one of: &#39;</span> <span class="o">+</span> <span class="nb">str</span><span class="p">(</span><span class="n">supported_losses</span><span class="p">))</span>
    +            <span class="k">raise</span> <span class="ne">ValueError</span><span class="p">(</span><span class="s2">&quot;Illegal loss_type value: &quot;</span> <span class="o">+</span> <span class="n">loss_type</span> <span class="o">+</span> <span class="s2">&quot;, expected one of: &quot;</span> <span class="o">+</span> <span class="nb">str</span><span class="p">(</span><span class="n">supported_losses</span><span class="p">))</span>
             <span class="k">if</span> <span class="n">reduction</span> <span class="ow">not</span> <span class="ow">in</span> <span class="n">supported_reductions</span><span class="p">:</span>
             <span class="k">if</span> <span class="n">reduction</span> <span class="ow">not</span> <span class="ow">in</span> <span class="n">supported_reductions</span><span class="p">:</span>
    -            <span class="k">raise</span> <span class="ne">ValueError</span><span class="p">(</span>
    -                <span class="s2">&quot;Illegal reduction value: &quot;</span> <span class="o">+</span> <span class="n">reduction</span> <span class="o">+</span> <span class="s1">&#39;, expected one of: &#39;</span> <span class="o">+</span> <span class="nb">str</span><span class="p">(</span><span class="n">supported_reductions</span><span class="p">))</span>
    +            <span class="k">raise</span> <span class="ne">ValueError</span><span class="p">(</span><span class="s2">&quot;Illegal reduction value: &quot;</span> <span class="o">+</span> <span class="n">reduction</span> <span class="o">+</span> <span class="s2">&quot;, expected one of: &quot;</span> <span class="o">+</span> <span class="nb">str</span><span class="p">(</span><span class="n">supported_reductions</span><span class="p">))</span>
     
     
         <span class="k">def</span> <span class="nf">forward</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">pred</span><span class="p">,</span> <span class="n">target</span><span class="p">):</span>
         <span class="k">def</span> <span class="nf">forward</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">pred</span><span class="p">,</span> <span class="n">target</span><span class="p">):</span>
             <span class="k">assert</span> <span class="n">pred</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span> <span class="o">==</span> <span class="n">target</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span>
             <span class="k">assert</span> <span class="n">pred</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span> <span class="o">==</span> <span class="n">target</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span>
    @@ -149,7 +153,7 @@
             <span class="n">iou</span> <span class="o">=</span> <span class="p">(</span><span class="n">area_i</span><span class="p">)</span> <span class="o">/</span> <span class="p">(</span><span class="n">area_u</span> <span class="o">+</span> <span class="mf">1e-16</span><span class="p">)</span>
             <span class="n">iou</span> <span class="o">=</span> <span class="p">(</span><span class="n">area_i</span><span class="p">)</span> <span class="o">/</span> <span class="p">(</span><span class="n">area_u</span> <span class="o">+</span> <span class="mf">1e-16</span><span class="p">)</span>
     
     
             <span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">loss_type</span> <span class="o">==</span> <span class="s2">&quot;iou&quot;</span><span class="p">:</span>
             <span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">loss_type</span> <span class="o">==</span> <span class="s2">&quot;iou&quot;</span><span class="p">:</span>
    -            <span class="n">loss</span> <span class="o">=</span> <span class="mi">1</span> <span class="o">-</span> <span class="n">iou</span> <span class="o">**</span> <span class="mi">2</span>
    +            <span class="n">loss</span> <span class="o">=</span> <span class="mi">1</span> <span class="o">-</span> <span class="n">iou</span><span class="o">**</span><span class="mi">2</span>
             <span class="k">elif</span> <span class="bp">self</span><span class="o">.</span><span class="n">loss_type</span> <span class="o">==</span> <span class="s2">&quot;giou&quot;</span><span class="p">:</span>
             <span class="k">elif</span> <span class="bp">self</span><span class="o">.</span><span class="n">loss_type</span> <span class="o">==</span> <span class="s2">&quot;giou&quot;</span><span class="p">:</span>
                 <span class="n">c_tl</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">min</span><span class="p">((</span><span class="n">pred</span><span class="p">[:,</span> <span class="p">:</span><span class="mi">2</span><span class="p">]</span> <span class="o">-</span> <span class="n">pred</span><span class="p">[:,</span> <span class="mi">2</span><span class="p">:]</span> <span class="o">/</span> <span class="mi">2</span><span class="p">),</spa
                 <span class="n">c_tl</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">min</span><span class="p">((</span><span class="n">pred</span><span class="p">[:,</span> <span class="p">:</span><span class="mi">2</span><span class="p">]</span> <span class="o">-</span> <span class="n">pred</span><span class="p">[:,</span> <span class="mi">2</span><span class="p">:]</span> <span class="o">/</span> <span class="mi">2</span><span class="p">),</spa
                 <span class="n">c_br</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">max</span><span class="p">((</span><span class="n">pred</span><span class="p">[:,</span> <span class="p">:</span><span class="mi">2</span><span class="p">]</span> <span class="o">+</span> <span class="n">pred</span><span class="p">[:,</span> <span class="mi">2</span><span class="p">:]</span> <span class="o">/</span> <span class="mi">2</span><span class="p">),</spa
                 <span class="n">c_br</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">max</span><span class="p">((</span><span class="n">pred</span><span class="p">[:,</span> <span class="p">:</span><span class="mi">2</span><span class="p">]</span> <span class="o">+</span> <span class="n">pred</span><span class="p">[:,</span> <span class="mi">2</span><span class="p">:]</span> <span class="o">/</span> <span class="mi">2</span><span class="p">),</spa
    @@ -165,7 +169,7 @@
             <span class="k">return</span> <span class="n">loss</span>
             <span class="k">return</span> <span class="n">loss</span>
     
     
     
     
    -<div class="viewcode-block" id="YoloXDetectionLoss"><a class="viewcode-back" href="../../../../super_gradients.training.losses.html#super_gradients.training.losses.YoloXDetectionLoss">[docs]</a><span class="k">class</span> <span class="nc">YoloXDetectionLoss</span><span class="p">(</span><span class="n">_Loss</span><span class="p">):</span>
    +<div class="viewcode-block" id="YoloXDetectionLoss"><a class="viewcode-back" href="../../../../super_gradients.training.html#super_gradients.training.losses.YoloXDetectionLoss">[docs]</a><span class="k">class</span> <span class="nc">YoloXDetectionLoss</span><span class="p">(</span><span class="n">_Loss</span><span class="p">):</span>
         <span class="sd">&quot;&quot;&quot;</span>
         <span class="sd">&quot;&quot;&quot;</span>
     <span class="sd">    Calculate YOLOX loss:</span>
     <span class="sd">    Calculate YOLOX loss:</span>
     <span class="sd">    L = L_objectivness + L_iou + L_classification + 1[use_l1]*L_l1</span>
     <span class="sd">    L = L_objectivness + L_iou + L_classification + 1[use_l1]*L_l1</span>
    @@ -200,8 +204,7 @@
     
     
     <span class="sd">    &quot;&quot;&quot;</span>
     <span class="sd">    &quot;&quot;&quot;</span>
     
     
    -    <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">strides</span><span class="p">:</span> <span class="nb">list</span><span class="p">,</span> <span class="n">num_classes</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span> <span class="n">use_l1</span><span class="p">:</span> <span class="nb">bool</span> <span class="o">=</span> <span class="kc">False</span><sp
    -                 <span class="n">iou_type</span><span class="o">=</span><span class="s1">&#39;iou&#39;</span><span class="p">):</span>
    +    <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">strides</span><span class="p">:</span> <span class="nb">list</span><span class="p">,</span> <span class="n">num_classes</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span> <span class="n">use_l1</span><span class="p">:</span> <span class="nb">bool</span> <span class="o">=</span> <span class="kc">False</span><sp
             <span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="fm">__init__</span><span class="p">()</span>
             <span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="fm">__init__</span><span class="p">()</span>
             <span class="bp">self</span><span class="o">.</span><span class="n">grids</span> <span class="o">=</span> <span class="p">[</span><span class="n">torch</span><span class="o">.</span><span class="n">zeros</span><span class="p">(</span><span class="mi">1</span><span class="p">)]</span> <span class="o">*</span> <span class="nb">len</span><span class="p">(</span><span class="n">strides</span><span class="p">)</span>
             <span class="bp">self</span><span class="o">.</span><span class="n">grids</span> <span class="o">=</span> <span class="p">[</span><span class="n">torch</span><span class="o">.</span><span class="n">zeros</span><span class="p">(</span><span class="mi">1</span><span class="p">)]</span> <span class="o">*</span> <span class="nb">len</span><span class="p">(</span><span class="n">strides</span><span class="p">)</span>
             <span class="bp">self</span><span class="o">.</span><span class="n">strides</span> <span class="o">=</span> <span class="n">strides</span>
             <span class="bp">self</span><span class="o">.</span><span class="n">strides</span> <span class="o">=</span> <span class="n">strides</span>
    @@ -213,7 +216,16 @@
             <span class="bp">self</span><span class="o">.</span><span class="n">bcewithlog_loss</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">BCEWithLogitsLoss</span><span class="p">(</span><span class="n">reduction</span><span class="o">=</span><span class="s2">&quot;none&quot;</span><span class="p">)</span>
             <span class="bp">self</span><span class="o">.</span><span class="n">bcewithlog_loss</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">BCEWithLogitsLoss</span><span class="p">(</span><span class="n">reduction</span><span class="o">=</span><span class="s2">&quot;none&quot;</span><span class="p">)</span>
             <span class="bp">self</span><span class="o">.</span><span class="n">iou_loss</span> <span class="o">=</span> <span class="n">IOUloss</span><span class="p">(</span><span class="n">reduction</span><span class="o">=</span><span class="s2">&quot;none&quot;</span><span class="p">,</span> <span class="n">loss_type</span><span class="o">=</span><span class="n">iou_type</span><span class="p">)</span>
             <span class="bp">self</span><span class="o">.</span><span class="n">iou_loss</span> <span class="o">=</span> <span class="n">IOUloss</span><span class="p">(</span><span class="n">reduction</span><span class="o">=</span><span class="s2">&quot;none&quot;</span><span class="p">,</span> <span class="n">loss_type</span><span class="o">=</span><span class="n">iou_type</span><span class="p">)</span>
     
     
    -<div class="viewcode-block" id="YoloXDetectionLoss.forward"><a class="viewcode-back" href="../../../../super_gradients.training.losses.html#super_gradients.training.losses.YoloXDetectionLoss.forward">[docs]</a>    <span class="k">def</span> <span class="nf">forward</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">model_output</span><span class="p">:</span> <span class="n">Union</span><span class="p">[</span><span class="nb">list</span><span clas
    +    <span class="nd">@property</span>
    +    <span class="k">def</span> <span class="nf">component_names</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span>
    +        <span class="sd">&quot;&quot;&quot;</span>
    +<span class="sd">        Component names for logging during training.</span>
    +<span class="sd">        These correspond to 2nd item in the tuple returned in self.forward(...).</span>
    +<span class="sd">        See super_gradients.Trainer.train() docs for more info.</span>
    +<span class="sd">        &quot;&quot;&quot;</span>
    +        <span class="k">return</span> <span class="p">[</span><span class="s2">&quot;iou&quot;</span><span class="p">,</span> <span class="s2">&quot;obj&quot;</span><span class="p">,</span> <span class="s2">&quot;cls&quot;</span><span class="p">,</span> <span class="s2">&quot;l1&quot;</span><span class="p">,</span> <span class="s2">&quot;num_fg&quot;</span><span class="p">,</span> <span class="s2">&quot;Loss&quot;</span><span class="p">]</span>
    +
    +<div class="viewcode-block" id="YoloXDetectionLoss.forward"><a class="viewcode-back" href="../../../../super_gradients.training.html#super_gradients.training.losses.YoloXDetectionLoss.forward">[docs]</a>    <span class="k">def</span> <span class="nf">forward</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">model_output</span><span class="p">:</span> <span class="n">Union</span><span class="p">[</span><span class="nb">list</span><span class="p">,
             <span class="sd">&quot;&quot;&quot;</span>
             <span class="sd">&quot;&quot;&quot;</span>
     <span class="sd">        :param model_output: Union[list, Tuple[torch.Tensor, List]]:</span>
     <span class="sd">        :param model_output: Union[list, Tuple[torch.Tensor, List]]:</span>
     <span class="sd">             When list-</span>
     <span class="sd">             When list-</span>
    @@ -241,11 +253,14 @@
     <span class="sd">        :param ny: int: cells along the y axis (default=20)</span>
     <span class="sd">        :param ny: int: cells along the y axis (default=20)</span>
     <span class="sd">        :return: torch.tensor of xy coordinates of size (1,1,nx,ny,2)</span>
     <span class="sd">        :return: torch.tensor of xy coordinates of size (1,1,nx,ny,2)</span>
     <span class="sd">        &quot;&quot;&quot;</span>
     <span class="sd">        &quot;&quot;&quot;</span>
    -        <span class="n">yv</span><span class="p">,</span> <span class="n">xv</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">meshgrid</span><span class="p">([</span><span class="n">torch</span><span class="o">.</span><span class="n">arange</span><span class="p">(</span><span class="n">ny</span><span class="p">),</span> <span class="n">torch</span><span class="o">.</span><span class="n">arange</span><span class="p">(</span><span class="n">nx</s
    +        <span class="k">if</span> <span class="n">torch_version_is_greater_or_equal</span><span class="p">(</span><span class="mi">1</span><span class="p">,</span> <span class="mi">10</span><span class="p">):</span>
    +            <span class="c1"># https://github.com/pytorch/pytorch/issues/50276</span>
    +            <span class="n">yv</span><span class="p">,</span> <span class="n">xv</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">meshgrid</span><span class="p">([</span><span class="n">torch</span><span class="o">.</span><span class="n">arange</span><span class="p">(</span><span class="n">ny</span><span class="p">),</span> <span class="n">torch</span><span class="o">.</span><span class="n">arange</span><span class="p">(</span><span class="n">n
    +        <span class="k">else</span><span class="p">:</span>
    +            <span class="n">yv</span><span class="p">,</span> <span class="n">xv</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">meshgrid</span><span class="p">([</span><span class="n">torch</span><span class="o">.</span><span class="n">arange</span><span class="p">(</span><span class="n">ny</span><span class="p">),</span> <span class="n">torch</span><span class="o">.</span><span class="n">arange</span><span class="p">(</span><span class="n">n
             <span class="k">return</span> <span class="n">torch</span><span class="o">.</span><span class="n">stack</span><span class="p">((</span><span class="n">xv</span><span class="p">,</span> <span class="n">yv</span><span class="p">),</span> <span class="mi">2</span><span class="p">)</span><span class="o">.</span><span class="n">view</span><span class="p">((</span><span class="mi">1</span><span class="p">,</span> <span class="mi">1</span><span class="p">,</span> <span class="n">ny</span><span
             <span class="k">return</span> <span class="n">torch</span><span class="o">.</span><span class="n">stack</span><span class="p">((</span><span class="n">xv</span><span class="p">,</span> <span class="n">yv</span><span class="p">),</span> <span class="mi">2</span><span class="p">)</span><span class="o">.</span><span class="n">view</span><span class="p">((</span><span class="mi">1</span><span class="p">,</span> <span class="mi">1</span><span class="p">,</span> <span class="n">ny</span><span
     
     
    -    <span class="k">def</span> <span class="nf">_compute_loss</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">predictions</span><span class="p">:</span> <span class="n">List</span><span class="p">[</span><span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">],</span> <span class="n">targets</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><
    -            <span class="o">-&gt;</span> <span class="n">Tuple</span><span class="p">[</span><span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">,</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">]:</span>
    +    <span class="k">def</span> <span class="nf">_compute_loss</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">predictions</span><span class="p">:</span> <span class="n">List</span><span class="p">[</span><span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">],</span> <span class="n">targets</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><
             <span class="sd">&quot;&quot;&quot;</span>
             <span class="sd">&quot;&quot;&quot;</span>
     <span class="sd">        :param predictions:     output from all Yolo levels, each of shape</span>
     <span class="sd">        :param predictions:     output from all Yolo levels, each of shape</span>
     <span class="sd">                                [Batch x 1 x GridSizeY x GridSizeX x (4 + 1 + Num_classes)]</span>
     <span class="sd">                                [Batch x 1 x GridSizeY x GridSizeX x (4 + 1 + Num_classes)]</span>
    @@ -267,7 +282,7 @@
             <span class="n">obj_targets</span> <span class="o">=</span> <span class="p">[]</span>
             <span class="n">obj_targets</span> <span class="o">=</span> <span class="p">[]</span>
             <span class="n">fg_masks</span> <span class="o">=</span> <span class="p">[]</span>
             <span class="n">fg_masks</span> <span class="o">=</span> <span class="p">[]</span>
     
     
    -        <span class="n">num_fg</span><span class="p">,</span> <span class="n">num_gts</span> <span class="o">=</span> <span class="mf">0.</span><span class="p">,</span> <span class="mf">0.</span>
    +        <span class="n">num_fg</span><span class="p">,</span> <span class="n">num_gts</span> <span class="o">=</span> <span class="mf">0.0</span><span class="p">,</span> <span class="mf">0.0</span>
     
     
             <span class="k">for</span> <span class="n">image_idx</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="n">transformed_outputs</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">0</span><span class="p">]):</span>
             <span class="k">for</span> <span class="n">image_idx</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="n">transformed_outputs</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">0</span><span class="p">]):</span>
                 <span class="n">labels_im</span> <span class="o">=</span> <span class="n">targets</span><span class="p">[</span><span class="n">targets</span><span class="p">[:,</span> <span class="mi">0</span><span class="p">]</span> <span class="o">==</span> <span class="n">image_idx</span><span class="p">]</span>
                 <span class="n">labels_im</span> <span class="o">=</span> <span class="n">targets</span><span class="p">[</span><span class="n">targets</span><span class="p">[:,</span> <span class="mi">0</span><span class="p">]</span> <span class="o">==</span> <span class="n">image_idx</span><span class="p">]</span>
    @@ -287,21 +302,42 @@
     
     
                     <span class="k">try</span><span class="p">:</span>
                     <span class="k">try</span><span class="p">:</span>
                         <span class="c1"># assign cells to ground truths, at most one GT per cell</span>
                         <span class="c1"># assign cells to ground truths, at most one GT per cell</span>
    -                    <span class="n">gt_matched_classes</span><span class="p">,</span> <span class="n">fg_mask</span><span class="p">,</span> <span class="n">pred_ious_this_matching</span><span class="p">,</span> <span class="n">matched_gt_inds</span><span class="p">,</span> <span class="n">num_fg_img</span> <span class="o">=</span> \
    -                        <span class="bp">self</span><span class="o">.</span><span class="n">get_assignments</span><span class="p">(</span><span class="n">image_idx</span><span class="p">,</span> <span class="n">num_gt</span><span class="p">,</span> <span class="n">total_num_anchors</span><span class="p">,</span> <span class="n">gt_bboxes_per_image</span><span class="p">,</span>
    -                                             <span class="n">gt_classes</span><span class="p">,</span> <span class="n">bboxes_preds_per_image</span><span class="p">,</span>
    -                                             <span class="n">expanded_strides</span><span class="p">,</span> <span class="n">x_shifts</span><span class="p">,</span> <span class="n">y_shifts</span><span class="p">,</span> <span class="n">cls_preds</span><span class="p">,</span> <span class="n">obj_preds</span><span class="p">)</span>
    +                    <span class="n">gt_matched_classes</span><span class="p">,</span> <span class="n">fg_mask</span><span class="p">,</span> <span class="n">pred_ious_this_matching</span><span class="p">,</span> <span class="n">matched_gt_inds</span><span class="p">,</span> <span class="n">num_fg_img</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">get_assignments</span><span class="p">(</span>
    +                        <span class="n">image_idx</span><span class="p">,</span>
    +                        <span class="n">num_gt</span><span class="p">,</span>
    +                        <span class="n">total_num_anchors</span><span class="p">,</span>
    +                        <span class="n">gt_bboxes_per_image</span><span class="p">,</span>
    +                        <span class="n">gt_classes</span><span class="p">,</span>
    +                        <span class="n">bboxes_preds_per_image</span><span class="p">,</span>
    +                        <span class="n">expanded_strides</span><span class="p">,</span>
    +                        <span class="n">x_shifts</span><span class="p">,</span>
    +                        <span class="n">y_shifts</span><span class="p">,</span>
    +                        <span class="n">cls_preds</span><span class="p">,</span>
    +                        <span class="n">obj_preds</span><span class="p">,</span>
    +                    <span class="p">)</span>
     
     
                     <span class="c1"># TODO: CHECK IF ERROR IS CUDA OUT OF MEMORY</span>
                     <span class="c1"># TODO: CHECK IF ERROR IS CUDA OUT OF MEMORY</span>
                     <span class="k">except</span> <span class="ne">RuntimeError</span><span class="p">:</span>
                     <span class="k">except</span> <span class="ne">RuntimeError</span><span class="p">:</span>
    -                    <span class="n">logging</span><span class="o">.</span><span class="n">error</span><span class="p">(</span><span class="s2">&quot;OOM RuntimeError is raised due to the huge memory cost during label assignment. </span><span class="se">\</span>
    +                    <span class="n">logging</span><span class="o">.</span><span class="n">error</span><span class="p">(</span>
    +                        <span class="s2">&quot;OOM RuntimeError is raised due to the huge memory cost during label assignment. </span><span class="se">\</span>
     <span class="s2">                                   CPU mode is applied in this batch. If you want to avoid this issue, </span><span class="se">\</span>
     <span class="s2">                                   CPU mode is applied in this batch. If you want to avoid this issue, </span><span class="se">\</span>
    -<span class="s2">                                   try to reduce the batch size or image size.&quot;</span><span class="p">)</span>
    +<span class="s2">                                   try to reduce the batch size or image size.&quot;</span>
    +                    <span class="p">)</span>
                         <span class="n">torch</span><span class="o">.</span><span class="n">cuda</span><span class="o">.</span><span class="n">empty_cache</span><span class="p">()</span>
                         <span class="n">torch</span><span class="o">.</span><span class="n">cuda</span><span class="o">.</span><span class="n">empty_cache</span><span class="p">()</span>
    -                    <span class="n">gt_matched_classes</span><span class="p">,</span> <span class="n">fg_mask</span><span class="p">,</span> <span class="n">pred_ious_this_matching</span><span class="p">,</span> <span class="n">matched_gt_inds</span><span class="p">,</span> <span class="n">num_fg_img</span> <span class="o">=</span> \
    -                        <span class="bp">self</span><span class="o">.</span><span class="n">get_assignments</span><span class="p">(</span><span class="n">image_idx</span><span class="p">,</span> <span class="n">num_gt</span><span class="p">,</span> <span class="n">total_num_anchors</span><span class="p">,</span> <span class="n">gt_bboxes_per_image</span><span class="p">,</span>
    -                                             <span class="n">gt_classes</span><span class="p">,</span> <span class="n">bboxes_preds_per_image</span><span class="p">,</span>
    -                                             <span class="n">expanded_strides</span><span class="p">,</span> <span class="n">x_shifts</span><span class="p">,</span> <span class="n">y_shifts</span><span class="p">,</span> <span class="n">cls_preds</span><span class="p">,</span> <span class="n">obj_preds</span><span class="p">,</span> <span class="s1">&#39;cpu&#39;</span><span class="p">)</span>
    +                    <span class="n">gt_matched_classes</span><span class="p">,</span> <span class="n">fg_mask</span><span class="p">,</span> <span class="n">pred_ious_this_matching</span><span class="p">,</span> <span class="n">matched_gt_inds</span><span class="p">,</span> <span class="n">num_fg_img</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">get_assignments</span><span class="p">(</span>
    +                        <span class="n">image_idx</span><span class="p">,</span>
    +                        <span class="n">num_gt</span><span class="p">,</span>
    +                        <span class="n">total_num_anchors</span><span class="p">,</span>
    +                        <span class="n">gt_bboxes_per_image</span><span class="p">,</span>
    +                        <span class="n">gt_classes</span><span class="p">,</span>
    +                        <span class="n">bboxes_preds_per_image</span><span class="p">,</span>
    +                        <span class="n">expanded_strides</span><span class="p">,</span>
    +                        <span class="n">x_shifts</span><span class="p">,</span>
    +                        <span class="n">y_shifts</span><span class="p">,</span>
    +                        <span class="n">cls_preds</span><span class="p">,</span>
    +                        <span class="n">obj_preds</span><span class="p">,</span>
    +                        <span class="s2">&quot;cpu&quot;</span><span class="p">,</span>
    +                    <span class="p">)</span>
     
     
                     <span class="n">torch</span><span class="o">.</span><span class="n">cuda</span><span class="o">.</span><span class="n">empty_cache</span><span class="p">()</span>
                     <span class="n">torch</span><span class="o">.</span><span class="n">cuda</span><span class="o">.</span><span class="n">empty_cache</span><span class="p">()</span>
                     <span class="n">num_fg</span> <span class="o">+=</span> <span class="n">num_fg_img</span>
                     <span class="n">num_fg</span> <span class="o">+=</span> <span class="n">num_fg_img</span>
    @@ -310,9 +346,13 @@
                     <span class="n">obj_target</span> <span class="o">=</span> <span class="n">fg_mask</span><span class="o">.</span><span class="n">unsqueeze</span><span class="p">(</span><span class="o">-</span><span class="mi">1</span><span class="p">)</span>
                     <span class="n">obj_target</span> <span class="o">=</span> <span class="n">fg_mask</span><span class="o">.</span><span class="n">unsqueeze</span><span class="p">(</span><span class="o">-</span><span class="mi">1</span><span class="p">)</span>
                     <span class="n">reg_target</span> <span class="o">=</span> <span class="n">gt_bboxes_per_image</span><span class="p">[</span><span class="n">matched_gt_inds</span><span class="p">]</span>
                     <span class="n">reg_target</span> <span class="o">=</span> <span class="n">gt_bboxes_per_image</span><span class="p">[</span><span class="n">matched_gt_inds</span><span class="p">]</span>
                     <span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">use_l1</span><span class="p">:</span>
                     <span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">use_l1</span><span class="p">:</span>
    -                    <span class="n">l1_target</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">get_l1_target</span><span class="p">(</span><span class="n">transformed_outputs</span><span class="o">.</span><span class="n">new_zeros</span><span class="p">((</span><span class="n">num_fg_img</span><span class="p">,</span> <span class="mi">4</span><span class="p">)),</span>
    -                                                   <span class="n">gt_bboxes_per_image</span><span class="p">[</span><span class="n">matched_gt_inds</span><span class="p">],</span> <span class="n">expanded_strides</span><span class="p">[</span><span class="mi">0</span><span class="p">][</span><span class="n">fg_mask</span><span class="p">],</span>
    -                                                   <span class="n">x_shifts</span><span class="o">=</span><span class="n">x_shifts</span><span class="p">[</span><span class="mi">0</span><span class="p">][</span><span class="n">fg_mask</span><span class="p">],</span> <span class="n">y_shifts</span><span class="o">=</span><span class="n">y_shifts</span><span class="p">[</span><span class="mi">0</span><span class="p">][</span><span class="n">fg_mask</span><span class="p">])</span>
    +                    <span class="n">l1_target</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">get_l1_target</span><span class="p">(</span>
    +                        <span class="n">transformed_outputs</span><span class="o">.</span><span class="n">new_zeros</span><span class="p">((</span><span class="n">num_fg_img</span><span class="p">,</span> <span class="mi">4</span><span class="p">)),</span>
    +                        <span class="n">gt_bboxes_per_image</span><span class="p">[</span><span class="n">matched_gt_inds</span><span class="p">],</span>
    +                        <span class="n">expanded_strides</span><span class="p">[</span><span class="mi">0</span><span class="p">][</span><span class="n">fg_mask</span><span class="p">],</span>
    +                        <span class="n">x_shifts</span><span class="o">=</span><span class="n">x_shifts</span><span class="p">[</span><span class="mi">0</span><span class="p">][</span><span class="n">fg_mask</span><span class="p">],</span>
    +                        <span class="n">y_shifts</span><span class="o">=</span><span class="n">y_shifts</span><span class="p">[</span><span class="mi">0</span><span class="p">][</span><span class="n">fg_mask</span><span class="p">],</span>
    +                    <span class="p">)</span>
     
     
                 <span class="c1"># collect targets for all loss terms over the whole batch</span>
                 <span class="c1"># collect targets for all loss terms over the whole batch</span>
                 <span class="n">cls_targets</span><span class="o">.</span><span class="n">append</span><span class="p">(</span><span class="n">cls_target</span><span class="p">)</span>
                 <span class="n">cls_targets</span><span class="o">.</span><span class="n">append</span><span class="p">(</span><span class="n">cls_target</span><span class="p">)</span>
    @@ -343,13 +383,21 @@
             <span class="n">reg_weight</span> <span class="o">=</span> <span class="mf">5.0</span>
             <span class="n">reg_weight</span> <span class="o">=</span> <span class="mf">5.0</span>
             <span class="n">loss</span> <span class="o">=</span> <span class="n">reg_weight</span> <span class="o">*</span> <span class="n">loss_iou</span> <span class="o">+</span> <span class="n">loss_obj</span> <span class="o">+</span> <span class="n">loss_cls</span> <span class="o">+</span> <span class="n">loss_l1</span>
             <span class="n">loss</span> <span class="o">=</span> <span class="n">reg_weight</span> <span class="o">*</span> <span class="n">loss_iou</span> <span class="o">+</span> <span class="n">loss_obj</span> <span class="o">+</span> <span class="n">loss_cls</span> <span class="o">+</span> <span class="n">loss_l1</span>
     
     
    -        <span class="k">return</span> <span class="n">loss</span><span class="p">,</span> <span class="n">torch</span><span class="o">.</span><span class="n">cat</span><span class="p">((</span><span class="n">loss_iou</span><span class="o">.</span><span class="n">unsqueeze</span><span class="p">(</span><span class="mi">0</span><span class="p">),</span> <span class="n">loss_obj</span><span class="o">.</span><span class="n">unsqueeze</span><span class="p">(</span><span class="mi">0</span><span cl
    -                                <span class="n">torch</span><span class="o">.</span><span class="n">tensor</span><span class="p">(</span><span class="n">loss_l1</span><span class="p">)</span><span class="o">.</span><span class="n">unsqueeze</span><span class="p">(</span><span class="mi">0</span><span class="p">)</span><span class="o">.</span><span class="n">to</span><span class="p">(</span><span class="n">loss</span><span class="o">.</span><span class="n">device</span><span class="p">),</span>
    -                                <span class="n">torch</span><span class="o">.</span><span class="n">tensor</span><span class="p">(</span><span class="n">num_fg</span> <span class="o">/</span> <span class="nb">max</span><span class="p">(</span><span class="n">num_gts</span><span class="p">,</span> <span class="mi">1</span><span class="p">))</span><span class="o">.</span><span class="n">unsqueeze</span><span class="p">(</span><span class="mi">0</span><span class="p">)</span><span class="o">.</spa
    -                                <span class="n">loss</span><span class="o">.</span><span class="n">unsqueeze</span><span class="p">(</span><span class="mi">0</span><span class="p">)))</span><span class="o">.</span><span class="n">detach</span><span class="p">()</span>
    -
    -<div class="viewcode-block" id="YoloXDetectionLoss.prepare_predictions"><a class="viewcode-back" href="../../../../super_gradients.training.losses.html#super_gradients.training.losses.YoloXDetectionLoss.prepare_predictions">[docs]</a>    <span class="k">def</span> <span class="nf">prepare_predictions</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">predictions</span><span class="p">:</span> <span class="n">List</span><span class="p">[</span><spa
    -            <span class="n">Tuple</span><span class="p">[</span><span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">,</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">,</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">,</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">,</span> <span class
    +        <span class="k">return</span> <span class="p">(</span>
    +            <span class="n">loss</span><span class="p">,</span>
    +            <span class="n">torch</span><span class="o">.</span><span class="n">cat</span><span class="p">(</span>
    +                <span class="p">(</span>
    +                    <span class="n">loss_iou</span><span class="o">.</span><span class="n">unsqueeze</span><span class="p">(</span><span class="mi">0</span><span class="p">),</span>
    +                    <span class="n">loss_obj</span><span class="o">.</span><span class="n">unsqueeze</span><span class="p">(</span><span class="mi">0</span><span class="p">),</span>
    +                    <span class="n">loss_cls</span><span class="o">.</span><span class="n">unsqueeze</span><span class="p">(</span><span class="mi">0</span><span class="p">),</span>
    +                    <span class="n">torch</span><span class="o">.</span><span class="n">tensor</span><span class="p">(</span><span class="n">loss_l1</span><span class="p">)</span><span class="o">.</span><span class="n">unsqueeze</span><span class="p">(</span><span class="mi">0</span><span class="p">)</span><span class="o">.</span><span class="n">to</span><span class="p">(</span><span class="n">loss</span><span class="o">.</span><span class="n">device</span><span class="p">),</span>
    +                    <span class="n">torch</span><span class="o">.</span><span class="n">tensor</span><span class="p">(</span><span class="n">num_fg</span> <span class="o">/</span> <span class="nb">max</span><span class="p">(</span><span class="n">num_gts</span><span class="p">,</span> <span class="mi">1</span><span class="p">))</span><span class="o">.</span><span class="n">unsqueeze</span><span class="p">(</span><span class="mi">0</span><span class="p">)</span><span class="o">.</span><span clas
    +                    <span class="n">loss</span><span class="o">.</span><span class="n">unsqueeze</span><span class="p">(</span><span class="mi">0</span><span class="p">),</span>
    +                <span class="p">)</span>
    +            <span class="p">)</span><span class="o">.</span><span class="n">detach</span><span class="p">(),</span>
    +        <span class="p">)</span>
    +
    +<div class="viewcode-block" id="YoloXDetectionLoss.prepare_predictions"><a class="viewcode-back" href="../../../../super_gradients.training.html#super_gradients.training.losses.YoloXDetectionLoss.prepare_predictions">[docs]</a>    <span class="k">def</span> <span class="nf">prepare_predictions</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">predictions</span><span class="p">:</span> <span class="n">List</span><span class="p">[</span><span class
             <span class="sd">&quot;&quot;&quot;</span>
             <span class="sd">&quot;&quot;&quot;</span>
     <span class="sd">        Convert raw outputs of the network into a format that merges outputs from all levels</span>
     <span class="sd">        Convert raw outputs of the network into a format that merges outputs from all levels</span>
     <span class="sd">        :param predictions:     output from all Yolo levels, each of shape</span>
     <span class="sd">        :param predictions:     output from all Yolo levels, each of shape</span>
    @@ -414,7 +462,7 @@
     
     
             <span class="k">return</span> <span class="n">x_shifts</span><span class="p">,</span> <span class="n">y_shifts</span><span class="p">,</span> <span class="n">expanded_strides</span><span class="p">,</span> <span class="n">transformed_outputs</span><span class="p">,</span> <span class="n">raw_outputs</span></div>
             <span class="k">return</span> <span class="n">x_shifts</span><span class="p">,</span> <span class="n">y_shifts</span><span class="p">,</span> <span class="n">expanded_strides</span><span class="p">,</span> <span class="n">transformed_outputs</span><span class="p">,</span> <span class="n">raw_outputs</span></div>
     
     
    -<div class="viewcode-block" id="YoloXDetectionLoss.get_l1_target"><a class="viewcode-back" href="../../../../super_gradients.training.losses.html#super_gradients.training.losses.YoloXDetectionLoss.get_l1_target">[docs]</a>    <span class="k">def</span> <span class="nf">get_l1_target</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">l1_target</span><span class="p">,</span> <span class="n">gt</span><span class="p">,</span> <span class="n">stride</s
    +<div class="viewcode-block" id="YoloXDetectionLoss.get_l1_target"><a class="viewcode-back" href="../../../../super_gradients.training.html#super_gradients.training.losses.YoloXDetectionLoss.get_l1_target">[docs]</a>    <span class="k">def</span> <span class="nf">get_l1_target</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">l1_target</span><span class="p">,</span> <span class="n">gt</span><span class="p">,</span> <span class="n">stride</span><sp
             <span class="sd">&quot;&quot;&quot;</span>
             <span class="sd">&quot;&quot;&quot;</span>
     <span class="sd">        :param l1_target:   tensor of zeros of shape [Num_cell_gt_pairs x 4]</span>
     <span class="sd">        :param l1_target:   tensor of zeros of shape [Num_cell_gt_pairs x 4]</span>
     <span class="sd">        :param gt:          targets in coordinates [Num_cell_gt_pairs x (4 + 1 + num_classes)]</span>
     <span class="sd">        :param gt:          targets in coordinates [Num_cell_gt_pairs x (4 + 1 + num_classes)]</span>
    @@ -427,10 +475,24 @@
             <span class="n">l1_target</span><span class="p">[:,</span> <span class="mi">3</span><span class="p">]</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">log</span><span class="p">(</span><span class="n">gt</span><span class="p">[:,</span> <span class="mi">3</span><span class="p">]</span> <span class="o">/</span> <span class="n">stride</span> <span class="o">+</span> <span class="n">eps</span><span class="p">)</span>
             <span class="n">l1_target</span><span class="p">[:,</span> <span class="mi">3</span><span class="p">]</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">log</span><span class="p">(</span><span class="n">gt</span><span class="p">[:,</span> <span class="mi">3</span><span class="p">]</span> <span class="o">/</span> <span class="n">stride</span> <span class="o">+</span> <span class="n">eps</span><span class="p">)</span>
             <span class="k">return</span> <span class="n">l1_target</span></div>
             <span class="k">return</span> <span class="n">l1_target</span></div>
     
     
    -<div class="viewcode-block" id="YoloXDetectionLoss.get_assignments"><a class="viewcode-back" href="../../../../super_gradients.training.losses.html#super_gradients.training.losses.YoloXDetectionLoss.get_assignments">[docs]</a>    <span class="nd">@torch</span><span class="o">.</span><span class="n">no_grad</span><span class="p">()</span>
    -    <span class="k">def</span> <span class="nf">get_assignments</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">image_idx</span><span class="p">,</span> <span class="n">num_gt</span><span class="p">,</span> <span class="n">total_num_anchors</span><span class="p">,</span> <span class="n">gt_bboxes_per_image</span><span class="p">,</span> <span class="n">gt_classes</span><span class="p">,</span>
    -                        <span class="n">bboxes_preds_per_image</span><span class="p">,</span> <span class="n">expanded_strides</span><span class="p">,</span> <span class="n">x_shifts</span><span class="p">,</span> <span class="n">y_shifts</span><span class="p">,</span> <span class="n">cls_preds</span><span class="p">,</span>
    -                        <span class="n">obj_preds</span><span class="p">,</span> <span class="n">mode</span><span class="o">=</span><span class="s2">&quot;gpu&quot;</span><span class="p">,</span> <span class="n">ious_loss_cost_coeff</span><span class="o">=</span><span class="mf">3.0</span><span class="p">,</span> <span class="n">outside_boxes_and_center_cost_coeff</span><span class="o">=</span><span class="mf">100000.0</span><span class="p">):</span>
    +<div class="viewcode-block" id="YoloXDetectionLoss.get_assignments"><a class="viewcode-back" href="../../../../super_gradients.training.html#super_gradients.training.losses.YoloXDetectionLoss.get_assignments">[docs]</a>    <span class="nd">@torch</span><span class="o">.</span><span class="n">no_grad</span><span class="p">()</span>
    +    <span class="k">def</span> <span class="nf">get_assignments</span><span class="p">(</span>
    +        <span class="bp">self</span><span class="p">,</span>
    +        <span class="n">image_idx</span><span class="p">,</span>
    +        <span class="n">num_gt</span><span class="p">,</span>
    +        <span class="n">total_num_anchors</span><span class="p">,</span>
    +        <span class="n">gt_bboxes_per_image</span><span class="p">,</span>
    +        <span class="n">gt_classes</span><span class="p">,</span>
    +        <span class="n">bboxes_preds_per_image</span><span class="p">,</span>
    +        <span class="n">expanded_strides</span><span class="p">,</span>
    +        <span class="n">x_shifts</span><span class="p">,</span>
    +        <span class="n">y_shifts</span><span class="p">,</span>
    +        <span class="n">cls_preds</span><span class="p">,</span>
    +        <span class="n">obj_preds</span><span class="p">,</span>
    +        <span class="n">mode</span><span class="o">=</span><span class="s2">&quot;gpu&quot;</span><span class="p">,</span>
    +        <span class="n">ious_loss_cost_coeff</span><span class="o">=</span><span class="mf">3.0</span><span class="p">,</span>
    +        <span class="n">outside_boxes_and_center_cost_coeff</span><span class="o">=</span><span class="mf">100000.0</span><span class="p">,</span>
    +    <span class="p">):</span>
             <span class="sd">&quot;&quot;&quot;</span>
             <span class="sd">&quot;&quot;&quot;</span>
     <span class="sd">        Match cells to ground truth:</span>
     <span class="sd">        Match cells to ground truth:</span>
     <span class="sd">            * at most 1 GT per cell</span>
     <span class="sd">            * at most 1 GT per cell</span>
    @@ -464,8 +526,7 @@
                 <span class="n">y_shifts</span> <span class="o">=</span> <span class="n">y_shifts</span><span class="o">.</span><span class="n">cpu</span><span class="p">()</span>
                 <span class="n">y_shifts</span> <span class="o">=</span> <span class="n">y_shifts</span><span class="o">.</span><span class="n">cpu</span><span class="p">()</span>
     
     
             <span class="c1"># create a mask for foreground cells</span>
             <span class="c1"># create a mask for foreground cells</span>
    -        <span class="n">fg_mask</span><span class="p">,</span> <span class="n">is_in_boxes_and_center</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">get_in_boxes_info</span><span class="p">(</span><span class="n">gt_bboxes_per_image</span><span class="p">,</span> <span class="n">expanded_strides</span><span class="p">,</span>
    -                                                                 <span class="n">x_shifts</span><span class="p">,</span> <span class="n">y_shifts</span><span class="p">,</span> <span class="n">total_num_anchors</span><span class="p">,</span> <span class="n">num_gt</span><span class="p">)</span>
    +        <span class="n">fg_mask</span><span class="p">,</span> <span class="n">is_in_boxes_and_center</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">get_in_boxes_info</span><span class="p">(</span><span class="n">gt_bboxes_per_image</span><span class="p">,</span> <span class="n">expanded_strides</span><span class="p">,</span> <span class="n">x_shifts</span><span class="p">,</span> <span class="n">y_shifts</span><span class="p">,</span> <span 
     
     
             <span class="n">bboxes_preds_per_image</span> <span class="o">=</span> <span class="n">bboxes_preds_per_image</span><span class="p">[</span><span class="n">fg_mask</span><span class="p">]</span>
             <span class="n">bboxes_preds_per_image</span> <span class="o">=</span> <span class="n">bboxes_preds_per_image</span><span class="p">[</span><span class="n">fg_mask</span><span class="p">]</span>
             <span class="n">cls_preds_</span> <span class="o">=</span> <span class="n">cls_preds</span><span class="p">[</span><span class="n">image_idx</span><span class="p">][</span><span class="n">fg_mask</span><span class="p">]</span>
             <span class="n">cls_preds_</span> <span class="o">=</span> <span class="n">cls_preds</span><span class="p">[</span><span class="n">image_idx</span><span class="p">][</span><span class="n">fg_mask</span><span class="p">]</span>
    @@ -490,12 +551,10 @@
                 <span class="n">pair_wise_cls_loss</span> <span class="o">=</span> <span class="n">F</span><span class="o">.</span><span class="n">binary_cross_entropy</span><span class="p">(</span><span class="n">cls_preds_</span><span class="o">.</span><span class="n">sqrt_</span><span class="p">(),</span> <span class="n">gt_cls_per_image</span><span class="p">,</span> <span class="n">reduction</span><span class="o">=</span><span class="s2">&quot;none&quot;</span><span class="p">)</span><span cla
                 <span class="n">pair_wise_cls_loss</span> <span class="o">=</span> <span class="n">F</span><span class="o">.</span><span class="n">binary_cross_entropy</span><span class="p">(</span><span class="n">cls_preds_</span><span class="o">.</span><span class="n">sqrt_</span><span class="p">(),</span> <span class="n">gt_cls_per_image</span><span class="p">,</span> <span class="n">reduction</span><span class="o">=</span><span class="s2">&quot;none&quot;</span><span class="p">)</span><span cla
             <span class="k">del</span> <span class="n">cls_preds_</span>
             <span class="k">del</span> <span class="n">cls_preds_</span>
     
     
    -        <span class="n">cost</span> <span class="o">=</span> <span class="n">pair_wise_cls_loss</span> <span class="o">+</span> <span class="n">ious_loss_cost_coeff</span> <span class="o">*</span> <span class="n">pair_wise_ious_loss</span> <span class="o">+</span> <span class="n">outside_boxes_and_center_cost_coeff</span> <span class="o">*</span> <span class="p">(</span>
    -            <span class="o">~</span><span class="n">is_in_boxes_and_center</span><span class="p">)</span>
    +        <span class="n">cost</span> <span class="o">=</span> <span class="n">pair_wise_cls_loss</span> <span class="o">+</span> <span class="n">ious_loss_cost_coeff</span> <span class="o">*</span> <span class="n">pair_wise_ious_loss</span> <span class="o">+</span> <span class="n">outside_boxes_and_center_cost_coeff</span> <span class="o">*</span> <span class="p">(</span><span class="o">~</span><span class="n">is_in_boxes_and_center</span><span class="p">)</span>
     
     
             <span class="c1"># further filter foregrounds: create pairs between cells and ground truth, based on cost and IoUs</span>
             <span class="c1"># further filter foregrounds: create pairs between cells and ground truth, based on cost and IoUs</span>
    -        <span class="n">num_fg</span><span class="p">,</span> <span class="n">gt_matched_classes</span><span class="p">,</span> <span class="n">pred_ious_this_matching</span><span class="p">,</span> <span class="n">matched_gt_inds</span> <span class="o">=</span> \
    -            <span class="bp">self</span><span class="o">.</span><span class="n">dynamic_k_matching</span><span class="p">(</span><span class="n">cost</span><span class="p">,</span> <span class="n">pair_wise_ious</span><span class="p">,</span> <span class="n">gt_classes</span><span class="p">,</span> <span class="n">num_gt</span><span class="p">,</span> <span class="n">fg_mask</span><span class="p">)</span>
    +        <span class="n">num_fg</span><span class="p">,</span> <span class="n">gt_matched_classes</span><span class="p">,</span> <span class="n">pred_ious_this_matching</span><span class="p">,</span> <span class="n">matched_gt_inds</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">dynamic_k_matching</span><span class="p">(</span><span class="n">cost</span><span class="p">,</span> <span class="n">pair_wise_ious</span><span class="p">,</span> <span
             <span class="c1"># discard tensors related to cost</span>
             <span class="c1"># discard tensors related to cost</span>
             <span class="k">del</span> <span class="n">pair_wise_cls_loss</span><span class="p">,</span> <span class="n">cost</span><span class="p">,</span> <span class="n">pair_wise_ious</span><span class="p">,</span> <span class="n">pair_wise_ious_loss</span>
             <span class="k">del</span> <span class="n">pair_wise_cls_loss</span><span class="p">,</span> <span class="n">cost</span><span class="p">,</span> <span class="n">pair_wise_ious</span><span class="p">,</span> <span class="n">pair_wise_ious_loss</span>
     
     
    @@ -507,7 +566,7 @@
     
     
             <span class="k">return</span> <span class="n">gt_matched_classes</span><span class="p">,</span> <span class="n">fg_mask</span><span class="p">,</span> <span class="n">pred_ious_this_matching</span><span class="p">,</span> <span class="n">matched_gt_inds</span><span class="p">,</span> <span class="n">num_fg</span></div>
             <span class="k">return</span> <span class="n">gt_matched_classes</span><span class="p">,</span> <span class="n">fg_mask</span><span class="p">,</span> <span class="n">pred_ious_this_matching</span><span class="p">,</span> <span class="n">matched_gt_inds</span><span class="p">,</span> <span class="n">num_fg</span></div>
     
     
    -<div class="viewcode-block" id="YoloXDetectionLoss.get_in_boxes_info"><a class="viewcode-back" href="../../../../super_gradients.training.losses.html#super_gradients.training.losses.YoloXDetectionLoss.get_in_boxes_info">[docs]</a>    <span class="k">def</span> <span class="nf">get_in_boxes_info</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">gt_bboxes_per_image</span><span class="p">,</span> <span class="n">expanded_strides</span><span class="p
    +<div class="viewcode-block" id="YoloXDetectionLoss.get_in_boxes_info"><a class="viewcode-back" href="../../../../super_gradients.training.html#super_gradients.training.losses.YoloXDetectionLoss.get_in_boxes_info">[docs]</a>    <span class="k">def</span> <span class="nf">get_in_boxes_info</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">gt_bboxes_per_image</span><span class="p">,</span> <span class="n">expanded_strides</span><span class="p">,</sp
             <span class="sd">&quot;&quot;&quot;</span>
             <span class="sd">&quot;&quot;&quot;</span>
     <span class="sd">        Create a mask for all cells, mask in only foreground: cells that have a center located:</span>
     <span class="sd">        Create a mask for all cells, mask in only foreground: cells that have a center located:</span>
     <span class="sd">            * withing a GT box;</span>
     <span class="sd">            * withing a GT box;</span>
    @@ -561,14 +620,18 @@
             <span class="c1"># FIND CELL CENTERS THAT ARE WITHIN +- self.center_sampling_radius CELLS FROM GROUND TRUTH BOXES CENTERS</span>
             <span class="c1"># FIND CELL CENTERS THAT ARE WITHIN +- self.center_sampling_radius CELLS FROM GROUND TRUTH BOXES CENTERS</span>
     
     
             <span class="c1"># define fake boxes: instead of ground truth boxes step +- self.center_sampling_radius from their centers</span>
             <span class="c1"># define fake boxes: instead of ground truth boxes step +- self.center_sampling_radius from their centers</span>
    -        <span class="n">gt_bboxes_per_image_l</span> <span class="o">=</span> <span class="p">((</span><span class="n">gt_bboxes_per_image</span><span class="p">[:,</span> <span class="mi">0</span><span class="p">])</span><span class="o">.</span><span class="n">unsqueeze</span><span class="p">(</span><span class="mi">1</span><span class="p">)</span><span class="o">.</span><span class="n">repeat</span><span class="p">(</span><span class="mi">1</span><span class="p">,</span> <span class="n">total
    -                                 <span class="bp">self</span><span class="o">.</span><span class="n">center_sampling_radius</span> <span class="o">*</span> <span class="n">expanded_strides_per_image</span><span class="o">.</span><span class="n">unsqueeze</span><span class="p">(</span><span class="mi">0</span><span class="p">))</span>
    -        <span class="n">gt_bboxes_per_image_r</span> <span class="o">=</span> <span class="p">((</span><span class="n">gt_bboxes_per_image</span><span class="p">[:,</span> <span class="mi">0</span><span class="p">])</span><span class="o">.</span><span class="n">unsqueeze</span><span class="p">(</span><span class="mi">1</span><span class="p">)</span><span class="o">.</span><span class="n">repeat</span><span class="p">(</span><span class="mi">1</span><span class="p">,</span> <span class="n">total
    -                                 <span class="bp">self</span><span class="o">.</span><span class="n">center_sampling_radius</span> <span class="o">*</span> <span class="n">expanded_strides_per_image</span><span class="o">.</span><span class="n">unsqueeze</span><span class="p">(</span><span class="mi">0</span><span class="p">))</span>
    -        <span class="n">gt_bboxes_per_image_t</span> <span class="o">=</span> <span class="p">((</span><span class="n">gt_bboxes_per_image</span><span class="p">[:,</span> <span class="mi">1</span><span class="p">])</span><span class="o">.</span><span class="n">unsqueeze</span><span class="p">(</span><span class="mi">1</span><span class="p">)</span><span class="o">.</span><span class="n">repeat</span><span class="p">(</span><span class="mi">1</span><span class="p">,</span> <span class="n">total
    -                                 <span class="bp">self</span><span class="o">.</span><span class="n">center_sampling_radius</span> <span class="o">*</span> <span class="n">expanded_strides_per_image</span><span class="o">.</span><span class="n">unsqueeze</span><span class="p">(</span><span class="mi">0</span><span class="p">))</span>
    -        <span class="n">gt_bboxes_per_image_b</span> <span class="o">=</span> <span class="p">((</span><span class="n">gt_bboxes_per_image</span><span class="p">[:,</span> <span class="mi">1</span><span class="p">])</span><span class="o">.</span><span class="n">unsqueeze</span><span class="p">(</span><span class="mi">1</span><span class="p">)</span><span class="o">.</span><span class="n">repeat</span><span class="p">(</span><span class="mi">1</span><span class="p">,</span> <span class="n">total
    -                                 <span class="bp">self</span><span class="o">.</span><span class="n">center_sampling_radius</span> <span class="o">*</span> <span class="n">expanded_strides_per_image</span><span class="o">.</span><span class="n">unsqueeze</span><span class="p">(</span><span class="mi">0</span><span class="p">))</span>
    +        <span class="n">gt_bboxes_per_image_l</span> <span class="o">=</span> <span class="p">(</span><span class="n">gt_bboxes_per_image</span><span class="p">[:,</span> <span class="mi">0</span><span class="p">])</span><span class="o">.</span><span class="n">unsqueeze</span><span class="p">(</span><span class="mi">1</span><span class="p">)</span><span class="o">.</span><span class="n">repeat</span><span class="p">(</span>
    +            <span class="mi">1</span><span class="p">,</span> <span class="n">total_num_anchors</span>
    +        <span class="p">)</span> <span class="o">-</span> <span class="bp">self</span><span class="o">.</span><span class="n">center_sampling_radius</span> <span class="o">*</span> <span class="n">expanded_strides_per_image</span><span class="o">.</span><span class="n">unsqueeze</span><span class="p">(</span><span class="mi">0</span><span class="p">)</span>
    +        <span class="n">gt_bboxes_per_image_r</span> <span class="o">=</span> <span class="p">(</span><span class="n">gt_bboxes_per_image</span><span class="p">[:,</span> <span class="mi">0</span><span class="p">])</span><span class="o">.</span><span class="n">unsqueeze</span><span class="p">(</span><span class="mi">1</span><span class="p">)</span><span class="o">.</span><span class="n">repeat</span><span class="p">(</span>
    +            <span class="mi">1</span><span class="p">,</span> <span class="n">total_num_anchors</span>
    +        <span class="p">)</span> <span class="o">+</span> <span class="bp">self</span><span class="o">.</span><span class="n">center_sampling_radius</span> <span class="o">*</span> <span class="n">expanded_strides_per_image</span><span class="o">.</span><span class="n">unsqueeze</span><span class="p">(</span><span class="mi">0</span><span class="p">)</span>
    +        <span class="n">gt_bboxes_per_image_t</span> <span class="o">=</span> <span class="p">(</span><span class="n">gt_bboxes_per_image</span><span class="p">[:,</span> <span class="mi">1</span><span class="p">])</span><span class="o">.</span><span class="n">unsqueeze</span><span class="p">(</span><span class="mi">1</span><span class="p">)</span><span class="o">.</span><span class="n">repeat</span><span class="p">(</span>
    +            <span class="mi">1</span><span class="p">,</span> <span class="n">total_num_anchors</span>
    +        <span class="p">)</span> <span class="o">-</span> <span class="bp">self</span><span class="o">.</span><span class="n">center_sampling_radius</span> <span class="o">*</span> <span class="n">expanded_strides_per_image</span><span class="o">.</span><span class="n">unsqueeze</span><span class="p">(</span><span class="mi">0</span><span class="p">)</span>
    +        <span class="n">gt_bboxes_per_image_b</span> <span class="o">=</span> <span class="p">(</span><span class="n">gt_bboxes_per_image</span><span class="p">[:,</span> <span class="mi">1</span><span class="p">])</span><span class="o">.</span><span class="n">unsqueeze</span><span class="p">(</span><span class="mi">1</span><span class="p">)</span><span class="o">.</span><span class="n">repeat</span><span class="p">(</span>
    +            <span class="mi">1</span><span class="p">,</span> <span class="n">total_num_anchors</span>
    +        <span class="p">)</span> <span class="o">+</span> <span class="bp">self</span><span class="o">.</span><span class="n">center_sampling_radius</span> <span class="o">*</span> <span class="n">expanded_strides_per_image</span><span class="o">.</span><span class="n">unsqueeze</span><span class="p">(</span><span class="mi">0</span><span class="p">)</span>
     
     
             <span class="n">c_l</span> <span class="o">=</span> <span class="n">x_centers_per_image</span> <span class="o">-</span> <span class="n">gt_bboxes_per_image_l</span>
             <span class="n">c_l</span> <span class="o">=</span> <span class="n">x_centers_per_image</span> <span class="o">-</span> <span class="n">gt_bboxes_per_image_l</span>
             <span class="n">c_r</span> <span class="o">=</span> <span class="n">gt_bboxes_per_image_r</span> <span class="o">-</span> <span class="n">x_centers_per_image</span>
             <span class="n">c_r</span> <span class="o">=</span> <span class="n">gt_bboxes_per_image_r</span> <span class="o">-</span> <span class="n">x_centers_per_image</span>
    @@ -582,10 +645,10 @@
             <span class="n">is_in_boxes_anchor</span> <span class="o">=</span> <span class="n">is_in_boxes_all</span> <span class="o">|</span> <span class="n">is_in_centers_all</span>
             <span class="n">is_in_boxes_anchor</span> <span class="o">=</span> <span class="n">is_in_boxes_all</span> <span class="o">|</span> <span class="n">is_in_centers_all</span>
     
     
             <span class="c1"># in boxes AND in centers, preserving a shape [num_GTs x num_FGs]</span>
             <span class="c1"># in boxes AND in centers, preserving a shape [num_GTs x num_FGs]</span>
    -        <span class="n">is_in_boxes_and_center</span> <span class="o">=</span> <span class="p">(</span><span class="n">is_in_boxes</span><span class="p">[:,</span> <span class="n">is_in_boxes_anchor</span><span class="p">]</span> <span class="o">&amp;</span> <span class="n">is_in_centers</span><span class="p">[:,</span> <span class="n">is_in_boxes_anchor</span><span class="p">])</span>
    +        <span class="n">is_in_boxes_and_center</span> <span class="o">=</span> <span class="n">is_in_boxes</span><span class="p">[:,</span> <span class="n">is_in_boxes_anchor</span><span class="p">]</span> <span class="o">&amp;</span> <span class="n">is_in_centers</span><span class="p">[:,</span> <span class="n">is_in_boxes_anchor</span><span class="p">]</span>
             <span class="k">return</span> <span class="n">is_in_boxes_anchor</span><span class="p">,</span> <span class="n">is_in_boxes_and_center</span></div>
             <span class="k">return</span> <span class="n">is_in_boxes_anchor</span><span class="p">,</span> <span class="n">is_in_boxes_and_center</span></div>
     
     
    -<div class="viewcode-block" id="YoloXDetectionLoss.dynamic_k_matching"><a class="viewcode-back" href="../../../../super_gradients.training.losses.html#super_gradients.training.losses.YoloXDetectionLoss.dynamic_k_matching">[docs]</a>    <span class="k">def</span> <span class="nf">dynamic_k_matching</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">cost</span><span class="p">,</span> <span class="n">pair_wise_ious</span><span class="p">,</span> <sp
    +<div class="viewcode-block" id="YoloXDetectionLoss.dynamic_k_matching"><a class="viewcode-back" href="../../../../super_gradients.training.html#super_gradients.training.losses.YoloXDetectionLoss.dynamic_k_matching">[docs]</a>    <span class="k">def</span> <span class="nf">dynamic_k_matching</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">cost</span><span class="p">,</span> <span class="n">pair_wise_ious</span><span class="p">,</span> <span clas
             <span class="sd">&quot;&quot;&quot;</span>
             <span class="sd">&quot;&quot;&quot;</span>
     <span class="sd">        :param cost:            pairwise cost, [num_FGs x num_GTs]</span>
     <span class="sd">        :param cost:            pairwise cost, [num_FGs x num_GTs]</span>
     <span class="sd">        :param pair_wise_ious:  pairwise IoUs, [num_FGs x num_GTs]</span>
     <span class="sd">        :param pair_wise_ious:  pairwise IoUs, [num_FGs x num_GTs]</span>
    @@ -632,6 +695,399 @@
     
     
             <span class="n">pred_ious_this_matching</span> <span class="o">=</span> <span class="p">(</span><span class="n">matching_matrix</span> <span class="o">*</span> <span class="n">pair_wise_ious</span><span class="p">)</span><span class="o">.</span><span class="n">sum</span><span class="p">(</span><span class="mi">0</span><span class="p">)[</span><span class="n">fg_mask_inboxes</span><span class="p">]</span>
             <span class="n">pred_ious_this_matching</span> <span class="o">=</span> <span class="p">(</span><span class="n">matching_matrix</span> <span class="o">*</span> <span class="n">pair_wise_ious</span><span class="p">)</span><span class="o">.</span><span class="n">sum</span><span class="p">(</span><span class="mi">0</span><span class="p">)[</span><span class="n">fg_mask_inboxes</span><span class="p">]</span>
             <span class="k">return</span> <span class="n">num_fg</span><span class="p">,</span> <span class="n">gt_matched_classes</span><span class="p">,</span> <span class="n">pred_ious_this_matching</span><span class="p">,</span> <span class="n">matched_gt_inds</span></div></div>
             <span class="k">return</span> <span class="n">num_fg</span><span class="p">,</span> <span class="n">gt_matched_classes</span><span class="p">,</span> <span class="n">pred_ious_this_matching</span><span class="p">,</span> <span class="n">matched_gt_inds</span></div></div>
    +
    +
    +<div class="viewcode-block" id="YoloXFastDetectionLoss"><a class="viewcode-back" href="../../../../super_gradients.training.html#super_gradients.training.losses.YoloXFastDetectionLoss">[docs]</a><span class="k">class</span> <span class="nc">YoloXFastDetectionLoss</span><span class="p">(</span><span class="n">YoloXDetectionLoss</span><span class="p">):</span>
    +    <span class="sd">&quot;&quot;&quot;</span>
    +<span class="sd">    A completely new implementation of YOLOX loss.</span>
    +<span class="sd">    This is NOT an equivalent implementation to the regular yolox loss.</span>
    +
    +<span class="sd">    * Completely avoids using loops compared to the nested loops in the original implementation.</span>
    +<span class="sd">        As a result runs much faster (speedup depends on the type of GPUs, their count, the batch size, etc.).</span>
    +<span class="sd">    * Tensors format is very different the original implementation.</span>
    +<span class="sd">        Tensors contain image ids, ground truth ids and anchor ids as values to support variable length data.</span>
    +<span class="sd">    * There are differences in terms of the algorithm itself:</span>
    +<span class="sd">    1. When computing a dynamic k for a ground truth,</span>
    +<span class="sd">        in the original implementation they consider the sum of top 10 predictions sorted by ious among the initial</span>
    +<span class="sd">        foregrounds of any ground truth in the image,</span>
    +<span class="sd">        while in our implementation we consider only the initial foreground of that particular ground truth.</span>
    +<span class="sd">        To compensate for that difference we introduce the dynamic_ks_bias hyperparamter which makes the dynamic ks larger.</span>
    +<span class="sd">    2. When computing the k matched detections for a ground truth,</span>
    +<span class="sd">        in the original implementation they consider the initial foregrounds of any ground truth in the image as candidates,</span>
    +<span class="sd">        while in our implementation we consider only the initial foreground of that particular ground truth as candidates.</span>
    +<span class="sd">        We believe that this difference is minor.</span>
    +
    +<span class="sd">    :param dynamic_ks_bias: hyperparameter to compensate for the discrepancies between the regular loss and this loss.</span>
    +<span class="sd">    :param sync_num_fgs:    sync num of fgs.</span>
    +<span class="sd">                            Can be used for DDP training.</span>
    +<span class="sd">    :param obj_loss_fix:    devide by total of num anchors instead num of matching fgs.</span>
    +<span class="sd">                            Can be used for objectness loss.</span>
    +<span class="sd">    &quot;&quot;&quot;</span>
    +
    +    <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span>
    +        <span class="bp">self</span><span class="p">,</span> <span class="n">strides</span><span class="p">,</span> <span class="n">num_classes</span><span class="p">,</span> <span class="n">use_l1</span><span class="o">=</span><span class="kc">False</span><span class="p">,</span> <span class="n">center_sampling_radius</span><span class="o">=</span><span class="mf">2.5</span><span class="p">,</span> <span class="n">iou_type</span><span class="o">=</span><span class="s2">&quot;iou&quot;</span><s
    +    <span class="p">):</span>
    +        <span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="fm">__init__</span><span class="p">(</span><span class="n">strides</span><span class="o">=</span><span class="n">strides</span><span class="p">,</span> <span class="n">num_classes</span><span class="o">=</span><span class="n">num_classes</span><span class="p">,</span> <span class="n">use_l1</span><span class="o">=</span><span class="n">use_l1</span><span class="p">,</span> <span class="n">center_s
    +
    +        <span class="bp">self</span><span class="o">.</span><span class="n">dynamic_ks_bias</span> <span class="o">=</span> <span class="n">dynamic_ks_bias</span>
    +        <span class="bp">self</span><span class="o">.</span><span class="n">sync_num_fgs</span> <span class="o">=</span> <span class="n">sync_num_fgs</span>
    +        <span class="bp">self</span><span class="o">.</span><span class="n">obj_loss_fix</span> <span class="o">=</span> <span class="n">obj_loss_fix</span>
    +
    +    <span class="k">def</span> <span class="nf">_compute_loss</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">predictions</span><span class="p">:</span> <span class="n">List</span><span class="p">[</span><span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">],</span> <span class="n">targets</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><
    +        <span class="sd">&quot;&quot;&quot;</span>
    +<span class="sd">        L = L_objectness + L_iou + L_classification + 1[no_aug_epoch]*L_l1</span>
    +<span class="sd">        where:</span>
    +<span class="sd">            * L_iou, L_classification and L_l1 are calculated only between cells and targets that suit them;</span>
    +<span class="sd">            * L_objectness is calculated for all cells.</span>
    +
    +<span class="sd">        L_classification:</span>
    +<span class="sd">            for cells that have suitable ground truths in their grid locations add BCEs</span>
    +<span class="sd">            to force a prediction of IoU with a GT in a multi-label way</span>
    +<span class="sd">            Coef: 1.</span>
    +<span class="sd">        L_iou:</span>
    +<span class="sd">            for cells that have suitable ground truths in their grid locations</span>
    +<span class="sd">            add (1 - IoU^2), IoU between a predicted box and each GT box, force maximum IoU</span>
    +<span class="sd">            Coef: 1.</span>
    +<span class="sd">        L_l1:</span>
    +<span class="sd">            for cells that have suitable ground truths in their grid locations</span>
    +<span class="sd">            l1 distance between the logits and GTs in “logits” format (the inverse of “logits to predictions” ops)</span>
    +<span class="sd">            Coef: 1[no_aug_epoch]</span>
    +<span class="sd">        L_objectness:</span>
    +<span class="sd">            for each cell add BCE with a label of 1 if there is GT assigned to the cell</span>
    +<span class="sd">            Coef: 5</span>
    +
    +<span class="sd">        :param predictions:     output from all Yolo levels, each of shape</span>
    +<span class="sd">                                [Batch x Num_Anchors x GridSizeY x GridSizeX x (4 + 1 + Num_classes)]</span>
    +<span class="sd">        :param targets:         [Num_targets x (4 + 2)], values on dim 1 are: image id in a batch, class, box x y w h</span>
    +
    +<span class="sd">        :return:                loss, all losses separately in a detached tensor</span>
    +<span class="sd">        &quot;&quot;&quot;</span>
    +        <span class="n">x_shifts</span><span class="p">,</span> <span class="n">y_shifts</span><span class="p">,</span> <span class="n">expanded_strides</span><span class="p">,</span> <span class="n">transformed_outputs</span><span class="p">,</span> <span class="n">raw_outputs</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">prepare_predictions</span><span class="p">(</span><span class="n">predictions</span><span class="p">)</span>
    +
    +        <span class="n">bbox_preds</span> <span class="o">=</span> <span class="n">transformed_outputs</span><span class="p">[:,</span> <span class="p">:,</span> <span class="p">:</span><span class="mi">4</span><span class="p">]</span>  <span class="c1"># [batch, n_anchors_all, 4]</span>
    +        <span class="n">obj_preds</span> <span class="o">=</span> <span class="n">transformed_outputs</span><span class="p">[:,</span> <span class="p">:,</span> <span class="mi">4</span><span class="p">:</span><span class="mi">5</span><span class="p">]</span>  <span class="c1"># [batch, n_anchors_all, 1]</span>
    +        <span class="n">cls_preds</span> <span class="o">=</span> <span class="n">transformed_outputs</span><span class="p">[:,</span> <span class="p">:,</span> <span class="mi">5</span><span class="p">:]</span>  <span class="c1"># [batch, n_anchors_all, n_cls]</span>
    +
    +        <span class="c1"># assign cells to ground truths, at most one GT per cell</span>
    +        <span class="n">matched_fg_ids</span><span class="p">,</span> <span class="n">matched_gt_classes</span><span class="p">,</span> <span class="n">matched_gt_ids</span><span class="p">,</span> <span class="n">matched_img_ids</span><span class="p">,</span> <span class="n">matched_ious</span><span class="p">,</span> <span class="n">flattened_gts</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">_compute_matching</span><span class="p">(</span>
    +            <span class="n">bbox_preds</span><span class="p">,</span> <span class="n">cls_preds</span><span class="p">,</span> <span class="n">obj_preds</span><span class="p">,</span> <span class="n">expanded_strides</span><span class="p">,</span> <span class="n">x_shifts</span><span class="p">,</span> <span class="n">y_shifts</span><span class="p">,</span> <span class="n">targets</span>
    +        <span class="p">)</span>
    +
    +        <span class="n">num_gts</span> <span class="o">=</span> <span class="nb">max</span><span class="p">(</span><span class="n">flattened_gts</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">0</span><span class="p">],</span> <span class="mi">1</span><span class="p">)</span>
    +        <span class="n">num_fg</span> <span class="o">=</span> <span class="nb">max</span><span class="p">(</span><span class="n">matched_gt_ids</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">0</span><span class="p">],</span> <span class="mi">1</span><span class="p">)</span>
    +        <span class="n">total_num_anchors</span> <span class="o">=</span> <span class="nb">max</span><span class="p">(</span><span class="n">transformed_outputs</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span> <span class="o">*</span> <span class="n">transformed_outputs</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">1</span><span class="p">],</span> <span class=
    +
    +        <span class="n">cls_targets</span> <span class="o">=</span> <span class="n">F</span><span class="o">.</span><span class="n">one_hot</span><span class="p">(</span><span class="n">matched_gt_classes</span><span class="o">.</span><span class="n">to</span><span class="p">(</span><span class="n">torch</span><span class="o">.</span><span class="n">int64</span><span class="p">),</span> <span class="bp">self</span><span class="o">.</span><span class="n">num_classes</span><span class="p">)</span
    +        <span class="n">obj_targets</span> <span class="o">=</span> <span class="n">transformed_outputs</span><span class="o">.</span><span class="n">new_zeros</span><span class="p">((</span><span class="n">transformed_outputs</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">0</span><span class="p">],</span> <span class="n">transformed_outputs</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">1</s
    +        <span class="n">obj_targets</span><span class="p">[</span><span class="n">matched_img_ids</span><span class="p">,</span> <span class="n">matched_fg_ids</span><span class="p">]</span> <span class="o">=</span> <span class="mi">1</span>
    +        <span class="n">reg_targets</span> <span class="o">=</span> <span class="n">flattened_gts</span><span class="p">[</span><span class="n">matched_gt_ids</span><span class="p">][:,</span> <span class="mi">1</span><span class="p">:]</span>
    +        <span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">use_l1</span><span class="p">:</span>
    +            <span class="n">l1_targets</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">get_l1_target</span><span class="p">(</span>
    +                <span class="n">transformed_outputs</span><span class="o">.</span><span class="n">new_zeros</span><span class="p">((</span><span class="n">num_fg</span><span class="p">,</span> <span class="mi">4</span><span class="p">)),</span>
    +                <span class="n">flattened_gts</span><span class="p">[</span><span class="n">matched_gt_ids</span><span class="p">][:,</span> <span class="mi">1</span><span class="p">:],</span>
    +                <span class="n">expanded_strides</span><span class="o">.</span><span class="n">squeeze</span><span class="p">()[</span><span class="n">matched_fg_ids</span><span class="p">],</span>
    +                <span class="n">x_shifts</span><span class="o">=</span><span class="n">x_shifts</span><span class="o">.</span><span class="n">squeeze</span><span class="p">()[</span><span class="n">matched_fg_ids</span><span class="p">],</span>
    +                <span class="n">y_shifts</span><span class="o">=</span><span class="n">y_shifts</span><span class="o">.</span><span class="n">squeeze</span><span class="p">()[</span><span class="n">matched_fg_ids</span><span class="p">],</span>
    +            <span class="p">)</span>
    +        <span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">sync_num_fgs</span> <span class="ow">and</span> <span class="n">dist</span><span class="o">.</span><span class="n">group</span><span class="o">.</span><span class="n">WORLD</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span><span class="p">:</span>
    +            <span class="n">num_fg</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">scalar_tensor</span><span class="p">(</span><span class="n">num_fg</span><span class="p">)</span><span class="o">.</span><span class="n">to</span><span class="p">(</span><span class="n">matched_gt_ids</span><span class="o">.</span><span class="n">device</span><span class="p">)</span>
    +            <span class="n">dist</span><span class="o">.</span><span class="n">all_reduce</span><span class="p">(</span><span class="n">num_fg</span><span class="p">,</span> <span class="n">op</span><span class="o">=</span><span class="n">torch</span><span class="o">.</span><span class="n">_C</span><span class="o">.</span><span class="n">_distributed_c10d</span><span class="o">.</span><span class="n">ReduceOp</span><span class="o">.</span><span class="n">AVG</span><span class="p">)</span>
    +
    +        <span class="n">loss_iou</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">iou_loss</span><span class="p">(</span><span class="n">bbox_preds</span><span class="p">[</span><span class="n">matched_img_ids</span><span class="p">,</span> <span class="n">matched_fg_ids</span><span class="p">],</span> <span class="n">reg_targets</span><span class="p">)</span><span class="o">.</span><span class="n">sum</span><span class="p">()</span> <span clas
    +        <span class="n">loss_obj</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">bcewithlog_loss</span><span class="p">(</span><span class="n">obj_preds</span><span class="o">.</span><span class="n">squeeze</span><span class="p">(</span><span class="o">-</span><span class="mi">1</span><span class="p">),</span> <span class="n">obj_targets</span><span class="p">)</span><span class="o">.</span><span class="n">sum</span><span class="p">()</span> <
    +        <span class="n">loss_cls</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">bcewithlog_loss</span><span class="p">(</span><span class="n">cls_preds</span><span class="p">[</span><span class="n">matched_img_ids</span><span class="p">,</span> <span class="n">matched_fg_ids</span><span class="p">],</span> <span class="n">cls_targets</span><span class="p">)</span><span class="o">.</span><span class="n">sum</span><span class="p">()</span> <spa
    +        <span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">use_l1</span><span class="p">:</span>
    +            <span class="n">loss_l1</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">l1_loss</span><span class="p">(</span><span class="n">raw_outputs</span><span class="p">[</span><span class="n">matched_img_ids</span><span class="p">,</span> <span class="n">matched_fg_ids</span><span class="p">],</span> <span class="n">l1_targets</span><span class="p">)</span><span class="o">.</span><span class="n">sum</span><span class="p">()</span> <span cl
    +        <span class="k">else</span><span class="p">:</span>
    +            <span class="n">loss_l1</span> <span class="o">=</span> <span class="mf">0.0</span>
    +
    +        <span class="n">reg_weight</span> <span class="o">=</span> <span class="mf">5.0</span>
    +        <span class="n">loss</span> <span class="o">=</span> <span class="n">reg_weight</span> <span class="o">*</span> <span class="n">loss_iou</span> <span class="o">+</span> <span class="n">loss_obj</span> <span class="o">+</span> <span class="n">loss_cls</span> <span class="o">+</span> <span class="n">loss_l1</span>
    +
    +        <span class="k">return</span> <span class="p">(</span>
    +            <span class="n">loss</span><span class="p">,</span>
    +            <span class="n">torch</span><span class="o">.</span><span class="n">cat</span><span class="p">(</span>
    +                <span class="p">(</span>
    +                    <span class="n">loss_iou</span><span class="o">.</span><span class="n">unsqueeze</span><span class="p">(</span><span class="mi">0</span><span class="p">),</span>
    +                    <span class="n">loss_obj</span><span class="o">.</span><span class="n">unsqueeze</span><span class="p">(</span><span class="mi">0</span><span class="p">),</span>
    +                    <span class="n">loss_cls</span><span class="o">.</span><span class="n">unsqueeze</span><span class="p">(</span><span class="mi">0</span><span class="p">),</span>
    +                    <span class="n">torch</span><span class="o">.</span><span class="n">tensor</span><span class="p">(</span><span class="n">loss_l1</span><span class="p">)</span><span class="o">.</span><span class="n">unsqueeze</span><span class="p">(</span><span class="mi">0</span><span class="p">)</span><span class="o">.</span><span class="n">to</span><span class="p">(</span><span class="n">transformed_outputs</span><span class="o">.</span><span class="n">device</span><span class="p">),</spa
    +                    <span class="n">torch</span><span class="o">.</span><span class="n">tensor</span><span class="p">(</span><span class="n">num_fg</span> <span class="o">/</span> <span class="nb">max</span><span class="p">(</span><span class="n">num_gts</span><span class="p">,</span> <span class="mi">1</span><span class="p">))</span><span class="o">.</span><span class="n">unsqueeze</span><span class="p">(</span><span class="mi">0</span><span class="p">)</span><span class="o">.</span><span clas
    +                    <span class="n">loss</span><span class="o">.</span><span class="n">unsqueeze</span><span class="p">(</span><span class="mi">0</span><span class="p">),</span>
    +                <span class="p">)</span>
    +            <span class="p">)</span><span class="o">.</span><span class="n">detach</span><span class="p">(),</span>
    +        <span class="p">)</span>
    +
    +    <span class="k">def</span> <span class="nf">_get_initial_matching</span><span class="p">(</span>
    +        <span class="bp">self</span><span class="p">,</span> <span class="n">gt_bboxes</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">,</span> <span class="n">expanded_strides</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">,</span> <span class="n">x_shifts</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</sp
    +    <span class="p">)</span> <span class="o">-&gt;</span> <span class="n">Tuple</span><span class="p">[</span><span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">,</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">,</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">]:</span>
    +        <span class="sd">&quot;&quot;&quot;</span>
    +<span class="sd">        Get candidates using a mask for all cells.</span>
    +<span class="sd">        Mask in only foreground cells that have a center located:</span>
    +<span class="sd">            * withing a GT box (param: is_in_boxes);</span>
    +<span class="sd">            OR</span>
    +<span class="sd">            * within a fixed radius around a GT box (center sampling) (param: is_in_centers);</span>
    +
    +<span class="sd">        return:</span>
    +<span class="sd">            initial_matching: get a list of candidates pairs of (gt box id, anchor box id) based on cell = is_in_boxes | is_in_centers.</span>
    +<span class="sd">                              shape: [num_candidates, 2]</span>
    +<span class="sd">            strong candidate mask: get a list whether a candidate is a strong one or not.</span>
    +<span class="sd">                                   strong candidate is a cell from is_in_boxes &amp; is_in_centers.</span>
    +<span class="sd">                                   shape: [num_candidates].</span>
    +<span class="sd">        &quot;&quot;&quot;</span>
    +        <span class="n">cell_x_centers</span> <span class="o">=</span> <span class="p">(</span><span class="n">x_shifts</span> <span class="o">+</span> <span class="mf">0.5</span><span class="p">)</span> <span class="o">*</span> <span class="n">expanded_strides</span>
    +        <span class="n">cell_y_centers</span> <span class="o">=</span> <span class="p">(</span><span class="n">y_shifts</span> <span class="o">+</span> <span class="mf">0.5</span><span class="p">)</span> <span class="o">*</span> <span class="n">expanded_strides</span>
    +
    +        <span class="n">gt_bboxes_x_centers</span> <span class="o">=</span> <span class="n">gt_bboxes</span><span class="p">[:,</span> <span class="mi">0</span><span class="p">]</span><span class="o">.</span><span class="n">unsqueeze</span><span class="p">(</span><span class="mi">1</span><span class="p">)</span>
    +        <span class="n">gt_bboxes_y_centers</span> <span class="o">=</span> <span class="n">gt_bboxes</span><span class="p">[:,</span> <span class="mi">1</span><span class="p">]</span><span class="o">.</span><span class="n">unsqueeze</span><span class="p">(</span><span class="mi">1</span><span class="p">)</span>
    +
    +        <span class="n">gt_bboxes_half_w</span> <span class="o">=</span> <span class="p">(</span><span class="mf">0.5</span> <span class="o">*</span> <span class="n">gt_bboxes</span><span class="p">[:,</span> <span class="mi">2</span><span class="p">])</span><span class="o">.</span><span class="n">unsqueeze</span><span class="p">(</span><span class="mi">1</span><span class="p">)</span>
    +        <span class="n">gt_bboxes_half_h</span> <span class="o">=</span> <span class="p">(</span><span class="mf">0.5</span> <span class="o">*</span> <span class="n">gt_bboxes</span><span class="p">[:,</span> <span class="mi">3</span><span class="p">])</span><span class="o">.</span><span class="n">unsqueeze</span><span class="p">(</span><span class="mi">1</span><span class="p">)</span>
    +
    +        <span class="n">is_in_boxes</span> <span class="o">=</span> <span class="p">(</span>
    +            <span class="p">(</span><span class="n">cell_x_centers</span> <span class="o">&gt;</span> <span class="n">gt_bboxes_x_centers</span> <span class="o">-</span> <span class="n">gt_bboxes_half_w</span><span class="p">)</span>
    +            <span class="o">&amp;</span> <span class="p">(</span><span class="n">gt_bboxes_x_centers</span> <span class="o">+</span> <span class="n">gt_bboxes_half_w</span> <span class="o">&gt;</span> <span class="n">cell_x_centers</span><span class="p">)</span>
    +            <span class="o">&amp;</span> <span class="p">(</span><span class="n">cell_y_centers</span> <span class="o">&gt;</span> <span class="n">gt_bboxes_y_centers</span> <span class="o">-</span> <span class="n">gt_bboxes_half_h</span><span class="p">)</span>
    +            <span class="o">&amp;</span> <span class="p">(</span><span class="n">gt_bboxes_y_centers</span> <span class="o">+</span> <span class="n">gt_bboxes_half_h</span> <span class="o">&gt;</span> <span class="n">cell_y_centers</span><span class="p">)</span>
    +        <span class="p">)</span>
    +
    +        <span class="n">radius_shifts</span> <span class="o">=</span> <span class="mf">2.5</span> <span class="o">*</span> <span class="n">expanded_strides</span>
    +
    +        <span class="n">is_in_centers</span> <span class="o">=</span> <span class="p">(</span>
    +            <span class="p">(</span><span class="n">cell_x_centers</span> <span class="o">+</span> <span class="n">radius_shifts</span> <span class="o">&gt;</span> <span class="n">gt_bboxes_x_centers</span><span class="p">)</span>
    +            <span class="o">&amp;</span> <span class="p">(</span><span class="n">gt_bboxes_x_centers</span> <span class="o">&gt;</span> <span class="n">cell_x_centers</span> <span class="o">-</span> <span class="n">radius_shifts</span><span class="p">)</span>
    +            <span class="o">&amp;</span> <span class="p">(</span><span class="n">cell_y_centers</span> <span class="o">+</span> <span class="n">radius_shifts</span> <span class="o">&gt;</span> <span class="n">gt_bboxes_y_centers</span><span class="p">)</span>
    +            <span class="o">&amp;</span> <span class="p">(</span><span class="n">gt_bboxes_y_centers</span> <span class="o">&gt;</span> <span class="n">cell_y_centers</span> <span class="o">-</span> <span class="n">radius_shifts</span><span class="p">)</span>
    +        <span class="p">)</span>
    +
    +        <span class="n">initial_mask</span> <span class="o">=</span> <span class="n">is_in_boxes</span> <span class="o">|</span> <span class="n">is_in_centers</span>
    +        <span class="n">initial_matching</span> <span class="o">=</span> <span class="n">initial_mask</span><span class="o">.</span><span class="n">nonzero</span><span class="p">()</span>
    +        <span class="n">strong_candidate_mask</span> <span class="o">=</span> <span class="p">(</span><span class="n">is_in_boxes</span> <span class="o">&amp;</span> <span class="n">is_in_centers</span><span class="p">)[</span><span class="n">initial_mask</span><span class="p">]</span>
    +
    +        <span class="k">return</span> <span class="n">initial_matching</span><span class="p">[:,</span> <span class="mi">0</span><span class="p">],</span> <span class="n">initial_matching</span><span class="p">[:,</span> <span class="mi">1</span><span class="p">],</span> <span class="n">strong_candidate_mask</span>
    +
    +    <span class="nd">@torch</span><span class="o">.</span><span class="n">no_grad</span><span class="p">()</span>
    +    <span class="k">def</span> <span class="nf">_compute_matching</span><span class="p">(</span>
    +        <span class="bp">self</span><span class="p">,</span>
    +        <span class="n">bbox_preds</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">,</span>
    +        <span class="n">cls_preds</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">,</span>
    +        <span class="n">obj_preds</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">,</span>
    +        <span class="n">expanded_strides</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">,</span>
    +        <span class="n">x_shifts</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">,</span>
    +        <span class="n">y_shifts</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">,</span>
    +        <span class="n">labels</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">,</span>
    +        <span class="n">ious_loss_cost_coeff</span><span class="p">:</span> <span class="nb">float</span> <span class="o">=</span> <span class="mf">3.0</span><span class="p">,</span>
    +        <span class="n">outside_boxes_and_center_cost_coeff</span><span class="p">:</span> <span class="nb">float</span> <span class="o">=</span> <span class="mf">100000.0</span><span class="p">,</span>
    +    <span class="p">)</span> <span class="o">-&gt;</span> <span class="n">Tuple</span><span class="p">[</span><span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">,</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">,</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">,</span> <span class="n">torch</span><span class="o">.</span><span class="n">Ten
    +        <span class="sd">&quot;&quot;&quot;</span>
    +<span class="sd">        Match cells to ground truth:</span>
    +<span class="sd">            * at most 1 GT per cell</span>
    +<span class="sd">            * dynamic number of cells per GT</span>
    +
    +<span class="sd">        :param bbox_preds: predictions of bounding boxes. shape [batch, n_anchors_all, 4]</span>
    +<span class="sd">        :param cls_preds:  predictions of class.          shape [batch, n_anchors_all, n_cls]</span>
    +<span class="sd">        :param obj_preds:  predictions for objectness.    shape [batch, n_anchors_all, 1]</span>
    +<span class="sd">        :param expanded_strides:  stride of the output grid the prediction is coming from. shape [1, n_anchors_all]</span>
    +<span class="sd">        :param x_shifts: x coordinate on the grid cell the prediction is coming from.      shape [1, n_anchors_all]</span>
    +<span class="sd">        :param y_shifts: y coordinate on the grid cell the prediction is coming from.      shape [1, n_anchors_all]</span>
    +<span class="sd">        :param labels:   labels for each grid cell.  shape [n_anchors_all, (4 + 2)]</span>
    +<span class="sd">        :return: candidate_fg_ids       shape [num_fg]</span>
    +<span class="sd">                 candidate_gt_classes   shape [num_fg]</span>
    +<span class="sd">                 candidate_gt_ids       shape [num_fg]</span>
    +<span class="sd">                 candidate_img_ids      shape [num_fg]</span>
    +<span class="sd">                 candidate_ious         shape [num_fg]</span>
    +<span class="sd">                 flattened_gts          shape [num_gts, 5]</span>
    +<span class="sd">        &quot;&quot;&quot;</span>
    +
    +        <span class="n">flattened_gts</span><span class="p">,</span> <span class="n">gt_id_to_img_id</span> <span class="o">=</span> <span class="n">labels</span><span class="p">[:,</span> <span class="mi">1</span><span class="p">:],</span> <span class="n">labels</span><span class="p">[:,</span> <span class="mi">0</span><span class="p">]</span><span class="o">.</span><span class="n">type</span><span class="p">(</span><span class="n">torch</span><span class="o">.</span><span class="n">int64</spa
    +
    +        <span class="c1"># COMPUTE CANDIDATES</span>
    +        <span class="n">candidate_gt_ids</span><span class="p">,</span> <span class="n">candidate_fg_ids</span><span class="p">,</span> <span class="n">strong_candidate_mask</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">_get_initial_matching</span><span class="p">(</span><span class="n">flattened_gts</span><span class="p">[:,</span> <span class="mi">1</span><span class="p">:],</span> <span class="n">expanded_strides</span><span class="p">,</
    +        <span class="n">candidate_img_ids</span> <span class="o">=</span> <span class="n">gt_id_to_img_id</span><span class="p">[</span><span class="n">candidate_gt_ids</span><span class="p">]</span>
    +        <span class="n">candidate_gts_bbox</span> <span class="o">=</span> <span class="n">flattened_gts</span><span class="p">[</span><span class="n">candidate_gt_ids</span><span class="p">,</span> <span class="mi">1</span><span class="p">:]</span>
    +        <span class="n">candidate_det_bbox</span> <span class="o">=</span> <span class="n">bbox_preds</span><span class="p">[</span><span class="n">candidate_img_ids</span><span class="p">,</span> <span class="n">candidate_fg_ids</span><span class="p">]</span>
    +
    +        <span class="c1"># COMPUTE DYNAMIC KS</span>
    +        <span class="n">candidate_ious</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">_calculate_pairwise_bbox_iou</span><span class="p">(</span><span class="n">candidate_gts_bbox</span><span class="p">,</span> <span class="n">candidate_det_bbox</span><span class="p">,</span> <span class="n">xyxy</span><span class="o">=</span><span class="kc">False</span><span class="p">)</span>
    +        <span class="n">dynamic_ks</span><span class="p">,</span> <span class="n">matching_index_to_dynamic_k_index</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">_compute_dynamic_ks</span><span class="p">(</span><span class="n">candidate_gt_ids</span><span class="p">,</span> <span class="n">candidate_ious</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">dynamic_ks_bias</span><span class="p">)
    +        <span class="k">del</span> <span class="n">candidate_gts_bbox</span><span class="p">,</span> <span class="n">candidate_det_bbox</span>
    +
    +        <span class="c1"># ORDER CANDIDATES BY COST</span>
    +        <span class="n">candidate_gt_classes</span> <span class="o">=</span> <span class="n">flattened_gts</span><span class="p">[</span><span class="n">candidate_gt_ids</span><span class="p">,</span> <span class="mi">0</span><span class="p">]</span>
    +        <span class="n">cost_order</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">_compute_cost_order</span><span class="p">(</span>
    +            <span class="bp">self</span><span class="o">.</span><span class="n">num_classes</span><span class="p">,</span>
    +            <span class="n">candidate_img_ids</span><span class="p">,</span>
    +            <span class="n">candidate_gt_classes</span><span class="p">,</span>
    +            <span class="n">candidate_fg_ids</span><span class="p">,</span>
    +            <span class="n">candidate_ious</span><span class="p">,</span>
    +            <span class="n">cls_preds</span><span class="p">,</span>
    +            <span class="n">obj_preds</span><span class="p">,</span>
    +            <span class="n">strong_candidate_mask</span><span class="p">,</span>
    +            <span class="n">ious_loss_cost_coeff</span><span class="p">,</span>
    +            <span class="n">outside_boxes_and_center_cost_coeff</span><span class="p">,</span>
    +        <span class="p">)</span>
    +
    +        <span class="n">candidate_gt_ids</span> <span class="o">=</span> <span class="n">candidate_gt_ids</span><span class="p">[</span><span class="n">cost_order</span><span class="p">]</span>
    +        <span class="n">candidate_gt_classes</span> <span class="o">=</span> <span class="n">candidate_gt_classes</span><span class="p">[</span><span class="n">cost_order</span><span class="p">]</span>
    +        <span class="n">candidate_img_ids</span> <span class="o">=</span> <span class="n">candidate_img_ids</span><span class="p">[</span><span class="n">cost_order</span><span class="p">]</span>
    +        <span class="n">candidate_fg_ids</span> <span class="o">=</span> <span class="n">candidate_fg_ids</span><span class="p">[</span><span class="n">cost_order</span><span class="p">]</span>
    +        <span class="n">candidate_ious</span> <span class="o">=</span> <span class="n">candidate_ious</span><span class="p">[</span><span class="n">cost_order</span><span class="p">]</span>
    +        <span class="n">matching_index_to_dynamic_k_index</span> <span class="o">=</span> <span class="n">matching_index_to_dynamic_k_index</span><span class="p">[</span><span class="n">cost_order</span><span class="p">]</span>
    +        <span class="k">del</span> <span class="n">cost_order</span>
    +
    +        <span class="c1"># FILTER MATCHING TO LOWEST K COST MATCHES PER GT</span>
    +        <span class="n">ranks</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">_compute_ranks</span><span class="p">(</span><span class="n">candidate_gt_ids</span><span class="p">)</span>
    +        <span class="n">corresponding_dynamic_ks</span> <span class="o">=</span> <span class="n">dynamic_ks</span><span class="p">[</span><span class="n">matching_index_to_dynamic_k_index</span><span class="p">]</span>
    +        <span class="n">topk_mask</span> <span class="o">=</span> <span class="n">ranks</span> <span class="o">&lt;</span> <span class="n">corresponding_dynamic_ks</span>
    +
    +        <span class="n">candidate_gt_ids</span> <span class="o">=</span> <span class="n">candidate_gt_ids</span><span class="p">[</span><span class="n">topk_mask</span><span class="p">]</span>
    +        <span class="n">candidate_gt_classes</span> <span class="o">=</span> <span class="n">candidate_gt_classes</span><span class="p">[</span><span class="n">topk_mask</span><span class="p">]</span>
    +        <span class="n">candidate_img_ids</span> <span class="o">=</span> <span class="n">candidate_img_ids</span><span class="p">[</span><span class="n">topk_mask</span><span class="p">]</span>
    +        <span class="n">candidate_fg_ids</span> <span class="o">=</span> <span class="n">candidate_fg_ids</span><span class="p">[</span><span class="n">topk_mask</span><span class="p">]</span>
    +        <span class="n">candidate_ious</span> <span class="o">=</span> <span class="n">candidate_ious</span><span class="p">[</span><span class="n">topk_mask</span><span class="p">]</span>
    +        <span class="k">del</span> <span class="n">ranks</span><span class="p">,</span> <span class="n">topk_mask</span><span class="p">,</span> <span class="n">dynamic_ks</span><span class="p">,</span> <span class="n">matching_index_to_dynamic_k_index</span><span class="p">,</span> <span class="n">corresponding_dynamic_ks</span>
    +
    +        <span class="c1"># FILTER MATCHING TO AT MOST 1 MATCH FOR DET BY TAKING THE LOWEST COST MATCH</span>
    +        <span class="n">candidate_img_and_fg_ids_combined</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">_combine_candidates_img_id_fg_id</span><span class="p">(</span><span class="n">candidate_img_ids</span><span class="p">,</span> <span class="n">candidate_fg_ids</span><span class="p">)</span>
    +        <span class="n">top1_mask</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">_compute_is_first_mask</span><span class="p">(</span><span class="n">candidate_img_and_fg_ids_combined</span><span class="p">)</span>
    +        <span class="n">candidate_gt_ids</span> <span class="o">=</span> <span class="n">candidate_gt_ids</span><span class="p">[</span><span class="n">top1_mask</span><span class="p">]</span>
    +        <span class="n">candidate_gt_classes</span> <span class="o">=</span> <span class="n">candidate_gt_classes</span><span class="p">[</span><span class="n">top1_mask</span><span class="p">]</span>
    +        <span class="n">candidate_fg_ids</span> <span class="o">=</span> <span class="n">candidate_fg_ids</span><span class="p">[</span><span class="n">top1_mask</span><span class="p">]</span>
    +        <span class="n">candidate_img_ids</span> <span class="o">=</span> <span class="n">candidate_img_ids</span><span class="p">[</span><span class="n">top1_mask</span><span class="p">]</span>
    +        <span class="n">candidate_ious</span> <span class="o">=</span> <span class="n">candidate_ious</span><span class="p">[</span><span class="n">top1_mask</span><span class="p">]</span>
    +
    +        <span class="k">return</span> <span class="n">candidate_fg_ids</span><span class="p">,</span> <span class="n">candidate_gt_classes</span><span class="p">,</span> <span class="n">candidate_gt_ids</span><span class="p">,</span> <span class="n">candidate_img_ids</span><span class="p">,</span> <span class="n">candidate_ious</span><span class="p">,</span> <span class="n">flattened_gts</span>
    +
    +    <span class="k">def</span> <span class="nf">_combine_candidates_img_id_fg_id</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">candidate_img_ids</span><span class="p">,</span> <span class="n">candidate_anchor_ids</span><span class="p">):</span>
    +        <span class="sd">&quot;&quot;&quot;</span>
    +<span class="sd">        Create one dim tensor with unique pairs of img_id and fg_id.</span>
    +<span class="sd">        e.g: candidate_img_ids = [0,1,0,0]</span>
    +<span class="sd">             candidate_fg_ids = [0,0,0,1]</span>
    +<span class="sd">             result = [0,1,0,2]</span>
    +<span class="sd">        &quot;&quot;&quot;</span>
    +        <span class="n">candidate_img_and_fg_ids_combined</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">stack</span><span class="p">((</span><span class="n">candidate_img_ids</span><span class="p">,</span> <span class="n">candidate_anchor_ids</span><span class="p">),</span> <span class="n">dim</span><span class="o">=</span><span class="mi">1</span><span class="p">)</span><span class="o">.</span><span class="n">unique</span><span class="p">(<
    +        <span class="k">return</span> <span class="n">candidate_img_and_fg_ids_combined</span>
    +
    +    <span class="k">def</span> <span class="nf">_compute_dynamic_ks</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">ids</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">,</span> <span class="n">ious</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">,</span> <span class="n">dynamic_ks_bias</
    +        <span class="sd">&quot;&quot;&quot;</span>
    +<span class="sd">        :param ids:                 ids of GTs, shape: [num_candidates]</span>
    +<span class="sd">        :param ious:                pairwise IoUs, shape: [num_candidates]</span>
    +<span class="sd">        :param dynamic_ks_bias:     multiply the resulted k to compensate the regular loss</span>
    +<span class="sd">        &quot;&quot;&quot;</span>
    +        <span class="k">assert</span> <span class="nb">len</span><span class="p">(</span><span class="n">ids</span><span class="o">.</span><span class="n">shape</span><span class="p">)</span> <span class="o">==</span> <span class="mi">1</span><span class="p">,</span> <span class="s2">&quot;ids must be of shape [num_candidates]&quot;</span>
    +        <span class="k">assert</span> <span class="nb">len</span><span class="p">(</span><span class="n">ious</span><span class="o">.</span><span class="n">shape</span><span class="p">)</span> <span class="o">==</span> <span class="mi">1</span><span class="p">,</span> <span class="s2">&quot;ious must be of shape [num_candidates]&quot;</span>
    +        <span class="k">assert</span> <span class="n">ids</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span> <span class="o">==</span> <span class="n">ious</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">0</span><span class="p">],</span> <span class="s2">&quot;num of ids.shape[0] must be the same as num of ious.shape[0]&quot;</span>
    +        <span class="c1"># sort ious and ids by ious</span>
    +        <span class="n">ious</span><span class="p">,</span> <span class="n">ious_argsort</span> <span class="o">=</span> <span class="n">ious</span><span class="o">.</span><span class="n">sort</span><span class="p">(</span><span class="n">descending</span><span class="o">=</span><span class="kc">True</span><span class="p">)</span>
    +        <span class="n">ids</span> <span class="o">=</span> <span class="n">ids</span><span class="p">[</span><span class="n">ious_argsort</span><span class="p">]</span>
    +
    +        <span class="c1"># stable sort indices, so that ious are first sorted by id and second by value</span>
    +        <span class="n">ids</span><span class="p">,</span> <span class="n">ids_argsort</span> <span class="o">=</span> <span class="n">ids</span><span class="o">.</span><span class="n">sort</span><span class="p">(</span><span class="n">stable</span><span class="o">=</span><span class="kc">True</span><span class="p">)</span>
    +        <span class="n">ious</span> <span class="o">=</span> <span class="n">ious</span><span class="p">[</span><span class="n">ids_argsort</span><span class="p">]</span>
    +
    +        <span class="n">unique_ids</span><span class="p">,</span> <span class="n">ids_index_to_unique_ids_index</span> <span class="o">=</span> <span class="n">ids</span><span class="o">.</span><span class="n">unique_consecutive</span><span class="p">(</span><span class="n">dim</span><span class="o">=</span><span class="mi">0</span><span class="p">,</span> <span class="n">return_inverse</span><span class="o">=</span><span class="kc">True</span><span class="p">)</span>
    +        <span class="n">num_unique_ids</span> <span class="o">=</span> <span class="n">unique_ids</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span>
    +
    +        <span class="k">if</span> <span class="n">ids</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span> <span class="o">&gt;</span> <span class="mi">10</span><span class="p">:</span>
    +            <span class="n">is_in_top_10</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">cat</span><span class="p">((</span><span class="n">torch</span><span class="o">.</span><span class="n">ones</span><span class="p">((</span><span class="mi">10</span><span class="p">,),</span> <span class="n">dtype</span><span class="o">=</span><span class="n">torch</span><span class="o">.</span><span class="n">bool</span><span class="p">,</span> <span clas
    +        <span class="k">else</span><span class="p">:</span>
    +            <span class="n">is_in_top_10</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">ones_like</span><span class="p">(</span><span class="n">ids</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="n">torch</span><span class="o">.</span><span class="n">bool</span><span class="p">)</span>
    +
    +        <span class="n">dynamic_ks</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">zeros</span><span class="p">((</span><span class="n">num_unique_ids</span><span class="p">,),</span> <span class="n">dtype</span><span class="o">=</span><span class="n">ious</span><span class="o">.</span><span class="n">dtype</span><span class="p">,</span> <span class="n">device</span><span class="o">=</span><span class="n">ious</span><span class="o">.</span><sp
    +        <span class="n">dynamic_ks</span><span class="o">.</span><span class="n">index_put_</span><span class="p">((</span><span class="n">ids_index_to_unique_ids_index</span><span class="p">,),</span> <span class="n">is_in_top_10</span> <span class="o">*</span> <span class="n">ious</span><span class="p">,</span> <span class="n">accumulate</span><span class="o">=</span><span class="kc">True</span><span class="p">)</span>
    +        <span class="k">if</span> <span class="n">dynamic_ks_bias</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span><span class="p">:</span>
    +            <span class="n">dynamic_ks</span> <span class="o">*=</span> <span class="n">dynamic_ks_bias</span>
    +        <span class="n">dynamic_ks</span> <span class="o">=</span> <span class="n">dynamic_ks</span><span class="o">.</span><span class="n">long</span><span class="p">()</span><span class="o">.</span><span class="n">clamp</span><span class="p">(</span><span class="nb">min</span><span class="o">=</span><span class="mi">1</span><span class="p">)</span>
    +
    +        <span class="n">all_argsort</span> <span class="o">=</span> <span class="n">ious_argsort</span><span class="p">[</span><span class="n">ids_argsort</span><span class="p">]</span>
    +        <span class="n">inverse_all_argsort</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">zeros_like</span><span class="p">(</span><span class="n">ious_argsort</span><span class="p">)</span>
    +        <span class="n">inverse_all_argsort</span><span class="p">[</span><span class="n">all_argsort</span><span class="p">]</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">arange</span><span class="p">(</span><span class="n">all_argsort</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">0</span><span class="p">],</span> <span class="n">dtype</span><span class="o">=</span><span class="n">all_arg
    +
    +        <span class="k">return</span> <span class="n">dynamic_ks</span><span class="p">,</span> <span class="n">ids_index_to_unique_ids_index</span><span class="p">[</span><span class="n">inverse_all_argsort</span><span class="p">]</span>
    +
    +    <span class="k">def</span> <span class="nf">_compute_cost_order</span><span class="p">(</span>
    +        <span class="bp">self</span><span class="p">,</span>
    +        <span class="n">num_classes</span><span class="p">,</span>
    +        <span class="n">candidate_gt_img_ids</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">,</span>
    +        <span class="n">candidate_gt_classes</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">,</span>
    +        <span class="n">candidate_anchor_ids</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">,</span>
    +        <span class="n">candidate_ious</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">,</span>
    +        <span class="n">cls_preds</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">,</span>
    +        <span class="n">obj_preds</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">,</span>
    +        <span class="n">strong_candidate_mask</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">,</span>
    +        <span class="n">ious_loss_cost_coeff</span><span class="p">:</span> <span class="nb">float</span><span class="p">,</span>
    +        <span class="n">outside_boxes_and_center_cost_coeff</span><span class="p">:</span> <span class="nb">float</span><span class="p">,</span>
    +    <span class="p">)</span> <span class="o">-&gt;</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">:</span>
    +        <span class="n">gt_cls_per_image</span> <span class="o">=</span> <span class="n">F</span><span class="o">.</span><span class="n">one_hot</span><span class="p">(</span><span class="n">candidate_gt_classes</span><span class="o">.</span><span class="n">to</span><span class="p">(</span><span class="n">torch</span><span class="o">.</span><span class="n">int64</span><span class="p">),</span> <span class="n">num_classes</span><span class="p">)</span><span class="o">.</span><span class="n">floa
    +        <span class="k">with</span> <span class="n">torch</span><span class="o">.</span><span class="n">cuda</span><span class="o">.</span><span class="n">amp</span><span class="o">.</span><span class="n">autocast</span><span class="p">(</span><span class="n">enabled</span><span class="o">=</span><span class="kc">False</span><span class="p">):</span>
    +            <span class="n">cls_preds_</span> <span class="o">=</span> <span class="p">(</span>
    +                <span class="n">cls_preds</span><span class="p">[</span><span class="n">candidate_gt_img_ids</span><span class="p">,</span> <span class="n">candidate_anchor_ids</span><span class="p">]</span><span class="o">.</span><span class="n">float</span><span class="p">()</span><span class="o">.</span><span class="n">sigmoid_</span><span class="p">()</span>
    +                <span class="o">*</span> <span class="n">obj_preds</span><span class="p">[</span><span class="n">candidate_gt_img_ids</span><span class="p">,</span> <span class="n">candidate_anchor_ids</span><span class="p">]</span><span class="o">.</span><span class="n">float</span><span class="p">()</span><span class="o">.</span><span class="n">sigmoid_</span><span class="p">()</span>
    +            <span class="p">)</span>
    +            <span class="n">pair_wise_cls_cost</span> <span class="o">=</span> <span class="n">F</span><span class="o">.</span><span class="n">binary_cross_entropy</span><span class="p">(</span><span class="n">cls_preds_</span><span class="o">.</span><span class="n">sqrt_</span><span class="p">(),</span> <span class="n">gt_cls_per_image</span><span class="p">,</span> <span class="n">reduction</span><span class="o">=</span><span class="s2">&quot;none&quot;</span><span class="p">)</span><span cla
    +
    +        <span class="n">ious_cost</span> <span class="o">=</span> <span class="o">-</span><span class="n">torch</span><span class="o">.</span><span class="n">log</span><span class="p">(</span><span class="n">candidate_ious</span> <span class="o">+</span> <span class="mf">1e-8</span><span class="p">)</span>
    +        <span class="n">cost</span> <span class="o">=</span> <span class="n">pair_wise_cls_cost</span> <span class="o">+</span> <span class="n">ious_loss_cost_coeff</span> <span class="o">*</span> <span class="n">ious_cost</span> <span class="o">+</span> <span class="n">outside_boxes_and_center_cost_coeff</span> <span class="o">*</span> <span class="n">strong_candidate_mask</span><span class="o">.</span><span class="n">logical_not</span><span class="p">()</span>
    +        <span class="k">return</span> <span class="n">cost</span><span class="o">.</span><span class="n">argsort</span><span class="p">()</span>
    +
    +    <span class="k">def</span> <span class="nf">_calculate_pairwise_bbox_iou</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">bboxes_a</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">,</span> <span class="n">bboxes_b</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">,</span> <span class="n"
    +        <span class="k">if</span> <span class="n">bboxes_a</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">1</span><span class="p">]</span> <span class="o">!=</span> <span class="mi">4</span> <span class="ow">or</span> <span class="n">bboxes_b</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">1</span><span class="p">]</span> <span class="o">!=</span> <span class="mi">4</span><span class="p">:</sp
    +            <span class="k">raise</span> <span class="ne">IndexError</span>
    +
    +        <span class="k">if</span> <span class="n">xyxy</span><span class="p">:</span>
    +            <span class="n">tl</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">max</span><span class="p">(</span><span class="n">bboxes_a</span><span class="p">[:,</span> <span class="p">:</span><span class="mi">2</span><span class="p">],</span> <span class="n">bboxes_b</span><span class="p">[:,</span> <span class="p">:</span><span class="mi">2</span><span class="p">])</span>
    +            <span class="n">br</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">min</span><span class="p">(</span><span class="n">bboxes_a</span><span class="p">[:,</span> <span class="mi">2</span><span class="p">:],</span> <span class="n">bboxes_b</span><span class="p">[:,</span> <span class="mi">2</span><span class="p">:])</span>
    +            <span class="n">area_a</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">prod</span><span class="p">(</span><span class="n">bboxes_a</span><span class="p">[:,</span> <span class="mi">2</span><span class="p">:]</span> <span class="o">-</span> <span class="n">bboxes_a</span><span class="p">[:,</span> <span class="p">:</span><span class="mi">2</span><span class="p">],</span> <span class="mi">1</span><span class="p">)</span>
    +            <span class="n">area_b</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">prod</span><span class="p">(</span><span class="n">bboxes_b</span><span class="p">[:,</span> <span class="mi">2</span><span class="p">:]</span> <span class="o">-</span> <span class="n">bboxes_b</span><span class="p">[:,</span> <span class="p">:</span><span class="mi">2</span><span class="p">],</span> <span class="mi">1</span><span class="p">)</span>
    +        <span class="k">else</span><span class="p">:</span>
    +            <span class="n">tl</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">max</span><span class="p">(</span>
    +                <span class="p">(</span><span class="n">bboxes_a</span><span class="p">[:,</span> <span class="p">:</span><span class="mi">2</span><span class="p">]</span> <span class="o">-</span> <span class="n">bboxes_a</span><span class="p">[:,</span> <span class="mi">2</span><span class="p">:]</span> <span class="o">/</span> <span class="mi">2</span><span class="p">),</span>
    +                <span class="p">(</span><span class="n">bboxes_b</span><span class="p">[:,</span> <span class="p">:</span><span class="mi">2</span><span class="p">]</span> <span class="o">-</span> <span class="n">bboxes_b</span><span class="p">[:,</span> <span class="mi">2</span><span class="p">:]</span> <span class="o">/</span> <span class="mi">2</span><span class="p">),</span>
    +            <span class="p">)</span>
    +            <span class="n">br</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">min</span><span class="p">(</span>
    +                <span class="p">(</span><span class="n">bboxes_a</span><span class="p">[:,</span> <span class="p">:</span><span class="mi">2</span><span class="p">]</span> <span class="o">+</span> <span class="n">bboxes_a</span><span class="p">[:,</span> <span class="mi">2</span><span class="p">:]</span> <span class="o">/</span> <span class="mi">2</span><span class="p">),</span>
    +                <span class="p">(</span><span class="n">bboxes_b</span><span class="p">[:,</span> <span class="p">:</span><span class="mi">2</span><span class="p">]</span> <span class="o">+</span> <span class="n">bboxes_b</span><span class="p">[:,</span> <span class="mi">2</span><span class="p">:]</span> <span class="o">/</span> <span class="mi">2</span><span class="p">),</span>
    +            <span class="p">)</span>
    +
    +            <span class="n">area_a</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">prod</span><span class="p">(</span><span class="n">bboxes_a</span><span class="p">[:,</span> <span class="mi">2</span><span class="p">:],</span> <span class="mi">1</span><span class="p">)</span>
    +            <span class="n">area_b</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">prod</span><span class="p">(</span><span class="n">bboxes_b</span><span class="p">[:,</span> <span class="mi">2</span><span class="p">:],</span> <span class="mi">1</span><span class="p">)</span>
    +        <span class="n">en</span> <span class="o">=</span> <span class="p">(</span><span class="n">tl</span> <span class="o">&lt;</span> <span class="n">br</span><span class="p">)</span><span class="o">.</span><span class="n">prod</span><span class="p">(</span><span class="n">dim</span><span class="o">=</span><span class="mi">1</span><span class="p">)</span>
    +        <span class="n">area_i</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">prod</span><span class="p">(</span><span class="n">br</span> <span class="o">-</span> <span class="n">tl</span><span class="p">,</span> <span class="mi">1</span><span class="p">)</span> <span class="o">*</span> <span class="n">en</span>
    +        <span class="k">return</span> <span class="n">area_i</span> <span class="o">/</span> <span class="p">(</span><span class="n">area_a</span> <span class="o">+</span> <span class="n">area_b</span> <span class="o">-</span> <span class="n">area_i</span><span class="p">)</span>
    +
    +    <span class="k">def</span> <span class="nf">_compute_ranks</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">ids</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">:</span>
    +        <span class="n">ids</span><span class="p">,</span> <span class="n">ids_argsort</span> <span class="o">=</span> <span class="n">ids</span><span class="o">.</span><span class="n">sort</span><span class="p">(</span><span class="n">stable</span><span class="o">=</span><span class="kc">True</span><span class="p">)</span>
    +
    +        <span class="k">if</span> <span class="n">ids</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span> <span class="o">&gt;</span> <span class="mi">1</span><span class="p">:</span>
    +            <span class="n">is_not_first</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">cat</span><span class="p">((</span><span class="n">torch</span><span class="o">.</span><span class="n">zeros</span><span class="p">((</span><span class="mi">1</span><span class="p">,),</span> <span class="n">dtype</span><span class="o">=</span><span class="n">torch</span><span class="o">.</span><span class="n">bool</span><span class="p">,</span> <span clas
    +        <span class="k">else</span><span class="p">:</span>
    +            <span class="n">is_not_first</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">zeros_like</span><span class="p">(</span><span class="n">ids</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="n">torch</span><span class="o">.</span><span class="n">bool</span><span class="p">)</span>
    +
    +        <span class="n">subtract</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">arange</span><span class="p">(</span><span class="n">ids</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">0</span><span class="p">],</span> <span class="n">dtype</span><span class="o">=</span><span class="n">ids_argsort</span><span class="o">.</span><span class="n">dtype</span><span class="p">,</span> <span class="
    +        <span class="n">subtract</span><span class="p">[</span><span class="n">is_not_first</span><span class="p">]</span> <span class="o">=</span> <span class="mi">0</span>
    +        <span class="n">subtract</span> <span class="o">=</span> <span class="n">subtract</span><span class="o">.</span><span class="n">cummax</span><span class="p">(</span><span class="n">dim</span><span class="o">=</span><span class="mi">0</span><span class="p">)[</span><span class="mi">0</span><span class="p">]</span>
    +        <span class="n">rank</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">arange</span><span class="p">(</span><span class="n">ids</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">0</span><span class="p">],</span> <span class="n">dtype</span><span class="o">=</span><span class="n">ids_argsort</span><span class="o">.</span><span class="n">dtype</span><span class="p">,</span> <span class="n">d
    +
    +        <span class="n">inverse_argsort</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">zeros_like</span><span class="p">(</span><span class="n">ids_argsort</span><span class="p">)</span>
    +        <span class="n">inverse_argsort</span><span class="p">[</span><span class="n">ids_argsort</span><span class="p">]</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">arange</span><span class="p">(</span><span class="n">ids_argsort</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">0</span><span class="p">],</span> <span class="n">dtype</span><span class="o">=</span><span class="n">ids_argsort
    +
    +        <span class="k">return</span> <span class="n">rank</span><span class="p">[</span><span class="n">inverse_argsort</span><span class="p">]</span>
    +
    +    <span class="k">def</span> <span class="nf">_compute_is_first_mask</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">ids</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">:</span>
    +        <span class="sd">&quot;&quot;&quot;</span>
    +<span class="sd">        Filter fg that matches two gts.</span>
    +<span class="sd">        &quot;&quot;&quot;</span>
    +        <span class="n">ids</span><span class="p">,</span> <span class="n">ids_argsort</span> <span class="o">=</span> <span class="n">ids</span><span class="o">.</span><span class="n">sort</span><span class="p">(</span><span class="n">stable</span><span class="o">=</span><span class="kc">True</span><span class="p">)</span>
    +
    +        <span class="k">if</span> <span class="n">ids</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span> <span class="o">&gt;</span> <span class="mi">1</span><span class="p">:</span>
    +            <span class="n">is_first</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">cat</span><span class="p">((</span><span class="n">torch</span><span class="o">.</span><span class="n">ones</span><span class="p">((</span><span class="mi">1</span><span class="p">,),</span> <span class="n">dtype</span><span class="o">=</span><span class="n">torch</span><span class="o">.</span><span class="n">bool</span><span class="p">,</span> <span class="n"
    +        <span class="k">else</span><span class="p">:</span>
    +            <span class="n">is_first</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">ones_like</span><span class="p">(</span><span class="n">ids</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="n">torch</span><span class="o">.</span><span class="n">bool</span><span class="p">)</span>
    +
    +        <span class="n">inverse_argsort</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">zeros_like</span><span class="p">(</span><span class="n">ids_argsort</span><span class="p">)</span>
    +        <span class="n">inverse_argsort</span><span class="p">[</span><span class="n">ids_argsort</span><span class="p">]</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">arange</span><span class="p">(</span><span class="n">ids_argsort</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">0</span><span class="p">],</span> <span class="n">dtype</span><span class="o">=</span><span class="n">ids_argsort
    +
    +        <span class="k">return</span> <span class="n">is_first</span><span class="p">[</span><span class="n">inverse_argsort</span><span class="p">]</span></div>
     </pre></div>
     </pre></div>
     
     
                </div>
                </div>
    @@ -661,4 +1117,4 @@
       </script> 
       </script> 
     
     
     </body>
     </body>
    -</html>
    +</html>
    Discard
    Some lines were truncated since they exceed the maximum allowed length of 500, please use a local Git client to see the full diff.
    @@ -3,10 +3,11 @@
     <head>
     <head>
       <meta charset="utf-8" />
       <meta charset="utf-8" />
       <meta name="viewport" content="width=device-width, initial-scale=1.0" />
       <meta name="viewport" content="width=device-width, initial-scale=1.0" />
    -  <title>super_gradients.training.metrics.classification_metrics &mdash; SuperGradients 1.0 documentation</title>
    +  <title>super_gradients.training.metrics.classification_metrics &mdash; SuperGradients 3.0.3 documentation</title>
           <link rel="stylesheet" href="../../../../_static/pygments.css" type="text/css" />
           <link rel="stylesheet" href="../../../../_static/pygments.css" type="text/css" />
           <link rel="stylesheet" href="../../../../_static/css/theme.css" type="text/css" />
           <link rel="stylesheet" href="../../../../_static/css/theme.css" type="text/css" />
           <link rel="stylesheet" href="../../../../_static/graphviz.css" type="text/css" />
           <link rel="stylesheet" href="../../../../_static/graphviz.css" type="text/css" />
    +      <link rel="stylesheet" href="../../../../_static/custom.css" type="text/css" />
       <!--[if lt IE 9]>
       <!--[if lt IE 9]>
         <script src="../../../../_static/js/html5shiv.min.js"></script>
         <script src="../../../../_static/js/html5shiv.min.js"></script>
       <![endif]-->
       <![endif]-->
    @@ -14,7 +15,9 @@
             <script data-url_root="../../../../" id="documentation_options" src="../../../../_static/documentation_options.js"></script>
             <script data-url_root="../../../../" id="documentation_options" src="../../../../_static/documentation_options.js"></script>
             <script src="../../../../_static/jquery.js"></script>
             <script src="../../../../_static/jquery.js"></script>
             <script src="../../../../_static/underscore.js"></script>
             <script src="../../../../_static/underscore.js"></script>
    +        <script src="../../../../_static/_sphinx_javascript_frameworks_compat.js"></script>
             <script src="../../../../_static/doctools.js"></script>
             <script src="../../../../_static/doctools.js"></script>
    +        <script src="../../../../_static/sphinx_highlight.js"></script>
         <script src="../../../../_static/js/theme.js"></script>
         <script src="../../../../_static/js/theme.js"></script>
         <link rel="index" title="Index" href="../../../../genindex.html" />
         <link rel="index" title="Index" href="../../../../genindex.html" />
         <link rel="search" title="Search" href="../../../../search.html" /> 
         <link rel="search" title="Search" href="../../../../search.html" /> 
    @@ -35,30 +38,29 @@
       </form>
       </form>
     </div>
     </div>
             </div><div class="wy-menu wy-menu-vertical" data-spy="affix" role="navigation" aria-label="Navigation menu">
             </div><div class="wy-menu wy-menu-vertical" data-spy="affix" role="navigation" aria-label="Navigation menu">
    -              <p class="caption"><span class="caption-text">Welcome To SuperGradients</span></p>
    +              <p class="caption" role="heading"><span class="caption-text">Welcome To SuperGradients</span></p>
     <ul>
     <ul>
    -<li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html">Fill our 4-question quick survey! We will raffle free SuperGradients swag between those who will participate -&gt; Fill Survey</a></li>
    -<li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#supergradients">SuperGradients</a></li>
    +<li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html">Version 3 is out! Notebooks have been updated!</a></li>
    +<li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#build-with-supergradients">Build with SuperGradients</a></li>
    +<li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#quick-installation">Quick Installation</a></li>
    +<li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#what-s-new">What’s New</a></li>
    +<li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#coming-soon">Coming soon</a></li>
    +<li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#table-of-content">Table of Content</a></li>
    +<li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#getting-started">Getting Started</a></li>
    +<li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#advanced-features">Advanced Features</a></li>
    +<li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#installation-methods">Installation Methods</a></li>
    +<li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#implemented-model-architectures">Implemented Model Architectures</a></li>
    +<li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#documentation">Documentation</a></li>
    +<li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#contributing">Contributing</a></li>
    +<li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#citation">Citation</a></li>
    +<li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#community">Community</a></li>
    +<li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#license">License</a></li>
    +<li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#deci-platform">Deci Platform</a></li>
     </ul>
     </ul>
    -<p class="caption"><span class="caption-text">Technical Documentation</span></p>
    +<p class="caption" role="heading"><span class="caption-text">Technical Documentation</span></p>
     <ul>
     <ul>
     <li class="toctree-l1"><a class="reference internal" href="../../../../super_gradients.common.html">Common package</a></li>
     <li class="toctree-l1"><a class="reference internal" href="../../../../super_gradients.common.html">Common package</a></li>
     <li class="toctree-l1"><a class="reference internal" href="../../../../super_gradients.training.html">Training package</a></li>
     <li class="toctree-l1"><a class="reference internal" href="../../../../super_gradients.training.html">Training package</a></li>
    -</ul>
    -<p class="caption"><span class="caption-text">User Guide</span></p>
    -<ul>
    -<li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html">What is SuperGradients?</a></li>
    -<li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#introducing-the-supergradients-library">Introducing the SuperGradients library</a></li>
    -<li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#installation">Installation</a></li>
    -<li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#integrating-your-training-code-complete-walkthrough">Integrating Your Training Code - Complete Walkthrough</a></li>
    -<li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#training-parameters">Training Parameters</a></li>
    -<li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#logs-and-checkpoints">Logs and Checkpoints</a></li>
    -<li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#dataset-parameters">Dataset Parameters</a></li>
    -<li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#network-architectures">Network Architectures</a></li>
    -<li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#pretrained-models">Pretrained Models</a></li>
    -<li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#how-to-reproduce-our-training-recipes">How To Reproduce Our Training Recipes</a></li>
    -<li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#professional-tools-integration">Professional Tools Integration</a></li>
    -<li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#supergradients-faq">SuperGradients FAQ</a></li>
     </ul>
     </ul>
     
     
             </div>
             </div>
    @@ -92,7 +94,7 @@
     <span class="kn">from</span> <span class="nn">torchmetrics</span> <span class="kn">import</span> <span class="n">Metric</span>
     <span class="kn">from</span> <span class="nn">torchmetrics</span> <span class="kn">import</span> <span class="n">Metric</span>
     
     
     
     
    -<div class="viewcode-block" id="accuracy"><a class="viewcode-back" href="../../../../super_gradients.training.metrics.html#super_gradients.training.metrics.accuracy">[docs]</a><span class="k">def</span> <span class="nf">accuracy</span><span class="p">(</span><span class="n">output</span><span class="p">,</span> <span class="n">target</span><span class="p">,</span> <span class="n">topk</span><span class="o">=</span><span class="p">(</span><span class="mi">1</span><span class="p">,)):</span>
    +<div class="viewcode-block" id="accuracy"><a class="viewcode-back" href="../../../../super_gradients.training.html#super_gradients.training.metrics.accuracy">[docs]</a><span class="k">def</span> <span class="nf">accuracy</span><span class="p">(</span><span class="n">output</span><span class="p">,</span> <span class="n">target</span><span class="p">,</span> <span class="n">topk</span><span class="o">=</span><span class="p">(</span><span class="mi">1</span><span class="p">,)):</span>
         <span class="sd">&quot;&quot;&quot;Computes the precision@k for the specified values of k</span>
         <span class="sd">&quot;&quot;&quot;Computes the precision@k for the specified values of k</span>
     <span class="sd">    :param output: Tensor / Numpy / List</span>
     <span class="sd">    :param output: Tensor / Numpy / List</span>
     <span class="sd">        The prediction</span>
     <span class="sd">        The prediction</span>
    @@ -122,24 +124,26 @@
         <span class="k">return</span> <span class="n">res</span></div>
         <span class="k">return</span> <span class="n">res</span></div>
     
     
     
     
    -<div class="viewcode-block" id="Accuracy"><a class="viewcode-back" href="../../../../super_gradients.training.metrics.html#super_gradients.training.metrics.Accuracy">[docs]</a><span class="k">class</span> <span class="nc">Accuracy</span><span class="p">(</span><span class="n">torchmetrics</span><span class="o">.</span><span class="n">Accuracy</span><span class="p">):</span>
    +<div class="viewcode-block" id="Accuracy"><a class="viewcode-back" href="../../../../super_gradients.training.html#super_gradients.training.metrics.Accuracy">[docs]</a><span class="k">class</span> <span class="nc">Accuracy</span><span class="p">(</span><span class="n">torchmetrics</span><span class="o">.</span><span class="n">Accuracy</span><span class="p">):</span>
         <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">dist_sync_on_step</span><span class="o">=</span><span class="kc">False</span><span class="p">):</span>
         <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">dist_sync_on_step</span><span class="o">=</span><span class="kc">False</span><span class="p">):</span>
             <span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="fm">__init__</span><span class="p">(</span><span class="n">dist_sync_on_step</span><span class="o">=</span><span class="n">dist_sync_on_step</span><span class="p">)</span>
             <span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="fm">__init__</span><span class="p">(</span><span class="n">dist_sync_on_step</span><span class="o">=</span><span class="n">dist_sync_on_step</span><span class="p">)</span>
    +        <span class="bp">self</span><span class="o">.</span><span class="n">greater_is_better</span> <span class="o">=</span> <span class="kc">True</span>
     
     
    -<div class="viewcode-block" id="Accuracy.update"><a class="viewcode-back" href="../../../../super_gradients.training.metrics.html#super_gradients.training.metrics.Accuracy.update">[docs]</a>    <span class="k">def</span> <span class="nf">update</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">preds</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">,</span> <span class=
    +<div class="viewcode-block" id="Accuracy.update"><a class="viewcode-back" href="../../../../super_gradients.training.html#super_gradients.training.metrics.Accuracy.update">[docs]</a>    <span class="k">def</span> <span class="nf">update</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">preds</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">,</span> <span class="n">targ
             <span class="k">if</span> <span class="n">target</span><span class="o">.</span><span class="n">shape</span> <span class="o">==</span> <span class="n">preds</span><span class="o">.</span><span class="n">shape</span><span class="p">:</span>
             <span class="k">if</span> <span class="n">target</span><span class="o">.</span><span class="n">shape</span> <span class="o">==</span> <span class="n">preds</span><span class="o">.</span><span class="n">shape</span><span class="p">:</span>
                 <span class="n">target</span> <span class="o">=</span> <span class="n">target</span><span class="o">.</span><span class="n">argmax</span><span class="p">(</span><span class="mi">1</span><span class="p">)</span>  <span class="c1"># supports smooth labels</span>
                 <span class="n">target</span> <span class="o">=</span> <span class="n">target</span><span class="o">.</span><span class="n">argmax</span><span class="p">(</span><span class="mi">1</span><span class="p">)</span>  <span class="c1"># supports smooth labels</span>
             <span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="n">update</span><span class="p">(</span><span class="n">preds</span><span class="o">=</span><span class="n">preds</span><span class="o">.</span><span class="n">argmax</span><span class="p">(</span><span class="mi">1</span><span class="p">),</span> <span class="n">target</span><span class="o">=</span><span class="n">target</span><span class="p">)</span></div></div>
             <span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="n">update</span><span class="p">(</span><span class="n">preds</span><span class="o">=</span><span class="n">preds</span><span class="o">.</span><span class="n">argmax</span><span class="p">(</span><span class="mi">1</span><span class="p">),</span> <span class="n">target</span><span class="o">=</span><span class="n">target</span><span class="p">)</span></div></div>
     
     
     
     
    -<div class="viewcode-block" id="Top5"><a class="viewcode-back" href="../../../../super_gradients.training.metrics.html#super_gradients.training.metrics.Top5">[docs]</a><span class="k">class</span> <span class="nc">Top5</span><span class="p">(</span><span class="n">Metric</span><span class="p">):</span>
    +<div class="viewcode-block" id="Top5"><a class="viewcode-back" href="../../../../super_gradients.training.html#super_gradients.training.metrics.Top5">[docs]</a><span class="k">class</span> <span class="nc">Top5</span><span class="p">(</span><span class="n">Metric</span><span class="p">):</span>
         <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">dist_sync_on_step</span><span class="o">=</span><span class="kc">False</span><span class="p">):</span>
         <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">dist_sync_on_step</span><span class="o">=</span><span class="kc">False</span><span class="p">):</span>
             <span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="fm">__init__</span><span class="p">(</span><span class="n">dist_sync_on_step</span><span class="o">=</span><span class="n">dist_sync_on_step</span><span class="p">)</span>
             <span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="fm">__init__</span><span class="p">(</span><span class="n">dist_sync_on_step</span><span class="o">=</span><span class="n">dist_sync_on_step</span><span class="p">)</span>
    +        <span class="bp">self</span><span class="o">.</span><span class="n">greater_is_better</span> <span class="o">=</span> <span class="kc">True</span>
     
     
    -        <span class="bp">self</span><span class="o">.</span><span class="n">add_state</span><span class="p">(</span><span class="s2">&quot;correct&quot;</span><span class="p">,</span> <span class="n">default</span><span class="o">=</span><span class="n">torch</span><span class="o">.</span><span class="n">tensor</span><span class="p">(</span><span class="mf">0.</span><span class="p">),</span> <span class="n">dist_reduce_fx</span><span class="o">=</span><span class="s2">&quot;sum&quot;</span><spa
    +        <span class="bp">self</span><span class="o">.</span><span class="n">add_state</span><span class="p">(</span><span class="s2">&quot;correct&quot;</span><span class="p">,</span> <span class="n">default</span><span class="o">=</span><span class="n">torch</span><span class="o">.</span><span class="n">tensor</span><span class="p">(</span><span class="mf">0.0</span><span class="p">),</span> <span class="n">dist_reduce_fx</span><span class="o">=</span><span class="s2">&quot;sum&quot;</span><sp
             <span class="bp">self</span><span class="o">.</span><span class="n">add_state</span><span class="p">(</span><span class="s2">&quot;total&quot;</span><span class="p">,</span> <span class="n">default</span><span class="o">=</span><span class="n">torch</span><span class="o">.</span><span class="n">tensor</span><span class="p">(</span><span class="mi">0</span><span class="p">),</span> <span class="n">dist_reduce_fx</span><span class="o">=</span><span class="s2">&quot;sum&quot;</span><span c
             <span class="bp">self</span><span class="o">.</span><span class="n">add_state</span><span class="p">(</span><span class="s2">&quot;total&quot;</span><span class="p">,</span> <span class="n">default</span><span class="o">=</span><span class="n">torch</span><span class="o">.</span><span class="n">tensor</span><span class="p">(</span><span class="mi">0</span><span class="p">),</span> <span class="n">dist_reduce_fx</span><span class="o">=</span><span class="s2">&quot;sum&quot;</span><span c
     
     
    -<div class="viewcode-block" id="Top5.update"><a class="viewcode-back" href="../../../../super_gradients.training.metrics.html#super_gradients.training.metrics.Top5.update">[docs]</a>    <span class="k">def</span> <span class="nf">update</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">preds</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">,</span> <span class="n">targ
    +<div class="viewcode-block" id="Top5.update"><a class="viewcode-back" href="../../../../super_gradients.training.html#super_gradients.training.metrics.Top5.update">[docs]</a>    <span class="k">def</span> <span class="nf">update</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">preds</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">,</span> <span class="n">target</span
             <span class="k">if</span> <span class="n">target</span><span class="o">.</span><span class="n">shape</span> <span class="o">==</span> <span class="n">preds</span><span class="o">.</span><span class="n">shape</span><span class="p">:</span>
             <span class="k">if</span> <span class="n">target</span><span class="o">.</span><span class="n">shape</span> <span class="o">==</span> <span class="n">preds</span><span class="o">.</span><span class="n">shape</span><span class="p">:</span>
                 <span class="n">target</span> <span class="o">=</span> <span class="n">target</span><span class="o">.</span><span class="n">argmax</span><span class="p">(</span><span class="mi">1</span><span class="p">)</span>  <span class="c1"># supports smooth labels</span>
                 <span class="n">target</span> <span class="o">=</span> <span class="n">target</span><span class="o">.</span><span class="n">argmax</span><span class="p">(</span><span class="mi">1</span><span class="p">)</span>  <span class="c1"># supports smooth labels</span>
     
     
    @@ -155,21 +159,22 @@
             <span class="bp">self</span><span class="o">.</span><span class="n">correct</span> <span class="o">+=</span> <span class="n">correct5</span>
             <span class="bp">self</span><span class="o">.</span><span class="n">correct</span> <span class="o">+=</span> <span class="n">correct5</span>
             <span class="bp">self</span><span class="o">.</span><span class="n">total</span> <span class="o">+=</span> <span class="n">batch_size</span></div>
             <span class="bp">self</span><span class="o">.</span><span class="n">total</span> <span class="o">+=</span> <span class="n">batch_size</span></div>
     
     
    -<div class="viewcode-block" id="Top5.compute"><a class="viewcode-back" href="../../../../super_gradients.training.metrics.html#super_gradients.training.metrics.Top5.compute">[docs]</a>    <span class="k">def</span> <span class="nf">compute</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span>
    +<div class="viewcode-block" id="Top5.compute"><a class="viewcode-back" href="../../../../super_gradients.training.html#super_gradients.training.metrics.Top5.compute">[docs]</a>    <span class="k">def</span> <span class="nf">compute</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span>
             <span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">correct</span><span class="o">.</span><span class="n">float</span><span class="p">()</span> <span class="o">/</span> <span class="bp">self</span><span class="o">.</span><span class="n">total</span></div></div>
             <span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">correct</span><span class="o">.</span><span class="n">float</span><span class="p">()</span> <span class="o">/</span> <span class="bp">self</span><span class="o">.</span><span class="n">total</span></div></div>
     
     
     
     
    -<div class="viewcode-block" id="ToyTestClassificationMetric"><a class="viewcode-back" href="../../../../super_gradients.training.metrics.html#super_gradients.training.metrics.ToyTestClassificationMetric">[docs]</a><span class="k">class</span> <span class="nc">ToyTestClassificationMetric</span><span class="p">(</span><span class="n">Metric</span><span class="p">):</span>
    +<div class="viewcode-block" id="ToyTestClassificationMetric"><a class="viewcode-back" href="../../../../super_gradients.training.html#super_gradients.training.metrics.ToyTestClassificationMetric">[docs]</a><span class="k">class</span> <span class="nc">ToyTestClassificationMetric</span><span class="p">(</span><span class="n">Metric</span><span class="p">):</span>
         <span class="sd">&quot;&quot;&quot;</span>
         <span class="sd">&quot;&quot;&quot;</span>
     <span class="sd">    Dummy classification Mettric object returning 0 always (for testing).</span>
     <span class="sd">    Dummy classification Mettric object returning 0 always (for testing).</span>
     <span class="sd">    &quot;&quot;&quot;</span>
     <span class="sd">    &quot;&quot;&quot;</span>
    +
         <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">dist_sync_on_step</span><span class="o">=</span><span class="kc">False</span><span class="p">):</span>
         <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">dist_sync_on_step</span><span class="o">=</span><span class="kc">False</span><span class="p">):</span>
             <span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="fm">__init__</span><span class="p">(</span><span class="n">dist_sync_on_step</span><span class="o">=</span><span class="n">dist_sync_on_step</span><span class="p">)</span>
             <span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="fm">__init__</span><span class="p">(</span><span class="n">dist_sync_on_step</span><span class="o">=</span><span class="n">dist_sync_on_step</span><span class="p">)</span>
     
     
    -<div class="viewcode-block" id="ToyTestClassificationMetric.update"><a class="viewcode-back" href="../../../../super_gradients.training.metrics.html#super_gradients.training.metrics.ToyTestClassificationMetric.update">[docs]</a>    <span class="k">def</span> <span class="nf">update</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">preds</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span
    +<div class="viewcode-block" id="ToyTestClassificationMetric.update"><a class="viewcode-back" href="../../../../super_gradients.training.html#super_gradients.training.metrics.ToyTestClassificationMetric.update">[docs]</a>    <span class="k">def</span> <span class="nf">update</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">preds</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span c
             <span class="k">pass</span></div>
             <span class="k">pass</span></div>
     
     
    -<div class="viewcode-block" id="ToyTestClassificationMetric.compute"><a class="viewcode-back" href="../../../../super_gradients.training.metrics.html#super_gradients.training.metrics.ToyTestClassificationMetric.compute">[docs]</a>    <span class="k">def</span> <span class="nf">compute</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span>
    +<div class="viewcode-block" id="ToyTestClassificationMetric.compute"><a class="viewcode-back" href="../../../../super_gradients.training.html#super_gradients.training.metrics.ToyTestClassificationMetric.compute">[docs]</a>    <span class="k">def</span> <span class="nf">compute</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span>
             <span class="k">return</span> <span class="mi">0</span></div></div>
             <span class="k">return</span> <span class="mi">0</span></div></div>
     </pre></div>
     </pre></div>
     
     
    @@ -200,4 +205,4 @@
       </script> 
       </script> 
     
     
     </body>
     </body>
    -</html>
    +</html>
    Discard
    Some lines were truncated since they exceed the maximum allowed length of 500, please use a local Git client to see the full diff.
    @@ -3,10 +3,11 @@
     <head>
     <head>
       <meta charset="utf-8" />
       <meta charset="utf-8" />
       <meta name="viewport" content="width=device-width, initial-scale=1.0" />
       <meta name="viewport" content="width=device-width, initial-scale=1.0" />
    -  <title>super_gradients.training.metrics.detection_metrics &mdash; SuperGradients 1.0 documentation</title>
    +  <title>super_gradients.training.metrics.detection_metrics &mdash; SuperGradients 3.0.3 documentation</title>
           <link rel="stylesheet" href="../../../../_static/pygments.css" type="text/css" />
           <link rel="stylesheet" href="../../../../_static/pygments.css" type="text/css" />
           <link rel="stylesheet" href="../../../../_static/css/theme.css" type="text/css" />
           <link rel="stylesheet" href="../../../../_static/css/theme.css" type="text/css" />
           <link rel="stylesheet" href="../../../../_static/graphviz.css" type="text/css" />
           <link rel="stylesheet" href="../../../../_static/graphviz.css" type="text/css" />
    +      <link rel="stylesheet" href="../../../../_static/custom.css" type="text/css" />
       <!--[if lt IE 9]>
       <!--[if lt IE 9]>
         <script src="../../../../_static/js/html5shiv.min.js"></script>
         <script src="../../../../_static/js/html5shiv.min.js"></script>
       <![endif]-->
       <![endif]-->
    @@ -14,7 +15,9 @@
             <script data-url_root="../../../../" id="documentation_options" src="../../../../_static/documentation_options.js"></script>
             <script data-url_root="../../../../" id="documentation_options" src="../../../../_static/documentation_options.js"></script>
             <script src="../../../../_static/jquery.js"></script>
             <script src="../../../../_static/jquery.js"></script>
             <script src="../../../../_static/underscore.js"></script>
             <script src="../../../../_static/underscore.js"></script>
    +        <script src="../../../../_static/_sphinx_javascript_frameworks_compat.js"></script>
             <script src="../../../../_static/doctools.js"></script>
             <script src="../../../../_static/doctools.js"></script>
    +        <script src="../../../../_static/sphinx_highlight.js"></script>
         <script src="../../../../_static/js/theme.js"></script>
         <script src="../../../../_static/js/theme.js"></script>
         <link rel="index" title="Index" href="../../../../genindex.html" />
         <link rel="index" title="Index" href="../../../../genindex.html" />
         <link rel="search" title="Search" href="../../../../search.html" /> 
         <link rel="search" title="Search" href="../../../../search.html" /> 
    @@ -35,30 +38,29 @@
       </form>
       </form>
     </div>
     </div>
             </div><div class="wy-menu wy-menu-vertical" data-spy="affix" role="navigation" aria-label="Navigation menu">
             </div><div class="wy-menu wy-menu-vertical" data-spy="affix" role="navigation" aria-label="Navigation menu">
    -              <p class="caption"><span class="caption-text">Welcome To SuperGradients</span></p>
    +              <p class="caption" role="heading"><span class="caption-text">Welcome To SuperGradients</span></p>
     <ul>
     <ul>
    -<li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html">Fill our 4-question quick survey! We will raffle free SuperGradients swag between those who will participate -&gt; Fill Survey</a></li>
    -<li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#supergradients">SuperGradients</a></li>
    +<li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html">Version 3 is out! Notebooks have been updated!</a></li>
    +<li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#build-with-supergradients">Build with SuperGradients</a></li>
    +<li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#quick-installation">Quick Installation</a></li>
    +<li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#what-s-new">What’s New</a></li>
    +<li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#coming-soon">Coming soon</a></li>
    +<li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#table-of-content">Table of Content</a></li>
    +<li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#getting-started">Getting Started</a></li>
    +<li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#advanced-features">Advanced Features</a></li>
    +<li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#installation-methods">Installation Methods</a></li>
    +<li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#implemented-model-architectures">Implemented Model Architectures</a></li>
    +<li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#documentation">Documentation</a></li>
    +<li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#contributing">Contributing</a></li>
    +<li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#citation">Citation</a></li>
    +<li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#community">Community</a></li>
    +<li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#license">License</a></li>
    +<li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#deci-platform">Deci Platform</a></li>
     </ul>
     </ul>
    -<p class="caption"><span class="caption-text">Technical Documentation</span></p>
    +<p class="caption" role="heading"><span class="caption-text">Technical Documentation</span></p>
     <ul>
     <ul>
     <li class="toctree-l1"><a class="reference internal" href="../../../../super_gradients.common.html">Common package</a></li>
     <li class="toctree-l1"><a class="reference internal" href="../../../../super_gradients.common.html">Common package</a></li>
     <li class="toctree-l1"><a class="reference internal" href="../../../../super_gradients.training.html">Training package</a></li>
     <li class="toctree-l1"><a class="reference internal" href="../../../../super_gradients.training.html">Training package</a></li>
    -</ul>
    -<p class="caption"><span class="caption-text">User Guide</span></p>
    -<ul>
    -<li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html">What is SuperGradients?</a></li>
    -<li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#introducing-the-supergradients-library">Introducing the SuperGradients library</a></li>
    -<li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#installation">Installation</a></li>
    -<li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#integrating-your-training-code-complete-walkthrough">Integrating Your Training Code - Complete Walkthrough</a></li>
    -<li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#training-parameters">Training Parameters</a></li>
    -<li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#logs-and-checkpoints">Logs and Checkpoints</a></li>
    -<li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#dataset-parameters">Dataset Parameters</a></li>
    -<li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#network-architectures">Network Architectures</a></li>
    -<li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#pretrained-models">Pretrained Models</a></li>
    -<li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#how-to-reproduce-our-training-recipes">How To Reproduce Our Training Recipes</a></li>
    -<li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#professional-tools-integration">Professional Tools Integration</a></li>
    -<li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#supergradients-faq">SuperGradients FAQ</a></li>
     </ul>
     </ul>
     
     
             </div>
             </div>
    @@ -94,10 +96,11 @@
     <span class="kn">from</span> <span class="nn">super_gradients.training.utils.detection_utils</span> <span class="kn">import</span> <span class="n">compute_detection_matching</span><span class="p">,</span> <span class="n">compute_detection_metrics</span>
     <span class="kn">from</span> <span class="nn">super_gradients.training.utils.detection_utils</span> <span class="kn">import</span> <span class="n">compute_detection_matching</span><span class="p">,</span> <span class="n">compute_detection_metrics</span>
     <span class="kn">from</span> <span class="nn">super_gradients.training.utils.detection_utils</span> <span class="kn">import</span> <span class="n">DetectionPostPredictionCallback</span><span class="p">,</span> <span class="n">IouThreshold</span>
     <span class="kn">from</span> <span class="nn">super_gradients.training.utils.detection_utils</span> <span class="kn">import</span> <span class="n">DetectionPostPredictionCallback</span><span class="p">,</span> <span class="n">IouThreshold</span>
     <span class="kn">from</span> <span class="nn">super_gradients.common.abstractions.abstract_logger</span> <span class="kn">import</span> <span class="n">get_logger</span>
     <span class="kn">from</span> <span class="nn">super_gradients.common.abstractions.abstract_logger</span> <span class="kn">import</span> <span class="n">get_logger</span>
    +
     <span class="n">logger</span> <span class="o">=</span> <span class="n">get_logger</span><span class="p">(</span><span class="vm">__name__</span><span class="p">)</span>
     <span class="n">logger</span> <span class="o">=</span> <span class="n">get_logger</span><span class="p">(</span><span class="vm">__name__</span><span class="p">)</span>
     
     
     
     
    -<div class="viewcode-block" id="DetectionMetrics"><a class="viewcode-back" href="../../../../super_gradients.training.metrics.html#super_gradients.training.metrics.DetectionMetrics">[docs]</a><span class="k">class</span> <span class="nc">DetectionMetrics</span><span class="p">(</span><span class="n">Metric</span><span class="p">):</span>
    +<div class="viewcode-block" id="DetectionMetrics"><a class="viewcode-back" href="../../../../super_gradients.training.html#super_gradients.training.metrics.DetectionMetrics">[docs]</a><span class="k">class</span> <span class="nc">DetectionMetrics</span><span class="p">(</span><span class="n">Metric</span><span class="p">):</span>
         <span class="sd">&quot;&quot;&quot;</span>
         <span class="sd">&quot;&quot;&quot;</span>
     <span class="sd">    DetectionMetrics</span>
     <span class="sd">    DetectionMetrics</span>
     
     
    @@ -121,37 +124,52 @@
     <span class="sd">        accumulate_on_cpu:     Run on CPU regardless of device used in other parts.</span>
     <span class="sd">        accumulate_on_cpu:     Run on CPU regardless of device used in other parts.</span>
     <span class="sd">                            This is to avoid &quot;CUDA out of memory&quot; that might happen on GPU (default False)</span>
     <span class="sd">                            This is to avoid &quot;CUDA out of memory&quot; that might happen on GPU (default False)</span>
     <span class="sd">    &quot;&quot;&quot;</span>
     <span class="sd">    &quot;&quot;&quot;</span>
    -    <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">num_cls</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span>
    -                 <span class="n">post_prediction_callback</span><span class="p">:</span> <span class="n">DetectionPostPredictionCallback</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span>
    -                 <span class="n">normalize_targets</span><span class="p">:</span> <span class="nb">bool</span> <span class="o">=</span> <span class="kc">False</span><span class="p">,</span>
    -                 <span class="n">iou_thres</span><span class="p">:</span> <span class="n">IouThreshold</span> <span class="o">=</span> <span class="n">IouThreshold</span><span class="o">.</span><span class="n">MAP_05_TO_095</span><span class="p">,</span>
    -                 <span class="n">recall_thres</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span>
    -                 <span class="n">score_thres</span><span class="p">:</span> <span class="nb">float</span> <span class="o">=</span> <span class="mf">0.1</span><span class="p">,</span>
    -                 <span class="n">top_k_predictions</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="mi">100</span><span class="p">,</span>
    -                 <span class="n">dist_sync_on_step</span><span class="p">:</span> <span class="nb">bool</span> <span class="o">=</span> <span class="kc">False</span><span class="p">,</span>
    -                 <span class="n">accumulate_on_cpu</span><span class="p">:</span> <span class="nb">bool</span> <span class="o">=</span> <span class="kc">True</span><span class="p">):</span>
    +
    +    <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span>
    +        <span class="bp">self</span><span class="p">,</span>
    +        <span class="n">num_cls</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span>
    +        <span class="n">post_prediction_callback</span><span class="p">:</span> <span class="n">DetectionPostPredictionCallback</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span>
    +        <span class="n">normalize_targets</span><span class="p">:</span> <span class="nb">bool</span> <span class="o">=</span> <span class="kc">False</span><span class="p">,</span>
    +        <span class="n">iou_thres</span><span class="p">:</span> <span class="n">Union</span><span class="p">[</span><span class="n">IouThreshold</span><span class="p">,</span> <span class="nb">float</span><span class="p">]</span> <span class="o">=</span> <span class="n">IouThreshold</span><span class="o">.</span><span class="n">MAP_05_TO_095</span><span class="p">,</span>
    +        <span class="n">recall_thres</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span>
    +        <span class="n">score_thres</span><span class="p">:</span> <span class="nb">float</span> <span class="o">=</span> <span class="mf">0.1</span><span class="p">,</span>
    +        <span class="n">top_k_predictions</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="mi">100</span><span class="p">,</span>
    +        <span class="n">dist_sync_on_step</span><span class="p">:</span> <span class="nb">bool</span> <span class="o">=</span> <span class="kc">False</span><span class="p">,</span>
    +        <span class="n">accumulate_on_cpu</span><span class="p">:</span> <span class="nb">bool</span> <span class="o">=</span> <span class="kc">True</span><span class="p">,</span>
    +    <span class="p">):</span>
             <span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="fm">__init__</span><span class="p">(</span><span class="n">dist_sync_on_step</span><span class="o">=</span><span class="n">dist_sync_on_step</span><span class="p">)</span>
             <span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="fm">__init__</span><span class="p">(</span><span class="n">dist_sync_on_step</span><span class="o">=</span><span class="n">dist_sync_on_step</span><span class="p">)</span>
             <span class="bp">self</span><span class="o">.</span><span class="n">num_cls</span> <span class="o">=</span> <span class="n">num_cls</span>
             <span class="bp">self</span><span class="o">.</span><span class="n">num_cls</span> <span class="o">=</span> <span class="n">num_cls</span>
             <span class="bp">self</span><span class="o">.</span><span class="n">iou_thres</span> <span class="o">=</span> <span class="n">iou_thres</span>
             <span class="bp">self</span><span class="o">.</span><span class="n">iou_thres</span> <span class="o">=</span> <span class="n">iou_thres</span>
    -        <span class="bp">self</span><span class="o">.</span><span class="n">map_str</span> <span class="o">=</span> <span class="s1">&#39;mAP@</span><span class="si">%.1f</span><span class="s1">&#39;</span> <span class="o">%</span> <span class="n">iou_thres</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span> <span class="k">if</span> <span class="ow">not</span> <span class="n">iou_thres</span><span class="o">.</span><span class="n">is_range</span><span class="p">()</
    -        <span class="bp">self</span><span class="o">.</span><span class="n">component_names</span> <span class="o">=</span> <span class="p">[</span><span class="s2">&quot;Precision&quot;</span><span class="p">,</span> <span class="s2">&quot;Recall&quot;</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">map_str</span><span class="p">,</span> <span class="s2">&quot;F1&quot;</span><span class="p">]</span>
    +
    +        <span class="k">if</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">iou_thres</span><span class="p">,</span> <span class="n">IouThreshold</span><span class="p">):</span>
    +            <span class="bp">self</span><span class="o">.</span><span class="n">iou_thresholds</span> <span class="o">=</span> <span class="n">iou_thres</span><span class="o">.</span><span class="n">to_tensor</span><span class="p">()</span>
    +        <span class="k">else</span><span class="p">:</span>
    +            <span class="bp">self</span><span class="o">.</span><span class="n">iou_thresholds</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">tensor</span><span class="p">([</span><span class="n">iou_thres</span><span class="p">])</span>
    +
    +        <span class="bp">self</span><span class="o">.</span><span class="n">map_str</span> <span class="o">=</span> <span class="s2">&quot;mAP&quot;</span> <span class="o">+</span> <span class="bp">self</span><span class="o">.</span><span class="n">_get_range_str</span><span class="p">()</span>
    +        <span class="bp">self</span><span class="o">.</span><span class="n">greater_component_is_better</span> <span class="o">=</span> <span class="p">{</span>
    +            <span class="sa">f</span><span class="s2">&quot;Precision</span><span class="si">{</span><span class="bp">self</span><span class="o">.</span><span class="n">_get_range_str</span><span class="p">()</span><span class="si">}</span><span class="s2">&quot;</span><span class="p">:</span> <span class="kc">True</span><span class="p">,</span>
    +            <span class="sa">f</span><span class="s2">&quot;Recall</span><span class="si">{</span><span class="bp">self</span><span class="o">.</span><span class="n">_get_range_str</span><span class="p">()</span><span class="si">}</span><span class="s2">&quot;</span><span class="p">:</span> <span class="kc">True</span><span class="p">,</span>
    +            <span class="sa">f</span><span class="s2">&quot;mAP</span><span class="si">{</span><span class="bp">self</span><span class="o">.</span><span class="n">_get_range_str</span><span class="p">()</span><span class="si">}</span><span class="s2">&quot;</span><span class="p">:</span> <span class="kc">True</span><span class="p">,</span>
    +            <span class="sa">f</span><span class="s2">&quot;F1</span><span class="si">{</span><span class="bp">self</span><span class="o">.</span><span class="n">_get_range_str</span><span class="p">()</span><span class="si">}</span><span class="s2">&quot;</span><span class="p">:</span> <span class="kc">True</span><span class="p">,</span>
    +        <span class="p">}</span>
    +        <span class="bp">self</span><span class="o">.</span><span class="n">component_names</span> <span class="o">=</span> <span class="nb">list</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">greater_component_is_better</span><span class="o">.</span><span class="n">keys</span><span class="p">())</span>
             <span class="bp">self</span><span class="o">.</span><span class="n">components</span> <span class="o">=</span> <span class="nb">len</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">component_names</span><span class="p">)</span>
             <span class="bp">self</span><span class="o">.</span><span class="n">components</span> <span class="o">=</span> <span class="nb">len</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">component_names</span><span class="p">)</span>
    +
             <span class="bp">self</span><span class="o">.</span><span class="n">post_prediction_callback</span> <span class="o">=</span> <span class="n">post_prediction_callback</span>
             <span class="bp">self</span><span class="o">.</span><span class="n">post_prediction_callback</span> <span class="o">=</span> <span class="n">post_prediction_callback</span>
             <span class="bp">self</span><span class="o">.</span><span class="n">is_distributed</span> <span class="o">=</span> <span class="n">super_gradients</span><span class="o">.</span><span class="n">is_distributed</span><span class="p">()</span>
             <span class="bp">self</span><span class="o">.</span><span class="n">is_distributed</span> <span class="o">=</span> <span class="n">super_gradients</span><span class="o">.</span><span class="n">is_distributed</span><span class="p">()</span>
             <span class="bp">self</span><span class="o">.</span><span class="n">denormalize_targets</span> <span class="o">=</span> <span class="ow">not</span> <span class="n">normalize_targets</span>
             <span class="bp">self</span><span class="o">.</span><span class="n">denormalize_targets</span> <span class="o">=</span> <span class="ow">not</span> <span class="n">normalize_targets</span>
             <span class="bp">self</span><span class="o">.</span><span class="n">world_size</span> <span class="o">=</span> <span class="kc">None</span>
             <span class="bp">self</span><span class="o">.</span><span class="n">world_size</span> <span class="o">=</span> <span class="kc">None</span>
             <span class="bp">self</span><span class="o">.</span><span class="n">rank</span> <span class="o">=</span> <span class="kc">None</span>
             <span class="bp">self</span><span class="o">.</span><span class="n">rank</span> <span class="o">=</span> <span class="kc">None</span>
    -        <span class="bp">self</span><span class="o">.</span><span class="n">add_state</span><span class="p">(</span><span class="s2">&quot;matching_info&quot;</span><span class="p">,</span> <span class="n">default</span><span class="o">=</span><span class="p">[],</span> <span class="n">dist_reduce_fx</span><span class="o">=</span><span class="kc">None</span><span class="p">)</span>
    +        <span class="bp">self</span><span class="o">.</span><span class="n">add_state</span><span class="p">(</span><span class="sa">f</span><span class="s2">&quot;matching_info</span><span class="si">{</span><span class="bp">self</span><span class="o">.</span><span class="n">_get_range_str</span><span class="p">()</span><span class="si">}</span><span class="s2">&quot;</span><span class="p">,</span> <span class="n">default</span><span class="o">=</span><span class="p">[],</span> <span class="n"
     
     
    -        <span class="bp">self</span><span class="o">.</span><span class="n">iou_thresholds</span> <span class="o">=</span> <span class="n">iou_thres</span><span class="o">.</span><span class="n">to_tensor</span><span class="p">()</span>
             <span class="bp">self</span><span class="o">.</span><span class="n">recall_thresholds</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">linspace</span><span class="p">(</span><span class="mi">0</span><span class="p">,</span> <span class="mi">1</span><span class="p">,</span> <span class="mi">101</span><span class="p">)</span> <span class="k">if</span> <span class="n">recall_thres</span> <span class="ow">is</span> <span class="kc">None</sp
             <span class="bp">self</span><span class="o">.</span><span class="n">recall_thresholds</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">linspace</span><span class="p">(</span><span class="mi">0</span><span class="p">,</span> <span class="mi">1</span><span class="p">,</span> <span class="mi">101</span><span class="p">)</span> <span class="k">if</span> <span class="n">recall_thres</span> <span class="ow">is</span> <span class="kc">None</sp
             <span class="bp">self</span><span class="o">.</span><span class="n">score_threshold</span> <span class="o">=</span> <span class="n">score_thres</span>
             <span class="bp">self</span><span class="o">.</span><span class="n">score_threshold</span> <span class="o">=</span> <span class="n">score_thres</span>
             <span class="bp">self</span><span class="o">.</span><span class="n">top_k_predictions</span> <span class="o">=</span> <span class="n">top_k_predictions</span>
             <span class="bp">self</span><span class="o">.</span><span class="n">top_k_predictions</span> <span class="o">=</span> <span class="n">top_k_predictions</span>
     
     
             <span class="bp">self</span><span class="o">.</span><span class="n">accumulate_on_cpu</span> <span class="o">=</span> <span class="n">accumulate_on_cpu</span>
             <span class="bp">self</span><span class="o">.</span><span class="n">accumulate_on_cpu</span> <span class="o">=</span> <span class="n">accumulate_on_cpu</span>
     
     
    -<div class="viewcode-block" id="DetectionMetrics.update"><a class="viewcode-back" href="../../../../super_gradients.training.metrics.html#super_gradients.training.metrics.DetectionMetrics.update">[docs]</a>    <span class="k">def</span> <span class="nf">update</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">preds</span><span class="p">,</span> <span class="n">target</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</s
    -               <span class="n">inputs</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">tensor</span><span class="p">,</span> <span class="n">crowd_targets</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">]</span> <span class="o">=</span> <span class="kc">None</span><span class="p">):</span>
    +<div class="viewcode-block" id="DetectionMetrics.update"><a class="viewcode-back" href="../../../../super_gradients.training.html#super_gradients.training.metrics.DetectionMetrics.update">[docs]</a>    <span class="k">def</span> <span class="nf">update</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">preds</span><span class="p">,</span> <span class="n">target</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><spa
             <span class="sd">&quot;&quot;&quot;</span>
             <span class="sd">&quot;&quot;&quot;</span>
     <span class="sd">        Apply NMS and match all the predictions and targets of a given batch, and update the metric state accordingly.</span>
     <span class="sd">        Apply NMS and match all the predictions and targets of a given batch, and update the metric state accordingly.</span>
     
     
    @@ -173,27 +191,38 @@
             <span class="n">preds</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">post_prediction_callback</span><span class="p">(</span><span class="n">preds</span><span class="p">,</span> <span class="n">device</span><span class="o">=</span><span class="n">device</span><span class="p">)</span>
             <span class="n">preds</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">post_prediction_callback</span><span class="p">(</span><span class="n">preds</span><span class="p">,</span> <span class="n">device</span><span class="o">=</span><span class="n">device</span><span class="p">)</span>
     
     
             <span class="n">new_matching_info</span> <span class="o">=</span> <span class="n">compute_detection_matching</span><span class="p">(</span>
             <span class="n">new_matching_info</span> <span class="o">=</span> <span class="n">compute_detection_matching</span><span class="p">(</span>
    -            <span class="n">preds</span><span class="p">,</span> <span class="n">targets</span><span class="p">,</span> <span class="n">height</span><span class="p">,</span> <span class="n">width</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">iou_thresholds</span><span class="p">,</span> <span class="n">crowd_targets</span><span class="o">=</span><span class="n">crowd_targets</span><span class="p">,</span>
    -            <span class="n">top_k</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">top_k_predictions</span><span class="p">,</span> <span class="n">denormalize_targets</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">denormalize_targets</span><span class="p">,</span>
    -            <span class="n">device</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">device</span><span class="p">,</span> <span class="n">return_on_cpu</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">accumulate_on_cpu</span><span class="p">)</span>
    -
    -        <span class="n">accumulated_matching_info</span> <span class="o">=</span> <span class="nb">getattr</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="s2">&quot;matching_info&quot;</span><span class="p">)</span>
    -        <span class="nb">setattr</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="s2">&quot;matching_info&quot;</span><span class="p">,</span> <span class="n">accumulated_matching_info</span> <span class="o">+</span> <span class="n">new_matching_info</span><span class="p">)</span></div>
    -
    -<div class="viewcode-block" id="DetectionMetrics.compute"><a class="viewcode-back" href="../../../../super_gradients.training.metrics.html#super_gradients.training.metrics.DetectionMetrics.compute">[docs]</a>    <span class="k">def</span> <span class="nf">compute</span><span class="p">(</span><span class="bp">self</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="n">Dict</span><span class="p">[</span><span class="nb">str</span><span class="p">,</span> <span class="n">Unio
    +            <span class="n">preds</span><span class="p">,</span>
    +            <span class="n">targets</span><span class="p">,</span>
    +            <span class="n">height</span><span class="p">,</span>
    +            <span class="n">width</span><span class="p">,</span>
    +            <span class="bp">self</span><span class="o">.</span><span class="n">iou_thresholds</span><span class="p">,</span>
    +            <span class="n">crowd_targets</span><span class="o">=</span><span class="n">crowd_targets</span><span class="p">,</span>
    +            <span class="n">top_k</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">top_k_predictions</span><span class="p">,</span>
    +            <span class="n">denormalize_targets</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">denormalize_targets</span><span class="p">,</span>
    +            <span class="n">device</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">device</span><span class="p">,</span>
    +            <span class="n">return_on_cpu</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">accumulate_on_cpu</span><span class="p">,</span>
    +        <span class="p">)</span>
    +
    +        <span class="n">accumulated_matching_info</span> <span class="o">=</span> <span class="nb">getattr</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="sa">f</span><span class="s2">&quot;matching_info</span><span class="si">{</span><span class="bp">self</span><span class="o">.</span><span class="n">_get_range_str</span><span class="p">()</span><span class="si">}</span><span class="s2">&quot;</span><span class="p">)</span>
    +        <span class="nb">setattr</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="sa">f</span><span class="s2">&quot;matching_info</span><span class="si">{</span><span class="bp">self</span><span class="o">.</span><span class="n">_get_range_str</span><span class="p">()</span><span class="si">}</span><span class="s2">&quot;</span><span class="p">,</span> <span class="n">accumulated_matching_info</span> <span class="o">+</span> <span class="n">new_ma
    +
    +<div class="viewcode-block" id="DetectionMetrics.compute"><a class="viewcode-back" href="../../../../super_gradients.training.html#super_gradients.training.metrics.DetectionMetrics.compute">[docs]</a>    <span class="k">def</span> <span class="nf">compute</span><span class="p">(</span><span class="bp">self</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="n">Dict</span><span class="p">[</span><span class="nb">str</span><span class="p">,</span> <span class="n">Union</span>
             <span class="sd">&quot;&quot;&quot;Compute the metrics for all the accumulated results.</span>
             <span class="sd">&quot;&quot;&quot;Compute the metrics for all the accumulated results.</span>
    -<span class="sd">            :return: Metrics of interest</span>
    +<span class="sd">        :return: Metrics of interest</span>
     <span class="sd">        &quot;&quot;&quot;</span>
     <span class="sd">        &quot;&quot;&quot;</span>
    -        <span class="n">mean_ap</span><span class="p">,</span> <span class="n">mean_precision</span><span class="p">,</span> <span class="n">mean_recall</span><span class="p">,</span> <span class="n">mean_f1</span> <span class="o">=</span> <span class="o">-</span><span class="mf">1.</span><span class="p">,</span> <span class="o">-</span><span class="mf">1.</span><span class="p">,</span> <span class="o">-</span><span class="mf">1.</span><span class="p">,</span> <span class="o">-</span><span clas
    -        <span class="n">accumulated_matching_info</span> <span class="o">=</span> <span class="nb">getattr</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="s2">&quot;matching_info&quot;</span><span class="p">)</span>
    +        <span class="n">mean_ap</span><span class="p">,</span> <span class="n">mean_precision</span><span class="p">,</span> <span class="n">mean_recall</span><span class="p">,</span> <span class="n">mean_f1</span> <span class="o">=</span> <span class="o">-</span><span class="mf">1.0</span><span class="p">,</span> <span class="o">-</span><span class="mf">1.0</span><span class="p">,</span> <span class="o">-</span><span class="mf">1.0</span><span class="p">,</span> <span class="o">-</span><span c
    +        <span class="n">accumulated_matching_info</span> <span class="o">=</span> <span class="nb">getattr</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="sa">f</span><span class="s2">&quot;matching_info</span><span class="si">{</span><span class="bp">self</span><span class="o">.</span><span class="n">_get_range_str</span><span class="p">()</span><span class="si">}</span><span class="s2">&quot;</span><span class="p">)</span>
     
     
             <span class="k">if</span> <span class="nb">len</span><span class="p">(</span><span class="n">accumulated_matching_info</span><span class="p">):</span>
             <span class="k">if</span> <span class="nb">len</span><span class="p">(</span><span class="n">accumulated_matching_info</span><span class="p">):</span>
                 <span class="n">matching_info_tensors</span> <span class="o">=</span> <span class="p">[</span><span class="n">torch</span><span class="o">.</span><span class="n">cat</span><span class="p">(</span><span class="n">x</span><span class="p">,</span> <span class="mi">0</span><span class="p">)</span> <span class="k">for</span> <span class="n">x</span> <span class="ow">in</span> <span class="nb">list</span><span class="p">(</span><span class="nb">zip</span><span class="p">(</span><span clas
                 <span class="n">matching_info_tensors</span> <span class="o">=</span> <span class="p">[</span><span class="n">torch</span><span class="o">.</span><span class="n">cat</span><span class="p">(</span><span class="n">x</span><span class="p">,</span> <span class="mi">0</span><span class="p">)</span> <span class="k">for</span> <span class="n">x</span> <span class="ow">in</span> <span class="nb">list</span><span class="p">(</span><span class="nb">zip</span><span class="p">(</span><span clas
     
     
                 <span class="c1"># shape (n_class, nb_iou_thresh)</span>
                 <span class="c1"># shape (n_class, nb_iou_thresh)</span>
                 <span class="n">ap</span><span class="p">,</span> <span class="n">precision</span><span class="p">,</span> <span class="n">recall</span><span class="p">,</span> <span class="n">f1</span><span class="p">,</span> <span class="n">unique_classes</span> <span class="o">=</span> <span class="n">compute_detection_metrics</span><span class="p">(</span>
                 <span class="n">ap</span><span class="p">,</span> <span class="n">precision</span><span class="p">,</span> <span class="n">recall</span><span class="p">,</span> <span class="n">f1</span><span class="p">,</span> <span class="n">unique_classes</span> <span class="o">=</span> <span class="n">compute_detection_metrics</span><span class="p">(</span>
    -                <span class="o">*</span><span class="n">matching_info_tensors</span><span class="p">,</span> <span class="n">recall_thresholds</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">recall_thresholds</span><span class="p">,</span> <span class="n">score_threshold</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">score_threshold</span><span class="p">,</span>
    -                <span class="n">device</span><span class="o">=</span><span class="s2">&quot;cpu&quot;</span> <span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">accumulate_on_cpu</span> <span class="k">else</span> <span class="bp">self</span><span class="o">.</span><span class="n">device</span><span class="p">)</span>
    +                <span class="o">*</span><span class="n">matching_info_tensors</span><span class="p">,</span>
    +                <span class="n">recall_thresholds</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">recall_thresholds</span><span class="p">,</span>
    +                <span class="n">score_threshold</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">score_threshold</span><span class="p">,</span>
    +                <span class="n">device</span><span class="o">=</span><span class="s2">&quot;cpu&quot;</span> <span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">accumulate_on_cpu</span> <span class="k">else</span> <span class="bp">self</span><span class="o">.</span><span class="n">device</span><span class="p">,</span>
    +            <span class="p">)</span>
     
     
                 <span class="c1"># Precision, recall and f1 are computed for smallest IoU threshold (usually 0.5), averaged over classes</span>
                 <span class="c1"># Precision, recall and f1 are computed for smallest IoU threshold (usually 0.5), averaged over classes</span>
                 <span class="n">mean_precision</span><span class="p">,</span> <span class="n">mean_recall</span><span class="p">,</span> <span class="n">mean_f1</span> <span class="o">=</span> <span class="n">precision</span><span class="p">[:,</span> <span class="mi">0</span><span class="p">]</span><span class="o">.</span><span class="n">mean</span><span class="p">(),</span> <span class="n">recall</span><span class="p">[:,</span> <span class="mi">0</span><span class="p">]</span><span class="o">.</
                 <span class="n">mean_precision</span><span class="p">,</span> <span class="n">mean_recall</span><span class="p">,</span> <span class="n">mean_f1</span> <span class="o">=</span> <span class="n">precision</span><span class="p">[:,</span> <span class="mi">0</span><span class="p">]</span><span class="o">.</span><span class="n">mean</span><span class="p">(),</span> <span class="n">recall</span><span class="p">[:,</span> <span class="mi">0</span><span class="p">]</span><span class="o">.</
    @@ -201,7 +230,12 @@
                 <span class="c1"># MaP is averaged over IoU thresholds and over classes</span>
                 <span class="c1"># MaP is averaged over IoU thresholds and over classes</span>
                 <span class="n">mean_ap</span> <span class="o">=</span> <span class="n">ap</span><span class="o">.</span><span class="n">mean</span><span class="p">()</span>
                 <span class="n">mean_ap</span> <span class="o">=</span> <span class="n">ap</span><span class="o">.</span><span class="n">mean</span><span class="p">()</span>
     
     
    -        <span class="k">return</span> <span class="p">{</span><span class="s2">&quot;Precision&quot;</span><span class="p">:</span> <span class="n">mean_precision</span><span class="p">,</span> <span class="s2">&quot;Recall&quot;</span><span class="p">:</span> <span class="n">mean_recall</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">map_str</span><span class="p">:</span> <span class="n">mean_ap</span><span class="p">,</span> <span class="s2">
    +        <span class="k">return</span> <span class="p">{</span>
    +            <span class="sa">f</span><span class="s2">&quot;Precision</span><span class="si">{</span><span class="bp">self</span><span class="o">.</span><span class="n">_get_range_str</span><span class="p">()</span><span class="si">}</span><span class="s2">&quot;</span><span class="p">:</span> <span class="n">mean_precision</span><span class="p">,</span>
    +            <span class="sa">f</span><span class="s2">&quot;Recall</span><span class="si">{</span><span class="bp">self</span><span class="o">.</span><span class="n">_get_range_str</span><span class="p">()</span><span class="si">}</span><span class="s2">&quot;</span><span class="p">:</span> <span class="n">mean_recall</span><span class="p">,</span>
    +            <span class="sa">f</span><span class="s2">&quot;mAP</span><span class="si">{</span><span class="bp">self</span><span class="o">.</span><span class="n">_get_range_str</span><span class="p">()</span><span class="si">}</span><span class="s2">&quot;</span><span class="p">:</span> <span class="n">mean_ap</span><span class="p">,</span>
    +            <span class="sa">f</span><span class="s2">&quot;F1</span><span class="si">{</span><span class="bp">self</span><span class="o">.</span><span class="n">_get_range_str</span><span class="p">()</span><span class="si">}</span><span class="s2">&quot;</span><span class="p">:</span> <span class="n">mean_f1</span><span class="p">,</span>
    +        <span class="p">}</span></div>
     
     
         <span class="k">def</span> <span class="nf">_sync_dist</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">dist_sync_fn</span><span class="o">=</span><span class="kc">None</span><span class="p">,</span> <span class="n">process_group</span><span class="o">=</span><span class="kc">None</span><span class="p">):</span>
         <span class="k">def</span> <span class="nf">_sync_dist</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">dist_sync_fn</span><span class="o">=</span><span class="kc">None</span><span class="p">,</span> <span class="n">process_group</span><span class="o">=</span><span class="kc">None</span><span class="p">):</span>
             <span class="sd">&quot;&quot;&quot;</span>
             <span class="sd">&quot;&quot;&quot;</span>
    @@ -223,10 +257,83 @@
                 <span class="n">torch</span><span class="o">.</span><span class="n">distributed</span><span class="o">.</span><span class="n">all_gather_object</span><span class="p">(</span><span class="n">gathered_state_dicts</span><span class="p">,</span> <span class="n">local_state_dict</span><span class="p">)</span>
                 <span class="n">torch</span><span class="o">.</span><span class="n">distributed</span><span class="o">.</span><span class="n">all_gather_object</span><span class="p">(</span><span class="n">gathered_state_dicts</span><span class="p">,</span> <span class="n">local_state_dict</span><span class="p">)</span>
                 <span class="n">matching_info</span> <span class="o">=</span> <span class="p">[]</span>
                 <span class="n">matching_info</span> <span class="o">=</span> <span class="p">[]</span>
                 <span class="k">for</span> <span class="n">state_dict</span> <span class="ow">in</span> <span class="n">gathered_state_dicts</span><span class="p">:</span>
                 <span class="k">for</span> <span class="n">state_dict</span> <span class="ow">in</span> <span class="n">gathered_state_dicts</span><span class="p">:</span>
    -                <span class="n">matching_info</span> <span class="o">+=</span> <span class="n">state_dict</span><span class="p">[</span><span class="s2">&quot;matching_info&quot;</span><span class="p">]</span>
    +                <span class="n">matching_info</span> <span class="o">+=</span> <span class="n">state_dict</span><span class="p">[</span><span class="sa">f</span><span class="s2">&quot;matching_info</span><span class="si">{</span><span class="bp">self</span><span class="o">.</span><span class="n">_get_range_str</span><span class="p">()</span><span class="si">}</span><span class="s2">&quot;</span><span class="p">]</span>
                 <span class="n">matching_info</span> <span class="o">=</span> <span class="n">tensor_container_to_device</span><span class="p">(</span><span class="n">matching_info</span><span class="p">,</span> <span class="n">device</span><span class="o">=</span><span class="s2">&quot;cpu&quot;</span> <span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">accumulate_on_cpu</span> <span class="k">else</span> <span class="bp">self</span><span class="o">.</span
                 <span class="n">matching_info</span> <span class="o">=</span> <span class="n">tensor_container_to_device</span><span class="p">(</span><span class="n">matching_info</span><span class="p">,</span> <span class="n">device</span><span class="o">=</span><span class="s2">&quot;cpu&quot;</span> <span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">accumulate_on_cpu</span> <span class="k">else</span> <span class="bp">self</span><span class="o">.</span
     
     
    -            <span class="nb">setattr</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="s2">&quot;matching_info&quot;</span><span class="p">,</span> <span class="n">matching_info</span><span class="p">)</span></div>
    +            <span class="nb">setattr</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="sa">f</span><span class="s2">&quot;matching_info</span><span class="si">{</span><span class="bp">self</span><span class="o">.</span><span class="n">_get_range_str</span><span class="p">()</span><span class="si">}</span><span class="s2">&quot;</span><span class="p">,</span> <span class="n">matching_info</span><span class="p">)</span>
    +
    +    <span class="k">def</span> <span class="nf">_get_range_str</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span>
    +        <span class="k">return</span> <span class="s2">&quot;@</span><span class="si">%.2f</span><span class="s2">&quot;</span> <span class="o">%</span> <span class="bp">self</span><span class="o">.</span><span class="n">iou_thresholds</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span> <span class="k">if</span> <span class="ow">not</span> <span class="nb">len</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">iou_thresh
    +
    +
    +<div class="viewcode-block" id="DetectionMetrics_050"><a class="viewcode-back" href="../../../../super_gradients.training.html#super_gradients.training.metrics.DetectionMetrics_050">[docs]</a><span class="k">class</span> <span class="nc">DetectionMetrics_050</span><span class="p">(</span><span class="n">DetectionMetrics</span><span class="p">):</span>
    +    <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span>
    +        <span class="bp">self</span><span class="p">,</span>
    +        <span class="n">num_cls</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span>
    +        <span class="n">post_prediction_callback</span><span class="p">:</span> <span class="n">DetectionPostPredictionCallback</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span>
    +        <span class="n">normalize_targets</span><span class="p">:</span> <span class="nb">bool</span> <span class="o">=</span> <span class="kc">False</span><span class="p">,</span>
    +        <span class="n">recall_thres</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span>
    +        <span class="n">score_thres</span><span class="p">:</span> <span class="nb">float</span> <span class="o">=</span> <span class="mf">0.1</span><span class="p">,</span>
    +        <span class="n">top_k_predictions</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="mi">100</span><span class="p">,</span>
    +        <span class="n">dist_sync_on_step</span><span class="p">:</span> <span class="nb">bool</span> <span class="o">=</span> <span class="kc">False</span><span class="p">,</span>
    +        <span class="n">accumulate_on_cpu</span><span class="p">:</span> <span class="nb">bool</span> <span class="o">=</span> <span class="kc">True</span><span class="p">,</span>
    +    <span class="p">):</span>
    +
    +        <span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="fm">__init__</span><span class="p">(</span>
    +            <span class="n">num_cls</span><span class="p">,</span>
    +            <span class="n">post_prediction_callback</span><span class="p">,</span>
    +            <span class="n">normalize_targets</span><span class="p">,</span>
    +            <span class="n">IouThreshold</span><span class="o">.</span><span class="n">MAP_05</span><span class="p">,</span>
    +            <span class="n">recall_thres</span><span class="p">,</span>
    +            <span class="n">score_thres</span><span class="p">,</span>
    +            <span class="n">top_k_predictions</span><span class="p">,</span>
    +            <span class="n">dist_sync_on_step</span><span class="p">,</span>
    +            <span class="n">accumulate_on_cpu</span><span class="p">,</span>
    +        <span class="p">)</span></div>
    +
    +
    +<div class="viewcode-block" id="DetectionMetrics_075"><a class="viewcode-back" href="../../../../super_gradients.training.html#super_gradients.training.metrics.DetectionMetrics_075">[docs]</a><span class="k">class</span> <span class="nc">DetectionMetrics_075</span><span class="p">(</span><span class="n">DetectionMetrics</span><span class="p">):</span>
    +    <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span>
    +        <span class="bp">self</span><span class="p">,</span>
    +        <span class="n">num_cls</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span>
    +        <span class="n">post_prediction_callback</span><span class="p">:</span> <span class="n">DetectionPostPredictionCallback</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span>
    +        <span class="n">normalize_targets</span><span class="p">:</span> <span class="nb">bool</span> <span class="o">=</span> <span class="kc">False</span><span class="p">,</span>
    +        <span class="n">recall_thres</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span>
    +        <span class="n">score_thres</span><span class="p">:</span> <span class="nb">float</span> <span class="o">=</span> <span class="mf">0.1</span><span class="p">,</span>
    +        <span class="n">top_k_predictions</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="mi">100</span><span class="p">,</span>
    +        <span class="n">dist_sync_on_step</span><span class="p">:</span> <span class="nb">bool</span> <span class="o">=</span> <span class="kc">False</span><span class="p">,</span>
    +        <span class="n">accumulate_on_cpu</span><span class="p">:</span> <span class="nb">bool</span> <span class="o">=</span> <span class="kc">True</span><span class="p">,</span>
    +    <span class="p">):</span>
    +
    +        <span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="fm">__init__</span><span class="p">(</span>
    +            <span class="n">num_cls</span><span class="p">,</span> <span class="n">post_prediction_callback</span><span class="p">,</span> <span class="n">normalize_targets</span><span class="p">,</span> <span class="mf">0.75</span><span class="p">,</span> <span class="n">recall_thres</span><span class="p">,</span> <span class="n">score_thres</span><span class="p">,</span> <span class="n">top_k_predictions</span><span class="p">,</span> <span class="n">dist_sync_on_step</span><span class="p">,<
    +        <span class="p">)</span></div>
    +
    +
    +<div class="viewcode-block" id="DetectionMetrics_050_095"><a class="viewcode-back" href="../../../../super_gradients.training.html#super_gradients.training.metrics.DetectionMetrics_050_095">[docs]</a><span class="k">class</span> <span class="nc">DetectionMetrics_050_095</span><span class="p">(</span><span class="n">DetectionMetrics</span><span class="p">):</span>
    +    <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span>
    +        <span class="bp">self</span><span class="p">,</span>
    +        <span class="n">num_cls</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span>
    +        <span class="n">post_prediction_callback</span><span class="p">:</span> <span class="n">DetectionPostPredictionCallback</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span>
    +        <span class="n">normalize_targets</span><span class="p">:</span> <span class="nb">bool</span> <span class="o">=</span> <span class="kc">False</span><span class="p">,</span>
    +        <span class="n">recall_thres</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span>
    +        <span class="n">score_thres</span><span class="p">:</span> <span class="nb">float</span> <span class="o">=</span> <span class="mf">0.1</span><span class="p">,</span>
    +        <span class="n">top_k_predictions</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="mi">100</span><span class="p">,</span>
    +        <span class="n">dist_sync_on_step</span><span class="p">:</span> <span class="nb">bool</span> <span class="o">=</span> <span class="kc">False</span><span class="p">,</span>
    +        <span class="n">accumulate_on_cpu</span><span class="p">:</span> <span class="nb">bool</span> <span class="o">=</span> <span class="kc">True</span><span class="p">,</span>
    +    <span class="p">):</span>
    +
    +        <span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="fm">__init__</span><span class="p">(</span>
    +            <span class="n">num_cls</span><span class="p">,</span>
    +            <span class="n">post_prediction_callback</span><span class="p">,</span>
    +            <span class="n">normalize_targets</span><span class="p">,</span>
    +            <span class="n">IouThreshold</span><span class="o">.</span><span class="n">MAP_05_TO_095</span><span class="p">,</span>
    +            <span class="n">recall_thres</span><span class="p">,</span>
    +            <span class="n">score_thres</span><span class="p">,</span>
    +            <span class="n">top_k_predictions</span><span class="p">,</span>
    +            <span class="n">dist_sync_on_step</span><span class="p">,</span>
    +            <span class="n">accumulate_on_cpu</span><span class="p">,</span>
    +        <span class="p">)</span></div>
     </pre></div>
     </pre></div>
     
     
                </div>
                </div>
    @@ -256,4 +363,4 @@
       </script> 
       </script> 
     
     
     </body>
     </body>
    -</html>
    +</html>
    Discard
    1
    2
    3
    4
    5
    6
    7
    8
    9
    10
    11
    12
    13
    14
    15
    16
    17
    18
    19
    20
    21
    22
    23
    24
    25
    26
    27
    28
    29
    30
    31
    32
    33
    34
    35
    36
    37
    38
    39
    40
    41
    42
    43
    44
    45
    46
    47
    48
    49
    50
    51
    52
    53
    54
    55
    56
    57
    58
    59
    60
    61
    62
    63
    64
    65
    66
    67
    68
    69
    70
    71
    72
    73
    74
    75
    76
    77
    78
    79
    80
    81
    82
    83
    84
    85
    86
    87
    88
    89
    90
    91
    92
    93
    94
    95
    96
    97
    98
    99
    100
    101
    102
    103
    104
    105
    106
    107
    108
    109
    110
    111
    112
    113
    114
    115
    116
    117
    118
    119
    120
    121
    122
    123
    124
    125
    126
    127
    128
    129
    130
    131
    132
    133
    134
    135
    136
    137
    138
    139
    140
    141
    142
    143
    144
    145
    146
    147
    148
    149
    150
    151
    152
    153
    154
    155
    156
    157
    158
    159
    160
    161
    162
    163
    164
    165
    166
    167
    168
    169
    170
    171
    172
    173
    174
    175
    176
    177
    178
    179
    180
    181
    182
    183
    184
    185
    186
    187
    188
    189
    190
    191
    192
    193
    194
    195
    196
    197
    198
    199
    200
    201
    202
    203
    204
    205
    206
    207
    208
    209
    210
    211
    212
    213
    214
    215
    216
    217
    218
    219
    220
    221
    222
    223
    224
    225
    226
    1. <!DOCTYPE html>
    2. <html class="writer-html5" lang="en" >
    3. <head>
    4. <meta charset="utf-8" />
    5. <meta name="viewport" content="width=device-width, initial-scale=1.0" />
    6. <title>super_gradients.training.metrics.metric_utils &mdash; SuperGradients 1.0 documentation</title>
    7. <link rel="stylesheet" href="../../../../_static/pygments.css" type="text/css" />
    8. <link rel="stylesheet" href="../../../../_static/css/theme.css" type="text/css" />
    9. <link rel="stylesheet" href="../../../../_static/graphviz.css" type="text/css" />
    10. <!--[if lt IE 9]>
    11. <script src="../../../../_static/js/html5shiv.min.js"></script>
    12. <![endif]-->
    13. <script data-url_root="../../../../" id="documentation_options" src="../../../../_static/documentation_options.js"></script>
    14. <script src="../../../../_static/jquery.js"></script>
    15. <script src="../../../../_static/underscore.js"></script>
    16. <script src="../../../../_static/doctools.js"></script>
    17. <script src="../../../../_static/js/theme.js"></script>
    18. <link rel="index" title="Index" href="../../../../genindex.html" />
    19. <link rel="search" title="Search" href="../../../../search.html" />
    20. </head>
    21. <body class="wy-body-for-nav">
    22. <div class="wy-grid-for-nav">
    23. <nav data-toggle="wy-nav-shift" class="wy-nav-side">
    24. <div class="wy-side-scroll">
    25. <div class="wy-side-nav-search" >
    26. <a href="../../../../index.html" class="icon icon-home"> SuperGradients
    27. </a>
    28. <div role="search">
    29. <form id="rtd-search-form" class="wy-form" action="../../../../search.html" method="get">
    30. <input type="text" name="q" placeholder="Search docs" />
    31. <input type="hidden" name="check_keywords" value="yes" />
    32. <input type="hidden" name="area" value="default" />
    33. </form>
    34. </div>
    35. </div><div class="wy-menu wy-menu-vertical" data-spy="affix" role="navigation" aria-label="Navigation menu">
    36. <p class="caption"><span class="caption-text">Welcome To SuperGradients</span></p>
    37. <ul>
    38. <li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html">Fill our 4-question quick survey! We will raffle free SuperGradients swag between those who will participate -&gt; Fill Survey</a></li>
    39. <li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#supergradients">SuperGradients</a></li>
    40. </ul>
    41. <p class="caption"><span class="caption-text">Technical Documentation</span></p>
    42. <ul>
    43. <li class="toctree-l1"><a class="reference internal" href="../../../../super_gradients.common.html">Common package</a></li>
    44. <li class="toctree-l1"><a class="reference internal" href="../../../../super_gradients.training.html">Training package</a></li>
    45. </ul>
    46. <p class="caption"><span class="caption-text">User Guide</span></p>
    47. <ul>
    48. <li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html">What is SuperGradients?</a></li>
    49. <li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#introducing-the-supergradients-library">Introducing the SuperGradients library</a></li>
    50. <li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#installation">Installation</a></li>
    51. <li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#integrating-your-training-code-complete-walkthrough">Integrating Your Training Code - Complete Walkthrough</a></li>
    52. <li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#training-parameters">Training Parameters</a></li>
    53. <li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#logs-and-checkpoints">Logs and Checkpoints</a></li>
    54. <li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#dataset-parameters">Dataset Parameters</a></li>
    55. <li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#network-architectures">Network Architectures</a></li>
    56. <li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#pretrained-models">Pretrained Models</a></li>
    57. <li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#how-to-reproduce-our-training-recipes">How To Reproduce Our Training Recipes</a></li>
    58. <li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#professional-tools-integration">Professional Tools Integration</a></li>
    59. <li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#supergradients-faq">SuperGradients FAQ</a></li>
    60. </ul>
    61. </div>
    62. </div>
    63. </nav>
    64. <section data-toggle="wy-nav-shift" class="wy-nav-content-wrap"><nav class="wy-nav-top" aria-label="Mobile navigation menu" >
    65. <i data-toggle="wy-nav-top" class="fa fa-bars"></i>
    66. <a href="../../../../index.html">SuperGradients</a>
    67. </nav>
    68. <div class="wy-nav-content">
    69. <div class="rst-content">
    70. <div role="navigation" aria-label="Page navigation">
    71. <ul class="wy-breadcrumbs">
    72. <li><a href="../../../../index.html" class="icon icon-home"></a> &raquo;</li>
    73. <li><a href="../../../index.html">Module code</a> &raquo;</li>
    74. <li>super_gradients.training.metrics.metric_utils</li>
    75. <li class="wy-breadcrumbs-aside">
    76. </li>
    77. </ul>
    78. <hr/>
    79. </div>
    80. <div role="main" class="document" itemscope="itemscope" itemtype="http://schema.org/Article">
    81. <div itemprop="articleBody">
    82. <h1>Source code for super_gradients.training.metrics.metric_utils</h1><div class="highlight"><pre>
    83. <span></span><span class="kn">import</span> <span class="nn">torch</span>
    84. <span class="kn">from</span> <span class="nn">torchmetrics</span> <span class="kn">import</span> <span class="n">MetricCollection</span>
    85. <span class="kn">from</span> <span class="nn">super_gradients.training.utils.utils</span> <span class="kn">import</span> <span class="n">AverageMeter</span>
    86. <div class="viewcode-block" id="get_logging_values"><a class="viewcode-back" href="../../../../super_gradients.training.metrics.html#super_gradients.training.metrics.metric_utils.get_logging_values">[docs]</a><span class="k">def</span> <span class="nf">get_logging_values</span><span class="p">(</span><span class="n">loss_loggings</span><span class="p">:</span> <span class="n">AverageMeter</span><span class="p">,</span> <span class="n">metrics</span><span class="p">:</span> <span class="n">MetricCollection</span><span class="p">,</span> <span class="n">criterion</span><span class="o">=</span><span class="kc">None</span><span class="p">):</span>
    87. <span class="sd">&quot;&quot;&quot;</span>
    88. <span class="sd"> @param loss_loggings: AverageMeter running average for the loss items</span>
    89. <span class="sd"> @param metrics: MetricCollection object for running user specified metrics</span>
    90. <span class="sd"> @param criterion the object loss_loggings average meter is monitoring, when set to None- only the metrics values are</span>
    91. <span class="sd"> computed and returned.</span>
    92. <span class="sd"> @return: tuple of the computed values</span>
    93. <span class="sd"> &quot;&quot;&quot;</span>
    94. <span class="k">if</span> <span class="n">criterion</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span><span class="p">:</span>
    95. <span class="n">loss_loggingg_avg</span> <span class="o">=</span> <span class="n">loss_loggings</span><span class="o">.</span><span class="n">average</span>
    96. <span class="k">if</span> <span class="ow">not</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">loss_loggingg_avg</span><span class="p">,</span> <span class="nb">tuple</span><span class="p">):</span>
    97. <span class="n">loss_loggingg_avg</span> <span class="o">=</span> <span class="nb">tuple</span><span class="p">([</span><span class="n">loss_loggingg_avg</span><span class="p">])</span>
    98. <span class="n">logging_vals</span> <span class="o">=</span> <span class="n">loss_loggingg_avg</span> <span class="o">+</span> <span class="n">get_metrics_results_tuple</span><span class="p">(</span><span class="n">metrics</span><span class="p">)</span>
    99. <span class="k">else</span><span class="p">:</span>
    100. <span class="n">logging_vals</span> <span class="o">=</span> <span class="n">get_metrics_results_tuple</span><span class="p">(</span><span class="n">metrics</span><span class="p">)</span>
    101. <span class="k">return</span> <span class="n">logging_vals</span></div>
    102. <div class="viewcode-block" id="get_metrics_titles"><a class="viewcode-back" href="../../../../super_gradients.training.metrics.html#super_gradients.training.metrics.metric_utils.get_metrics_titles">[docs]</a><span class="k">def</span> <span class="nf">get_metrics_titles</span><span class="p">(</span><span class="n">metrics_collection</span><span class="p">:</span> <span class="n">MetricCollection</span><span class="p">):</span>
    103. <span class="sd">&quot;&quot;&quot;</span>
    104. <span class="sd"> @param metrics_collection: MetricCollection object for running user specified metrics</span>
    105. <span class="sd"> @return: list of all the names of the computed values list(str)</span>
    106. <span class="sd"> &quot;&quot;&quot;</span>
    107. <span class="n">titles</span> <span class="o">=</span> <span class="p">[]</span>
    108. <span class="k">for</span> <span class="n">metric_name</span><span class="p">,</span> <span class="n">metric</span> <span class="ow">in</span> <span class="n">metrics_collection</span><span class="o">.</span><span class="n">items</span><span class="p">():</span>
    109. <span class="k">if</span> <span class="n">metric_name</span> <span class="o">==</span> <span class="s2">&quot;additional_items&quot;</span><span class="p">:</span>
    110. <span class="k">continue</span>
    111. <span class="k">elif</span> <span class="nb">hasattr</span><span class="p">(</span><span class="n">metric</span><span class="p">,</span> <span class="s2">&quot;component_names&quot;</span><span class="p">):</span>
    112. <span class="n">titles</span> <span class="o">+=</span> <span class="n">metric</span><span class="o">.</span><span class="n">component_names</span>
    113. <span class="k">else</span><span class="p">:</span>
    114. <span class="n">titles</span><span class="o">.</span><span class="n">append</span><span class="p">(</span><span class="n">metric_name</span><span class="p">)</span>
    115. <span class="k">return</span> <span class="n">titles</span></div>
    116. <div class="viewcode-block" id="get_metrics_results_tuple"><a class="viewcode-back" href="../../../../super_gradients.training.metrics.html#super_gradients.training.metrics.metric_utils.get_metrics_results_tuple">[docs]</a><span class="k">def</span> <span class="nf">get_metrics_results_tuple</span><span class="p">(</span><span class="n">metrics_collection</span><span class="p">:</span> <span class="n">MetricCollection</span><span class="p">):</span>
    117. <span class="sd">&quot;&quot;&quot;</span>
    118. <span class="sd"> @param metrics_collection: metrics collection of the user specified metrics</span>
    119. <span class="sd"> @type metrics_collection</span>
    120. <span class="sd"> @return: tuple of metrics values</span>
    121. <span class="sd"> &quot;&quot;&quot;</span>
    122. <span class="k">if</span> <span class="n">metrics_collection</span> <span class="ow">is</span> <span class="kc">None</span><span class="p">:</span>
    123. <span class="n">results_tuple</span> <span class="o">=</span> <span class="p">()</span>
    124. <span class="k">else</span><span class="p">:</span>
    125. <span class="n">results_tuple</span> <span class="o">=</span> <span class="nb">tuple</span><span class="p">(</span><span class="n">flatten_metrics_dict</span><span class="p">(</span><span class="n">metrics_collection</span><span class="o">.</span><span class="n">compute</span><span class="p">())</span><span class="o">.</span><span class="n">values</span><span class="p">())</span>
    126. <span class="k">return</span> <span class="n">results_tuple</span></div>
    127. <div class="viewcode-block" id="flatten_metrics_dict"><a class="viewcode-back" href="../../../../super_gradients.training.metrics.html#super_gradients.training.metrics.metric_utils.flatten_metrics_dict">[docs]</a><span class="k">def</span> <span class="nf">flatten_metrics_dict</span><span class="p">(</span><span class="n">metrics_dict</span><span class="p">:</span> <span class="nb">dict</span><span class="p">):</span>
    128. <span class="sd">&quot;&quot;&quot;</span>
    129. <span class="sd"> :param metrics_dict - dictionary of metric values where values can also be dictionaries containing subvalues</span>
    130. <span class="sd"> (in the case of compound metrics)</span>
    131. <span class="sd"> @return: flattened dict of metric values i.e {metric1_name: metric1_value...}</span>
    132. <span class="sd"> &quot;&quot;&quot;</span>
    133. <span class="n">flattened</span> <span class="o">=</span> <span class="p">{}</span>
    134. <span class="k">for</span> <span class="n">metric_name</span><span class="p">,</span> <span class="n">metric_val</span> <span class="ow">in</span> <span class="n">metrics_dict</span><span class="o">.</span><span class="n">items</span><span class="p">():</span>
    135. <span class="k">if</span> <span class="n">metric_name</span> <span class="o">==</span> <span class="s2">&quot;additional_items&quot;</span><span class="p">:</span>
    136. <span class="k">continue</span>
    137. <span class="c1"># COLLECT ALL OF THE COMPONENTS IN THE CASE OF COMPOUND METRICS</span>
    138. <span class="k">elif</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">metric_val</span><span class="p">,</span> <span class="nb">dict</span><span class="p">):</span>
    139. <span class="k">for</span> <span class="n">sub_metric_name</span><span class="p">,</span> <span class="n">sub_metric_val</span> <span class="ow">in</span> <span class="n">metric_val</span><span class="o">.</span><span class="n">items</span><span class="p">():</span>
    140. <span class="n">flattened</span><span class="p">[</span><span class="n">sub_metric_name</span><span class="p">]</span> <span class="o">=</span> <span class="n">sub_metric_val</span>
    141. <span class="k">else</span><span class="p">:</span>
    142. <span class="n">flattened</span><span class="p">[</span><span class="n">metric_name</span><span class="p">]</span> <span class="o">=</span> <span class="n">metric_val</span>
    143. <span class="k">return</span> <span class="n">flattened</span></div>
    144. <div class="viewcode-block" id="get_metrics_dict"><a class="viewcode-back" href="../../../../super_gradients.training.metrics.html#super_gradients.training.metrics.metric_utils.get_metrics_dict">[docs]</a><span class="k">def</span> <span class="nf">get_metrics_dict</span><span class="p">(</span><span class="n">metrics_tuple</span><span class="p">,</span> <span class="n">metrics_collection</span><span class="p">,</span> <span class="n">loss_logging_item_names</span><span class="p">):</span>
    145. <span class="sd">&quot;&quot;&quot;</span>
    146. <span class="sd"> Returns a dictionary with the epoch results as values and their names as keys.</span>
    147. <span class="sd"> @param metrics_tuple: the result tuple</span>
    148. <span class="sd"> @param metrics_collection: MetricsCollection</span>
    149. <span class="sd"> @param loss_logging_item_names: loss component&#39;s names.</span>
    150. <span class="sd"> @return: dict</span>
    151. <span class="sd"> &quot;&quot;&quot;</span>
    152. <span class="n">keys</span> <span class="o">=</span> <span class="n">loss_logging_item_names</span> <span class="o">+</span> <span class="n">get_metrics_titles</span><span class="p">(</span><span class="n">metrics_collection</span><span class="p">)</span>
    153. <span class="n">metrics_dict</span> <span class="o">=</span> <span class="nb">dict</span><span class="p">(</span><span class="nb">zip</span><span class="p">(</span><span class="n">keys</span><span class="p">,</span> <span class="nb">list</span><span class="p">(</span><span class="n">metrics_tuple</span><span class="p">)))</span>
    154. <span class="k">return</span> <span class="n">metrics_dict</span></div>
    155. <div class="viewcode-block" id="get_train_loop_description_dict"><a class="viewcode-back" href="../../../../super_gradients.training.metrics.html#super_gradients.training.metrics.metric_utils.get_train_loop_description_dict">[docs]</a><span class="k">def</span> <span class="nf">get_train_loop_description_dict</span><span class="p">(</span><span class="n">metrics_tuple</span><span class="p">,</span> <span class="n">metrics_collection</span><span class="p">,</span> <span class="n">loss_logging_item_names</span><span class="p">,</span> <span class="o">**</span><span class="n">log_items</span><span class="p">):</span>
    156. <span class="sd">&quot;&quot;&quot;</span>
    157. <span class="sd"> Returns a dictionary with the epoch&#39;s logging items as values and their names as keys, with the purpose of</span>
    158. <span class="sd"> passing it as a description to tqdm&#39;s progress bar.</span>
    159. <span class="sd"> @param metrics_tuple: the result tuple</span>
    160. <span class="sd"> @param metrics_collection: MetricsCollection</span>
    161. <span class="sd"> @param loss_logging_item_names: loss component&#39;s names.</span>
    162. <span class="sd"> @param log_items additional logging items to be rendered.</span>
    163. <span class="sd"> @return: dict</span>
    164. <span class="sd"> &quot;&quot;&quot;</span>
    165. <span class="n">log_items</span><span class="o">.</span><span class="n">update</span><span class="p">(</span><span class="n">get_metrics_dict</span><span class="p">(</span><span class="n">metrics_tuple</span><span class="p">,</span> <span class="n">metrics_collection</span><span class="p">,</span> <span class="n">loss_logging_item_names</span><span class="p">))</span>
    166. <span class="k">for</span> <span class="n">key</span><span class="p">,</span> <span class="n">value</span> <span class="ow">in</span> <span class="n">log_items</span><span class="o">.</span><span class="n">items</span><span class="p">():</span>
    167. <span class="k">if</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">value</span><span class="p">,</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">):</span>
    168. <span class="n">log_items</span><span class="p">[</span><span class="n">key</span><span class="p">]</span> <span class="o">=</span> <span class="n">value</span><span class="o">.</span><span class="n">detach</span><span class="p">()</span><span class="o">.</span><span class="n">item</span><span class="p">()</span>
    169. <span class="k">return</span> <span class="n">log_items</span></div>
    170. </pre></div>
    171. </div>
    172. </div>
    173. <footer>
    174. <hr/>
    175. <div role="contentinfo">
    176. <p>&#169; Copyright 2021, SuperGradients team.</p>
    177. </div>
    178. Built with <a href="https://www.sphinx-doc.org/">Sphinx</a> using a
    179. <a href="https://github.com/readthedocs/sphinx_rtd_theme">theme</a>
    180. provided by <a href="https://readthedocs.org">Read the Docs</a>.
    181. </footer>
    182. </div>
    183. </div>
    184. </section>
    185. </div>
    186. <script>
    187. jQuery(function () {
    188. SphinxRtdTheme.Navigation.enable(true);
    189. });
    190. </script>
    191. </body>
    192. </html>
    Discard
    Some lines were truncated since they exceed the maximum allowed length of 500, please use a local Git client to see the full diff.
    @@ -3,10 +3,11 @@
     <head>
     <head>
       <meta charset="utf-8" />
       <meta charset="utf-8" />
       <meta name="viewport" content="width=device-width, initial-scale=1.0" />
       <meta name="viewport" content="width=device-width, initial-scale=1.0" />
    -  <title>super_gradients.training.metrics.segmentation_metrics &mdash; SuperGradients 1.0 documentation</title>
    +  <title>super_gradients.training.metrics.segmentation_metrics &mdash; SuperGradients 3.0.3 documentation</title>
           <link rel="stylesheet" href="../../../../_static/pygments.css" type="text/css" />
           <link rel="stylesheet" href="../../../../_static/pygments.css" type="text/css" />
           <link rel="stylesheet" href="../../../../_static/css/theme.css" type="text/css" />
           <link rel="stylesheet" href="../../../../_static/css/theme.css" type="text/css" />
           <link rel="stylesheet" href="../../../../_static/graphviz.css" type="text/css" />
           <link rel="stylesheet" href="../../../../_static/graphviz.css" type="text/css" />
    +      <link rel="stylesheet" href="../../../../_static/custom.css" type="text/css" />
       <!--[if lt IE 9]>
       <!--[if lt IE 9]>
         <script src="../../../../_static/js/html5shiv.min.js"></script>
         <script src="../../../../_static/js/html5shiv.min.js"></script>
       <![endif]-->
       <![endif]-->
    @@ -14,7 +15,9 @@
             <script data-url_root="../../../../" id="documentation_options" src="../../../../_static/documentation_options.js"></script>
             <script data-url_root="../../../../" id="documentation_options" src="../../../../_static/documentation_options.js"></script>
             <script src="../../../../_static/jquery.js"></script>
             <script src="../../../../_static/jquery.js"></script>
             <script src="../../../../_static/underscore.js"></script>
             <script src="../../../../_static/underscore.js"></script>
    +        <script src="../../../../_static/_sphinx_javascript_frameworks_compat.js"></script>
             <script src="../../../../_static/doctools.js"></script>
             <script src="../../../../_static/doctools.js"></script>
    +        <script src="../../../../_static/sphinx_highlight.js"></script>
         <script src="../../../../_static/js/theme.js"></script>
         <script src="../../../../_static/js/theme.js"></script>
         <link rel="index" title="Index" href="../../../../genindex.html" />
         <link rel="index" title="Index" href="../../../../genindex.html" />
         <link rel="search" title="Search" href="../../../../search.html" /> 
         <link rel="search" title="Search" href="../../../../search.html" /> 
    @@ -35,30 +38,29 @@
       </form>
       </form>
     </div>
     </div>
             </div><div class="wy-menu wy-menu-vertical" data-spy="affix" role="navigation" aria-label="Navigation menu">
             </div><div class="wy-menu wy-menu-vertical" data-spy="affix" role="navigation" aria-label="Navigation menu">
    -              <p class="caption"><span class="caption-text">Welcome To SuperGradients</span></p>
    +              <p class="caption" role="heading"><span class="caption-text">Welcome To SuperGradients</span></p>
     <ul>
     <ul>
    -<li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html">Fill our 4-question quick survey! We will raffle free SuperGradients swag between those who will participate -&gt; Fill Survey</a></li>
    -<li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#supergradients">SuperGradients</a></li>
    +<li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html">Version 3 is out! Notebooks have been updated!</a></li>
    +<li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#build-with-supergradients">Build with SuperGradients</a></li>
    +<li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#quick-installation">Quick Installation</a></li>
    +<li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#what-s-new">What’s New</a></li>
    +<li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#coming-soon">Coming soon</a></li>
    +<li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#table-of-content">Table of Content</a></li>
    +<li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#getting-started">Getting Started</a></li>
    +<li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#advanced-features">Advanced Features</a></li>
    +<li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#installation-methods">Installation Methods</a></li>
    +<li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#implemented-model-architectures">Implemented Model Architectures</a></li>
    +<li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#documentation">Documentation</a></li>
    +<li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#contributing">Contributing</a></li>
    +<li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#citation">Citation</a></li>
    +<li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#community">Community</a></li>
    +<li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#license">License</a></li>
    +<li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#deci-platform">Deci Platform</a></li>
     </ul>
     </ul>
    -<p class="caption"><span class="caption-text">Technical Documentation</span></p>
    +<p class="caption" role="heading"><span class="caption-text">Technical Documentation</span></p>
     <ul>
     <ul>
     <li class="toctree-l1"><a class="reference internal" href="../../../../super_gradients.common.html">Common package</a></li>
     <li class="toctree-l1"><a class="reference internal" href="../../../../super_gradients.common.html">Common package</a></li>
     <li class="toctree-l1"><a class="reference internal" href="../../../../super_gradients.training.html">Training package</a></li>
     <li class="toctree-l1"><a class="reference internal" href="../../../../super_gradients.training.html">Training package</a></li>
    -</ul>
    -<p class="caption"><span class="caption-text">User Guide</span></p>
    -<ul>
    -<li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html">What is SuperGradients?</a></li>
    -<li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#introducing-the-supergradients-library">Introducing the SuperGradients library</a></li>
    -<li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#installation">Installation</a></li>
    -<li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#integrating-your-training-code-complete-walkthrough">Integrating Your Training Code - Complete Walkthrough</a></li>
    -<li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#training-parameters">Training Parameters</a></li>
    -<li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#logs-and-checkpoints">Logs and Checkpoints</a></li>
    -<li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#dataset-parameters">Dataset Parameters</a></li>
    -<li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#network-architectures">Network Architectures</a></li>
    -<li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#pretrained-models">Pretrained Models</a></li>
    -<li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#how-to-reproduce-our-training-recipes">How To Reproduce Our Training Recipes</a></li>
    -<li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#professional-tools-integration">Professional Tools Integration</a></li>
    -<li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#supergradients-faq">SuperGradients FAQ</a></li>
     </ul>
     </ul>
     
     
             </div>
             </div>
    @@ -95,7 +97,7 @@
     <span class="kn">from</span> <span class="nn">abc</span> <span class="kn">import</span> <span class="n">ABC</span><span class="p">,</span> <span class="n">abstractmethod</span>
     <span class="kn">from</span> <span class="nn">abc</span> <span class="kn">import</span> <span class="n">ABC</span><span class="p">,</span> <span class="n">abstractmethod</span>
     
     
     
     
    -<div class="viewcode-block" id="batch_pix_accuracy"><a class="viewcode-back" href="../../../../super_gradients.training.metrics.html#super_gradients.training.metrics.batch_pix_accuracy">[docs]</a><span class="k">def</span> <span class="nf">batch_pix_accuracy</span><span class="p">(</span><span class="n">predict</span><span class="p">,</span> <span class="n">target</span><span class="p">):</span>
    +<span class="k">def</span> <span class="nf">batch_pix_accuracy</span><span class="p">(</span><span class="n">predict</span><span class="p">,</span> <span class="n">target</span><span class="p">):</span>
         <span class="sd">&quot;&quot;&quot;Batch Pixel Accuracy</span>
         <span class="sd">&quot;&quot;&quot;Batch Pixel Accuracy</span>
     <span class="sd">    Args:</span>
     <span class="sd">    Args:</span>
     <span class="sd">        predict: input 4D tensor</span>
     <span class="sd">        predict: input 4D tensor</span>
    @@ -106,12 +108,11 @@
         <span class="n">target</span> <span class="o">=</span> <span class="n">target</span><span class="o">.</span><span class="n">cpu</span><span class="p">()</span><span class="o">.</span><span class="n">numpy</span><span class="p">()</span> <span class="o">+</span> <span class="mi">1</span>
         <span class="n">target</span> <span class="o">=</span> <span class="n">target</span><span class="o">.</span><span class="n">cpu</span><span class="p">()</span><span class="o">.</span><span class="n">numpy</span><span class="p">()</span> <span class="o">+</span> <span class="mi">1</span>
         <span class="n">pixel_labeled</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">sum</span><span class="p">(</span><span class="n">target</span> <span class="o">&gt;</span> <span class="mi">0</span><span class="p">)</span>
         <span class="n">pixel_labeled</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">sum</span><span class="p">(</span><span class="n">target</span> <span class="o">&gt;</span> <span class="mi">0</span><span class="p">)</span>
         <span class="n">pixel_correct</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">sum</span><span class="p">((</span><span class="n">predict</span> <span class="o">==</span> <span class="n">target</span><span class="p">)</span> <span class="o">*</span> <span class="p">(</span><span class="n">target</span> <span class="o">&gt;</span> <span class="mi">0</span><span class="p">))</span>
         <span class="n">pixel_correct</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">sum</span><span class="p">((</span><span class="n">predict</span> <span class="o">==</span> <span class="n">target</span><span class="p">)</span> <span class="o">*</span> <span class="p">(</span><span class="n">target</span> <span class="o">&gt;</span> <span class="mi">0</span><span class="p">))</span>
    -    <span class="k">assert</span> <span class="n">pixel_correct</span> <span class="o">&lt;=</span> <span class="n">pixel_labeled</span><span class="p">,</span> \
    -        <span class="s2">&quot;Correct area should be smaller than Labeled&quot;</span>
    -    <span class="k">return</span> <span class="n">pixel_correct</span><span class="p">,</span> <span class="n">pixel_labeled</span></div>
    +    <span class="k">assert</span> <span class="n">pixel_correct</span> <span class="o">&lt;=</span> <span class="n">pixel_labeled</span><span class="p">,</span> <span class="s2">&quot;Correct area should be smaller than Labeled&quot;</span>
    +    <span class="k">return</span> <span class="n">pixel_correct</span><span class="p">,</span> <span class="n">pixel_labeled</span>
     
     
     
     
    -<div class="viewcode-block" id="batch_intersection_union"><a class="viewcode-back" href="../../../../super_gradients.training.metrics.html#super_gradients.training.metrics.batch_intersection_union">[docs]</a><span class="k">def</span> <span class="nf">batch_intersection_union</span><span class="p">(</span><span class="n">predict</span><span class="p">,</span> <span class="n">target</span><span class="p">,</span> <span class="n">nclass</span><span class="p">):</span>
    +<span class="k">def</span> <span class="nf">batch_intersection_union</span><span class="p">(</span><span class="n">predict</span><span class="p">,</span> <span class="n">target</span><span class="p">,</span> <span class="n">nclass</span><span class="p">):</span>
         <span class="sd">&quot;&quot;&quot;Batch Intersection of Union</span>
         <span class="sd">&quot;&quot;&quot;Batch Intersection of Union</span>
     <span class="sd">    Args:</span>
     <span class="sd">    Args:</span>
     <span class="sd">        predict: input 4D tensor</span>
     <span class="sd">        predict: input 4D tensor</span>
    @@ -132,13 +133,12 @@
         <span class="n">area_pred</span><span class="p">,</span> <span class="n">_</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">histogram</span><span class="p">(</span><span class="n">predict</span><span class="p">,</span> <span class="n">bins</span><span class="o">=</span><span class="n">nbins</span><span class="p">,</span> <span class="nb">range</span><span class="o">=</span><span class="p">(</span><span class="n">mini</span><span class="p">,</s
         <span class="n">area_pred</span><span class="p">,</span> <span class="n">_</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">histogram</span><span class="p">(</span><span class="n">predict</span><span class="p">,</span> <span class="n">bins</span><span class="o">=</span><span class="n">nbins</span><span class="p">,</span> <span class="nb">range</span><span class="o">=</span><span class="p">(</span><span class="n">mini</span><span class="p">,</s
         <span class="n">area_lab</span><span class="p">,</span> <span class="n">_</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">histogram</span><span class="p">(</span><span class="n">target</span><span class="p">,</span> <span class="n">bins</span><span class="o">=</span><span class="n">nbins</span><span class="p">,</span> <span class="nb">range</span><span class="o">=</span><span class="p">(</span><span class="n">mini</span><span class="p">,</spa
         <span class="n">area_lab</span><span class="p">,</span> <span class="n">_</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">histogram</span><span class="p">(</span><span class="n">target</span><span class="p">,</span> <span class="n">bins</span><span class="o">=</span><span class="n">nbins</span><span class="p">,</span> <span class="nb">range</span><span class="o">=</span><span class="p">(</span><span class="n">mini</span><span class="p">,</spa
         <span class="n">area_union</span> <span class="o">=</span> <span class="n">area_pred</span> <span class="o">+</span> <span class="n">area_lab</span> <span class="o">-</span> <span class="n">area_inter</span>
         <span class="n">area_union</span> <span class="o">=</span> <span class="n">area_pred</span> <span class="o">+</span> <span class="n">area_lab</span> <span class="o">-</span> <span class="n">area_inter</span>
    -    <span class="k">assert</span> <span class="p">(</span><span class="n">area_inter</span> <span class="o">&lt;=</span> <span class="n">area_union</span><span class="p">)</span><span class="o">.</span><span class="n">all</span><span class="p">(),</span> \
    -        <span class="s2">&quot;Intersection area should be smaller than Union area&quot;</span>
    -    <span class="k">return</span> <span class="n">area_inter</span><span class="p">,</span> <span class="n">area_union</span></div>
    +    <span class="k">assert</span> <span class="p">(</span><span class="n">area_inter</span> <span class="o">&lt;=</span> <span class="n">area_union</span><span class="p">)</span><span class="o">.</span><span class="n">all</span><span class="p">(),</span> <span class="s2">&quot;Intersection area should be smaller than Union area&quot;</span>
    +    <span class="k">return</span> <span class="n">area_inter</span><span class="p">,</span> <span class="n">area_union</span>
     
     
     
     
     <span class="c1"># ref https://github.com/CSAILVision/sceneparsing/blob/master/evaluationCode/utils_eval.py</span>
     <span class="c1"># ref https://github.com/CSAILVision/sceneparsing/blob/master/evaluationCode/utils_eval.py</span>
    -<div class="viewcode-block" id="pixel_accuracy"><a class="viewcode-back" href="../../../../super_gradients.training.metrics.html#super_gradients.training.metrics.pixel_accuracy">[docs]</a><span class="k">def</span> <span class="nf">pixel_accuracy</span><span class="p">(</span><span class="n">im_pred</span><span class="p">,</span> <span class="n">im_lab</span><span class="p">):</span>
    +<span class="k">def</span> <span class="nf">pixel_accuracy</span><span class="p">(</span><span class="n">im_pred</span><span class="p">,</span> <span class="n">im_lab</span><span class="p">):</span>
         <span class="n">im_pred</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">asarray</span><span class="p">(</span><span class="n">im_pred</span><span class="p">)</span>
         <span class="n">im_pred</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">asarray</span><span class="p">(</span><span class="n">im_pred</span><span class="p">)</span>
         <span class="n">im_lab</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">asarray</span><span class="p">(</span><span class="n">im_lab</span><span class="p">)</span>
         <span class="n">im_lab</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">asarray</span><span class="p">(</span><span class="n">im_lab</span><span class="p">)</span>
     
     
    @@ -147,7 +147,7 @@
         <span class="n">pixel_labeled</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">sum</span><span class="p">(</span><span class="n">im_lab</span> <span class="o">&gt;</span> <span class="mi">0</span><span class="p">)</span>
         <span class="n">pixel_labeled</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">sum</span><span class="p">(</span><span class="n">im_lab</span> <span class="o">&gt;</span> <span class="mi">0</span><span class="p">)</span>
         <span class="n">pixel_correct</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">sum</span><span class="p">((</span><span class="n">im_pred</span> <span class="o">==</span> <span class="n">im_lab</span><span class="p">)</span> <span class="o">*</span> <span class="p">(</span><span class="n">im_lab</span> <span class="o">&gt;</span> <span class="mi">0</span><span class="p">))</span>
         <span class="n">pixel_correct</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">sum</span><span class="p">((</span><span class="n">im_pred</span> <span class="o">==</span> <span class="n">im_lab</span><span class="p">)</span> <span class="o">*</span> <span class="p">(</span><span class="n">im_lab</span> <span class="o">&gt;</span> <span class="mi">0</span><span class="p">))</span>
         <span class="c1"># pixel_accuracy = 1.0 * pixel_correct / pixel_labeled</span>
         <span class="c1"># pixel_accuracy = 1.0 * pixel_correct / pixel_labeled</span>
    -    <span class="k">return</span> <span class="n">pixel_correct</span><span class="p">,</span> <span class="n">pixel_labeled</span></div>
    +    <span class="k">return</span> <span class="n">pixel_correct</span><span class="p">,</span> <span class="n">pixel_labeled</span>
     
     
     
     
     <span class="k">def</span> <span class="nf">_dice_from_confmat</span><span class="p">(</span>
     <span class="k">def</span> <span class="nf">_dice_from_confmat</span><span class="p">(</span>
    @@ -188,51 +188,48 @@
             <span class="n">scores</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">cat</span><span class="p">(</span>
             <span class="n">scores</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">cat</span><span class="p">(</span>
                 <span class="p">[</span>
                 <span class="p">[</span>
                     <span class="n">scores</span><span class="p">[:</span><span class="n">ignore_index</span><span class="p">],</span>
                     <span class="n">scores</span><span class="p">[:</span><span class="n">ignore_index</span><span class="p">],</span>
    -                <span class="n">scores</span><span class="p">[</span><span class="n">ignore_index</span> <span class="o">+</span> <span class="mi">1</span><span class="p">:],</span>
    +                <span class="n">scores</span><span class="p">[</span><span class="n">ignore_index</span> <span class="o">+</span> <span class="mi">1</span> <span class="p">:],</span>
                 <span class="p">]</span>
                 <span class="p">]</span>
             <span class="p">)</span>
             <span class="p">)</span>
     
     
         <span class="k">return</span> <span class="n">reduce</span><span class="p">(</span><span class="n">scores</span><span class="p">,</span> <span class="n">reduction</span><span class="o">=</span><span class="n">reduction</span><span class="p">)</span>
         <span class="k">return</span> <span class="n">reduce</span><span class="p">(</span><span class="n">scores</span><span class="p">,</span> <span class="n">reduction</span><span class="o">=</span><span class="n">reduction</span><span class="p">)</span>
     
     
     
     
    -<div class="viewcode-block" id="intersection_and_union"><a class="viewcode-back" href="../../../../super_gradients.training.metrics.html#super_gradients.training.metrics.intersection_and_union">[docs]</a><span class="k">def</span> <span class="nf">intersection_and_union</span><span class="p">(</span><span class="n">im_pred</span><span class="p">,</span> <span class="n">im_lab</span><span class="p">,</span> <span class="n">num_class</span><span class="p">):</span>
    +<span class="k">def</span> <span class="nf">intersection_and_union</span><span class="p">(</span><span class="n">im_pred</span><span class="p">,</span> <span class="n">im_lab</span><span class="p">,</span> <span class="n">num_class</span><span class="p">):</span>
         <span class="n">im_pred</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">asarray</span><span class="p">(</span><span class="n">im_pred</span><span class="p">)</span>
         <span class="n">im_pred</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">asarray</span><span class="p">(</span><span class="n">im_pred</span><span class="p">)</span>
         <span class="n">im_lab</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">asarray</span><span class="p">(</span><span class="n">im_lab</span><span class="p">)</span>
         <span class="n">im_lab</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">asarray</span><span class="p">(</span><span class="n">im_lab</span><span class="p">)</span>
         <span class="c1"># Remove classes from unlabeled pixels in gt image.</span>
         <span class="c1"># Remove classes from unlabeled pixels in gt image.</span>
         <span class="n">im_pred</span> <span class="o">=</span> <span class="n">im_pred</span> <span class="o">*</span> <span class="p">(</span><span class="n">im_lab</span> <span class="o">&gt;</span> <span class="mi">0</span><span class="p">)</span>
         <span class="n">im_pred</span> <span class="o">=</span> <span class="n">im_pred</span> <span class="o">*</span> <span class="p">(</span><span class="n">im_lab</span> <span class="o">&gt;</span> <span class="mi">0</span><span class="p">)</span>
         <span class="c1"># Compute area intersection:</span>
         <span class="c1"># Compute area intersection:</span>
         <span class="n">intersection</span> <span class="o">=</span> <span class="n">im_pred</span> <span class="o">*</span> <span class="p">(</span><span class="n">im_pred</span> <span class="o">==</span> <span class="n">im_lab</span><span class="p">)</span>
         <span class="n">intersection</span> <span class="o">=</span> <span class="n">im_pred</span> <span class="o">*</span> <span class="p">(</span><span class="n">im_pred</span> <span class="o">==</span> <span class="n">im_lab</span><span class="p">)</span>
    -    <span class="n">area_inter</span><span class="p">,</span> <span class="n">_</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">histogram</span><span class="p">(</span><span class="n">intersection</span><span class="p">,</span> <span class="n">bins</span><span class="o">=</span><span class="n">num_class</span> <span class="o">-</span> <span class="mi">1</span><span class="p">,</span>
    -                                 <span class="nb">range</span><span class="o">=</span><span class="p">(</span><span class="mi">1</span><span class="p">,</span> <span class="n">num_class</span> <span class="o">-</span> <span class="mi">1</span><span class="p">))</span>
    +    <span class="n">area_inter</span><span class="p">,</span> <span class="n">_</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">histogram</span><span class="p">(</span><span class="n">intersection</span><span class="p">,</span> <span class="n">bins</span><span class="o">=</span><span class="n">num_class</span> <span class="o">-</span> <span class="mi">1</span><span class="p">,</span> <span class="nb">range</span><span class="o">=</span><span clas
         <span class="c1"># Compute area union:</span>
         <span class="c1"># Compute area union:</span>
    -    <span class="n">area_pred</span><span class="p">,</span> <span class="n">_</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">histogram</span><span class="p">(</span><span class="n">im_pred</span><span class="p">,</span> <span class="n">bins</span><span class="o">=</span><span class="n">num_class</span> <span class="o">-</span> <span class="mi">1</span><span class="p">,</span>
    -                                <span class="nb">range</span><span class="o">=</span><span class="p">(</span><span class="mi">1</span><span class="p">,</span> <span class="n">num_class</span> <span class="o">-</span> <span class="mi">1</span><span class="p">))</span>
    -    <span class="n">area_lab</span><span class="p">,</span> <span class="n">_</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">histogram</span><span class="p">(</span><span class="n">im_lab</span><span class="p">,</span> <span class="n">bins</span><span class="o">=</span><span class="n">num_class</span> <span class="o">-</span> <span class="mi">1</span><span class="p">,</span>
    -                               <span class="nb">range</span><span class="o">=</span><span class="p">(</span><span class="mi">1</span><span class="p">,</span> <span class="n">num_class</span> <span class="o">-</span> <span class="mi">1</span><span class="p">))</span>
    +    <span class="n">area_pred</span><span class="p">,</span> <span class="n">_</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">histogram</span><span class="p">(</span><span class="n">im_pred</span><span class="p">,</span> <span class="n">bins</span><span class="o">=</span><span class="n">num_class</span> <span class="o">-</span> <span class="mi">1</span><span class="p">,</span> <span class="nb">range</span><span class="o">=</span><span class="p">
    +    <span class="n">area_lab</span><span class="p">,</span> <span class="n">_</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">histogram</span><span class="p">(</span><span class="n">im_lab</span><span class="p">,</span> <span class="n">bins</span><span class="o">=</span><span class="n">num_class</span> <span class="o">-</span> <span class="mi">1</span><span class="p">,</span> <span class="nb">range</span><span class="o">=</span><span class="p">(<
         <span class="n">area_union</span> <span class="o">=</span> <span class="n">area_pred</span> <span class="o">+</span> <span class="n">area_lab</span> <span class="o">-</span> <span class="n">area_inter</span>
         <span class="n">area_union</span> <span class="o">=</span> <span class="n">area_pred</span> <span class="o">+</span> <span class="n">area_lab</span> <span class="o">-</span> <span class="n">area_inter</span>
    -    <span class="k">return</span> <span class="n">area_inter</span><span class="p">,</span> <span class="n">area_union</span></div>
    +    <span class="k">return</span> <span class="n">area_inter</span><span class="p">,</span> <span class="n">area_union</span>
     
     
     
     
    -<div class="viewcode-block" id="AbstractMetricsArgsPrepFn"><a class="viewcode-back" href="../../../../super_gradients.training.metrics.html#super_gradients.training.metrics.AbstractMetricsArgsPrepFn">[docs]</a><span class="k">class</span> <span class="nc">AbstractMetricsArgsPrepFn</span><span class="p">(</span><span class="n">ABC</span><span class="p">):</span>
    +<span class="k">class</span> <span class="nc">AbstractMetricsArgsPrepFn</span><span class="p">(</span><span class="n">ABC</span><span class="p">):</span>
         <span class="sd">&quot;&quot;&quot;</span>
         <span class="sd">&quot;&quot;&quot;</span>
     <span class="sd">    Abstract preprocess metrics arguments class.</span>
     <span class="sd">    Abstract preprocess metrics arguments class.</span>
     <span class="sd">    &quot;&quot;&quot;</span>
     <span class="sd">    &quot;&quot;&quot;</span>
    +
         <span class="nd">@abstractmethod</span>
         <span class="nd">@abstractmethod</span>
         <span class="k">def</span> <span class="fm">__call__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">preds</span><span class="p">,</span> <span class="n">target</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="n">Tuple</span><span class="p">[</span><span class="n">torch</span><span class="o">.</span><span class="
         <span class="k">def</span> <span class="fm">__call__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">preds</span><span class="p">,</span> <span class="n">target</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="n">Tuple</span><span class="p">[</span><span class="n">torch</span><span class="o">.</span><span class="
             <span class="sd">&quot;&quot;&quot;</span>
             <span class="sd">&quot;&quot;&quot;</span>
     <span class="sd">        All base classes must implement this function and return a tuple of torch tensors (predictions, target).</span>
     <span class="sd">        All base classes must implement this function and return a tuple of torch tensors (predictions, target).</span>
     <span class="sd">        &quot;&quot;&quot;</span>
     <span class="sd">        &quot;&quot;&quot;</span>
    -        <span class="k">raise</span> <span class="ne">NotImplementedError</span><span class="p">()</span></div>
    +        <span class="k">raise</span> <span class="ne">NotImplementedError</span><span class="p">()</span>
     
     
     
     
    -<div class="viewcode-block" id="PreprocessSegmentationMetricsArgs"><a class="viewcode-back" href="../../../../super_gradients.training.metrics.html#super_gradients.training.metrics.PreprocessSegmentationMetricsArgs">[docs]</a><span class="k">class</span> <span class="nc">PreprocessSegmentationMetricsArgs</span><span class="p">(</span><span class="n">AbstractMetricsArgsPrepFn</span><span class="p">):</span>
    +<div class="viewcode-block" id="PreprocessSegmentationMetricsArgs"><a class="viewcode-back" href="../../../../super_gradients.training.html#super_gradients.training.metrics.PreprocessSegmentationMetricsArgs">[docs]</a><span class="k">class</span> <span class="nc">PreprocessSegmentationMetricsArgs</span><span class="p">(</span><span class="n">AbstractMetricsArgsPrepFn</span><span class="p">):</span>
         <span class="sd">&quot;&quot;&quot;</span>
         <span class="sd">&quot;&quot;&quot;</span>
     <span class="sd">    Default segmentation inputs preprocess function before updating segmentation metrics, handles multiple inputs and</span>
     <span class="sd">    Default segmentation inputs preprocess function before updating segmentation metrics, handles multiple inputs and</span>
     <span class="sd">    apply normalizations.</span>
     <span class="sd">    apply normalizations.</span>
     <span class="sd">    &quot;&quot;&quot;</span>
     <span class="sd">    &quot;&quot;&quot;</span>
    -    <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span>
    -                 <span class="n">apply_arg_max</span><span class="p">:</span> <span class="nb">bool</span> <span class="o">=</span> <span class="kc">False</span><span class="p">,</span>
    -                 <span class="n">apply_sigmoid</span><span class="p">:</span> <span class="nb">bool</span> <span class="o">=</span> <span class="kc">False</span><span class="p">):</span>
    +
    +    <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">apply_arg_max</span><span class="p">:</span> <span class="nb">bool</span> <span class="o">=</span> <span class="kc">False</span><span class="p">,</span> <span class="n">apply_sigmoid</span><span class="p">:</span> <span class="nb">bool</span> <span class="o">=</span> <span class="kc">False</span><span class="p">):</span>
             <span class="sd">&quot;&quot;&quot;</span>
             <span class="sd">&quot;&quot;&quot;</span>
     <span class="sd">        :param apply_arg_max: Whether to apply argmax on predictions tensor.</span>
     <span class="sd">        :param apply_arg_max: Whether to apply argmax on predictions tensor.</span>
     <span class="sd">        :param apply_sigmoid:  Whether to apply sigmoid on predictions tensor.</span>
     <span class="sd">        :param apply_sigmoid:  Whether to apply sigmoid on predictions tensor.</span>
    @@ -253,18 +250,16 @@
             <span class="k">return</span> <span class="n">preds</span><span class="p">,</span> <span class="n">target</span></div>
             <span class="k">return</span> <span class="n">preds</span><span class="p">,</span> <span class="n">target</span></div>
     
     
     
     
    -<div class="viewcode-block" id="PixelAccuracy"><a class="viewcode-back" href="../../../../super_gradients.training.metrics.html#super_gradients.training.metrics.PixelAccuracy">[docs]</a><span class="k">class</span> <span class="nc">PixelAccuracy</span><span class="p">(</span><span class="n">Metric</span><span class="p">):</span>
    -    <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span>
    -                 <span class="n">ignore_label</span><span class="o">=-</span><span class="mi">100</span><span class="p">,</span>
    -                 <span class="n">dist_sync_on_step</span><span class="o">=</span><span class="kc">False</span><span class="p">,</span>
    -                 <span class="n">metrics_args_prep_fn</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="n">AbstractMetricsArgsPrepFn</span><span class="p">]</span> <span class="o">=</span> <span class="kc">None</span><span class="p">):</span>
    +<div class="viewcode-block" id="PixelAccuracy"><a class="viewcode-back" href="../../../../super_gradients.training.html#super_gradients.training.metrics.PixelAccuracy">[docs]</a><span class="k">class</span> <span class="nc">PixelAccuracy</span><span class="p">(</span><span class="n">Metric</span><span class="p">):</span>
    +    <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">ignore_label</span><span class="o">=-</span><span class="mi">100</span><span class="p">,</span> <span class="n">dist_sync_on_step</span><span class="o">=</span><span class="kc">False</span><span class="p">,</span> <span class="n">metrics_args_prep_fn</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span 
             <span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="fm">__init__</span><span class="p">(</span><span class="n">dist_sync_on_step</span><span class="o">=</span><span class="n">dist_sync_on_step</span><span class="p">)</span>
             <span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="fm">__init__</span><span class="p">(</span><span class="n">dist_sync_on_step</span><span class="o">=</span><span class="n">dist_sync_on_step</span><span class="p">)</span>
             <span class="bp">self</span><span class="o">.</span><span class="n">ignore_label</span> <span class="o">=</span> <span class="n">ignore_label</span>
             <span class="bp">self</span><span class="o">.</span><span class="n">ignore_label</span> <span class="o">=</span> <span class="n">ignore_label</span>
    -        <span class="bp">self</span><span class="o">.</span><span class="n">add_state</span><span class="p">(</span><span class="s2">&quot;total_correct&quot;</span><span class="p">,</span> <span class="n">default</span><span class="o">=</span><span class="n">torch</span><span class="o">.</span><span class="n">tensor</span><span class="p">(</span><span class="mf">0.</span><span class="p">),</span> <span class="n">dist_reduce_fx</span><span class="o">=</span><span class="s2">&quot;sum&quot;</spa
    -        <span class="bp">self</span><span class="o">.</span><span class="n">add_state</span><span class="p">(</span><span class="s2">&quot;total_label&quot;</span><span class="p">,</span> <span class="n">default</span><span class="o">=</span><span class="n">torch</span><span class="o">.</span><span class="n">tensor</span><span class="p">(</span><span class="mf">0.</span><span class="p">),</span> <span class="n">dist_reduce_fx</span><span class="o">=</span><span class="s2">&quot;sum&quot;</span>
    +        <span class="bp">self</span><span class="o">.</span><span class="n">greater_is_better</span> <span class="o">=</span> <span class="kc">True</span>
    +        <span class="bp">self</span><span class="o">.</span><span class="n">add_state</span><span class="p">(</span><span class="s2">&quot;total_correct&quot;</span><span class="p">,</span> <span class="n">default</span><span class="o">=</span><span class="n">torch</span><span class="o">.</span><span class="n">tensor</span><span class="p">(</span><span class="mf">0.0</span><span class="p">),</span> <span class="n">dist_reduce_fx</span><span class="o">=</span><span class="s2">&quot;sum&quot;</sp
    +        <span class="bp">self</span><span class="o">.</span><span class="n">add_state</span><span class="p">(</span><span class="s2">&quot;total_label&quot;</span><span class="p">,</span> <span class="n">default</span><span class="o">=</span><span class="n">torch</span><span class="o">.</span><span class="n">tensor</span><span class="p">(</span><span class="mf">0.0</span><span class="p">),</span> <span class="n">dist_reduce_fx</span><span class="o">=</span><span class="s2">&quot;sum&quot;</span
             <span class="bp">self</span><span class="o">.</span><span class="n">metrics_args_prep_fn</span> <span class="o">=</span> <span class="n">metrics_args_prep_fn</span> <span class="ow">or</span> <span class="n">PreprocessSegmentationMetricsArgs</span><span class="p">(</span><span class="n">apply_arg_max</span><span class="o">=</span><span class="kc">True</span><span class="p">)</span>
             <span class="bp">self</span><span class="o">.</span><span class="n">metrics_args_prep_fn</span> <span class="o">=</span> <span class="n">metrics_args_prep_fn</span> <span class="ow">or</span> <span class="n">PreprocessSegmentationMetricsArgs</span><span class="p">(</span><span class="n">apply_arg_max</span><span class="o">=</span><span class="kc">True</span><span class="p">)</span>
     
     
    -<div class="viewcode-block" id="PixelAccuracy.update"><a class="viewcode-back" href="../../../../super_gradients.training.metrics.html#super_gradients.training.metrics.PixelAccuracy.update">[docs]</a>    <span class="k">def</span> <span class="nf">update</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">preds</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">,</span> <s
    +<div class="viewcode-block" id="PixelAccuracy.update"><a class="viewcode-back" href="../../../../super_gradients.training.html#super_gradients.training.metrics.PixelAccuracy.update">[docs]</a>    <span class="k">def</span> <span class="nf">update</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">preds</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">,</span> <span clas
             <span class="n">predict</span><span class="p">,</span> <span class="n">target</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">metrics_args_prep_fn</span><span class="p">(</span><span class="n">preds</span><span class="p">,</span> <span class="n">target</span><span class="p">)</span>
             <span class="n">predict</span><span class="p">,</span> <span class="n">target</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">metrics_args_prep_fn</span><span class="p">(</span><span class="n">preds</span><span class="p">,</span> <span class="n">target</span><span class="p">)</span>
     
     
             <span class="n">labeled_mask</span> <span class="o">=</span> <span class="n">target</span><span class="o">.</span><span class="n">ne</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">ignore_label</span><span class="p">)</span>
             <span class="n">labeled_mask</span> <span class="o">=</span> <span class="n">target</span><span class="o">.</span><span class="n">ne</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">ignore_label</span><span class="p">)</span>
    @@ -273,81 +268,117 @@
             <span class="bp">self</span><span class="o">.</span><span class="n">total_correct</span> <span class="o">+=</span> <span class="n">pixel_correct</span>
             <span class="bp">self</span><span class="o">.</span><span class="n">total_correct</span> <span class="o">+=</span> <span class="n">pixel_correct</span>
             <span class="bp">self</span><span class="o">.</span><span class="n">total_label</span> <span class="o">+=</span> <span class="n">pixel_labeled</span></div>
             <span class="bp">self</span><span class="o">.</span><span class="n">total_label</span> <span class="o">+=</span> <span class="n">pixel_labeled</span></div>
     
     
    -<div class="viewcode-block" id="PixelAccuracy.compute"><a class="viewcode-back" href="../../../../super_gradients.training.metrics.html#super_gradients.training.metrics.PixelAccuracy.compute">[docs]</a>    <span class="k">def</span> <span class="nf">compute</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span>
    -        <span class="n">_total_correct</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">total_correct</span><span class="o">.</span><span class="n">cpu</span><span class="p">()</span><span class="o">.</span><span class="n">detach</span><span class="p">()</span><span class="o">.</span><span class="n">numpy</span><span class="p">()</span><span class="o">.</span><span class="n">astype</span><span class="p">(</span><span class="s1">&#39;int64&#39;<
    -        <span class="n">_total_label</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">total_label</span><span class="o">.</span><span class="n">cpu</span><span class="p">()</span><span class="o">.</span><span class="n">detach</span><span class="p">()</span><span class="o">.</span><span class="n">numpy</span><span class="p">()</span><span class="o">.</span><span class="n">astype</span><span class="p">(</span><span class="s1">&#39;int64&#39;</spa
    +<div class="viewcode-block" id="PixelAccuracy.compute"><a class="viewcode-back" href="../../../../super_gradients.training.html#super_gradients.training.metrics.PixelAccuracy.compute">[docs]</a>    <span class="k">def</span> <span class="nf">compute</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span>
    +        <span class="n">_total_correct</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">total_correct</span><span class="o">.</span><span class="n">cpu</span><span class="p">()</span><span class="o">.</span><span class="n">detach</span><span class="p">()</span><span class="o">.</span><span class="n">numpy</span><span class="p">()</span><span class="o">.</span><span class="n">astype</span><span class="p">(</span><span class="s2">&quot;int64&quot
    +        <span class="n">_total_label</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">total_label</span><span class="o">.</span><span class="n">cpu</span><span class="p">()</span><span class="o">.</span><span class="n">detach</span><span class="p">()</span><span class="o">.</span><span class="n">numpy</span><span class="p">()</span><span class="o">.</span><span class="n">astype</span><span class="p">(</span><span class="s2">&quot;int64&quot;</s
             <span class="n">pix_acc</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">float64</span><span class="p">(</span><span class="mf">1.0</span><span class="p">)</span> <span class="o">*</span> <span class="n">_total_correct</span> <span class="o">/</span> <span class="p">(</span><span class="n">np</span><span class="o">.</span><span class="n">spacing</span><span class="p">(</span><span class="mi">1</span><span class="p">,</span> <span class="n"
             <span class="n">pix_acc</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">float64</span><span class="p">(</span><span class="mf">1.0</span><span class="p">)</span> <span class="o">*</span> <span class="n">_total_correct</span> <span class="o">/</span> <span class="p">(</span><span class="n">np</span><span class="o">.</span><span class="n">spacing</span><span class="p">(</span><span class="mi">1</span><span class="p">,</span> <span class="n"
             <span class="k">return</span> <span class="n">pix_acc</span></div></div>
             <span class="k">return</span> <span class="n">pix_acc</span></div></div>
     
     
     
     
    -<div class="viewcode-block" id="IoU"><a class="viewcode-back" href="../../../../super_gradients.training.metrics.html#super_gradients.training.metrics.IoU">[docs]</a><span class="k">class</span> <span class="nc">IoU</span><span class="p">(</span><span class="n">torchmetrics</span><span class="o">.</span><span class="n">JaccardIndex</span><span class="p">):</span>
    -    <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span>
    -                 <span class="n">num_classes</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span>
    -                 <span class="n">dist_sync_on_step</span><span class="p">:</span> <span class="nb">bool</span> <span class="o">=</span> <span class="kc">False</span><span class="p">,</span>
    -                 <span class="n">ignore_index</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="nb">int</span><span class="p">]</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span>
    -                 <span class="n">reduction</span><span class="p">:</span> <span class="nb">str</span> <span class="o">=</span> <span class="s2">&quot;elementwise_mean&quot;</span><span class="p">,</span>
    -                 <span class="n">threshold</span><span class="p">:</span> <span class="nb">float</span> <span class="o">=</span> <span class="mf">0.5</span><span class="p">,</span>
    -                 <span class="n">metrics_args_prep_fn</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="n">AbstractMetricsArgsPrepFn</span><span class="p">]</span> <span class="o">=</span> <span class="kc">None</span><span class="p">):</span>
    -        <span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="fm">__init__</span><span class="p">(</span><span class="n">num_classes</span><span class="o">=</span><span class="n">num_classes</span><span class="p">,</span> <span class="n">dist_sync_on_step</span><span class="o">=</span><span class="n">dist_sync_on_step</span><span class="p">,</span> <span class="n">ignore_index</span><span class="o">=</span><span class="n">ignore_index</span><span class="p">,
    -                         <span class="n">reduction</span><span class="o">=</span><span class="n">reduction</span><span class="p">,</span> <span class="n">threshold</span><span class="o">=</span><span class="n">threshold</span><span class="p">)</span>
    +<div class="viewcode-block" id="IoU"><a class="viewcode-back" href="../../../../super_gradients.training.html#super_gradients.training.metrics.IoU">[docs]</a><span class="k">class</span> <span class="nc">IoU</span><span class="p">(</span><span class="n">torchmetrics</span><span class="o">.</span><span class="n">JaccardIndex</span><span class="p">):</span>
    +    <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span>
    +        <span class="bp">self</span><span class="p">,</span>
    +        <span class="n">num_classes</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span>
    +        <span class="n">dist_sync_on_step</span><span class="p">:</span> <span class="nb">bool</span> <span class="o">=</span> <span class="kc">False</span><span class="p">,</span>
    +        <span class="n">ignore_index</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="nb">int</span><span class="p">]</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span>
    +        <span class="n">reduction</span><span class="p">:</span> <span class="nb">str</span> <span class="o">=</span> <span class="s2">&quot;elementwise_mean&quot;</span><span class="p">,</span>
    +        <span class="n">threshold</span><span class="p">:</span> <span class="nb">float</span> <span class="o">=</span> <span class="mf">0.5</span><span class="p">,</span>
    +        <span class="n">metrics_args_prep_fn</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="n">AbstractMetricsArgsPrepFn</span><span class="p">]</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span>
    +    <span class="p">):</span>
    +
    +        <span class="k">if</span> <span class="n">num_classes</span> <span class="o">&lt;=</span> <span class="mi">1</span><span class="p">:</span>
    +            <span class="k">raise</span> <span class="ne">ValueError</span><span class="p">(</span><span class="sa">f</span><span class="s2">&quot;IoU class only for multi-class usage! For binary usage, please call </span><span class="si">{</span><span class="n">BinaryIOU</span><span class="o">.</span><span class="vm">__name__</span><span class="si">}</span><span class="s2">&quot;</span><span class="p">)</span>
    +
    +        <span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="fm">__init__</span><span class="p">(</span><span class="n">num_classes</span><span class="o">=</span><span class="n">num_classes</span><span class="p">,</span> <span class="n">dist_sync_on_step</span><span class="o">=</span><span class="n">dist_sync_on_step</span><span class="p">,</span> <span class="n">ignore_index</span><span class="o">=</span><span class="n">ignore_index</span><span class="p">,
             <span class="bp">self</span><span class="o">.</span><span class="n">metrics_args_prep_fn</span> <span class="o">=</span> <span class="n">metrics_args_prep_fn</span> <span class="ow">or</span> <span class="n">PreprocessSegmentationMetricsArgs</span><span class="p">(</span><span class="n">apply_arg_max</span><span class="o">=</span><span class="kc">True</span><span class="p">)</span>
             <span class="bp">self</span><span class="o">.</span><span class="n">metrics_args_prep_fn</span> <span class="o">=</span> <span class="n">metrics_args_prep_fn</span> <span class="ow">or</span> <span class="n">PreprocessSegmentationMetricsArgs</span><span class="p">(</span><span class="n">apply_arg_max</span><span class="o">=</span><span class="kc">True</span><span class="p">)</span>
    +        <span class="bp">self</span><span class="o">.</span><span class="n">greater_is_better</span> <span class="o">=</span> <span class="kc">True</span>
     
     
    -<div class="viewcode-block" id="IoU.update"><a class="viewcode-back" href="../../../../super_gradients.training.metrics.html#super_gradients.training.metrics.IoU.update">[docs]</a>    <span class="k">def</span> <span class="nf">update</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">preds</span><span class="p">,</span> <span class="n">target</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor
    +<div class="viewcode-block" id="IoU.update"><a class="viewcode-back" href="../../../../super_gradients.training.html#super_gradients.training.metrics.IoU.update">[docs]</a>    <span class="k">def</span> <span class="nf">update</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">preds</span><span class="p">,</span> <span class="n">target</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><
             <span class="n">preds</span><span class="p">,</span> <span class="n">target</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">metrics_args_prep_fn</span><span class="p">(</span><span class="n">preds</span><span class="p">,</span> <span class="n">target</span><span class="p">)</span>
             <span class="n">preds</span><span class="p">,</span> <span class="n">target</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">metrics_args_prep_fn</span><span class="p">(</span><span class="n">preds</span><span class="p">,</span> <span class="n">target</span><span class="p">)</span>
             <span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="n">update</span><span class="p">(</span><span class="n">preds</span><span class="o">=</span><span class="n">preds</span><span class="p">,</span> <span class="n">target</span><span class="o">=</span><span class="n">target</span><span class="p">)</span></div></div>
             <span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="n">update</span><span class="p">(</span><span class="n">preds</span><span class="o">=</span><span class="n">preds</span><span class="p">,</span> <span class="n">target</span><span class="o">=</span><span class="n">target</span><span class="p">)</span></div></div>
     
     
     
     
    -<div class="viewcode-block" id="Dice"><a class="viewcode-back" href="../../../../super_gradients.training.metrics.html#super_gradients.training.metrics.Dice">[docs]</a><span class="k">class</span> <span class="nc">Dice</span><span class="p">(</span><span class="n">torchmetrics</span><span class="o">.</span><span class="n">JaccardIndex</span><span class="p">):</span>
    -    <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span>
    -                 <span class="n">num_classes</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span>
    -                 <span class="n">dist_sync_on_step</span><span class="p">:</span> <span class="nb">bool</span> <span class="o">=</span> <span class="kc">False</span><span class="p">,</span>
    -                 <span class="n">ignore_index</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="nb">int</span><span class="p">]</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span>
    -                 <span class="n">reduction</span><span class="p">:</span> <span class="nb">str</span> <span class="o">=</span> <span class="s2">&quot;elementwise_mean&quot;</span><span class="p">,</span>
    -                 <span class="n">threshold</span><span class="p">:</span> <span class="nb">float</span> <span class="o">=</span> <span class="mf">0.5</span><span class="p">,</span>
    -                 <span class="n">metrics_args_prep_fn</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="n">AbstractMetricsArgsPrepFn</span><span class="p">]</span> <span class="o">=</span> <span class="kc">None</span><span class="p">):</span>
    -        <span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="fm">__init__</span><span class="p">(</span><span class="n">num_classes</span><span class="o">=</span><span class="n">num_classes</span><span class="p">,</span> <span class="n">dist_sync_on_step</span><span class="o">=</span><span class="n">dist_sync_on_step</span><span class="p">,</span> <span class="n">ignore_index</span><span class="o">=</span><span class="n">ignore_index</span><span class="p">,
    -                         <span class="n">reduction</span><span class="o">=</span><span class="n">reduction</span><span class="p">,</span> <span class="n">threshold</span><span class="o">=</span><span class="n">threshold</span><span class="p">)</span>
    +<div class="viewcode-block" id="Dice"><a class="viewcode-back" href="../../../../super_gradients.training.html#super_gradients.training.metrics.Dice">[docs]</a><span class="k">class</span> <span class="nc">Dice</span><span class="p">(</span><span class="n">torchmetrics</span><span class="o">.</span><span class="n">JaccardIndex</span><span class="p">):</span>
    +    <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span>
    +        <span class="bp">self</span><span class="p">,</span>
    +        <span class="n">num_classes</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span>
    +        <span class="n">dist_sync_on_step</span><span class="p">:</span> <span class="nb">bool</span> <span class="o">=</span> <span class="kc">False</span><span class="p">,</span>
    +        <span class="n">ignore_index</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="nb">int</span><span class="p">]</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span>
    +        <span class="n">reduction</span><span class="p">:</span> <span class="nb">str</span> <span class="o">=</span> <span class="s2">&quot;elementwise_mean&quot;</span><span class="p">,</span>
    +        <span class="n">threshold</span><span class="p">:</span> <span class="nb">float</span> <span class="o">=</span> <span class="mf">0.5</span><span class="p">,</span>
    +        <span class="n">metrics_args_prep_fn</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="n">AbstractMetricsArgsPrepFn</span><span class="p">]</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span>
    +    <span class="p">):</span>
    +
    +        <span class="k">if</span> <span class="n">num_classes</span> <span class="o">&lt;=</span> <span class="mi">1</span><span class="p">:</span>
    +            <span class="k">raise</span> <span class="ne">ValueError</span><span class="p">(</span><span class="sa">f</span><span class="s2">&quot;Dice class only for multi-class usage! For binary usage, please call </span><span class="si">{</span><span class="n">BinaryDice</span><span class="o">.</span><span class="vm">__name__</span><span class="si">}</span><span class="s2">&quot;</span><span class="p">)</span>
    +
    +        <span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="fm">__init__</span><span class="p">(</span><span class="n">num_classes</span><span class="o">=</span><span class="n">num_classes</span><span class="p">,</span> <span class="n">dist_sync_on_step</span><span class="o">=</span><span class="n">dist_sync_on_step</span><span class="p">,</span> <span class="n">ignore_index</span><span class="o">=</span><span class="n">ignore_index</span><span class="p">,
             <span class="bp">self</span><span class="o">.</span><span class="n">metrics_args_prep_fn</span> <span class="o">=</span> <span class="n">metrics_args_prep_fn</span> <span class="ow">or</span> <span class="n">PreprocessSegmentationMetricsArgs</span><span class="p">(</span><span class="n">apply_arg_max</span><span class="o">=</span><span class="kc">True</span><span class="p">)</span>
             <span class="bp">self</span><span class="o">.</span><span class="n">metrics_args_prep_fn</span> <span class="o">=</span> <span class="n">metrics_args_prep_fn</span> <span class="ow">or</span> <span class="n">PreprocessSegmentationMetricsArgs</span><span class="p">(</span><span class="n">apply_arg_max</span><span class="o">=</span><span class="kc">True</span><span class="p">)</span>
    +        <span class="bp">self</span><span class="o">.</span><span class="n">greater_is_better</span> <span class="o">=</span> <span class="kc">True</span>
     
     
    -<div class="viewcode-block" id="Dice.update"><a class="viewcode-back" href="../../../../super_gradients.training.metrics.html#super_gradients.training.metrics.Dice.update">[docs]</a>    <span class="k">def</span> <span class="nf">update</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">preds</span><span class="p">,</span> <span class="n">target</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tens
    +<div class="viewcode-block" id="Dice.update"><a class="viewcode-back" href="../../../../super_gradients.training.html#super_gradients.training.metrics.Dice.update">[docs]</a>    <span class="k">def</span> <span class="nf">update</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">preds</span><span class="p">,</span> <span class="n">target</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span
             <span class="n">preds</span><span class="p">,</span> <span class="n">target</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">metrics_args_prep_fn</span><span class="p">(</span><span class="n">preds</span><span class="p">,</span> <span class="n">target</span><span class="p">)</span>
             <span class="n">preds</span><span class="p">,</span> <span class="n">target</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">metrics_args_prep_fn</span><span class="p">(</span><span class="n">preds</span><span class="p">,</span> <span class="n">target</span><span class="p">)</span>
             <span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="n">update</span><span class="p">(</span><span class="n">preds</span><span class="o">=</span><span class="n">preds</span><span class="p">,</span> <span class="n">target</span><span class="o">=</span><span class="n">target</span><span class="p">)</span></div>
             <span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="n">update</span><span class="p">(</span><span class="n">preds</span><span class="o">=</span><span class="n">preds</span><span class="p">,</span> <span class="n">target</span><span class="o">=</span><span class="n">target</span><span class="p">)</span></div>
     
     
    -<div class="viewcode-block" id="Dice.compute"><a class="viewcode-back" href="../../../../super_gradients.training.metrics.html#super_gradients.training.metrics.Dice.compute">[docs]</a>    <span class="k">def</span> <span class="nf">compute</span><span class="p">(</span><span class="bp">self</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">:</span>
    +<div class="viewcode-block" id="Dice.compute"><a class="viewcode-back" href="../../../../super_gradients.training.html#super_gradients.training.metrics.Dice.compute">[docs]</a>    <span class="k">def</span> <span class="nf">compute</span><span class="p">(</span><span class="bp">self</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">:</span>
             <span class="sd">&quot;&quot;&quot;Computes Dice coefficient&quot;&quot;&quot;</span>
             <span class="sd">&quot;&quot;&quot;Computes Dice coefficient&quot;&quot;&quot;</span>
    -        <span class="k">return</span> <span class="n">_dice_from_confmat</span><span class="p">(</span>
    -            <span class="bp">self</span><span class="o">.</span><span class="n">confmat</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">num_classes</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">ignore_index</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">absent_score</span><span class="p">,</span> <span class="bp">self</span><span class="o">.
    -        <span class="p">)</span></div></div>
    +        <span class="k">return</span> <span class="n">_dice_from_confmat</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">confmat</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">num_classes</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">ignore_index</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n
     
     
     
     
    -<div class="viewcode-block" id="BinaryIOU"><a class="viewcode-back" href="../../../../super_gradients.training.metrics.html#super_gradients.training.metrics.BinaryIOU">[docs]</a><span class="k">class</span> <span class="nc">BinaryIOU</span><span class="p">(</span><span class="n">IoU</span><span class="p">):</span>
    -    <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span>
    -                 <span class="n">dist_sync_on_step</span><span class="o">=</span><span class="kc">True</span><span class="p">,</span>
    -                 <span class="n">ignore_index</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="nb">int</span><span class="p">]</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span>
    -                 <span class="n">threshold</span><span class="p">:</span> <span class="nb">float</span> <span class="o">=</span> <span class="mf">0.5</span><span class="p">,</span>
    -                 <span class="n">metrics_args_prep_fn</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="n">AbstractMetricsArgsPrepFn</span><span class="p">]</span> <span class="o">=</span> <span class="kc">None</span><span class="p">):</span>
    +<div class="viewcode-block" id="BinaryIOU"><a class="viewcode-back" href="../../../../super_gradients.training.html#super_gradients.training.metrics.BinaryIOU">[docs]</a><span class="k">class</span> <span class="nc">BinaryIOU</span><span class="p">(</span><span class="n">IoU</span><span class="p">):</span>
    +    <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span>
    +        <span class="bp">self</span><span class="p">,</span>
    +        <span class="n">dist_sync_on_step</span><span class="o">=</span><span class="kc">True</span><span class="p">,</span>
    +        <span class="n">ignore_index</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="nb">int</span><span class="p">]</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span>
    +        <span class="n">threshold</span><span class="p">:</span> <span class="nb">float</span> <span class="o">=</span> <span class="mf">0.5</span><span class="p">,</span>
    +        <span class="n">metrics_args_prep_fn</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="n">AbstractMetricsArgsPrepFn</span><span class="p">]</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span>
    +    <span class="p">):</span>
             <span class="n">metrics_args_prep_fn</span> <span class="o">=</span> <span class="n">metrics_args_prep_fn</span> <span class="ow">or</span> <span class="n">PreprocessSegmentationMetricsArgs</span><span class="p">(</span><span class="n">apply_sigmoid</span><span class="o">=</span><span class="kc">True</span><span class="p">)</span>
             <span class="n">metrics_args_prep_fn</span> <span class="o">=</span> <span class="n">metrics_args_prep_fn</span> <span class="ow">or</span> <span class="n">PreprocessSegmentationMetricsArgs</span><span class="p">(</span><span class="n">apply_sigmoid</span><span class="o">=</span><span class="kc">True</span><span class="p">)</span>
    -        <span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="fm">__init__</span><span class="p">(</span><span class="n">num_classes</span><span class="o">=</span><span class="mi">2</span><span class="p">,</span> <span class="n">dist_sync_on_step</span><span class="o">=</span><span class="n">dist_sync_on_step</span><span class="p">,</span> <span class="n">ignore_index</span><span class="o">=</span><span class="n">ignore_index</span><span class="p">,</span>
    -                         <span class="n">reduction</span><span class="o">=</span><span class="s2">&quot;none&quot;</span><span class="p">,</span> <span class="n">threshold</span><span class="o">=</span><span class="n">threshold</span><span class="p">,</span> <span class="n">metrics_args_prep_fn</span><span class="o">=</span><span class="n">metrics_args_prep_fn</span><span class="p">)</span>
    -        <span class="bp">self</span><span class="o">.</span><span class="n">component_names</span> <span class="o">=</span> <span class="p">[</span><span class="s2">&quot;target_IOU&quot;</span><span class="p">,</span> <span class="s2">&quot;background_IOU&quot;</span><span class="p">,</span> <span class="s2">&quot;mean_IOU&quot;</span><span class="p">]</span>
    -
    -<div class="viewcode-block" id="BinaryIOU.compute"><a class="viewcode-back" href="../../../../super_gradients.training.metrics.html#super_gradients.training.metrics.BinaryIOU.compute">[docs]</a>    <span class="k">def</span> <span class="nf">compute</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span>
    +        <span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="fm">__init__</span><span class="p">(</span>
    +            <span class="n">num_classes</span><span class="o">=</span><span class="mi">2</span><span class="p">,</span>
    +            <span class="n">dist_sync_on_step</span><span class="o">=</span><span class="n">dist_sync_on_step</span><span class="p">,</span>
    +            <span class="n">ignore_index</span><span class="o">=</span><span class="n">ignore_index</span><span class="p">,</span>
    +            <span class="n">reduction</span><span class="o">=</span><span class="s2">&quot;none&quot;</span><span class="p">,</span>
    +            <span class="n">threshold</span><span class="o">=</span><span class="n">threshold</span><span class="p">,</span>
    +            <span class="n">metrics_args_prep_fn</span><span class="o">=</span><span class="n">metrics_args_prep_fn</span><span class="p">,</span>
    +        <span class="p">)</span>
    +        <span class="bp">self</span><span class="o">.</span><span class="n">greater_component_is_better</span> <span class="o">=</span> <span class="p">{</span>
    +            <span class="s2">&quot;target_IOU&quot;</span><span class="p">:</span> <span class="kc">True</span><span class="p">,</span>
    +            <span class="s2">&quot;background_IOU&quot;</span><span class="p">:</span> <span class="kc">True</span><span class="p">,</span>
    +            <span class="s2">&quot;mean_IOU&quot;</span><span class="p">:</span> <span class="kc">True</span><span class="p">,</span>
    +        <span class="p">}</span>
    +        <span class="bp">self</span><span class="o">.</span><span class="n">component_names</span> <span class="o">=</span> <span class="nb">list</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">greater_component_is_better</span><span class="o">.</span><span class="n">keys</span><span class="p">())</span>
    +
    +<div class="viewcode-block" id="BinaryIOU.compute"><a class="viewcode-back" href="../../../../super_gradients.training.html#super_gradients.training.metrics.BinaryIOU.compute">[docs]</a>    <span class="k">def</span> <span class="nf">compute</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span>
             <span class="n">ious</span> <span class="o">=</span> <span class="nb">super</span><span class="p">(</span><span class="n">BinaryIOU</span><span class="p">,</span> <span class="bp">self</span><span class="p">)</span><span class="o">.</span><span class="n">compute</span><span class="p">()</span>
             <span class="n">ious</span> <span class="o">=</span> <span class="nb">super</span><span class="p">(</span><span class="n">BinaryIOU</span><span class="p">,</span> <span class="bp">self</span><span class="p">)</span><span class="o">.</span><span class="n">compute</span><span class="p">()</span>
             <span class="k">return</span> <span class="p">{</span><span class="s2">&quot;target_IOU&quot;</span><span class="p">:</span> <span class="n">ious</span><span class="p">[</span><span class="mi">1</span><span class="p">],</span> <span class="s2">&quot;background_IOU&quot;</span><span class="p">:</span> <span class="n">ious</span><span class="p">[</span><span class="mi">0</span><span class="p">],</span> <span class="s2">&quot;mean_IOU&quot;</span><span class="p">:</span> <span class="n">io
             <span class="k">return</span> <span class="p">{</span><span class="s2">&quot;target_IOU&quot;</span><span class="p">:</span> <span class="n">ious</span><span class="p">[</span><span class="mi">1</span><span class="p">],</span> <span class="s2">&quot;background_IOU&quot;</span><span class="p">:</span> <span class="n">ious</span><span class="p">[</span><span class="mi">0</span><span class="p">],</span> <span class="s2">&quot;mean_IOU&quot;</span><span class="p">:</span> <span class="n">io
     
     
     
     
    -<div class="viewcode-block" id="BinaryDice"><a class="viewcode-back" href="../../../../super_gradients.training.metrics.html#super_gradients.training.metrics.BinaryDice">[docs]</a><span class="k">class</span> <span class="nc">BinaryDice</span><span class="p">(</span><span class="n">Dice</span><span class="p">):</span>
    -    <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span>
    -                 <span class="n">dist_sync_on_step</span><span class="o">=</span><span class="kc">True</span><span class="p">,</span>
    -                 <span class="n">ignore_index</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="nb">int</span><span class="p">]</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span>
    -                 <span class="n">threshold</span><span class="p">:</span> <span class="nb">float</span> <span class="o">=</span> <span class="mf">0.5</span><span class="p">,</span>
    -                 <span class="n">metrics_args_prep_fn</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="n">AbstractMetricsArgsPrepFn</span><span class="p">]</span> <span class="o">=</span> <span class="kc">None</span><span class="p">):</span>
    +<div class="viewcode-block" id="BinaryDice"><a class="viewcode-back" href="../../../../super_gradients.training.html#super_gradients.training.metrics.BinaryDice">[docs]</a><span class="k">class</span> <span class="nc">BinaryDice</span><span class="p">(</span><span class="n">Dice</span><span class="p">):</span>
    +    <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span>
    +        <span class="bp">self</span><span class="p">,</span>
    +        <span class="n">dist_sync_on_step</span><span class="o">=</span><span class="kc">True</span><span class="p">,</span>
    +        <span class="n">ignore_index</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="nb">int</span><span class="p">]</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span>
    +        <span class="n">threshold</span><span class="p">:</span> <span class="nb">float</span> <span class="o">=</span> <span class="mf">0.5</span><span class="p">,</span>
    +        <span class="n">metrics_args_prep_fn</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="n">AbstractMetricsArgsPrepFn</span><span class="p">]</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span>
    +    <span class="p">):</span>
             <span class="n">metrics_args_prep_fn</span> <span class="o">=</span> <span class="n">metrics_args_prep_fn</span> <span class="ow">or</span> <span class="n">PreprocessSegmentationMetricsArgs</span><span class="p">(</span><span class="n">apply_sigmoid</span><span class="o">=</span><span class="kc">True</span><span class="p">)</span>
             <span class="n">metrics_args_prep_fn</span> <span class="o">=</span> <span class="n">metrics_args_prep_fn</span> <span class="ow">or</span> <span class="n">PreprocessSegmentationMetricsArgs</span><span class="p">(</span><span class="n">apply_sigmoid</span><span class="o">=</span><span class="kc">True</span><span class="p">)</span>
    -        <span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="fm">__init__</span><span class="p">(</span><span class="n">num_classes</span><span class="o">=</span><span class="mi">2</span><span class="p">,</span> <span class="n">dist_sync_on_step</span><span class="o">=</span><span class="n">dist_sync_on_step</span><span class="p">,</span> <span class="n">ignore_index</span><span class="o">=</span><span class="n">ignore_index</span><span class="p">,</span>
    -                         <span class="n">reduction</span><span class="o">=</span><span class="s2">&quot;none&quot;</span><span class="p">,</span> <span class="n">threshold</span><span class="o">=</span><span class="n">threshold</span><span class="p">,</span> <span class="n">metrics_args_prep_fn</span><span class="o">=</span><span class="n">metrics_args_prep_fn</span><span class="p">)</span>
    -        <span class="bp">self</span><span class="o">.</span><span class="n">component_names</span> <span class="o">=</span> <span class="p">[</span><span class="s2">&quot;target_Dice&quot;</span><span class="p">,</span> <span class="s2">&quot;background_Dice&quot;</span><span class="p">,</span> <span class="s2">&quot;mean_Dice&quot;</span><span class="p">]</span>
    -
    -<div class="viewcode-block" id="BinaryDice.compute"><a class="viewcode-back" href="../../../../super_gradients.training.metrics.html#super_gradients.training.metrics.BinaryDice.compute">[docs]</a>    <span class="k">def</span> <span class="nf">compute</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span>
    +        <span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="fm">__init__</span><span class="p">(</span>
    +            <span class="n">num_classes</span><span class="o">=</span><span class="mi">2</span><span class="p">,</span>
    +            <span class="n">dist_sync_on_step</span><span class="o">=</span><span class="n">dist_sync_on_step</span><span class="p">,</span>
    +            <span class="n">ignore_index</span><span class="o">=</span><span class="n">ignore_index</span><span class="p">,</span>
    +            <span class="n">reduction</span><span class="o">=</span><span class="s2">&quot;none&quot;</span><span class="p">,</span>
    +            <span class="n">threshold</span><span class="o">=</span><span class="n">threshold</span><span class="p">,</span>
    +            <span class="n">metrics_args_prep_fn</span><span class="o">=</span><span class="n">metrics_args_prep_fn</span><span class="p">,</span>
    +        <span class="p">)</span>
    +        <span class="bp">self</span><span class="o">.</span><span class="n">greater_component_is_better</span> <span class="o">=</span> <span class="p">{</span>
    +            <span class="s2">&quot;target_Dice&quot;</span><span class="p">:</span> <span class="kc">True</span><span class="p">,</span>
    +            <span class="s2">&quot;background_Dice&quot;</span><span class="p">:</span> <span class="kc">True</span><span class="p">,</span>
    +            <span class="s2">&quot;mean_Dice&quot;</span><span class="p">:</span> <span class="kc">True</span><span class="p">,</span>
    +        <span class="p">}</span>
    +        <span class="bp">self</span><span class="o">.</span><span class="n">component_names</span> <span class="o">=</span> <span class="nb">list</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">greater_component_is_better</span><span class="o">.</span><span class="n">keys</span><span class="p">())</span>
    +
    +<div class="viewcode-block" id="BinaryDice.compute"><a class="viewcode-back" href="../../../../super_gradients.training.html#super_gradients.training.metrics.BinaryDice.compute">[docs]</a>    <span class="k">def</span> <span class="nf">compute</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span>
             <span class="n">dices</span> <span class="o">=</span> <span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="n">compute</span><span class="p">()</span>
             <span class="n">dices</span> <span class="o">=</span> <span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="n">compute</span><span class="p">()</span>
             <span class="k">return</span> <span class="p">{</span><span class="s2">&quot;target_Dice&quot;</span><span class="p">:</span> <span class="n">dices</span><span class="p">[</span><span class="mi">1</span><span class="p">],</span> <span class="s2">&quot;background_Dice&quot;</span><span class="p">:</span> <span class="n">dices</span><span class="p">[</span><span class="mi">0</span><span class="p">],</span> <span class="s2">&quot;mean_Dice&quot;</span><span class="p">:</span> <span class="
             <span class="k">return</span> <span class="p">{</span><span class="s2">&quot;target_Dice&quot;</span><span class="p">:</span> <span class="n">dices</span><span class="p">[</span><span class="mi">1</span><span class="p">],</span> <span class="s2">&quot;background_Dice&quot;</span><span class="p">:</span> <span class="n">dices</span><span class="p">[</span><span class="mi">0</span><span class="p">],</span> <span class="s2">&quot;mean_Dice&quot;</span><span class="p">:</span> <span class="
     </pre></div>
     </pre></div>
    @@ -379,4 +410,4 @@
       </script> 
       </script> 
     
     
     </body>
     </body>
    -</html>
    +</html>
    Discard
    1
    2
    3
    4
    5
    6
    7
    8
    9
    10
    11
    12
    13
    14
    15
    16
    17
    18
    19
    20
    21
    22
    23
    24
    25
    26
    27
    28
    29
    30
    31
    32
    33
    34
    35
    36
    37
    38
    39
    40
    41
    42
    43
    44
    45
    46
    47
    48
    49
    50
    51
    52
    53
    54
    55
    56
    57
    58
    59
    60
    61
    62
    63
    64
    65
    66
    67
    68
    69
    70
    71
    72
    73
    74
    75
    76
    77
    78
    79
    80
    81
    82
    83
    84
    85
    86
    87
    88
    89
    90
    91
    92
    93
    94
    95
    96
    97
    98
    99
    100
    101
    102
    103
    104
    105
    106
    107
    108
    109
    110
    111
    112
    113
    114
    115
    116
    117
    118
    119
    120
    121
    122
    123
    124
    125
    126
    127
    128
    129
    130
    131
    132
    133
    134
    135
    136
    137
    138
    139
    140
    141
    142
    143
    144
    145
    146
    147
    148
    149
    150
    151
    152
    153
    154
    155
    156
    157
    158
    159
    160
    161
    162
    163
    164
    165
    166
    167
    168
    169
    170
    171
    172
    173
    174
    175
    176
    177
    178
    179
    180
    181
    182
    183
    184
    1. <!DOCTYPE html>
    2. <html class="writer-html5" lang="en" >
    3. <head>
    4. <meta charset="utf-8" />
    5. <meta name="viewport" content="width=device-width, initial-scale=1.0" />
    6. <title>super_gradients.training.models.sg_module &mdash; SuperGradients 1.0 documentation</title>
    7. <link rel="stylesheet" href="../../../../_static/pygments.css" type="text/css" />
    8. <link rel="stylesheet" href="../../../../_static/css/theme.css" type="text/css" />
    9. <link rel="stylesheet" href="../../../../_static/graphviz.css" type="text/css" />
    10. <!--[if lt IE 9]>
    11. <script src="../../../../_static/js/html5shiv.min.js"></script>
    12. <![endif]-->
    13. <script data-url_root="../../../../" id="documentation_options" src="../../../../_static/documentation_options.js"></script>
    14. <script src="../../../../_static/jquery.js"></script>
    15. <script src="../../../../_static/underscore.js"></script>
    16. <script src="../../../../_static/doctools.js"></script>
    17. <script src="../../../../_static/js/theme.js"></script>
    18. <link rel="index" title="Index" href="../../../../genindex.html" />
    19. <link rel="search" title="Search" href="../../../../search.html" />
    20. </head>
    21. <body class="wy-body-for-nav">
    22. <div class="wy-grid-for-nav">
    23. <nav data-toggle="wy-nav-shift" class="wy-nav-side">
    24. <div class="wy-side-scroll">
    25. <div class="wy-side-nav-search" >
    26. <a href="../../../../index.html" class="icon icon-home"> SuperGradients
    27. </a>
    28. <div role="search">
    29. <form id="rtd-search-form" class="wy-form" action="../../../../search.html" method="get">
    30. <input type="text" name="q" placeholder="Search docs" />
    31. <input type="hidden" name="check_keywords" value="yes" />
    32. <input type="hidden" name="area" value="default" />
    33. </form>
    34. </div>
    35. </div><div class="wy-menu wy-menu-vertical" data-spy="affix" role="navigation" aria-label="Navigation menu">
    36. <p class="caption"><span class="caption-text">Welcome To SuperGradients</span></p>
    37. <ul>
    38. <li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html">Fill our 4-question quick survey! We will raffle free SuperGradients swag between those who will participate -&gt; Fill Survey</a></li>
    39. <li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#supergradients">SuperGradients</a></li>
    40. </ul>
    41. <p class="caption"><span class="caption-text">Technical Documentation</span></p>
    42. <ul>
    43. <li class="toctree-l1"><a class="reference internal" href="../../../../super_gradients.common.html">Common package</a></li>
    44. <li class="toctree-l1"><a class="reference internal" href="../../../../super_gradients.training.html">Training package</a></li>
    45. </ul>
    46. <p class="caption"><span class="caption-text">User Guide</span></p>
    47. <ul>
    48. <li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html">What is SuperGradients?</a></li>
    49. <li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#introducing-the-supergradients-library">Introducing the SuperGradients library</a></li>
    50. <li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#installation">Installation</a></li>
    51. <li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#integrating-your-training-code-complete-walkthrough">Integrating Your Training Code - Complete Walkthrough</a></li>
    52. <li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#training-parameters">Training Parameters</a></li>
    53. <li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#logs-and-checkpoints">Logs and Checkpoints</a></li>
    54. <li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#dataset-parameters">Dataset Parameters</a></li>
    55. <li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#network-architectures">Network Architectures</a></li>
    56. <li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#pretrained-models">Pretrained Models</a></li>
    57. <li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#how-to-reproduce-our-training-recipes">How To Reproduce Our Training Recipes</a></li>
    58. <li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#professional-tools-integration">Professional Tools Integration</a></li>
    59. <li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#supergradients-faq">SuperGradients FAQ</a></li>
    60. </ul>
    61. </div>
    62. </div>
    63. </nav>
    64. <section data-toggle="wy-nav-shift" class="wy-nav-content-wrap"><nav class="wy-nav-top" aria-label="Mobile navigation menu" >
    65. <i data-toggle="wy-nav-top" class="fa fa-bars"></i>
    66. <a href="../../../../index.html">SuperGradients</a>
    67. </nav>
    68. <div class="wy-nav-content">
    69. <div class="rst-content">
    70. <div role="navigation" aria-label="Page navigation">
    71. <ul class="wy-breadcrumbs">
    72. <li><a href="../../../../index.html" class="icon icon-home"></a> &raquo;</li>
    73. <li><a href="../../../index.html">Module code</a> &raquo;</li>
    74. <li>super_gradients.training.models.sg_module</li>
    75. <li class="wy-breadcrumbs-aside">
    76. </li>
    77. </ul>
    78. <hr/>
    79. </div>
    80. <div role="main" class="document" itemscope="itemscope" itemtype="http://schema.org/Article">
    81. <div itemprop="articleBody">
    82. <h1>Source code for super_gradients.training.models.sg_module</h1><div class="highlight"><pre>
    83. <span></span><span class="kn">from</span> <span class="nn">typing</span> <span class="kn">import</span> <span class="n">Union</span>
    84. <span class="kn">from</span> <span class="nn">torch</span> <span class="kn">import</span> <span class="n">nn</span>
    85. <span class="kn">from</span> <span class="nn">super_gradients.training.utils.utils</span> <span class="kn">import</span> <span class="n">HpmStruct</span>
    86. <div class="viewcode-block" id="SgModule"><a class="viewcode-back" href="../../../../super_gradients.training.models.html#super_gradients.training.models.sg_module.SgModule">[docs]</a><span class="k">class</span> <span class="nc">SgModule</span><span class="p">(</span><span class="n">nn</span><span class="o">.</span><span class="n">Module</span><span class="p">):</span>
    87. <div class="viewcode-block" id="SgModule.initialize_param_groups"><a class="viewcode-back" href="../../../../super_gradients.training.models.html#super_gradients.training.models.sg_module.SgModule.initialize_param_groups">[docs]</a> <span class="k">def</span> <span class="nf">initialize_param_groups</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">lr</span><span class="p">:</span> <span class="nb">float</span><span class="p">,</span> <span class="n">training_params</span><span class="p">:</span> <span class="n">HpmStruct</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="nb">list</span><span class="p">:</span>
    88. <span class="sd">&quot;&quot;&quot;</span>
    89. <span class="sd"> :return: list of dictionaries containing the key &#39;named_params&#39; with a list of named params</span>
    90. <span class="sd"> &quot;&quot;&quot;</span>
    91. <span class="k">return</span> <span class="p">[{</span><span class="s2">&quot;named_params&quot;</span><span class="p">:</span> <span class="bp">self</span><span class="o">.</span><span class="n">named_parameters</span><span class="p">()}]</span></div>
    92. <div class="viewcode-block" id="SgModule.update_param_groups"><a class="viewcode-back" href="../../../../super_gradients.training.models.html#super_gradients.training.models.sg_module.SgModule.update_param_groups">[docs]</a> <span class="k">def</span> <span class="nf">update_param_groups</span><span class="p">(</span>
    93. <span class="bp">self</span><span class="p">,</span> <span class="n">param_groups</span><span class="p">:</span> <span class="nb">list</span><span class="p">,</span> <span class="n">lr</span><span class="p">:</span> <span class="nb">float</span><span class="p">,</span> <span class="n">epoch</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span> <span class="nb">iter</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span> <span class="n">training_params</span><span class="p">:</span> <span class="n">HpmStruct</span><span class="p">,</span> <span class="n">total_batch</span><span class="p">:</span> <span class="nb">int</span>
    94. <span class="p">)</span> <span class="o">-&gt;</span> <span class="nb">list</span><span class="p">:</span>
    95. <span class="sd">&quot;&quot;&quot;</span>
    96. <span class="sd"> :param param_groups: list of dictionaries containing the params</span>
    97. <span class="sd"> :return: list of dictionaries containing the params</span>
    98. <span class="sd"> &quot;&quot;&quot;</span>
    99. <span class="k">for</span> <span class="n">param_group</span> <span class="ow">in</span> <span class="n">param_groups</span><span class="p">:</span>
    100. <span class="n">param_group</span><span class="p">[</span><span class="s2">&quot;lr&quot;</span><span class="p">]</span> <span class="o">=</span> <span class="n">lr</span>
    101. <span class="k">return</span> <span class="n">param_groups</span></div>
    102. <div class="viewcode-block" id="SgModule.get_include_attributes"><a class="viewcode-back" href="../../../../super_gradients.training.models.html#super_gradients.training.models.sg_module.SgModule.get_include_attributes">[docs]</a> <span class="k">def</span> <span class="nf">get_include_attributes</span><span class="p">(</span><span class="bp">self</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="nb">list</span><span class="p">:</span>
    103. <span class="sd">&quot;&quot;&quot;</span>
    104. <span class="sd"> This function is used by the EMA. When updating the EMA model, some attributes of the main model (used in training)</span>
    105. <span class="sd"> are updated to the EMA model along with the model weights.</span>
    106. <span class="sd"> By default, all attributes are updated except for private attributes (starting with &#39;_&#39;)</span>
    107. <span class="sd"> You can either set include_attributes or exclude_attributes. By returning a non empty list from this function,</span>
    108. <span class="sd"> you override the default behaviour and only attributes named in this list will be updated.</span>
    109. <span class="sd"> Note: This will also override the get_exclude_attributes list.</span>
    110. <span class="sd"> :return: list of attributes to update from main model to EMA model</span>
    111. <span class="sd"> &quot;&quot;&quot;</span>
    112. <span class="k">return</span> <span class="p">[]</span></div>
    113. <div class="viewcode-block" id="SgModule.get_exclude_attributes"><a class="viewcode-back" href="../../../../super_gradients.training.models.html#super_gradients.training.models.sg_module.SgModule.get_exclude_attributes">[docs]</a> <span class="k">def</span> <span class="nf">get_exclude_attributes</span><span class="p">(</span><span class="bp">self</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="nb">list</span><span class="p">:</span>
    114. <span class="sd">&quot;&quot;&quot;</span>
    115. <span class="sd"> This function is used by the EMA. When updating the EMA model, some attributes of the main model (used in training)</span>
    116. <span class="sd"> are updated to the EMA model along with the model weights.</span>
    117. <span class="sd"> By default, all attributes are updated except for private attributes (starting with &#39;_&#39;)</span>
    118. <span class="sd"> You can either set include_attributes or exclude_attributes. By returning a non empty list from this function,</span>
    119. <span class="sd"> you override the default behaviour and attributes named in this list will also be excluded from update.</span>
    120. <span class="sd"> Note: if get_include_attributes is not empty, it will override this list.</span>
    121. <span class="sd"> :return: list of attributes to not update from main model to EMA mode</span>
    122. <span class="sd"> &quot;&quot;&quot;</span>
    123. <span class="k">return</span> <span class="p">[]</span></div>
    124. <div class="viewcode-block" id="SgModule.prep_model_for_conversion"><a class="viewcode-back" href="../../../../super_gradients.training.models.html#super_gradients.training.models.sg_module.SgModule.prep_model_for_conversion">[docs]</a> <span class="k">def</span> <span class="nf">prep_model_for_conversion</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">input_size</span><span class="p">:</span> <span class="n">Union</span><span class="p">[</span><span class="nb">tuple</span><span class="p">,</span> <span class="nb">list</span><span class="p">]</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span> <span class="o">**</span><span class="n">kwargs</span><span class="p">):</span>
    125. <span class="sd">&quot;&quot;&quot;</span>
    126. <span class="sd"> Prepare the model to be converted to ONNX or other frameworks.</span>
    127. <span class="sd"> Typically, this function will freeze the size of layers which is otherwise flexible, replace some modules</span>
    128. <span class="sd"> with convertible substitutes and remove all auxiliary or training related parts.</span>
    129. <span class="sd"> :param input_size: [H,W]</span>
    130. <span class="sd"> &quot;&quot;&quot;</span></div>
    131. <div class="viewcode-block" id="SgModule.replace_head"><a class="viewcode-back" href="../../../../super_gradients.training.models.html#super_gradients.training.models.sg_module.SgModule.replace_head">[docs]</a> <span class="k">def</span> <span class="nf">replace_head</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="o">**</span><span class="n">kwargs</span><span class="p">):</span>
    132. <span class="sd">&quot;&quot;&quot;</span>
    133. <span class="sd"> Replace final layer for pretrained models. Since this varies between architectures, we leave it to the inheriting</span>
    134. <span class="sd"> class to implement.</span>
    135. <span class="sd"> &quot;&quot;&quot;</span>
    136. <span class="k">raise</span> <span class="ne">NotImplementedError</span></div></div>
    137. </pre></div>
    138. </div>
    139. </div>
    140. <footer>
    141. <hr/>
    142. <div role="contentinfo">
    143. <p>&#169; Copyright 2021, SuperGradients team.</p>
    144. </div>
    145. Built with <a href="https://www.sphinx-doc.org/">Sphinx</a> using a
    146. <a href="https://github.com/readthedocs/sphinx_rtd_theme">theme</a>
    147. provided by <a href="https://readthedocs.org">Read the Docs</a>.
    148. </footer>
    149. </div>
    150. </div>
    151. </section>
    152. </div>
    153. <script>
    154. jQuery(function () {
    155. SphinxRtdTheme.Navigation.enable(true);
    156. });
    157. </script>
    158. </body>
    159. </html>
    Discard
    1
    2
    3
    4
    5
    6
    7
    8
    9
    10
    11
    12
    13
    14
    15
    16
    17
    18
    19
    20
    21
    22
    23
    24
    25
    26
    27
    28
    29
    30
    31
    32
    33
    34
    35
    36
    37
    38
    39
    40
    41
    42
    43
    44
    45
    46
    47
    48
    49
    50
    51
    52
    53
    54
    55
    56
    57
    58
    59
    60
    61
    62
    63
    64
    65
    66
    67
    68
    69
    70
    71
    72
    73
    74
    75
    76
    77
    78
    79
    80
    81
    82
    83
    84
    85
    86
    87
    88
    89
    90
    91
    92
    93
    94
    95
    96
    97
    98
    99
    100
    101
    102
    103
    104
    105
    106
    107
    108
    109
    110
    111
    112
    113
    114
    115
    116
    117
    118
    119
    120
    121
    122
    123
    124
    125
    126
    127
    128
    129
    130
    131
    132
    133
    134
    135
    136
    137
    138
    139
    140
    141
    142
    143
    144
    145
    146
    147
    148
    149
    150
    151
    152
    153
    154
    155
    156
    157
    158
    159
    160
    161
    162
    163
    164
    165
    166
    167
    168
    169
    170
    171
    172
    173
    174
    175
    176
    177
    178
    179
    180
    181
    182
    183
    184
    185
    186
    187
    1. <!DOCTYPE html>
    2. <html class="writer-html5" lang="en" >
    3. <head>
    4. <meta charset="utf-8" />
    5. <meta name="viewport" content="width=device-width, initial-scale=1.0" />
    6. <title>super_gradients.training.params &mdash; SuperGradients 1.0 documentation</title>
    7. <link rel="stylesheet" href="../../../_static/pygments.css" type="text/css" />
    8. <link rel="stylesheet" href="../../../_static/css/theme.css" type="text/css" />
    9. <link rel="stylesheet" href="../../../_static/graphviz.css" type="text/css" />
    10. <!--[if lt IE 9]>
    11. <script src="../../../_static/js/html5shiv.min.js"></script>
    12. <![endif]-->
    13. <script data-url_root="../../../" id="documentation_options" src="../../../_static/documentation_options.js"></script>
    14. <script src="../../../_static/jquery.js"></script>
    15. <script src="../../../_static/underscore.js"></script>
    16. <script src="../../../_static/doctools.js"></script>
    17. <script src="../../../_static/js/theme.js"></script>
    18. <link rel="index" title="Index" href="../../../genindex.html" />
    19. <link rel="search" title="Search" href="../../../search.html" />
    20. </head>
    21. <body class="wy-body-for-nav">
    22. <div class="wy-grid-for-nav">
    23. <nav data-toggle="wy-nav-shift" class="wy-nav-side">
    24. <div class="wy-side-scroll">
    25. <div class="wy-side-nav-search" >
    26. <a href="../../../index.html" class="icon icon-home"> SuperGradients
    27. </a>
    28. <div role="search">
    29. <form id="rtd-search-form" class="wy-form" action="../../../search.html" method="get">
    30. <input type="text" name="q" placeholder="Search docs" />
    31. <input type="hidden" name="check_keywords" value="yes" />
    32. <input type="hidden" name="area" value="default" />
    33. </form>
    34. </div>
    35. </div><div class="wy-menu wy-menu-vertical" data-spy="affix" role="navigation" aria-label="Navigation menu">
    36. <p class="caption"><span class="caption-text">SuperGradients</span></p>
    37. <ul>
    38. <li class="toctree-l1"><a class="reference internal" href="../../../welcome.html">SuperGradients</a></li>
    39. <li class="toctree-l1"><a class="reference internal" href="../../../super_gradients.common.html">Common</a></li>
    40. <li class="toctree-l1"><a class="reference internal" href="../../../super_gradients.training.html">Training</a></li>
    41. </ul>
    42. </div>
    43. </div>
    44. </nav>
    45. <section data-toggle="wy-nav-shift" class="wy-nav-content-wrap"><nav class="wy-nav-top" aria-label="Mobile navigation menu" >
    46. <i data-toggle="wy-nav-top" class="fa fa-bars"></i>
    47. <a href="../../../index.html">SuperGradients</a>
    48. </nav>
    49. <div class="wy-nav-content">
    50. <div class="rst-content">
    51. <div role="navigation" aria-label="Page navigation">
    52. <ul class="wy-breadcrumbs">
    53. <li><a href="../../../index.html" class="icon icon-home"></a> &raquo;</li>
    54. <li><a href="../../index.html">Module code</a> &raquo;</li>
    55. <li>super_gradients.training.params</li>
    56. <li class="wy-breadcrumbs-aside">
    57. </li>
    58. </ul>
    59. <hr/>
    60. </div>
    61. <div role="main" class="document" itemscope="itemscope" itemtype="http://schema.org/Article">
    62. <div itemprop="articleBody">
    63. <h1>Source code for super_gradients.training.params</h1><div class="highlight"><pre>
    64. <span></span><span class="kn">from</span> <span class="nn">super_gradients.training.utils</span> <span class="kn">import</span> <span class="n">HpmStruct</span>
    65. <span class="n">DEFAULT_TRAINING_PARAMS</span> <span class="o">=</span> <span class="p">{</span><span class="s2">&quot;lr_warmup_epochs&quot;</span><span class="p">:</span> <span class="mi">0</span><span class="p">,</span>
    66. <span class="s2">&quot;cosine_final_lr_ratio&quot;</span><span class="p">:</span> <span class="mf">0.01</span><span class="p">,</span>
    67. <span class="s2">&quot;optimizer&quot;</span><span class="p">:</span> <span class="s2">&quot;SGD&quot;</span><span class="p">,</span>
    68. <span class="s2">&quot;criterion_params&quot;</span><span class="p">:</span> <span class="p">{},</span>
    69. <span class="s2">&quot;ema&quot;</span><span class="p">:</span> <span class="kc">False</span><span class="p">,</span>
    70. <span class="s2">&quot;batch_accumulate&quot;</span><span class="p">:</span> <span class="mi">1</span><span class="p">,</span> <span class="c1"># number of batches to accumulate before every backward pass</span>
    71. <span class="s2">&quot;ema_params&quot;</span><span class="p">:</span> <span class="p">{},</span>
    72. <span class="s2">&quot;zero_weight_decay_on_bias_and_bn&quot;</span><span class="p">:</span> <span class="kc">False</span><span class="p">,</span>
    73. <span class="s2">&quot;load_opt_params&quot;</span><span class="p">:</span> <span class="kc">True</span><span class="p">,</span>
    74. <span class="s2">&quot;run_validation_freq&quot;</span><span class="p">:</span> <span class="mi">1</span><span class="p">,</span>
    75. <span class="s2">&quot;save_model&quot;</span><span class="p">:</span> <span class="kc">True</span><span class="p">,</span>
    76. <span class="s2">&quot;metric_to_watch&quot;</span><span class="p">:</span> <span class="s2">&quot;Accuracy&quot;</span><span class="p">,</span>
    77. <span class="s2">&quot;launch_tensorboard&quot;</span><span class="p">:</span> <span class="kc">False</span><span class="p">,</span>
    78. <span class="s2">&quot;tb_files_user_prompt&quot;</span><span class="p">:</span> <span class="kc">False</span><span class="p">,</span> <span class="c1"># Asks User for Tensorboard Deletion Prompt</span>
    79. <span class="s2">&quot;silent_mode&quot;</span><span class="p">:</span> <span class="kc">False</span><span class="p">,</span> <span class="c1"># Silents the Print outs</span>
    80. <span class="s2">&quot;mixed_precision&quot;</span><span class="p">:</span> <span class="kc">False</span><span class="p">,</span>
    81. <span class="s2">&quot;tensorboard_port&quot;</span><span class="p">:</span> <span class="kc">None</span><span class="p">,</span>
    82. <span class="s2">&quot;save_ckpt_epoch_list&quot;</span><span class="p">:</span> <span class="p">[],</span> <span class="c1"># indices where the ckpt will save automatically</span>
    83. <span class="s2">&quot;average_best_models&quot;</span><span class="p">:</span> <span class="kc">True</span><span class="p">,</span>
    84. <span class="s2">&quot;dataset_statistics&quot;</span><span class="p">:</span> <span class="kc">False</span><span class="p">,</span> <span class="c1"># add a dataset statistical analysis and sample images to tensorboard</span>
    85. <span class="s2">&quot;save_tensorboard_to_s3&quot;</span><span class="p">:</span> <span class="kc">False</span><span class="p">,</span>
    86. <span class="s2">&quot;lr_schedule_function&quot;</span><span class="p">:</span> <span class="kc">None</span><span class="p">,</span>
    87. <span class="s2">&quot;train_metrics_list&quot;</span><span class="p">:</span> <span class="p">[],</span>
    88. <span class="s2">&quot;valid_metrics_list&quot;</span><span class="p">:</span> <span class="p">[],</span>
    89. <span class="s2">&quot;loss_logging_items_names&quot;</span><span class="p">:</span> <span class="p">[</span><span class="s2">&quot;Loss&quot;</span><span class="p">],</span>
    90. <span class="s2">&quot;greater_metric_to_watch_is_better&quot;</span><span class="p">:</span> <span class="kc">True</span><span class="p">,</span>
    91. <span class="s2">&quot;precise_bn&quot;</span><span class="p">:</span> <span class="kc">False</span><span class="p">,</span>
    92. <span class="s2">&quot;precise_bn_batch_size&quot;</span><span class="p">:</span> <span class="kc">None</span><span class="p">,</span>
    93. <span class="s2">&quot;seed&quot;</span><span class="p">:</span> <span class="mi">42</span><span class="p">,</span>
    94. <span class="s2">&quot;lr_mode&quot;</span><span class="p">:</span> <span class="kc">None</span><span class="p">,</span>
    95. <span class="s2">&quot;phase_callbacks&quot;</span><span class="p">:</span> <span class="p">[],</span>
    96. <span class="s2">&quot;log_installed_packages&quot;</span><span class="p">:</span> <span class="kc">True</span><span class="p">,</span>
    97. <span class="s2">&quot;save_full_train_log&quot;</span><span class="p">:</span> <span class="kc">False</span><span class="p">,</span>
    98. <span class="s2">&quot;sg_logger&quot;</span><span class="p">:</span> <span class="s2">&quot;base_sg_logger&quot;</span><span class="p">,</span>
    99. <span class="s2">&quot;sg_logger_params&quot;</span><span class="p">:</span>
    100. <span class="p">{</span><span class="s2">&quot;tb_files_user_prompt&quot;</span><span class="p">:</span> <span class="kc">False</span><span class="p">,</span> <span class="c1"># Asks User for Tensorboard Deletion Prompt</span>
    101. <span class="s2">&quot;project_name&quot;</span><span class="p">:</span> <span class="s2">&quot;&quot;</span><span class="p">,</span>
    102. <span class="s2">&quot;launch_tensorboard&quot;</span><span class="p">:</span> <span class="kc">False</span><span class="p">,</span>
    103. <span class="s2">&quot;tensorboard_port&quot;</span><span class="p">:</span> <span class="kc">None</span><span class="p">,</span>
    104. <span class="s2">&quot;save_checkpoints_remote&quot;</span><span class="p">:</span> <span class="kc">False</span><span class="p">,</span> <span class="c1"># upload checkpoint files to s3</span>
    105. <span class="s2">&quot;save_tensorboard_remote&quot;</span><span class="p">:</span> <span class="kc">False</span><span class="p">,</span> <span class="c1"># upload tensorboard files to s3</span>
    106. <span class="s2">&quot;save_logs_remote&quot;</span><span class="p">:</span> <span class="kc">False</span><span class="p">},</span> <span class="c1"># upload log files to s3</span>
    107. <span class="s2">&quot;warmup_mode&quot;</span><span class="p">:</span> <span class="s2">&quot;linear_step&quot;</span>
    108. <span class="p">}</span>
    109. <span class="n">DEFAULT_OPTIMIZER_PARAMS_SGD</span> <span class="o">=</span> <span class="p">{</span><span class="s2">&quot;weight_decay&quot;</span><span class="p">:</span> <span class="mf">1e-4</span><span class="p">,</span> <span class="s2">&quot;momentum&quot;</span><span class="p">:</span> <span class="mf">0.9</span><span class="p">}</span>
    110. <span class="n">DEFAULT_OPTIMIZER_PARAMS_ADAM</span> <span class="o">=</span> <span class="p">{</span><span class="s2">&quot;weight_decay&quot;</span><span class="p">:</span> <span class="mf">1e-4</span><span class="p">}</span>
    111. <span class="n">DEFAULT_OPTIMIZER_PARAMS_RMSPROP</span> <span class="o">=</span> <span class="p">{</span><span class="s2">&quot;weight_decay&quot;</span><span class="p">:</span> <span class="mf">1e-4</span><span class="p">,</span> <span class="s2">&quot;momentum&quot;</span><span class="p">:</span> <span class="mf">0.9</span><span class="p">}</span>
    112. <span class="n">DEFAULT_OPTIMIZER_PARAMS_RMSPROPTF</span> <span class="o">=</span> <span class="p">{</span><span class="s2">&quot;weight_decay&quot;</span><span class="p">:</span> <span class="mf">1e-4</span><span class="p">,</span> <span class="s2">&quot;momentum&quot;</span><span class="p">:</span> <span class="mf">0.9</span><span class="p">}</span>
    113. <span class="n">TRAINING_PARAM_SCHEMA</span> <span class="o">=</span> <span class="p">{</span><span class="s2">&quot;type&quot;</span><span class="p">:</span> <span class="s2">&quot;object&quot;</span><span class="p">,</span>
    114. <span class="s2">&quot;properties&quot;</span><span class="p">:</span> <span class="p">{</span>
    115. <span class="s2">&quot;max_epochs&quot;</span><span class="p">:</span> <span class="p">{</span><span class="s2">&quot;type&quot;</span><span class="p">:</span> <span class="s2">&quot;number&quot;</span><span class="p">,</span> <span class="s2">&quot;minimum&quot;</span><span class="p">:</span> <span class="mi">1</span><span class="p">,</span> <span class="s2">&quot;maximum&quot;</span><span class="p">:</span> <span class="mi">800</span><span class="p">},</span>
    116. <span class="c1"># FIXME: CHECK THE IMPORTANCE OF THE COMMENTED SCHEMA- AS IT CAUSES HYDRA USE TO CRASH</span>
    117. <span class="c1"># &quot;lr_updates&quot;: {&quot;type&quot;: &quot;array&quot;, &quot;minItems&quot;: 1},</span>
    118. <span class="s2">&quot;lr_decay_factor&quot;</span><span class="p">:</span> <span class="p">{</span><span class="s2">&quot;type&quot;</span><span class="p">:</span> <span class="s2">&quot;number&quot;</span><span class="p">,</span> <span class="s2">&quot;minimum&quot;</span><span class="p">:</span> <span class="mi">0</span><span class="p">,</span> <span class="s2">&quot;maximum&quot;</span><span class="p">:</span> <span class="mi">1</span><span class="p">},</span>
    119. <span class="s2">&quot;lr_warmup_epochs&quot;</span><span class="p">:</span> <span class="p">{</span><span class="s2">&quot;type&quot;</span><span class="p">:</span> <span class="s2">&quot;number&quot;</span><span class="p">,</span> <span class="s2">&quot;minimum&quot;</span><span class="p">:</span> <span class="mi">0</span><span class="p">,</span> <span class="s2">&quot;maximum&quot;</span><span class="p">:</span> <span class="mi">10</span><span class="p">},</span>
    120. <span class="s2">&quot;initial_lr&quot;</span><span class="p">:</span> <span class="p">{</span><span class="s2">&quot;type&quot;</span><span class="p">:</span> <span class="s2">&quot;number&quot;</span><span class="p">,</span> <span class="s2">&quot;exclusiveMinimum&quot;</span><span class="p">:</span> <span class="mi">0</span><span class="p">,</span> <span class="s2">&quot;maximum&quot;</span><span class="p">:</span> <span class="mi">10</span><span class="p">}</span>
    121. <span class="p">},</span>
    122. <span class="s2">&quot;if&quot;</span><span class="p">:</span> <span class="p">{</span>
    123. <span class="s2">&quot;properties&quot;</span><span class="p">:</span> <span class="p">{</span><span class="s2">&quot;lr_mode&quot;</span><span class="p">:</span> <span class="p">{</span><span class="s2">&quot;const&quot;</span><span class="p">:</span> <span class="s2">&quot;step&quot;</span><span class="p">}}</span>
    124. <span class="p">},</span>
    125. <span class="s2">&quot;then&quot;</span><span class="p">:</span> <span class="p">{</span>
    126. <span class="s2">&quot;required&quot;</span><span class="p">:</span> <span class="p">[</span><span class="s2">&quot;lr_updates&quot;</span><span class="p">,</span> <span class="s2">&quot;lr_decay_factor&quot;</span><span class="p">]</span>
    127. <span class="p">},</span>
    128. <span class="s2">&quot;required&quot;</span><span class="p">:</span> <span class="p">[</span><span class="s2">&quot;max_epochs&quot;</span><span class="p">,</span> <span class="s2">&quot;lr_mode&quot;</span><span class="p">,</span> <span class="s2">&quot;initial_lr&quot;</span><span class="p">,</span> <span class="s2">&quot;loss&quot;</span><span class="p">]</span>
    129. <span class="p">}</span>
    130. <div class="viewcode-block" id="TrainingParams"><a class="viewcode-back" href="../../../super_gradients.training.html#super_gradients.training.params.TrainingParams">[docs]</a><span class="k">class</span> <span class="nc">TrainingParams</span><span class="p">(</span><span class="n">HpmStruct</span><span class="p">):</span>
    131. <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="o">**</span><span class="n">entries</span><span class="p">):</span>
    132. <span class="c1"># WE initialize by the default training params, overridden by the provided params</span>
    133. <span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="fm">__init__</span><span class="p">(</span><span class="o">**</span><span class="n">DEFAULT_TRAINING_PARAMS</span><span class="p">)</span>
    134. <span class="bp">self</span><span class="o">.</span><span class="n">set_schema</span><span class="p">(</span><span class="n">TRAINING_PARAM_SCHEMA</span><span class="p">)</span>
    135. <span class="k">if</span> <span class="nb">len</span><span class="p">(</span><span class="n">entries</span><span class="p">)</span> <span class="o">&gt;</span> <span class="mi">0</span><span class="p">:</span>
    136. <span class="bp">self</span><span class="o">.</span><span class="n">override</span><span class="p">(</span><span class="o">**</span><span class="n">entries</span><span class="p">)</span>
    137. <div class="viewcode-block" id="TrainingParams.override"><a class="viewcode-back" href="../../../super_gradients.training.html#super_gradients.training.params.TrainingParams.override">[docs]</a> <span class="k">def</span> <span class="nf">override</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="o">**</span><span class="n">entries</span><span class="p">):</span>
    138. <span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="n">override</span><span class="p">(</span><span class="o">**</span><span class="n">entries</span><span class="p">)</span>
    139. <span class="bp">self</span><span class="o">.</span><span class="n">validate</span><span class="p">()</span></div></div>
    140. </pre></div>
    141. </div>
    142. </div>
    143. <footer>
    144. <hr/>
    145. <div role="contentinfo">
    146. <p>&#169; Copyright 2021, SuperGradients team.</p>
    147. </div>
    148. Built with <a href="https://www.sphinx-doc.org/">Sphinx</a> using a
    149. <a href="https://github.com/readthedocs/sphinx_rtd_theme">theme</a>
    150. provided by <a href="https://readthedocs.org">Read the Docs</a>.
    151. </footer>
    152. </div>
    153. </div>
    154. </section>
    155. </div>
    156. <script>
    157. jQuery(function () {
    158. SphinxRtdTheme.Navigation.enable(true);
    159. });
    160. </script>
    161. </body>
    162. </html>
    Discard
    Only showing up to 1000 lines per file, please use a local Git client to see the full diff.
    Some lines were truncated since they exceed the maximum allowed length of 500, please use a local Git client to see the full diff.
    @@ -3,10 +3,11 @@
     <head>
     <head>
       <meta charset="utf-8" />
       <meta charset="utf-8" />
       <meta name="viewport" content="width=device-width, initial-scale=1.0" />
       <meta name="viewport" content="width=device-width, initial-scale=1.0" />
    -  <title>super_gradients.training.sg_model.sg_model &mdash; SuperGradients 1.0 documentation</title>
    +  <title>super_gradients.training.sg_trainer.sg_trainer &mdash; SuperGradients 3.0.3 documentation</title>
           <link rel="stylesheet" href="../../../../_static/pygments.css" type="text/css" />
           <link rel="stylesheet" href="../../../../_static/pygments.css" type="text/css" />
           <link rel="stylesheet" href="../../../../_static/css/theme.css" type="text/css" />
           <link rel="stylesheet" href="../../../../_static/css/theme.css" type="text/css" />
           <link rel="stylesheet" href="../../../../_static/graphviz.css" type="text/css" />
           <link rel="stylesheet" href="../../../../_static/graphviz.css" type="text/css" />
    +      <link rel="stylesheet" href="../../../../_static/custom.css" type="text/css" />
       <!--[if lt IE 9]>
       <!--[if lt IE 9]>
         <script src="../../../../_static/js/html5shiv.min.js"></script>
         <script src="../../../../_static/js/html5shiv.min.js"></script>
       <![endif]-->
       <![endif]-->
    @@ -14,7 +15,9 @@
             <script data-url_root="../../../../" id="documentation_options" src="../../../../_static/documentation_options.js"></script>
             <script data-url_root="../../../../" id="documentation_options" src="../../../../_static/documentation_options.js"></script>
             <script src="../../../../_static/jquery.js"></script>
             <script src="../../../../_static/jquery.js"></script>
             <script src="../../../../_static/underscore.js"></script>
             <script src="../../../../_static/underscore.js"></script>
    +        <script src="../../../../_static/_sphinx_javascript_frameworks_compat.js"></script>
             <script src="../../../../_static/doctools.js"></script>
             <script src="../../../../_static/doctools.js"></script>
    +        <script src="../../../../_static/sphinx_highlight.js"></script>
         <script src="../../../../_static/js/theme.js"></script>
         <script src="../../../../_static/js/theme.js"></script>
         <link rel="index" title="Index" href="../../../../genindex.html" />
         <link rel="index" title="Index" href="../../../../genindex.html" />
         <link rel="search" title="Search" href="../../../../search.html" /> 
         <link rel="search" title="Search" href="../../../../search.html" /> 
    @@ -35,30 +38,29 @@
       </form>
       </form>
     </div>
     </div>
             </div><div class="wy-menu wy-menu-vertical" data-spy="affix" role="navigation" aria-label="Navigation menu">
             </div><div class="wy-menu wy-menu-vertical" data-spy="affix" role="navigation" aria-label="Navigation menu">
    -              <p class="caption"><span class="caption-text">Welcome To SuperGradients</span></p>
    +              <p class="caption" role="heading"><span class="caption-text">Welcome To SuperGradients</span></p>
     <ul>
     <ul>
    -<li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html">Fill our 4-question quick survey! We will raffle free SuperGradients swag between those who will participate -&gt; Fill Survey</a></li>
    -<li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#supergradients">SuperGradients</a></li>
    +<li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html">Version 3 is out! Notebooks have been updated!</a></li>
    +<li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#build-with-supergradients">Build with SuperGradients</a></li>
    +<li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#quick-installation">Quick Installation</a></li>
    +<li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#what-s-new">What’s New</a></li>
    +<li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#coming-soon">Coming soon</a></li>
    +<li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#table-of-content">Table of Content</a></li>
    +<li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#getting-started">Getting Started</a></li>
    +<li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#advanced-features">Advanced Features</a></li>
    +<li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#installation-methods">Installation Methods</a></li>
    +<li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#implemented-model-architectures">Implemented Model Architectures</a></li>
    +<li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#documentation">Documentation</a></li>
    +<li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#contributing">Contributing</a></li>
    +<li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#citation">Citation</a></li>
    +<li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#community">Community</a></li>
    +<li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#license">License</a></li>
    +<li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#deci-platform">Deci Platform</a></li>
     </ul>
     </ul>
    -<p class="caption"><span class="caption-text">Technical Documentation</span></p>
    +<p class="caption" role="heading"><span class="caption-text">Technical Documentation</span></p>
     <ul>
     <ul>
     <li class="toctree-l1"><a class="reference internal" href="../../../../super_gradients.common.html">Common package</a></li>
     <li class="toctree-l1"><a class="reference internal" href="../../../../super_gradients.common.html">Common package</a></li>
     <li class="toctree-l1"><a class="reference internal" href="../../../../super_gradients.training.html">Training package</a></li>
     <li class="toctree-l1"><a class="reference internal" href="../../../../super_gradients.training.html">Training package</a></li>
    -</ul>
    -<p class="caption"><span class="caption-text">User Guide</span></p>
    -<ul>
    -<li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html">What is SuperGradients?</a></li>
    -<li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#introducing-the-supergradients-library">Introducing the SuperGradients library</a></li>
    -<li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#installation">Installation</a></li>
    -<li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#integrating-your-training-code-complete-walkthrough">Integrating Your Training Code - Complete Walkthrough</a></li>
    -<li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#training-parameters">Training Parameters</a></li>
    -<li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#logs-and-checkpoints">Logs and Checkpoints</a></li>
    -<li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#dataset-parameters">Dataset Parameters</a></li>
    -<li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#network-architectures">Network Architectures</a></li>
    -<li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#pretrained-models">Pretrained Models</a></li>
    -<li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#how-to-reproduce-our-training-recipes">How To Reproduce Our Training Recipes</a></li>
    -<li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#professional-tools-integration">Professional Tools Integration</a></li>
    -<li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#supergradients-faq">SuperGradients FAQ</a></li>
     </ul>
     </ul>
     
     
             </div>
             </div>
    @@ -76,7 +78,7 @@
       <ul class="wy-breadcrumbs">
       <ul class="wy-breadcrumbs">
           <li><a href="../../../../index.html" class="icon icon-home"></a> &raquo;</li>
           <li><a href="../../../../index.html" class="icon icon-home"></a> &raquo;</li>
               <li><a href="../../../index.html">Module code</a> &raquo;</li>
               <li><a href="../../../index.html">Module code</a> &raquo;</li>
    -      <li>super_gradients.training.sg_model.sg_model</li>
    +      <li>super_gradients.training.sg_trainer.sg_trainer</li>
           <li class="wy-breadcrumbs-aside">
           <li class="wy-breadcrumbs-aside">
           </li>
           </li>
       </ul>
       </ul>
    @@ -85,74 +87,99 @@
               <div role="main" class="document" itemscope="itemscope" itemtype="http://schema.org/Article">
               <div role="main" class="document" itemscope="itemscope" itemtype="http://schema.org/Article">
                <div itemprop="articleBody">
                <div itemprop="articleBody">
                  
                  
    -  <h1>Source code for super_gradients.training.sg_model.sg_model</h1><div class="highlight"><pre>
    +  <h1>Source code for super_gradients.training.sg_trainer.sg_trainer</h1><div class="highlight"><pre>
     <span></span><span class="kn">import</span> <span class="nn">inspect</span>
     <span></span><span class="kn">import</span> <span class="nn">inspect</span>
     <span class="kn">import</span> <span class="nn">os</span>
     <span class="kn">import</span> <span class="nn">os</span>
    -<span class="kn">import</span> <span class="nn">sys</span>
     <span class="kn">from</span> <span class="nn">copy</span> <span class="kn">import</span> <span class="n">deepcopy</span>
     <span class="kn">from</span> <span class="nn">copy</span> <span class="kn">import</span> <span class="n">deepcopy</span>
    -<span class="kn">from</span> <span class="nn">typing</span> <span class="kn">import</span> <span class="n">Union</span><span class="p">,</span> <span class="n">Tuple</span><span class="p">,</span> <span class="n">Mapping</span><span class="p">,</span> <span class="n">List</span><span class="p">,</span> <span class="n">Any</span>
    +<span class="kn">from</span> <span class="nn">typing</span> <span class="kn">import</span> <span class="n">Union</span><span class="p">,</span> <span class="n">Tuple</span><span class="p">,</span> <span class="n">Mapping</span><span class="p">,</span> <span class="n">Dict</span>
    +<span class="kn">from</span> <span class="nn">pathlib</span> <span class="kn">import</span> <span class="n">Path</span>
     
     
     <span class="kn">import</span> <span class="nn">numpy</span> <span class="k">as</span> <span class="nn">np</span>
     <span class="kn">import</span> <span class="nn">numpy</span> <span class="k">as</span> <span class="nn">np</span>
    -<span class="kn">import</span> <span class="nn">pkg_resources</span>
     <span class="kn">import</span> <span class="nn">torch</span>
     <span class="kn">import</span> <span class="nn">torch</span>
    -<span class="kn">import</span> <span class="nn">torchvision.transforms</span> <span class="k">as</span> <span class="nn">transforms</span>
    -<span class="kn">from</span> <span class="nn">deprecated</span> <span class="kn">import</span> <span class="n">deprecated</span>
    +<span class="kn">import</span> <span class="nn">hydra</span>
    +<span class="kn">from</span> <span class="nn">omegaconf</span> <span class="kn">import</span> <span class="n">DictConfig</span>
     <span class="kn">from</span> <span class="nn">torch</span> <span class="kn">import</span> <span class="n">nn</span>
     <span class="kn">from</span> <span class="nn">torch</span> <span class="kn">import</span> <span class="n">nn</span>
    -<span class="kn">from</span> <span class="nn">torch.utils.data</span> <span class="kn">import</span> <span class="n">DataLoader</span><span class="p">,</span> <span class="n">DistributedSampler</span>
    +<span class="kn">from</span> <span class="nn">torch.utils.data</span> <span class="kn">import</span> <span class="n">DataLoader</span><span class="p">,</span> <span class="n">SequentialSampler</span>
     <span class="kn">from</span> <span class="nn">torch.cuda.amp</span> <span class="kn">import</span> <span class="n">GradScaler</span><span class="p">,</span> <span class="n">autocast</span>
     <span class="kn">from</span> <span class="nn">torch.cuda.amp</span> <span class="kn">import</span> <span class="n">GradScaler</span><span class="p">,</span> <span class="n">autocast</span>
     <span class="kn">from</span> <span class="nn">torchmetrics</span> <span class="kn">import</span> <span class="n">MetricCollection</span>
     <span class="kn">from</span> <span class="nn">torchmetrics</span> <span class="kn">import</span> <span class="n">MetricCollection</span>
     <span class="kn">from</span> <span class="nn">tqdm</span> <span class="kn">import</span> <span class="n">tqdm</span>
     <span class="kn">from</span> <span class="nn">tqdm</span> <span class="kn">import</span> <span class="n">tqdm</span>
     <span class="kn">from</span> <span class="nn">piptools.scripts.sync</span> <span class="kn">import</span> <span class="n">_get_installed_distributions</span>
     <span class="kn">from</span> <span class="nn">piptools.scripts.sync</span> <span class="kn">import</span> <span class="n">_get_installed_distributions</span>
     
     
    +<span class="kn">from</span> <span class="nn">torch.utils.data.distributed</span> <span class="kn">import</span> <span class="n">DistributedSampler</span>
    +
    +<span class="kn">from</span> <span class="nn">super_gradients.common.factories.type_factory</span> <span class="kn">import</span> <span class="n">TypeFactory</span>
    +<span class="kn">from</span> <span class="nn">super_gradients.training.datasets.samplers</span> <span class="kn">import</span> <span class="n">InfiniteSampler</span><span class="p">,</span> <span class="n">RepeatAugSampler</span>
    +
     <span class="kn">from</span> <span class="nn">super_gradients.common.factories.callbacks_factory</span> <span class="kn">import</span> <span class="n">CallbacksFactory</span>
     <span class="kn">from</span> <span class="nn">super_gradients.common.factories.callbacks_factory</span> <span class="kn">import</span> <span class="n">CallbacksFactory</span>
    +<span class="kn">from</span> <span class="nn">super_gradients.common.data_types.enum</span> <span class="kn">import</span> <span class="n">MultiGPUMode</span><span class="p">,</span> <span class="n">StrictLoad</span><span class="p">,</span> <span class="n">EvaluationType</span>
     <span class="kn">from</span> <span class="nn">super_gradients.training.models.all_architectures</span> <span class="kn">import</span> <span class="n">ARCHITECTURES</span>
     <span class="kn">from</span> <span class="nn">super_gradients.training.models.all_architectures</span> <span class="kn">import</span> <span class="n">ARCHITECTURES</span>
     <span class="kn">from</span> <span class="nn">super_gradients.common.decorators.factory_decorator</span> <span class="kn">import</span> <span class="n">resolve_param</span>
     <span class="kn">from</span> <span class="nn">super_gradients.common.decorators.factory_decorator</span> <span class="kn">import</span> <span class="n">resolve_param</span>
     <span class="kn">from</span> <span class="nn">super_gradients.common.environment</span> <span class="kn">import</span> <span class="n">env_helpers</span>
     <span class="kn">from</span> <span class="nn">super_gradients.common.environment</span> <span class="kn">import</span> <span class="n">env_helpers</span>
    -<span class="kn">from</span> <span class="nn">super_gradients.common.abstractions.abstract_logger</span> <span class="kn">import</span> <span class="n">get_logger</span>
    -<span class="kn">from</span> <span class="nn">super_gradients.common.factories.datasets_factory</span> <span class="kn">import</span> <span class="n">DatasetsFactory</span>
    +<span class="kn">from</span> <span class="nn">super_gradients.common.abstractions.abstract_logger</span> <span class="kn">import</span> <span class="n">get_logger</span><span class="p">,</span> <span class="n">mute_current_process</span>
     <span class="kn">from</span> <span class="nn">super_gradients.common.factories.list_factory</span> <span class="kn">import</span> <span class="n">ListFactory</span>
     <span class="kn">from</span> <span class="nn">super_gradients.common.factories.list_factory</span> <span class="kn">import</span> <span class="n">ListFactory</span>
     <span class="kn">from</span> <span class="nn">super_gradients.common.factories.losses_factory</span> <span class="kn">import</span> <span class="n">LossesFactory</span>
     <span class="kn">from</span> <span class="nn">super_gradients.common.factories.losses_factory</span> <span class="kn">import</span> <span class="n">LossesFactory</span>
     <span class="kn">from</span> <span class="nn">super_gradients.common.factories.metrics_factory</span> <span class="kn">import</span> <span class="n">MetricsFactory</span>
     <span class="kn">from</span> <span class="nn">super_gradients.common.factories.metrics_factory</span> <span class="kn">import</span> <span class="n">MetricsFactory</span>
     <span class="kn">from</span> <span class="nn">super_gradients.common.sg_loggers</span> <span class="kn">import</span> <span class="n">SG_LOGGERS</span>
     <span class="kn">from</span> <span class="nn">super_gradients.common.sg_loggers</span> <span class="kn">import</span> <span class="n">SG_LOGGERS</span>
     <span class="kn">from</span> <span class="nn">super_gradients.common.sg_loggers.abstract_sg_logger</span> <span class="kn">import</span> <span class="n">AbstractSGLogger</span>
     <span class="kn">from</span> <span class="nn">super_gradients.common.sg_loggers.abstract_sg_logger</span> <span class="kn">import</span> <span class="n">AbstractSGLogger</span>
     <span class="kn">from</span> <span class="nn">super_gradients.common.sg_loggers.base_sg_logger</span> <span class="kn">import</span> <span class="n">BaseSGLogger</span>
     <span class="kn">from</span> <span class="nn">super_gradients.common.sg_loggers.base_sg_logger</span> <span class="kn">import</span> <span class="n">BaseSGLogger</span>
    -<span class="kn">from</span> <span class="nn">super_gradients.training</span> <span class="kn">import</span> <span class="n">utils</span> <span class="k">as</span> <span class="n">core_utils</span>
    +<span class="kn">from</span> <span class="nn">super_gradients.training</span> <span class="kn">import</span> <span class="n">utils</span> <span class="k">as</span> <span class="n">core_utils</span><span class="p">,</span> <span class="n">models</span><span class="p">,</span> <span class="n">dataloaders</span>
     <span class="kn">from</span> <span class="nn">super_gradients.training.models</span> <span class="kn">import</span> <span class="n">SgModule</span>
     <span class="kn">from</span> <span class="nn">super_gradients.training.models</span> <span class="kn">import</span> <span class="n">SgModule</span>
     <span class="kn">from</span> <span class="nn">super_gradients.training.pretrained_models</span> <span class="kn">import</span> <span class="n">PRETRAINED_NUM_CLASSES</span>
     <span class="kn">from</span> <span class="nn">super_gradients.training.pretrained_models</span> <span class="kn">import</span> <span class="n">PRETRAINED_NUM_CLASSES</span>
    -<span class="kn">from</span> <span class="nn">super_gradients.training.utils</span> <span class="kn">import</span> <span class="n">sg_model_utils</span>
    -<span class="kn">from</span> <span class="nn">super_gradients.training.utils.quantization_utils</span> <span class="kn">import</span> <span class="n">QATCallback</span>
    -<span class="kn">from</span> <span class="nn">super_gradients.training.utils.sg_model_utils</span> <span class="kn">import</span> <span class="n">MonitoredValue</span>
    -<span class="kn">from</span> <span class="nn">super_gradients.training</span> <span class="kn">import</span> <span class="n">metrics</span>
    -<span class="kn">from</span> <span class="nn">super_gradients.training.exceptions.sg_model_exceptions</span> <span class="kn">import</span> <span class="n">UnsupportedOptimizerFormat</span><span class="p">,</span> \
    -    <span class="n">IllegalDataloaderInitialization</span>
    -<span class="kn">from</span> <span class="nn">super_gradients.training.datasets</span> <span class="kn">import</span> <span class="n">DatasetInterface</span>
    +<span class="kn">from</span> <span class="nn">super_gradients.training.utils</span> <span class="kn">import</span> <span class="n">sg_trainer_utils</span>
    +<span class="kn">from</span> <span class="nn">super_gradients.training.utils.sg_trainer_utils</span> <span class="kn">import</span> <span class="n">MonitoredValue</span><span class="p">,</span> <span class="n">parse_args</span><span class="p">,</span> <span class="n">log_main_training_params</span>
    +<span class="kn">from</span> <span class="nn">super_gradients.training.exceptions.sg_trainer_exceptions</span> <span class="kn">import</span> <span class="n">UnsupportedOptimizerFormat</span><span class="p">,</span> <span class="n">GPUModeNotSetupError</span>
     <span class="kn">from</span> <span class="nn">super_gradients.training.losses</span> <span class="kn">import</span> <span class="n">LOSSES</span>
     <span class="kn">from</span> <span class="nn">super_gradients.training.losses</span> <span class="kn">import</span> <span class="n">LOSSES</span>
    -<span class="kn">from</span> <span class="nn">super_gradients.training.metrics.metric_utils</span> <span class="kn">import</span> <span class="n">get_metrics_titles</span><span class="p">,</span> <span class="n">get_metrics_results_tuple</span><span class="p">,</span> \
    -    <span class="n">get_logging_values</span><span class="p">,</span> \
    -    <span class="n">get_metrics_dict</span><span class="p">,</span> <span class="n">get_train_loop_description_dict</span>
    +<span class="kn">from</span> <span class="nn">super_gradients.training.metrics.metric_utils</span> <span class="kn">import</span> <span class="p">(</span>
    +    <span class="n">get_metrics_titles</span><span class="p">,</span>
    +    <span class="n">get_metrics_results_tuple</span><span class="p">,</span>
    +    <span class="n">get_logging_values</span><span class="p">,</span>
    +    <span class="n">get_metrics_dict</span><span class="p">,</span>
    +    <span class="n">get_train_loop_description_dict</span><span class="p">,</span>
    +<span class="p">)</span>
     <span class="kn">from</span> <span class="nn">super_gradients.training.params</span> <span class="kn">import</span> <span class="n">TrainingParams</span>
     <span class="kn">from</span> <span class="nn">super_gradients.training.params</span> <span class="kn">import</span> <span class="n">TrainingParams</span>
    -<span class="kn">from</span> <span class="nn">super_gradients.training.utils.detection_utils</span> <span class="kn">import</span> <span class="n">DetectionPostPredictionCallback</span>
    -<span class="kn">from</span> <span class="nn">super_gradients.training.utils.distributed_training_utils</span> <span class="kn">import</span> <span class="n">MultiGPUModeAutocastWrapper</span><span class="p">,</span> \
    -    <span class="n">reduce_results_tuple_for_ddp</span><span class="p">,</span> <span class="n">compute_precise_bn_stats</span>
    +<span class="kn">from</span> <span class="nn">super_gradients.training.utils.distributed_training_utils</span> <span class="kn">import</span> <span class="p">(</span>
    +    <span class="n">MultiGPUModeAutocastWrapper</span><span class="p">,</span>
    +    <span class="n">reduce_results_tuple_for_ddp</span><span class="p">,</span>
    +    <span class="n">compute_precise_bn_stats</span><span class="p">,</span>
    +    <span class="n">setup_device</span><span class="p">,</span>
    +    <span class="n">require_gpu_setup</span><span class="p">,</span>
    +    <span class="n">get_gpu_mem_utilization</span><span class="p">,</span>
    +    <span class="n">get_world_size</span><span class="p">,</span>
    +    <span class="n">get_local_rank</span><span class="p">,</span>
    +    <span class="n">wait_for_the_master</span><span class="p">,</span>
    +<span class="p">)</span>
     <span class="kn">from</span> <span class="nn">super_gradients.training.utils.ema</span> <span class="kn">import</span> <span class="n">ModelEMA</span>
     <span class="kn">from</span> <span class="nn">super_gradients.training.utils.ema</span> <span class="kn">import</span> <span class="n">ModelEMA</span>
     <span class="kn">from</span> <span class="nn">super_gradients.training.utils.optimizer_utils</span> <span class="kn">import</span> <span class="n">build_optimizer</span>
     <span class="kn">from</span> <span class="nn">super_gradients.training.utils.optimizer_utils</span> <span class="kn">import</span> <span class="n">build_optimizer</span>
     <span class="kn">from</span> <span class="nn">super_gradients.training.utils.weight_averaging_utils</span> <span class="kn">import</span> <span class="n">ModelWeightAveraging</span>
     <span class="kn">from</span> <span class="nn">super_gradients.training.utils.weight_averaging_utils</span> <span class="kn">import</span> <span class="n">ModelWeightAveraging</span>
     <span class="kn">from</span> <span class="nn">super_gradients.training.metrics</span> <span class="kn">import</span> <span class="n">Accuracy</span><span class="p">,</span> <span class="n">Top5</span>
     <span class="kn">from</span> <span class="nn">super_gradients.training.metrics</span> <span class="kn">import</span> <span class="n">Accuracy</span><span class="p">,</span> <span class="n">Top5</span>
     <span class="kn">from</span> <span class="nn">super_gradients.training.utils</span> <span class="kn">import</span> <span class="n">random_seed</span>
     <span class="kn">from</span> <span class="nn">super_gradients.training.utils</span> <span class="kn">import</span> <span class="n">random_seed</span>
    -<span class="kn">from</span> <span class="nn">super_gradients.training.utils.checkpoint_utils</span> <span class="kn">import</span> <span class="n">get_ckpt_local_path</span><span class="p">,</span> <span class="n">read_ckpt_state_dict</span><span class="p">,</span> \
    -    <span class="n">load_checkpoint_to_model</span><span class="p">,</span> <span class="n">load_pretrained_weights</span>
    +<span class="kn">from</span> <span class="nn">super_gradients.training.utils.checkpoint_utils</span> <span class="kn">import</span> <span class="p">(</span>
    +    <span class="n">get_ckpt_local_path</span><span class="p">,</span>
    +    <span class="n">read_ckpt_state_dict</span><span class="p">,</span>
    +    <span class="n">load_checkpoint_to_model</span><span class="p">,</span>
    +    <span class="n">load_pretrained_weights</span><span class="p">,</span>
    +    <span class="n">get_checkpoints_dir_path</span><span class="p">,</span>
    +<span class="p">)</span>
     <span class="kn">from</span> <span class="nn">super_gradients.training.datasets.datasets_utils</span> <span class="kn">import</span> <span class="n">DatasetStatisticsTensorboardLogger</span>
     <span class="kn">from</span> <span class="nn">super_gradients.training.datasets.datasets_utils</span> <span class="kn">import</span> <span class="n">DatasetStatisticsTensorboardLogger</span>
    -<span class="kn">from</span> <span class="nn">super_gradients.training.utils.callbacks</span> <span class="kn">import</span> <span class="n">CallbackHandler</span><span class="p">,</span> <span class="n">Phase</span><span class="p">,</span> <span class="n">LR_SCHEDULERS_CLS_DICT</span><span class="p">,</span> <span class="n">PhaseContext</span><span class="p">,</span> \
    -    <span class="n">MetricsUpdateCallback</span><span class="p">,</span> <span class="n">LR_WARMUP_CLS_DICT</span><span class="p">,</span> <span class="n">ContextSgMethods</span><span class="p">,</span> <span class="n">LRCallbackBase</span>
    +<span class="kn">from</span> <span class="nn">super_gradients.training.utils.callbacks</span> <span class="kn">import</span> <span class="p">(</span>
    +    <span class="n">CallbackHandler</span><span class="p">,</span>
    +    <span class="n">Phase</span><span class="p">,</span>
    +    <span class="n">LR_SCHEDULERS_CLS_DICT</span><span class="p">,</span>
    +    <span class="n">PhaseContext</span><span class="p">,</span>
    +    <span class="n">MetricsUpdateCallback</span><span class="p">,</span>
    +    <span class="n">LR_WARMUP_CLS_DICT</span><span class="p">,</span>
    +    <span class="n">ContextSgMethods</span><span class="p">,</span>
    +    <span class="n">LRCallbackBase</span><span class="p">,</span>
    +<span class="p">)</span>
     <span class="kn">from</span> <span class="nn">super_gradients.common.environment</span> <span class="kn">import</span> <span class="n">environment_config</span>
     <span class="kn">from</span> <span class="nn">super_gradients.common.environment</span> <span class="kn">import</span> <span class="n">environment_config</span>
     <span class="kn">from</span> <span class="nn">super_gradients.training.utils</span> <span class="kn">import</span> <span class="n">HpmStruct</span>
     <span class="kn">from</span> <span class="nn">super_gradients.training.utils</span> <span class="kn">import</span> <span class="n">HpmStruct</span>
    -<span class="kn">from</span> <span class="nn">super_gradients.training.datasets.samplers.infinite_sampler</span> <span class="kn">import</span> <span class="n">InfiniteSampler</span>
    +<span class="kn">from</span> <span class="nn">super_gradients.training.utils.hydra_utils</span> <span class="kn">import</span> <span class="n">load_experiment_cfg</span><span class="p">,</span> <span class="n">add_params_to_cfg</span>
    +<span class="kn">from</span> <span class="nn">omegaconf</span> <span class="kn">import</span> <span class="n">OmegaConf</span>
     
     
    -<span class="kn">from</span> <span class="nn">super_gradients.common</span> <span class="kn">import</span> <span class="n">StrictLoad</span><span class="p">,</span> <span class="n">MultiGPUMode</span><span class="p">,</span> <span class="n">EvaluationType</span>
     <span class="n">logger</span> <span class="o">=</span> <span class="n">get_logger</span><span class="p">(</span><span class="vm">__name__</span><span class="p">)</span>
     <span class="n">logger</span> <span class="o">=</span> <span class="n">get_logger</span><span class="p">(</span><span class="vm">__name__</span><span class="p">)</span>
     
     
     
     
    -<div class="viewcode-block" id="SgModel"><a class="viewcode-back" href="../../../../super_gradients.training.sg_model.html#super_gradients.training.SgModel">[docs]</a><span class="k">class</span> <span class="nc">SgModel</span><span class="p">:</span>
    +<div class="viewcode-block" id="Trainer"><a class="viewcode-back" href="../../../../super_gradients.training.html#super_gradients.training.Trainer">[docs]</a><span class="k">class</span> <span class="nc">Trainer</span><span class="p">:</span>
         <span class="sd">&quot;&quot;&quot;</span>
         <span class="sd">&quot;&quot;&quot;</span>
     <span class="sd">    SuperGradient Model - Base Class for Sg Models</span>
     <span class="sd">    SuperGradient Model - Base Class for Sg Models</span>
     
     
    @@ -168,30 +195,17 @@
     <span class="sd">        returns the test loss, accuracy and runtime</span>
     <span class="sd">        returns the test loss, accuracy and runtime</span>
     <span class="sd">    &quot;&quot;&quot;</span>
     <span class="sd">    &quot;&quot;&quot;</span>
     
     
    -    <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">experiment_name</span><span class="p">:</span> <span class="nb">str</span><span class="p">,</span> <span class="n">device</span><span class="p">:</span> <span class="nb">str</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span> <span class="n">multi_gpu</span><span class="p">:</span> <span class="n">Union</span
    -                 <span class="n">model_checkpoints_location</span><span class="p">:</span> <span class="nb">str</span> <span class="o">=</span> <span class="s1">&#39;local&#39;</span><span class="p">,</span>
    -                 <span class="n">overwrite_local_checkpoint</span><span class="p">:</span> <span class="nb">bool</span> <span class="o">=</span> <span class="kc">True</span><span class="p">,</span> <span class="n">ckpt_name</span><span class="p">:</span> <span class="nb">str</span> <span class="o">=</span> <span class="s1">&#39;ckpt_latest.pth&#39;</span><span class="p">,</span>
    -                 <span class="n">post_prediction_callback</span><span class="p">:</span> <span class="n">DetectionPostPredictionCallback</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span> <span class="n">ckpt_root_dir</span><span class="p">:</span> <span class="nb">str</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span>
    -                 <span class="n">train_loader</span><span class="p">:</span> <span class="n">DataLoader</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span> <span class="n">valid_loader</span><span class="p">:</span> <span class="n">DataLoader</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span> <span class="n">test_loader</span><span class="p">:</span> <span class="n">DataLoader</span> <span class="o">=</span> <span class="kc">Non
    -                 <span class="n">classes</span><span class="p">:</span> <span class="n">List</span><span class="p">[</span><span class="n">Any</span><span class="p">]</span> <span class="o">=</span> <span class="kc">None</span><span class="p">):</span>
    +    <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">experiment_name</span><span class="p">:</span> <span class="nb">str</span><span class="p">,</span> <span class="n">device</span><span class="p">:</span> <span class="nb">str</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span> <span class="n">multi_gpu</span><span class="p">:</span> <span class="n">Union</span
             <span class="sd">&quot;&quot;&quot;</span>
             <span class="sd">&quot;&quot;&quot;</span>
     
     
     <span class="sd">        :param experiment_name:                      Used for logging and loading purposes</span>
     <span class="sd">        :param experiment_name:                      Used for logging and loading purposes</span>
     <span class="sd">        :param device:                          If equal to &#39;cpu&#39; runs on the CPU otherwise on GPU</span>
     <span class="sd">        :param device:                          If equal to &#39;cpu&#39; runs on the CPU otherwise on GPU</span>
     <span class="sd">        :param multi_gpu:                       If True, runs on all available devices</span>
     <span class="sd">        :param multi_gpu:                       If True, runs on all available devices</span>
    -<span class="sd">        :param model_checkpoints_location:      If set to &#39;s3&#39; saves the Checkpoints in AWS S3</span>
     <span class="sd">                                                otherwise saves the Checkpoints Locally</span>
     <span class="sd">                                                otherwise saves the Checkpoints Locally</span>
    -<span class="sd">        :param overwrite_local_checkpoint:      If set to False keeps the current local checkpoint when importing</span>
     <span class="sd">                                                checkpoint from cloud service, otherwise overwrites the local checkpoints file</span>
     <span class="sd">                                                checkpoint from cloud service, otherwise overwrites the local checkpoints file</span>
    -<span class="sd">        :param ckpt_name:                       The Checkpoint to Load</span>
     <span class="sd">        :param ckpt_root_dir:                   Local root directory path where all experiment logging directories will</span>
     <span class="sd">        :param ckpt_root_dir:                   Local root directory path where all experiment logging directories will</span>
     <span class="sd">                                                reside. When none is give, it is assumed that</span>
     <span class="sd">                                                reside. When none is give, it is assumed that</span>
     <span class="sd">                                                pkg_resources.resource_filename(&#39;checkpoints&#39;, &quot;&quot;) exists and will be used.</span>
     <span class="sd">                                                pkg_resources.resource_filename(&#39;checkpoints&#39;, &quot;&quot;) exists and will be used.</span>
    -<span class="sd">        :param train_loader:                    Training set Dataloader instead of using DatasetInterface, must pass &quot;valid_loader&quot;</span>
    -<span class="sd">                                                and &quot;classes&quot; along with it</span>
    -<span class="sd">        :param valid_loader:                    Validation set Dataloader</span>
    -<span class="sd">        :param test_loader:                     Test set Dataloader</span>
    -<span class="sd">        :param classes:                         List of class labels</span>
     
     
     <span class="sd">        &quot;&quot;&quot;</span>
     <span class="sd">        &quot;&quot;&quot;</span>
             <span class="c1"># SET THE EMPTY PROPERTIES</span>
             <span class="c1"># SET THE EMPTY PROPERTIES</span>
    @@ -201,7 +215,6 @@
             <span class="bp">self</span><span class="o">.</span><span class="n">ema_model</span> <span class="o">=</span> <span class="kc">None</span>
             <span class="bp">self</span><span class="o">.</span><span class="n">ema_model</span> <span class="o">=</span> <span class="kc">None</span>
             <span class="bp">self</span><span class="o">.</span><span class="n">sg_logger</span> <span class="o">=</span> <span class="kc">None</span>
             <span class="bp">self</span><span class="o">.</span><span class="n">sg_logger</span> <span class="o">=</span> <span class="kc">None</span>
             <span class="bp">self</span><span class="o">.</span><span class="n">update_param_groups</span> <span class="o">=</span> <span class="kc">None</span>
             <span class="bp">self</span><span class="o">.</span><span class="n">update_param_groups</span> <span class="o">=</span> <span class="kc">None</span>
    -        <span class="bp">self</span><span class="o">.</span><span class="n">post_prediction_callback</span> <span class="o">=</span> <span class="kc">None</span>
             <span class="bp">self</span><span class="o">.</span><span class="n">criterion</span> <span class="o">=</span> <span class="kc">None</span>
             <span class="bp">self</span><span class="o">.</span><span class="n">criterion</span> <span class="o">=</span> <span class="kc">None</span>
             <span class="bp">self</span><span class="o">.</span><span class="n">training_params</span> <span class="o">=</span> <span class="kc">None</span>
             <span class="bp">self</span><span class="o">.</span><span class="n">training_params</span> <span class="o">=</span> <span class="kc">None</span>
             <span class="bp">self</span><span class="o">.</span><span class="n">scaler</span> <span class="o">=</span> <span class="kc">None</span>
             <span class="bp">self</span><span class="o">.</span><span class="n">scaler</span> <span class="o">=</span> <span class="kc">None</span>
    @@ -217,157 +230,189 @@
             <span class="bp">self</span><span class="o">.</span><span class="n">ddp_silent_mode</span> <span class="o">=</span> <span class="kc">False</span>
             <span class="bp">self</span><span class="o">.</span><span class="n">ddp_silent_mode</span> <span class="o">=</span> <span class="kc">False</span>
             <span class="bp">self</span><span class="o">.</span><span class="n">source_ckpt_folder_name</span> <span class="o">=</span> <span class="kc">None</span>
             <span class="bp">self</span><span class="o">.</span><span class="n">source_ckpt_folder_name</span> <span class="o">=</span> <span class="kc">None</span>
             <span class="bp">self</span><span class="o">.</span><span class="n">model_weight_averaging</span> <span class="o">=</span> <span class="kc">None</span>
             <span class="bp">self</span><span class="o">.</span><span class="n">model_weight_averaging</span> <span class="o">=</span> <span class="kc">None</span>
    -        <span class="bp">self</span><span class="o">.</span><span class="n">average_model_checkpoint_filename</span> <span class="o">=</span> <span class="s1">&#39;average_model.pth&#39;</span>
    +        <span class="bp">self</span><span class="o">.</span><span class="n">average_model_checkpoint_filename</span> <span class="o">=</span> <span class="s2">&quot;average_model.pth&quot;</span>
             <span class="bp">self</span><span class="o">.</span><span class="n">start_epoch</span> <span class="o">=</span> <span class="mi">0</span>
             <span class="bp">self</span><span class="o">.</span><span class="n">start_epoch</span> <span class="o">=</span> <span class="mi">0</span>
             <span class="bp">self</span><span class="o">.</span><span class="n">best_metric</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">inf</span>
             <span class="bp">self</span><span class="o">.</span><span class="n">best_metric</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">inf</span>
             <span class="bp">self</span><span class="o">.</span><span class="n">external_checkpoint_path</span> <span class="o">=</span> <span class="kc">None</span>
             <span class="bp">self</span><span class="o">.</span><span class="n">external_checkpoint_path</span> <span class="o">=</span> <span class="kc">None</span>
             <span class="bp">self</span><span class="o">.</span><span class="n">strict_load</span> <span class="o">=</span> <span class="n">StrictLoad</span><span class="o">.</span><span class="n">ON</span>
             <span class="bp">self</span><span class="o">.</span><span class="n">strict_load</span> <span class="o">=</span> <span class="n">StrictLoad</span><span class="o">.</span><span class="n">ON</span>
             <span class="bp">self</span><span class="o">.</span><span class="n">load_ema_as_net</span> <span class="o">=</span> <span class="kc">False</span>
             <span class="bp">self</span><span class="o">.</span><span class="n">load_ema_as_net</span> <span class="o">=</span> <span class="kc">False</span>
    -        <span class="bp">self</span><span class="o">.</span><span class="n">ckpt_best_name</span> <span class="o">=</span> <span class="s1">&#39;ckpt_best.pth&#39;</span>
    +        <span class="bp">self</span><span class="o">.</span><span class="n">ckpt_best_name</span> <span class="o">=</span> <span class="s2">&quot;ckpt_best.pth&quot;</span>
             <span class="bp">self</span><span class="o">.</span><span class="n">enable_qat</span> <span class="o">=</span> <span class="kc">False</span>
             <span class="bp">self</span><span class="o">.</span><span class="n">enable_qat</span> <span class="o">=</span> <span class="kc">False</span>
             <span class="bp">self</span><span class="o">.</span><span class="n">qat_params</span> <span class="o">=</span> <span class="p">{}</span>
             <span class="bp">self</span><span class="o">.</span><span class="n">qat_params</span> <span class="o">=</span> <span class="p">{}</span>
             <span class="bp">self</span><span class="o">.</span><span class="n">_infinite_train_loader</span> <span class="o">=</span> <span class="kc">False</span>
             <span class="bp">self</span><span class="o">.</span><span class="n">_infinite_train_loader</span> <span class="o">=</span> <span class="kc">False</span>
    -
    -        <span class="c1"># DETERMINE THE LOCATION OF THE LOSS AND ACCURACY IN THE RESULTS TUPLE OUTPUTED BY THE TEST</span>
    -        <span class="bp">self</span><span class="o">.</span><span class="n">loss_idx_in_results_tuple</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">acc_idx_in_results_tuple</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span> <span class="kc">None</span>
    +        <span class="bp">self</span><span class="o">.</span><span class="n">_first_backward</span> <span class="o">=</span> <span class="kc">True</span>
     
     
             <span class="c1"># METRICS</span>
             <span class="c1"># METRICS</span>
             <span class="bp">self</span><span class="o">.</span><span class="n">loss_logging_items_names</span> <span class="o">=</span> <span class="kc">None</span>
             <span class="bp">self</span><span class="o">.</span><span class="n">loss_logging_items_names</span> <span class="o">=</span> <span class="kc">None</span>
             <span class="bp">self</span><span class="o">.</span><span class="n">train_metrics</span> <span class="o">=</span> <span class="kc">None</span>
             <span class="bp">self</span><span class="o">.</span><span class="n">train_metrics</span> <span class="o">=</span> <span class="kc">None</span>
             <span class="bp">self</span><span class="o">.</span><span class="n">valid_metrics</span> <span class="o">=</span> <span class="kc">None</span>
             <span class="bp">self</span><span class="o">.</span><span class="n">valid_metrics</span> <span class="o">=</span> <span class="kc">None</span>
             <span class="bp">self</span><span class="o">.</span><span class="n">greater_metric_to_watch_is_better</span> <span class="o">=</span> <span class="kc">None</span>
             <span class="bp">self</span><span class="o">.</span><span class="n">greater_metric_to_watch_is_better</span> <span class="o">=</span> <span class="kc">None</span>
    +        <span class="bp">self</span><span class="o">.</span><span class="n">metric_to_watch</span> <span class="o">=</span> <span class="kc">None</span>
    +        <span class="bp">self</span><span class="o">.</span><span class="n">greater_train_metrics_is_better</span><span class="p">:</span> <span class="n">Dict</span><span class="p">[</span><span class="nb">str</span><span class="p">,</span> <span class="nb">bool</span><span class="p">]</span> <span class="o">=</span> <span class="p">{}</span>  <span class="c1"># For each metric, indicates if greater is better</span>
    +        <span class="bp">self</span><span class="o">.</span><span class="n">greater_valid_metrics_is_better</span><span class="p">:</span> <span class="n">Dict</span><span class="p">[</span><span class="nb">str</span><span class="p">,</span> <span class="nb">bool</span><span class="p">]</span> <span class="o">=</span> <span class="p">{}</span>
     
     
             <span class="c1"># SETTING THE PROPERTIES FROM THE CONSTRUCTOR</span>
             <span class="c1"># SETTING THE PROPERTIES FROM THE CONSTRUCTOR</span>
             <span class="bp">self</span><span class="o">.</span><span class="n">experiment_name</span> <span class="o">=</span> <span class="n">experiment_name</span>
             <span class="bp">self</span><span class="o">.</span><span class="n">experiment_name</span> <span class="o">=</span> <span class="n">experiment_name</span>
    -        <span class="bp">self</span><span class="o">.</span><span class="n">ckpt_name</span> <span class="o">=</span> <span class="n">ckpt_name</span>
    -        <span class="bp">self</span><span class="o">.</span><span class="n">overwrite_local_checkpoint</span> <span class="o">=</span> <span class="n">overwrite_local_checkpoint</span>
    -        <span class="bp">self</span><span class="o">.</span><span class="n">model_checkpoints_location</span> <span class="o">=</span> <span class="n">model_checkpoints_location</span>
    -        <span class="bp">self</span><span class="o">.</span><span class="n">_set_dataset_properties</span><span class="p">(</span><span class="n">classes</span><span class="p">,</span> <span class="n">test_loader</span><span class="p">,</span> <span class="n">train_loader</span><span class="p">,</span> <span class="n">valid_loader</span><span class="p">)</span>
    -
    -        <span class="c1"># CREATING THE LOGGING DIR BASED ON THE INPUT PARAMS TO PREVENT OVERWRITE OF LOCAL VERSION</span>
    -        <span class="k">if</span> <span class="n">ckpt_root_dir</span><span class="p">:</span>
    -            <span class="bp">self</span><span class="o">.</span><span class="n">checkpoints_dir_path</span> <span class="o">=</span> <span class="n">os</span><span class="o">.</span><span class="n">path</span><span class="o">.</span><span class="n">join</span><span class="p">(</span><span class="n">ckpt_root_dir</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">experiment_name</span><span class="p">)</span>
    -        <span class="k">elif</span> <span class="n">pkg_resources</span><span class="o">.</span><span class="n">resource_exists</span><span class="p">(</span><span class="s2">&quot;checkpoints&quot;</span><span class="p">,</span> <span class="s2">&quot;&quot;</span><span class="p">):</span>
    -            <span class="bp">self</span><span class="o">.</span><span class="n">checkpoints_dir_path</span> <span class="o">=</span> <span class="n">pkg_resources</span><span class="o">.</span><span class="n">resource_filename</span><span class="p">(</span><span class="s1">&#39;checkpoints&#39;</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">experiment_name</span><span class="p">)</span>
    -        <span class="k">else</span><span class="p">:</span>
    -            <span class="k">raise</span> <span class="ne">ValueError</span><span class="p">(</span><span class="s2">&quot;Illegal checkpoints directory: pass ckpt_root_dir that exists, or add &#39;checkpoints&#39; to&quot;</span>
    -                             <span class="s2">&quot;resources.&quot;</span><span class="p">)</span>
    +        <span class="bp">self</span><span class="o">.</span><span class="n">ckpt_name</span> <span class="o">=</span> <span class="kc">None</span>
    +
    +        <span class="bp">self</span><span class="o">.</span><span class="n">checkpoints_dir_path</span> <span class="o">=</span> <span class="n">get_checkpoints_dir_path</span><span class="p">(</span><span class="n">experiment_name</span><span class="p">,</span> <span class="n">ckpt_root_dir</span><span class="p">)</span>
     
     
             <span class="c1"># INITIALIZE THE DEVICE FOR THE MODEL</span>
             <span class="c1"># INITIALIZE THE DEVICE FOR THE MODEL</span>
             <span class="bp">self</span><span class="o">.</span><span class="n">_initialize_device</span><span class="p">(</span><span class="n">requested_device</span><span class="o">=</span><span class="n">device</span><span class="p">,</span> <span class="n">requested_multi_gpu</span><span class="o">=</span><span class="n">multi_gpu</span><span class="p">)</span>
             <span class="bp">self</span><span class="o">.</span><span class="n">_initialize_device</span><span class="p">(</span><span class="n">requested_device</span><span class="o">=</span><span class="n">device</span><span class="p">,</span> <span class="n">requested_multi_gpu</span><span class="o">=</span><span class="n">multi_gpu</span><span class="p">)</span>
     
     
    -        <span class="bp">self</span><span class="o">.</span><span class="n">post_prediction_callback</span> <span class="o">=</span> <span class="n">post_prediction_callback</span>
             <span class="c1"># SET THE DEFAULTS</span>
             <span class="c1"># SET THE DEFAULTS</span>
             <span class="c1"># TODO: SET DEFAULT TRAINING PARAMS FOR EACH TASK</span>
             <span class="c1"># TODO: SET DEFAULT TRAINING PARAMS FOR EACH TASK</span>
     
     
    -        <span class="n">default_results_titles</span> <span class="o">=</span> <span class="p">[</span><span class="s1">&#39;Train Loss&#39;</span><span class="p">,</span> <span class="s1">&#39;Train Acc&#39;</span><span class="p">,</span> <span class="s1">&#39;Train Top5&#39;</span><span class="p">,</span> <span class="s1">&#39;Valid Loss&#39;</span><span class="p">,</span> <span class="s1">&#39;Valid Acc&#39;</span><span class="p">,</span> <span class="s1">&#39;Valid Top5&#39;</span><span cla
    +        <span class="n">default_results_titles</span> <span class="o">=</span> <span class="p">[</span><span class="s2">&quot;Train Loss&quot;</span><span class="p">,</span> <span class="s2">&quot;Train Acc&quot;</span><span class="p">,</span> <span class="s2">&quot;Train Top5&quot;</span><span class="p">,</span> <span class="s2">&quot;Valid Loss&quot;</span><span class="p">,</span> <span class="s2">&quot;Valid Acc&quot;</span><span class="p">,</span> <span class="s2">&quot;Valid Top5&quot;</sp
     
     
             <span class="bp">self</span><span class="o">.</span><span class="n">results_titles</span> <span class="o">=</span> <span class="n">default_results_titles</span>
             <span class="bp">self</span><span class="o">.</span><span class="n">results_titles</span> <span class="o">=</span> <span class="n">default_results_titles</span>
     
     
    -        <span class="bp">self</span><span class="o">.</span><span class="n">loss_idx_in_results_tuple</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">acc_idx_in_results_tuple</span> <span class="o">=</span> <span class="mi">0</span><span class="p">,</span> <span class="mi">1</span>
    -        <span class="n">default_train_metrics</span><span class="p">,</span> <span class="n">default_valid_metrics</span> <span class="o">=</span> <span class="n">MetricCollection</span><span class="p">([</span><span class="n">Accuracy</span><span class="p">(),</span> <span class="n">Top5</span><span class="p">()]),</span> <span class="n">MetricCollection</span><span class="p">(</span>
    -            <span class="p">[</span><span class="n">Accuracy</span><span class="p">(),</span> <span class="n">Top5</span><span class="p">()])</span>
    -
    -        <span class="n">default_loss_logging_items_names</span> <span class="o">=</span> <span class="p">[</span><span class="s2">&quot;Loss&quot;</span><span class="p">]</span>
    +        <span class="n">default_train_metrics</span><span class="p">,</span> <span class="n">default_valid_metrics</span> <span class="o">=</span> <span class="n">MetricCollection</span><span class="p">([</span><span class="n">Accuracy</span><span class="p">(),</span> <span class="n">Top5</span><span class="p">()]),</span> <span class="n">MetricCollection</span><span class="p">([</span><span class="n">Accuracy</span><span class="p">(),</span> <span class="n">Top5</span><span class="p">()])</spa
     
     
             <span class="bp">self</span><span class="o">.</span><span class="n">train_metrics</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">valid_metrics</span> <span class="o">=</span> <span class="n">default_train_metrics</span><span class="p">,</span> <span class="n">default_valid_metrics</span>
             <span class="bp">self</span><span class="o">.</span><span class="n">train_metrics</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">valid_metrics</span> <span class="o">=</span> <span class="n">default_train_metrics</span><span class="p">,</span> <span class="n">default_valid_metrics</span>
    -        <span class="bp">self</span><span class="o">.</span><span class="n">loss_logging_items_names</span> <span class="o">=</span> <span class="n">default_loss_logging_items_names</span>
     
     
             <span class="bp">self</span><span class="o">.</span><span class="n">train_monitored_values</span> <span class="o">=</span> <span class="p">{}</span>
             <span class="bp">self</span><span class="o">.</span><span class="n">train_monitored_values</span> <span class="o">=</span> <span class="p">{}</span>
             <span class="bp">self</span><span class="o">.</span><span class="n">valid_monitored_values</span> <span class="o">=</span> <span class="p">{}</span>
             <span class="bp">self</span><span class="o">.</span><span class="n">valid_monitored_values</span> <span class="o">=</span> <span class="p">{}</span>
     
     
    -    <span class="k">def</span> <span class="nf">_set_dataset_properties</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">classes</span><span class="p">,</span> <span class="n">test_loader</span><span class="p">,</span> <span class="n">train_loader</span><span class="p">,</span> <span class="n">valid_loader</span><span class="p">):</span>
    -        <span class="k">if</span> <span class="nb">any</span><span class="p">([</span><span class="n">train_loader</span><span class="p">,</span> <span class="n">valid_loader</span><span class="p">,</span> <span class="n">classes</span><span class="p">])</span> <span class="ow">and</span> <span class="ow">not</span> <span class="nb">all</span><span class="p">([</span><span class="n">train_loader</span><span class="p">,</span> <span class="n">valid_loader</span><span class="p">,</span> <span cla
    -            <span class="k">raise</span> <span class="n">IllegalDataloaderInitialization</span><span class="p">()</span>
    +<div class="viewcode-block" id="Trainer.train_from_config"><a class="viewcode-back" href="../../../../super_gradients.training.html#super_gradients.training.Trainer.train_from_config">[docs]</a>    <span class="nd">@classmethod</span>
    +    <span class="k">def</span> <span class="nf">train_from_config</span><span class="p">(</span><span class="bp">cls</span><span class="p">,</span> <span class="n">cfg</span><span class="p">:</span> <span class="n">Union</span><span class="p">[</span><span class="n">DictConfig</span><span class="p">,</span> <span class="nb">dict</span><span class="p">])</span> <span class="o">-&gt;</span> <span class="n">Tuple</span><span class="p">[</span><span class="n">nn</span><span class="o">.</span><span 
    +        <span class="sd">&quot;&quot;&quot;</span>
    +<span class="sd">        Trains according to cfg recipe configuration.</span>
     
     
    -        <span class="n">dataset_params</span> <span class="o">=</span> <span class="p">{</span><span class="s2">&quot;batch_size&quot;</span><span class="p">:</span> <span class="n">train_loader</span><span class="o">.</span><span class="n">batch_size</span> <span class="k">if</span> <span class="n">train_loader</span> <span class="k">else</span> <span class="kc">None</span><span class="p">,</span>
    -                          <span class="s2">&quot;val_batch_size&quot;</span><span class="p">:</span> <span class="n">valid_loader</span><span class="o">.</span><span class="n">batch_size</span> <span class="k">if</span> <span class="n">valid_loader</span> <span class="k">else</span> <span class="kc">None</span><span class="p">,</span>
    -                          <span class="s2">&quot;test_batch_size&quot;</span><span class="p">:</span> <span class="n">test_loader</span><span class="o">.</span><span class="n">batch_size</span> <span class="k">if</span> <span class="n">test_loader</span> <span class="k">else</span> <span class="kc">None</span><span class="p">,</span>
    -                          <span class="s2">&quot;dataset_dir&quot;</span><span class="p">:</span> <span class="kc">None</span><span class="p">,</span>
    -                          <span class="s2">&quot;s3_link&quot;</span><span class="p">:</span> <span class="kc">None</span><span class="p">}</span>
    +<span class="sd">        @param cfg: The parsed DictConfig from yaml recipe files or a dictionary</span>
    +<span class="sd">        @return: the model and the output of trainer.train(...) (i.e results tuple)</span>
    +<span class="sd">        &quot;&quot;&quot;</span>
     
     
    -        <span class="k">if</span> <span class="n">train_loader</span> <span class="ow">and</span> <span class="bp">self</span><span class="o">.</span><span class="n">multi_gpu</span> <span class="o">==</span> <span class="n">MultiGPUMode</span><span class="o">.</span><span class="n">DISTRIBUTED_DATA_PARALLEL</span><span class="p">:</span>
    -            <span class="k">if</span> <span class="ow">not</span> <span class="nb">all</span><span class="p">([</span><span class="nb">isinstance</span><span class="p">(</span><span class="n">train_loader</span><span class="o">.</span><span class="n">sampler</span><span class="p">,</span> <span class="n">DistributedSampler</span><span class="p">),</span>
    -                        <span class="nb">isinstance</span><span class="p">(</span><span class="n">valid_loader</span><span class="o">.</span><span class="n">sampler</span><span class="p">,</span> <span class="n">DistributedSampler</span><span class="p">),</span>
    -                        <span class="n">test_loader</span> <span class="ow">is</span> <span class="kc">None</span> <span class="ow">or</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">test_loader</span><span class="o">.</span><span class="n">sampler</span><span class="p">,</span> <span class="n">DistributedSampler</span><span class="p">)]):</span>
    -                <span class="n">logger</span><span class="o">.</span><span class="n">warning</span><span class="p">(</span><span class="s2">&quot;DDP training was selected but the dataloader samplers are not of type DistributedSamplers&quot;</span><span class="p">)</span>
    +        <span class="n">setup_device</span><span class="p">(</span><span class="n">multi_gpu</span><span class="o">=</span><span class="n">core_utils</span><span class="o">.</span><span class="n">get_param</span><span class="p">(</span><span class="n">cfg</span><span class="p">,</span> <span class="s2">&quot;multi_gpu&quot;</span><span class="p">,</span> <span class="n">MultiGPUMode</span><span class="o">.</span><span class="n">OFF</span><span class="p">),</span> <span class="n">num_gpus</span>
    +
    +        <span class="c1"># INSTANTIATE ALL OBJECTS IN CFG</span>
    +        <span class="n">cfg</span> <span class="o">=</span> <span class="n">hydra</span><span class="o">.</span><span class="n">utils</span><span class="o">.</span><span class="n">instantiate</span><span class="p">(</span><span class="n">cfg</span><span class="p">)</span>
    +
    +        <span class="n">kwargs</span> <span class="o">=</span> <span class="n">parse_args</span><span class="p">(</span><span class="n">cfg</span><span class="p">,</span> <span class="bp">cls</span><span class="o">.</span><span class="fm">__init__</span><span class="p">)</span>
    +
    +        <span class="n">trainer</span> <span class="o">=</span> <span class="n">Trainer</span><span class="p">(</span><span class="o">**</span><span class="n">kwargs</span><span class="p">)</span>
    +
    +        <span class="c1"># INSTANTIATE DATA LOADERS</span>
    +        <span class="n">train_dataloader</span> <span class="o">=</span> <span class="n">dataloaders</span><span class="o">.</span><span class="n">get</span><span class="p">(</span>
    +            <span class="n">name</span><span class="o">=</span><span class="n">cfg</span><span class="o">.</span><span class="n">train_dataloader</span><span class="p">,</span> <span class="n">dataset_params</span><span class="o">=</span><span class="n">cfg</span><span class="o">.</span><span class="n">dataset_params</span><span class="o">.</span><span class="n">train_dataset_params</span><span class="p">,</span> <span class="n">dataloader_params</span><span class="o">=</span><span class="n">cf
    +        <span class="p">)</span>
    +
    +        <span class="n">val_dataloader</span> <span class="o">=</span> <span class="n">dataloaders</span><span class="o">.</span><span class="n">get</span><span class="p">(</span>
    +            <span class="n">name</span><span class="o">=</span><span class="n">cfg</span><span class="o">.</span><span class="n">val_dataloader</span><span class="p">,</span> <span class="n">dataset_params</span><span class="o">=</span><span class="n">cfg</span><span class="o">.</span><span class="n">dataset_params</span><span class="o">.</span><span class="n">val_dataset_params</span><span class="p">,</span> <span class="n">dataloader_params</span><span class="o">=</span><span class="n">cfg</s
    +        <span class="p">)</span>
    +
    +        <span class="c1"># BUILD NETWORK</span>
    +        <span class="n">model</span> <span class="o">=</span> <span class="n">models</span><span class="o">.</span><span class="n">get</span><span class="p">(</span>
    +            <span class="n">model_name</span><span class="o">=</span><span class="n">cfg</span><span class="o">.</span><span class="n">architecture</span><span class="p">,</span>
    +            <span class="n">num_classes</span><span class="o">=</span><span class="n">cfg</span><span class="o">.</span><span class="n">arch_params</span><span class="o">.</span><span class="n">num_classes</span><span class="p">,</span>
    +            <span class="n">arch_params</span><span class="o">=</span><span class="n">cfg</span><span class="o">.</span><span class="n">arch_params</span><span class="p">,</span>
    +            <span class="n">strict_load</span><span class="o">=</span><span class="n">cfg</span><span class="o">.</span><span class="n">checkpoint_params</span><span class="o">.</span><span class="n">strict_load</span><span class="p">,</span>
    +            <span class="n">pretrained_weights</span><span class="o">=</span><span class="n">cfg</span><span class="o">.</span><span class="n">checkpoint_params</span><span class="o">.</span><span class="n">pretrained_weights</span><span class="p">,</span>
    +            <span class="n">checkpoint_path</span><span class="o">=</span><span class="n">cfg</span><span class="o">.</span><span class="n">checkpoint_params</span><span class="o">.</span><span class="n">checkpoint_path</span><span class="p">,</span>
    +            <span class="n">load_backbone</span><span class="o">=</span><span class="n">cfg</span><span class="o">.</span><span class="n">checkpoint_params</span><span class="o">.</span><span class="n">load_backbone</span><span class="p">,</span>
    +        <span class="p">)</span>
    +        <span class="n">recipe_logged_cfg</span> <span class="o">=</span> <span class="p">{</span><span class="s2">&quot;recipe_config&quot;</span><span class="p">:</span> <span class="n">OmegaConf</span><span class="o">.</span><span class="n">to_container</span><span class="p">(</span><span class="n">cfg</span><span class="p">,</span> <span class="n">resolve</span><span class="o">=</span><span class="kc">True</span><span class="p">)}</span>
    +        <span class="c1"># TRAIN</span>
    +        <span class="n">res</span> <span class="o">=</span> <span class="n">trainer</span><span class="o">.</span><span class="n">train</span><span class="p">(</span>
    +            <span class="n">model</span><span class="o">=</span><span class="n">model</span><span class="p">,</span>
    +            <span class="n">train_loader</span><span class="o">=</span><span class="n">train_dataloader</span><span class="p">,</span>
    +            <span class="n">valid_loader</span><span class="o">=</span><span class="n">val_dataloader</span><span class="p">,</span>
    +            <span class="n">training_params</span><span class="o">=</span><span class="n">cfg</span><span class="o">.</span><span class="n">training_hyperparams</span><span class="p">,</span>
    +            <span class="n">additional_configs_to_log</span><span class="o">=</span><span class="n">recipe_logged_cfg</span><span class="p">,</span>
    +        <span class="p">)</span>
    +
    +        <span class="k">return</span> <span class="n">model</span><span class="p">,</span> <span class="n">res</span></div>
    +
    +<div class="viewcode-block" id="Trainer.resume_experiment"><a class="viewcode-back" href="../../../../super_gradients.training.html#super_gradients.training.Trainer.resume_experiment">[docs]</a>    <span class="nd">@classmethod</span>
    +    <span class="k">def</span> <span class="nf">resume_experiment</span><span class="p">(</span><span class="bp">cls</span><span class="p">,</span> <span class="n">experiment_name</span><span class="p">:</span> <span class="nb">str</span><span class="p">,</span> <span class="n">ckpt_root_dir</span><span class="p">:</span> <span class="nb">str</span> <span class="o">=</span> <span class="kc">None</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="kc">None</span><span class=
    +        <span class="sd">&quot;&quot;&quot;</span>
    +<span class="sd">        Resume a training that was run using our recipes.</span>
     
     
    -        <span class="bp">self</span><span class="o">.</span><span class="n">dataset_params</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">train_loader</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">valid_loader</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">test_loader</span><span class="p">,</span> <span class="bp">self</span><span class="o
    -            <span class="n">HpmStruct</span><span class="p">(</span><span class="o">**</span><span class="n">dataset_params</span><span class="p">),</span> <span class="n">train_loader</span><span class="p">,</span> <span class="n">valid_loader</span><span class="p">,</span> <span class="n">test_loader</span><span class="p">,</span> <span class="n">classes</span>
    +<span class="sd">        :param experiment_name:     Name of the experiment to resume</span>
    +<span class="sd">        :param ckpt_root_dir:       Directory including the checkpoints</span>
    +<span class="sd">        &quot;&quot;&quot;</span>
    +        <span class="n">logger</span><span class="o">.</span><span class="n">info</span><span class="p">(</span><span class="s2">&quot;Resume training using the checkpoint recipe, ignoring the current recipe&quot;</span><span class="p">)</span>
    +        <span class="n">cfg</span> <span class="o">=</span> <span class="n">load_experiment_cfg</span><span class="p">(</span><span class="n">experiment_name</span><span class="p">,</span> <span class="n">ckpt_root_dir</span><span class="p">)</span>
    +        <span class="n">add_params_to_cfg</span><span class="p">(</span><span class="n">cfg</span><span class="p">,</span> <span class="n">params</span><span class="o">=</span><span class="p">[</span><span class="s2">&quot;training_hyperparams.resume=True&quot;</span><span class="p">])</span>
    +        <span class="bp">cls</span><span class="o">.</span><span class="n">train_from_config</span><span class="p">(</span><span class="n">cfg</span><span class="p">)</span></div>
     
     
    -<div class="viewcode-block" id="SgModel.connect_dataset_interface"><a class="viewcode-back" href="../../../../super_gradients.training.sg_model.html#super_gradients.training.SgModel.connect_dataset_interface">[docs]</a>    <span class="nd">@resolve_param</span><span class="p">(</span><span class="s1">&#39;dataset_interface&#39;</span><span class="p">,</span> <span class="n">DatasetsFactory</span><span class="p">())</span>
    -    <span class="k">def</span> <span class="nf">connect_dataset_interface</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">dataset_interface</span><span class="p">:</span> <span class="n">DatasetInterface</span><span class="p">,</span> <span class="n">data_loader_num_workers</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="mi">8</span><span class="p">):</span>
    +<div class="viewcode-block" id="Trainer.evaluate_from_recipe"><a class="viewcode-back" href="../../../../super_gradients.training.html#super_gradients.training.Trainer.evaluate_from_recipe">[docs]</a>    <span class="nd">@classmethod</span>
    +    <span class="k">def</span> <span class="nf">evaluate_from_recipe</span><span class="p">(</span><span class="bp">cls</span><span class="p">,</span> <span class="n">cfg</span><span class="p">:</span> <span class="n">DictConfig</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="kc">None</span><span class="p">:</span>
             <span class="sd">&quot;&quot;&quot;</span>
             <span class="sd">&quot;&quot;&quot;</span>
    -<span class="sd">        :param dataset_interface: DatasetInterface object</span>
    -<span class="sd">        :param data_loader_num_workers: The number of threads to initialize the Data Loaders with</span>
    -<span class="sd">            The dataset to be connected</span>
    +<span class="sd">        Evaluate according to a cfg recipe configuration.</span>
    +
    +<span class="sd">        Note:   This script does NOT run training, only validation.</span>
    +<span class="sd">                Please make sure that the config refers to a PRETRAINED MODEL either from one of your checkpoint or from pretrained weights from model zoo.</span>
    +<span class="sd">        :param cfg: The parsed DictConfig from yaml recipe files or a dictionary</span>
     <span class="sd">        &quot;&quot;&quot;</span>
     <span class="sd">        &quot;&quot;&quot;</span>
    -        <span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">train_loader</span><span class="p">:</span>
    -            <span class="n">logger</span><span class="o">.</span><span class="n">warning</span><span class="p">(</span><span class="s2">&quot;Overriding the dataloaders that SgModel was initialized with&quot;</span><span class="p">)</span>
    -        <span class="bp">self</span><span class="o">.</span><span class="n">dataset_interface</span> <span class="o">=</span> <span class="n">dataset_interface</span>
    -        <span class="bp">self</span><span class="o">.</span><span class="n">train_loader</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">valid_loader</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">test_loader</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">classes</span> <span class="o">=</span> \
    -            <span class="bp">self</span><span class="o">.</span><span class="n">dataset_interface</span><span class="o">.</span><span class="n">get_data_loaders</span><span class="p">(</span><span class="n">batch_size_factor</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">num_devices</span><span class="p">,</span>
    -                                                    <span class="n">num_workers</span><span class="o">=</span><span class="n">data_loader_num_workers</span><span class="p">,</span>
    -                                                    <span class="n">distributed_sampler</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">multi_gpu</span> <span class="o">==</span> <span class="n">MultiGPUMode</span><span class="o">.</span><span class="n">DISTRIBUTED_DATA_PARALLEL</span><span class="p">)</span>
     
     
    -        <span class="bp">self</span><span class="o">.</span><span class="n">dataset_params</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">dataset_interface</span><span class="o">.</span><span class="n">get_dataset_params</span><span class="p">()</span></div>
    +        <span class="n">setup_device</span><span class="p">(</span><span class="n">multi_gpu</span><span class="o">=</span><span class="n">core_utils</span><span class="o">.</span><span class="n">get_param</span><span class="p">(</span><span class="n">cfg</span><span class="p">,</span> <span class="s2">&quot;multi_gpu&quot;</span><span class="p">,</span> <span class="n">MultiGPUMode</span><span class="o">.</span><span class="n">OFF</span><span class="p">),</span> <span class="n">num_gpus</span>
     
     
    -    <span class="c1"># FIXME - we need to resolve flake8&#39;s &#39;function is too complex&#39; for this function</span>
    -<div class="viewcode-block" id="SgModel.build_model"><a class="viewcode-back" href="../../../../super_gradients.training.sg_model.html#super_gradients.training.SgModel.build_model">[docs]</a>    <span class="k">def</span> <span class="nf">build_model</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span>  <span class="c1"># noqa: C901 - too complex</span>
    -                    <span class="n">architecture</span><span class="p">:</span> <span class="n">Union</span><span class="p">[</span><span class="nb">str</span><span class="p">,</span> <span class="n">nn</span><span class="o">.</span><span class="n">Module</span><span class="p">],</span>
    -                    <span class="n">arch_params</span><span class="o">=</span><span class="p">{},</span> <span class="n">checkpoint_params</span><span class="o">=</span><span class="p">{},</span> <span class="o">*</span><span class="n">args</span><span class="p">,</span> <span class="o">**</span><span class="n">kwargs</span><span class="p">):</span>
    -        <span class="sd">&quot;&quot;&quot;</span>
    -<span class="sd">        :param architecture:               Defines the network&#39;s architecture from models/ALL_ARCHITECTURES</span>
    -<span class="sd">        :param arch_params:                Architecture H.P. e.g.: block, num_blocks, num_classes, etc.</span>
    -<span class="sd">        :param checkpoint_params:          Dictionary like object with the following key:values:</span>
    -
    -<span class="sd">            load_checkpoint:            Load a pre-trained checkpoint</span>
    -<span class="sd">            strict_load:                See StrictLoad class documentation for details.</span>
    -<span class="sd">            source_ckpt_folder_name:    folder name to load the checkpoint from (self.experiment_name if none is given)</span>
    -<span class="sd">            load_weights_only:          loads only the weight from the checkpoint and zeroize the training params</span>
    -<span class="sd">            load_backbone:              loads the provided checkpoint to self.net.backbone instead of self.net</span>
    -<span class="sd">            external_checkpoint_path:   The path to the external checkpoint to be loaded. Can be absolute or relative</span>
    -<span class="sd">                                               (ie: path/to/checkpoint.pth). If provided, will automatically attempt to</span>
    -<span class="sd">                                               load the checkpoint even if the load_checkpoint flag is not provided.</span>
    +        <span class="c1"># INSTANTIATE ALL OBJECTS IN CFG</span>
    +        <span class="n">cfg</span> <span class="o">=</span> <span class="n">hydra</span><span class="o">.</span><span class="n">utils</span><span class="o">.</span><span class="n">instantiate</span><span class="p">(</span><span class="n">cfg</span><span class="p">)</span>
     
     
    -<span class="sd">        &quot;&quot;&quot;</span>
    -        <span class="k">if</span> <span class="s1">&#39;num_classes&#39;</span> <span class="ow">not</span> <span class="ow">in</span> <span class="n">arch_params</span><span class="o">.</span><span class="n">keys</span><span class="p">():</span>
    -            <span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">classes</span> <span class="ow">is</span> <span class="kc">None</span> <span class="ow">and</span> <span class="bp">self</span><span class="o">.</span><span class="n">dataset_interface</span> <span class="ow">is</span> <span class="kc">None</span><span class="p">:</span>
    -                <span class="k">raise</span> <span class="ne">Exception</span><span class="p">(</span><span class="s1">&#39;Error&#39;</span><span class="p">,</span> <span class="s1">&#39;Number of classes not defined in arch params and dataset is not defined&#39;</span><span class="p">)</span>
    -            <span class="k">else</span><span class="p">:</span>
    -                <span class="n">arch_params</span><span class="p">[</span><span class="s1">&#39;num_classes&#39;</span><span class="p">]</span> <span class="o">=</span> <span class="nb">len</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">classes</span><span class="p">)</span>
    +        <span class="n">kwargs</span> <span class="o">=</span> <span class="n">parse_args</span><span class="p">(</span><span class="n">cfg</span><span class="p">,</span> <span class="bp">cls</span><span class="o">.</span><span class="fm">__init__</span><span class="p">)</span>
     
     
    -        <span class="bp">self</span><span class="o">.</span><span class="n">arch_params</span> <span class="o">=</span> <span class="n">core_utils</span><span class="o">.</span><span class="n">HpmStruct</span><span class="p">(</span><span class="o">**</span><span class="n">arch_params</span><span class="p">)</span>
    -        <span class="bp">self</span><span class="o">.</span><span class="n">checkpoint_params</span> <span class="o">=</span> <span class="n">core_utils</span><span class="o">.</span><span class="n">HpmStruct</span><span class="p">(</span><span class="o">**</span><span class="n">checkpoint_params</span><span class="p">)</span>
    +        <span class="n">trainer</span> <span class="o">=</span> <span class="n">Trainer</span><span class="p">(</span><span class="o">**</span><span class="n">kwargs</span><span class="p">)</span>
     
     
    -        <span class="bp">self</span><span class="o">.</span><span class="n">net</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">_instantiate_net</span><span class="p">(</span><span class="n">architecture</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">arch_params</span><span class="p">,</span> <span class="n">checkpoint_params</span><span class="p">,</span> <span class="o">*</span><span class=
    +        <span class="c1"># INSTANTIATE DATA LOADERS</span>
    +        <span class="n">val_dataloader</span> <span class="o">=</span> <span class="n">dataloaders</span><span class="o">.</span><span class="n">get</span><span class="p">(</span>
    +            <span class="n">name</span><span class="o">=</span><span class="n">cfg</span><span class="o">.</span><span class="n">val_dataloader</span><span class="p">,</span> <span class="n">dataset_params</span><span class="o">=</span><span class="n">cfg</span><span class="o">.</span><span class="n">dataset_params</span><span class="o">.</span><span class="n">val_dataset_params</span><span class="p">,</span> <span class="n">dataloader_params</span><span class="o">=</span><span class="n">cfg</s
    +        <span class="p">)</span>
     
     
    -        <span class="c1"># SAVE THE ARCHITECTURE FOR NEURAL ARCHITECTURE SEARCH</span>
    +        <span class="n">checkpoints_dir</span> <span class="o">=</span> <span class="n">Path</span><span class="p">(</span><span class="n">get_checkpoints_dir_path</span><span class="p">(</span><span class="n">experiment_name</span><span class="o">=</span><span class="n">cfg</span><span class="o">.</span><span class="n">experiment_name</span><span class="p">,</span> <span class="n">ckpt_root_dir</span><span class="o">=</span><span class="n">cfg</span><span class="o">.</span><span class="n">ckpt
    +        <span class="n">checkpoint_path</span> <span class="o">=</span> <span class="nb">str</span><span class="p">(</span><span class="n">checkpoints_dir</span> <span class="o">/</span> <span class="n">cfg</span><span class="o">.</span><span class="n">training_hyperparams</span><span class="o">.</span><span class="n">ckpt_name</span><span class="p">)</span>
    +        <span class="n">logger</span><span class="o">.</span><span class="n">info</span><span class="p">(</span><span class="sa">f</span><span class="s2">&quot;Evaluating checkpoint: </span><span class="si">{</span><span class="n">checkpoint_path</span><span class="si">}</span><span class="s2">&quot;</span><span class="p">)</span>
     
     
    -        <span class="bp">self</span><span class="o">.</span><span class="n">architecture</span> <span class="o">=</span> <span class="n">architecture</span>
    +        <span class="c1"># BUILD NETWORK</span>
    +        <span class="n">model</span> <span class="o">=</span> <span class="n">models</span><span class="o">.</span><span class="n">get</span><span class="p">(</span>
    +            <span class="n">model_name</span><span class="o">=</span><span class="n">cfg</span><span class="o">.</span><span class="n">architecture</span><span class="p">,</span>
    +            <span class="n">num_classes</span><span class="o">=</span><span class="n">cfg</span><span class="o">.</span><span class="n">arch_params</span><span class="o">.</span><span class="n">num_classes</span><span class="p">,</span>
    +            <span class="n">arch_params</span><span class="o">=</span><span class="n">cfg</span><span class="o">.</span><span class="n">arch_params</span><span class="p">,</span>
    +            <span class="n">pretrained_weights</span><span class="o">=</span><span class="n">cfg</span><span class="o">.</span><span class="n">checkpoint_params</span><span class="o">.</span><span class="n">pretrained_weights</span><span class="p">,</span>
    +            <span class="n">checkpoint_path</span><span class="o">=</span><span class="n">checkpoint_path</span><span class="p">,</span>
    +            <span class="n">load_backbone</span><span class="o">=</span><span class="n">cfg</span><span class="o">.</span><span class="n">checkpoint_params</span><span class="o">.</span><span class="n">load_backbone</span><span class="p">,</span>
    +        <span class="p">)</span>
     
     
    -        <span class="bp">self</span><span class="o">.</span><span class="n">_net_to_device</span><span class="p">()</span>
    +        <span class="c1"># TEST</span>
    +        <span class="n">val_results_tuple</span> <span class="o">=</span> <span class="n">trainer</span><span class="o">.</span><span class="n">test</span><span class="p">(</span><span class="n">model</span><span class="o">=</span><span class="n">model</span><span class="p">,</span> <span class="n">test_loader</span><span class="o">=</span><span class="n">val_dataloader</span><span class="p">,</span> <span class="n">test_metrics_list</span><span class="o">=</span><span class="n">cfg</span><span
     
     
    -        <span class="c1"># SET THE FLAG FOR DIFFERENT PARAMETER GROUP OPTIMIZER UPDATE</span>
    -        <span class="bp">self</span><span class="o">.</span><span class="n">update_param_groups</span> <span class="o">=</span> <span class="nb">hasattr</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">net</span><span class="o">.</span><span class="n">module</span><span class="p">,</span> <span class="s1">&#39;update_param_groups&#39;</span><span class="p">)</span>
    +        <span class="n">valid_metrics_dict</span> <span class="o">=</span> <span class="n">get_metrics_dict</span><span class="p">(</span><span class="n">val_results_tuple</span><span class="p">,</span> <span class="n">trainer</span><span class="o">.</span><span class="n">test_metrics</span><span class="p">,</span> <span class="n">trainer</span><span class="o">.</span><span class="n">loss_logging_items_names</span><span class="p">)</span>
     
     
    -        <span class="bp">self</span><span class="o">.</span><span class="n">_load_checkpoint_to_model</span><span class="p">()</span></div>
    +        <span class="n">results</span> <span class="o">=</span> <span class="p">[</span><span class="s2">&quot;Validate Results&quot;</span><span class="p">]</span>
    +        <span class="n">results</span> <span class="o">+=</span> <span class="p">[</span><span class="sa">f</span><span class="s2">&quot;   - </span><span class="si">{</span><span class="n">metric</span><span class="si">:</span><span class="s2">10</span><span class="si">}</span><span class="s2">: </span><span class="si">{</span><span class="n">value</span><span class="si">}</span><span class="s2">&quot;</span> <span class="k">for</span> <span class="n">metric</span><span class="p">,</span> <spa
    +        <span class="n">logger</span><span class="o">.</span><span class="n">info</span><span class="p">(</span><span class="s2">&quot;</span><span class="se">\n</span><span class="s2">&quot;</span><span class="o">.</span><span class="n">join</span><span class="p">(</span><span class="n">results</span><span class="p">))</span></div>
     
     
    -    <span class="k">def</span> <span class="nf">_set_ckpt_loading_attributes</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span>
    +<div class="viewcode-block" id="Trainer.evaluate_checkpoint"><a class="viewcode-back" href="../../../../super_gradients.training.html#super_gradients.training.Trainer.evaluate_checkpoint">[docs]</a>    <span class="nd">@classmethod</span>
    +    <span class="k">def</span> <span class="nf">evaluate_checkpoint</span><span class="p">(</span><span class="bp">cls</span><span class="p">,</span> <span class="n">experiment_name</span><span class="p">:</span> <span class="nb">str</span><span class="p">,</span> <span class="n">ckpt_name</span><span class="p">:</span> <span class="nb">str</span> <span class="o">=</span> <span class="s2">&quot;ckpt_latest.pth&quot;</span><span class="p">,</span> <span class="n">ckpt_root_dir</span><span class=
             <span class="sd">&quot;&quot;&quot;</span>
             <span class="sd">&quot;&quot;&quot;</span>
    -<span class="sd">        Sets checkpoint loading related attributes according to self.checkpoint_params</span>
    +<span class="sd">        Evaluate a checkpoint resulting from one of your previous experiment, using the same parameters (dataset, valid_metrics,...)</span>
    +<span class="sd">        as used during the training of the experiment</span>
    +
    +<span class="sd">        Note:</span>
    +<span class="sd">            The parameters will be unchanged even if the recipe used for that experiment was changed since then.</span>
    +<span class="sd">            This is to ensure that validation of the experiment will remain exactly the same as during training.</span>
    +
    +<span class="sd">        Example, evaluate the checkpoint &quot;average_model.pth&quot; from experiment &quot;my_experiment_name&quot;:</span>
    +<span class="sd">            &gt;&gt; evaluate_checkpoint(experiment_name=&quot;my_experiment_name&quot;, ckpt_name=&quot;average_model.pth&quot;)</span>
    +
    +<span class="sd">        :param experiment_name:     Name of the experiment to validate</span>
    +<span class="sd">        :param ckpt_name:           Name of the checkpoint to test (&quot;ckpt_latest.pth&quot;, &quot;average_model.pth&quot; or &quot;ckpt_best.pth&quot; for instance)</span>
    +<span class="sd">        :param ckpt_root_dir:       Directory including the checkpoints</span>
     <span class="sd">        &quot;&quot;&quot;</span>
     <span class="sd">        &quot;&quot;&quot;</span>
    -        <span class="bp">self</span><span class="o">.</span><span class="n">checkpoint</span> <span class="o">=</span> <span class="p">{}</span>
    -        <span class="bp">self</span><span class="o">.</span><span class="n">strict_load</span> <span class="o">=</span> <span class="n">core_utils</span><span class="o">.</span><span class="n">get_param</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">checkpoint_params</span><span class="p">,</span> <span class="s1">&#39;strict_load&#39;</span><span class="p">,</span> <span class="n">default_val</span><span class="o">=</span><span class="n">Stric
    -        <span class="bp">self</span><span class="o">.</span><span class="n">load_ema_as_net</span> <span class="o">=</span> <span class="n">core_utils</span><span class="o">.</span><span class="n">get_param</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">checkpoint_params</span><span class="p">,</span> <span class="s1">&#39;load_ema_as_net&#39;</span><span class="p">,</span> <span class="n">default_val</span><span class="o">=</span><span class="
    -        <span class="bp">self</span><span class="o">.</span><span class="n">source_ckpt_folder_name</span> <span class="o">=</span> <span class="n">core_utils</span><span class="o">.</span><span class="n">get_param</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">checkpoint_params</span><span class="p">,</span> <span class="s1">&#39;source_ckpt_folder_name&#39;</span><span class="p">)</span>
    -        <span class="bp">self</span><span class="o">.</span><span class="n">load_checkpoint</span> <span class="o">=</span> <span class="n">core_utils</span><span class="o">.</span><span class="n">get_param</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">checkpoint_params</span><span class="p">,</span> <span class="s1">&#39;load_checkpoint&#39;</span><span class="p">,</span> <span class="n">default_val</span><span class="o">=</span><span class="
    -        <span class="bp">self</span><span class="o">.</span><span class="n">load_backbone</span> <span class="o">=</span> <span class="n">core_utils</span><span class="o">.</span><span class="n">get_param</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">checkpoint_params</span><span class="p">,</span> <span class="s1">&#39;load_backbone&#39;</span><span class="p">,</span> <span class="n">default_val</span><span class="o">=</span><span class="kc">
    -        <span class="bp">self</span><span class="o">.</span><span class="n">external_checkpoint_path</span> <span class="o">=</span> <span class="n">core_utils</span><span class="o">.</span><span class="n">get_param</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">checkpoint_params</span><span class="p">,</span> <span class="s1">&#39;external_checkpoint_path&#39;</span><span class="p">)</span>
    -        <span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">load_checkpoint</span> <span class="ow">or</span> <span class="bp">self</span><span class="o">.</span><span class="n">external_checkpoint_path</span><span class="p">:</span>
    -            <span class="bp">self</span><span class="o">.</span><span class="n">load_weights_only</span> <span class="o">=</span> <span class="n">core_utils</span><span class="o">.</span><span class="n">get_param</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">checkpoint_params</span><span class="p">,</span> <span class="s1">&#39;load_weights_only&#39;</span><span class="p">,</span>
    -                                                          <span class="n">default_val</span><span class="o">=</span><span class="kc">False</span><span class="p">)</span>
    -        <span class="bp">self</span><span class="o">.</span><span class="n">ckpt_name</span> <span class="o">=</span> <span class="n">core_utils</span><span class="o">.</span><span class="n">get_param</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">checkpoint_params</span><span class="p">,</span> <span class="s1">&#39;ckpt_name&#39;</span><span class="p">,</span> <span class="n">default_val</span><span class="o">=</span><span class="bp">self</sp
    +        <span class="n">logger</span><span class="o">.</span><span class="n">info</span><span class="p">(</span><span class="s2">&quot;Evaluate checkpoint&quot;</span><span class="p">)</span>
    +        <span class="n">cfg</span> <span class="o">=</span> <span class="n">load_experiment_cfg</span><span class="p">(</span><span class="n">experiment_name</span><span class="p">,</span> <span class="n">ckpt_root_dir</span><span class="p">)</span>
    +        <span class="n">add_params_to_cfg</span><span class="p">(</span><span class="n">cfg</span><span class="p">,</span> <span class="n">params</span><span class="o">=</span><span class="p">[</span><span class="s2">&quot;training_hyperparams.resume=True&quot;</span><span class="p">,</span> <span class="sa">f</span><span class="s2">&quot;ckpt_name=</span><span class="si">{</span><span class="n">ckpt_name</span><span class="si">}</span><span class="s2">&quot;</span><span class="p">])</span>
    +        <span class="bp">cls</span><span class="o">.</span><span class="n">evaluate_from_recipe</span><span class="p">(</span><span class="n">cfg</span><span class="p">)</span></div>
    +
    +    <span class="k">def</span> <span class="nf">_set_dataset_params</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span>
    +        <span class="bp">self</span><span class="o">.</span><span class="n">dataset_params</span> <span class="o">=</span> <span class="p">{</span>
    +            <span class="s2">&quot;train_dataset_params&quot;</span><span class="p">:</span> <span class="bp">self</span><span class="o">.</span><span class="n">train_loader</span><span class="o">.</span><span class="n">dataset</span><span class="o">.</span><span class="n">dataset_params</span> <span class="k">if</span> <span class="nb">hasattr</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">train_loader</span><span class="o">.</span><span class
    +            <span class="s2">&quot;train_dataloader_params&quot;</span><span class="p">:</span> <span class="bp">self</span><span class="o">.</span><span class="n">train_loader</span><span class="o">.</span><span class="n">dataloader_params</span> <span class="k">if</span> <span class="nb">hasattr</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">train_loader</span><span class="p">,</span> <span class="s2">&quot;dataloader_params&quot;</span><span
    +            <span class="s2">&quot;valid_dataset_params&quot;</span><span class="p">:</span> <span class="bp">self</span><span class="o">.</span><span class="n">valid_loader</span><span class="o">.</span><span class="n">dataset</span><span class="o">.</span><span class="n">dataset_params</span> <span class="k">if</span> <span class="nb">hasattr</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">valid_loader</span><span class="o">.</span><span class
    +            <span class="s2">&quot;valid_dataloader_params&quot;</span><span class="p">:</span> <span class="bp">self</span><span class="o">.</span><span class="n">valid_loader</span><span class="o">.</span><span class="n">dataloader_params</span> <span class="k">if</span> <span class="nb">hasattr</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">valid_loader</span><span class="p">,</span> <span class="s2">&quot;dataloader_params&quot;</span><span
    +        <span class="p">}</span>
    +        <span class="bp">self</span><span class="o">.</span><span class="n">dataset_params</span> <span class="o">=</span> <span class="n">HpmStruct</span><span class="p">(</span><span class="o">**</span><span class="bp">self</span><span class="o">.</span><span class="n">dataset_params</span><span class="p">)</span>
     
     
         <span class="k">def</span> <span class="nf">_net_to_device</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span>
         <span class="k">def</span> <span class="nf">_net_to_device</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span>
             <span class="sd">&quot;&quot;&quot;</span>
             <span class="sd">&quot;&quot;&quot;</span>
    @@ -376,20 +421,17 @@
             <span class="bp">self</span><span class="o">.</span><span class="n">net</span><span class="o">.</span><span class="n">to</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">device</span><span class="p">)</span>
             <span class="bp">self</span><span class="o">.</span><span class="n">net</span><span class="o">.</span><span class="n">to</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">device</span><span class="p">)</span>
     
     
             <span class="c1"># FOR MULTI-GPU TRAINING (not distributed)</span>
             <span class="c1"># FOR MULTI-GPU TRAINING (not distributed)</span>
    -        <span class="bp">self</span><span class="o">.</span><span class="n">arch_params</span><span class="o">.</span><span class="n">sync_bn</span> <span class="o">=</span> <span class="n">core_utils</span><span class="o">.</span><span class="n">get_param</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">arch_params</span><span class="p">,</span> <span class="s1">&#39;sync_bn&#39;</span><span class="p">,</span> <span class="n">default_val</span><
    +        <span class="n">sync_bn</span> <span class="o">=</span> <span class="n">core_utils</span><span class="o">.</span><span class="n">get_param</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">training_params</span><span class="p">,</span> <span class="s2">&quot;sync_bn&quot;</span><span class="p">,</span> <span class="n">default_val</span><span class="o">=</span><span class="kc">False</span><span class="p">)</span>
             <span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">multi_gpu</span> <span class="o">==</span> <span class="n">MultiGPUMode</span><span class="o">.</span><span class="n">DATA_PARALLEL</span><span class="p">:</span>
             <span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">multi_gpu</span> <span class="o">==</span> <span class="n">MultiGPUMode</span><span class="o">.</span><span class="n">DATA_PARALLEL</span><span class="p">:</span>
                 <span class="bp">self</span><span class="o">.</span><span class="n">net</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">nn</span><span class="o">.</span><span class="n">DataParallel</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">net</span><span class="p">,</span> <span class="n">device_ids</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class
                 <span class="bp">self</span><span class="o">.</span><span class="n">net</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">nn</span><span class="o">.</span><span class="n">DataParallel</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">net</span><span class="p">,</span> <span class="n">device_ids</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class
             <span class="k">elif</span> <span class="bp">self</span><span class="o">.</span><span class="n">multi_gpu</span> <span class="o">==</span> <span class="n">MultiGPUMode</span><span class="o">.</span><span class="n">DISTRIBUTED_DATA_PARALLEL</span><span class="p">:</span>
             <span class="k">elif</span> <span class="bp">self</span><span class="o">.</span><span class="n">multi_gpu</span> <span class="o">==</span> <span class="n">MultiGPUMode</span><span class="o">.</span><span class="n">DISTRIBUTED_DATA_PARALLEL</span><span class="p">:</span>
    -            <span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">arch_params</span><span class="o">.</span><span class="n">sync_bn</span><span class="p">:</span>
    +            <span class="k">if</span> <span class="n">sync_bn</span><span class="p">:</span>
                     <span class="k">if</span> <span class="ow">not</span> <span class="bp">self</span><span class="o">.</span><span class="n">ddp_silent_mode</span><span class="p">:</span>
                     <span class="k">if</span> <span class="ow">not</span> <span class="bp">self</span><span class="o">.</span><span class="n">ddp_silent_mode</span><span class="p">:</span>
    -                    <span class="n">logger</span><span class="o">.</span><span class="n">info</span><span class="p">(</span><span class="s1">&#39;DDP - Using Sync Batch Norm... Training time will be affected accordingly&#39;</span><span class="p">)</span>
    +                    <span class="n">logger</span><span class="o">.</span><span class="n">info</span><span class="p">(</span><span class="s2">&quot;DDP - Using Sync Batch Norm... Training time will be affected accordingly&quot;</span><span class="p">)</span>
                     <span class="bp">self</span><span class="o">.</span><span class="n">net</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">nn</span><span class="o">.</span><span class="n">SyncBatchNorm</span><span class="o">.</span><span class="n">convert_sync_batchnorm</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">net</span><span class="p">)</span><span class="o">.</span><span class="n">to</spa
                     <span class="bp">self</span><span class="o">.</span><span class="n">net</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">nn</span><span class="o">.</span><span class="n">SyncBatchNorm</span><span class="o">.</span><span class="n">convert_sync_batchnorm</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">net</span><span class="p">)</span><span class="o">.</span><span class="n">to</spa
     
     
    -            <span class="n">local_rank</span> <span class="o">=</span> <span class="nb">int</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">device</span><span class="o">.</span><span class="n">split</span><span class="p">(</span><span class="s1">&#39;:&#39;</span><span class="p">)[</span><span class="mi">1</span><span class="p">])</span>
    -            <span class="bp">self</span><span class="o">.</span><span class="n">net</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">nn</span><span class="o">.</span><span class="n">parallel</span><span class="o">.</span><span class="n">DistributedDataParallel</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">net</span><span class="p">,</span>
    -                                                                 <span class="n">device_ids</span><span class="o">=</span><span class="p">[</span><span class="n">local_rank</span><span class="p">],</span>
    -                                                                 <span class="n">output_device</span><span class="o">=</span><span class="n">local_rank</span><span class="p">,</span>
    -                                                                 <span class="n">find_unused_parameters</span><span class="o">=</span><span class="kc">True</span><span class="p">)</span>
    +            <span class="n">local_rank</span> <span class="o">=</span> <span class="nb">int</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">device</span><span class="o">.</span><span class="n">split</span><span class="p">(</span><span class="s2">&quot;:&quot;</span><span class="p">)[</span><span class="mi">1</span><span class="p">])</span>
    +            <span class="bp">self</span><span class="o">.</span><span class="n">net</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">nn</span><span class="o">.</span><span class="n">parallel</span><span class="o">.</span><span class="n">DistributedDataParallel</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">net</span><span class="p">,</span> <span class="n">device_ids</span><span class="o">=</sp
     
     
             <span class="k">else</span><span class="p">:</span>
             <span class="k">else</span><span class="p">:</span>
                 <span class="bp">self</span><span class="o">.</span><span class="n">net</span> <span class="o">=</span> <span class="n">core_utils</span><span class="o">.</span><span class="n">WrappedModel</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">net</span><span class="p">)</span>
                 <span class="bp">self</span><span class="o">.</span><span class="n">net</span> <span class="o">=</span> <span class="n">core_utils</span><span class="o">.</span><span class="n">WrappedModel</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">net</span><span class="p">)</span>
    @@ -404,8 +446,7 @@
             <span class="c1"># SET THE MODEL IN training STATE</span>
             <span class="c1"># SET THE MODEL IN training STATE</span>
             <span class="bp">self</span><span class="o">.</span><span class="n">net</span><span class="o">.</span><span class="n">train</span><span class="p">()</span>
             <span class="bp">self</span><span class="o">.</span><span class="n">net</span><span class="o">.</span><span class="n">train</span><span class="p">()</span>
             <span class="c1"># THE DISABLE FLAG CONTROLS WHETHER THE PROGRESS BAR IS SILENT OR PRINTS THE LOGS</span>
             <span class="c1"># THE DISABLE FLAG CONTROLS WHETHER THE PROGRESS BAR IS SILENT OR PRINTS THE LOGS</span>
    -        <span class="n">progress_bar_train_loader</span> <span class="o">=</span> <span class="n">tqdm</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">train_loader</span><span class="p">,</span> <span class="n">bar_format</span><span class="o">=</span><span class="s2">&quot;</span><span class="si">{l_bar}{bar:10}{r_bar}</span><span class="s2">&quot;</span><span class="p">,</span> <span class="n">dynamic_ncols</span><span class="o">=</span><span 
    -                                         <span class="n">disable</span><span class="o">=</span><span class="n">silent_mode</span><span class="p">)</span>
    +        <span class="n">progress_bar_train_loader</span> <span class="o">=</span> <span class="n">tqdm</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">train_loader</span><span class="p">,</span> <span class="n">bar_format</span><span class="o">=</span><span class="s2">&quot;</span><span class="si">{l_bar}{bar:10}{r_bar}</span><span class="s2">&quot;</span><span class="p">,</span> <span class="n">dynamic_ncols</span><span class="o">=</span><span 
             <span class="n">progress_bar_train_loader</span><span class="o">.</span><span class="n">set_description</span><span class="p">(</span><span class="sa">f</span><span class="s2">&quot;Train epoch </span><span class="si">{</span><span class="n">epoch</span><span class="si">}</span><span class="s2">&quot;</span><span class="p">)</span>
             <span class="n">progress_bar_train_loader</span><span class="o">.</span><span class="n">set_description</span><span class="p">(</span><span class="sa">f</span><span class="s2">&quot;Train epoch </span><span class="si">{</span><span class="n">epoch</span><span class="si">}</span><span class="s2">&quot;</span><span class="p">)</span>
     
     
             <span class="c1"># RESET/INIT THE METRIC LOGGERS</span>
             <span class="c1"># RESET/INIT THE METRIC LOGGERS</span>
    @@ -414,21 +455,23 @@
             <span class="bp">self</span><span class="o">.</span><span class="n">train_metrics</span><span class="o">.</span><span class="n">to</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">device</span><span class="p">)</span>
             <span class="bp">self</span><span class="o">.</span><span class="n">train_metrics</span><span class="o">.</span><span class="n">to</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">device</span><span class="p">)</span>
             <span class="n">loss_avg_meter</span> <span class="o">=</span> <span class="n">core_utils</span><span class="o">.</span><span class="n">utils</span><span class="o">.</span><span class="n">AverageMeter</span><span class="p">()</span>
             <span class="n">loss_avg_meter</span> <span class="o">=</span> <span class="n">core_utils</span><span class="o">.</span><span class="n">utils</span><span class="o">.</span><span class="n">AverageMeter</span><span class="p">()</span>
     
     
    -        <span class="n">context</span> <span class="o">=</span> <span class="n">PhaseContext</span><span class="p">(</span><span class="n">epoch</span><span class="o">=</span><span class="n">epoch</span><span class="p">,</span>
    -                               <span class="n">optimizer</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">optimizer</span><span class="p">,</span>
    -                               <span class="n">metrics_compute_fn</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">train_metrics</span><span class="p">,</span>
    -                               <span class="n">loss_avg_meter</span><span class="o">=</span><span class="n">loss_avg_meter</span><span class="p">,</span>
    -                               <span class="n">criterion</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">criterion</span><span class="p">,</span>
    -                               <span class="n">device</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">device</span><span class="p">,</span>
    -                               <span class="n">lr_warmup_epochs</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">training_params</span><span class="o">.</span><span class="n">lr_warmup_epochs</span><span class="p">,</span>
    -                               <span class="n">sg_logger</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">sg_logger</span><span class="p">,</span>
    -                               <span class="n">train_loader</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">train_loader</span><span class="p">,</span>
    -                               <span class="n">context_methods</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">_get_context_methods</span><span class="p">(</span><span class="n">Phase</span><span class="o">.</span><span class="n">TRAIN_BATCH_END</span><span class="p">),</span>
    -                               <span class="n">ddp_silent_mode</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">ddp_silent_mode</span><span class="p">)</span>
    +        <span class="n">context</span> <span class="o">=</span> <span class="n">PhaseContext</span><span class="p">(</span>
    +            <span class="n">epoch</span><span class="o">=</span><span class="n">epoch</span><span class="p">,</span>
    +            <span class="n">optimizer</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">optimizer</span><span class="p">,</span>
    +            <span class="n">metrics_compute_fn</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">train_metrics</span><span class="p">,</span>
    +            <span class="n">loss_avg_meter</span><span class="o">=</span><span class="n">loss_avg_meter</span><span class="p">,</span>
    +            <span class="n">criterion</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">criterion</span><span class="p">,</span>
    +            <span class="n">device</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">device</span><span class="p">,</span>
    +            <span class="n">lr_warmup_epochs</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">training_params</span><span class="o">.</span><span class="n">lr_warmup_epochs</span><span class="p">,</span>
    +            <span class="n">sg_logger</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">sg_logger</span><span class="p">,</span>
    +            <span class="n">train_loader</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">train_loader</span><span class="p">,</span>
    +            <span class="n">context_methods</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">_get_context_methods</span><span class="p">(</span><span class="n">Phase</span><span class="o">.</span><span class="n">TRAIN_BATCH_END</span><span class="p">),</span>
    +            <span class="n">ddp_silent_mode</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">ddp_silent_mode</span><span class="p">,</span>
    +        <span class="p">)</span>
     
     
             <span class="k">for</span> <span class="n">batch_idx</span><span class="p">,</span> <span class="n">batch_items</span> <span class="ow">in</span> <span class="nb">enumerate</span><span class="p">(</span><span class="n">progress_bar_train_loader</span><span class="p">):</span>
             <span class="k">for</span> <span class="n">batch_idx</span><span class="p">,</span> <span class="n">batch_items</span> <span class="ow">in</span> <span class="nb">enumerate</span><span class="p">(</span><span class="n">progress_bar_train_loader</span><span class="p">):</span>
                 <span class="n">batch_items</span> <span class="o">=</span> <span class="n">core_utils</span><span class="o">.</span><span class="n">tensor_container_to_device</span><span class="p">(</span><span class="n">batch_items</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">device</span><span class="p">,</span> <span class="n">non_blocking</span><span class="o">=</span><span class="kc">True</span><span class="p">)</span>
                 <span class="n">batch_items</span> <span class="o">=</span> <span class="n">core_utils</span><span class="o">.</span><span class="n">tensor_container_to_device</span><span class="p">(</span><span class="n">batch_items</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">device</span><span class="p">,</span> <span class="n">non_blocking</span><span class="o">=</span><span class="kc">True</span><span class="p">)</span>
    -            <span class="n">inputs</span><span class="p">,</span> <span class="n">targets</span><span class="p">,</span> <span class="n">additional_batch_items</span> <span class="o">=</span> <span class="n">sg_model_utils</span><span class="o">.</span><span class="n">unpack_batch_items</span><span class="p">(</span><span class="n">batch_items</span><span class="p">)</span>
    +            <span class="n">inputs</span><span class="p">,</span> <span class="n">targets</span><span class="p">,</span> <span class="n">additional_batch_items</span> <span class="o">=</span> <span class="n">sg_trainer_utils</span><span class="o">.</span><span class="n">unpack_batch_items</span><span class="p">(</span><span class="n">batch_items</span><span class="p">)</span>
     
     
                 <span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">pre_prediction_callback</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span><span class="p">:</span>
                 <span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">pre_prediction_callback</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span><span class="p">:</span>
                     <span class="n">inputs</span><span class="p">,</span> <span class="n">targets</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">pre_prediction_callback</span><span class="p">(</span><span class="n">inputs</span><span class="p">,</span> <span class="n">targets</span><span class="p">,</span> <span class="n">batch_idx</span><span class="p">)</span>
                     <span class="n">inputs</span><span class="p">,</span> <span class="n">targets</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">pre_prediction_callback</span><span class="p">(</span><span class="n">inputs</span><span class="p">,</span> <span class="n">targets</span><span class="p">,</span> <span class="n">batch_idx</span><span class="p">)</span>
    @@ -440,12 +483,7 @@
                     <span class="c1"># COMPUTE THE LOSS FOR BACK PROP + EXTRA METRICS COMPUTED DURING THE LOSS FORWARD PASS</span>
                     <span class="c1"># COMPUTE THE LOSS FOR BACK PROP + EXTRA METRICS COMPUTED DURING THE LOSS FORWARD PASS</span>
                     <span class="n">loss</span><span class="p">,</span> <span class="n">loss_log_items</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">_get_losses</span><span class="p">(</span><span class="n">outputs</span><span class="p">,</span> <span class="n">targets</span><span class="p">)</span>
                     <span class="n">loss</span><span class="p">,</span> <span class="n">loss_log_items</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">_get_losses</span><span class="p">(</span><span class="n">outputs</span><span class="p">,</span> <span class="n">targets</span><span class="p">)</span>
     
     
    -            <span class="n">context</span><span class="o">.</span><span class="n">update_context</span><span class="p">(</span><span class="n">batch_idx</span><span class="o">=</span><span class="n">batch_idx</span><span class="p">,</span>
    -                                   <span class="n">inputs</span><span class="o">=</span><span class="n">inputs</span><span class="p">,</span>
    -                                   <span class="n">preds</span><span class="o">=</span><span class="n">outputs</span><span class="p">,</span>
    -                                   <span class="n">target</span><span class="o">=</span><span class="n">targets</span><span class="p">,</span>
    -                                   <span class="n">loss_log_items</span><span class="o">=</span><span class="n">loss_log_items</span><span class="p">,</span>
    -                                   <span class="o">**</span><span class="n">additional_batch_items</span><span class="p">)</span>
    +            <span class="n">context</span><span class="o">.</span><span class="n">update_context</span><span class="p">(</span><span class="n">batch_idx</span><span class="o">=</span><span class="n">batch_idx</span><span class="p">,</span> <span class="n">inputs</span><span class="o">=</span><span class="n">inputs</span><span class="p">,</span> <span class="n">preds</span><span class="o">=</span><span class="n">outputs</span><span class="p">,</span> <span class="n">target</span><span class="o">
     
     
                 <span class="bp">self</span><span class="o">.</span><span class="n">phase_callback_handler</span><span class="p">(</span><span class="n">Phase</span><span class="o">.</span><span class="n">TRAIN_BATCH_END</span><span class="p">,</span> <span class="n">context</span><span class="p">)</span>
                 <span class="bp">self</span><span class="o">.</span><span class="n">phase_callback_handler</span><span class="p">(</span><span class="n">Phase</span><span class="o">.</span><span class="n">TRAIN_BATCH_END</span><span class="p">,</span> <span class="n">context</span><span class="p">)</span>
     
     
    @@ -457,13 +495,12 @@
     
     
                 <span class="c1"># COMPUTE THE RUNNING USER METRICS AND LOSS RUNNING ITEMS. RESULT TUPLE IS THEIR CONCATENATION.</span>
                 <span class="c1"># COMPUTE THE RUNNING USER METRICS AND LOSS RUNNING ITEMS. RESULT TUPLE IS THEIR CONCATENATION.</span>
                 <span class="n">logging_values</span> <span class="o">=</span> <span class="n">loss_avg_meter</span><span class="o">.</span><span class="n">average</span> <span class="o">+</span> <span class="n">get_metrics_results_tuple</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">train_metrics</span><span class="p">)</span>
                 <span class="n">logging_values</span> <span class="o">=</span> <span class="n">loss_avg_meter</span><span class="o">.</span><span class="n">average</span> <span class="o">+</span> <span class="n">get_metrics_results_tuple</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">train_metrics</span><span class="p">)</span>
    -            <span class="n">gpu_memory_utilization</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">cuda</span><span class="o">.</span><span class="n">memory_cached</span><span class="p">()</span> <span class="o">/</span> <span class="mf">1E9</span> <span class="k">if</span> <span class="n">torch</span><span class="o">.</span><span class="n">cuda</span><span class="o">.</span><span class="n">is_available</span><span class="p">()</span> <span cl
    +            <span class="n">gpu_memory_utilization</span> <span class="o">=</span> <span class="n">get_gpu_mem_utilization</span><span class="p">()</span> <span class="o">/</span> <span class="mf">1e9</span> <span class="k">if</span> <span class="n">torch</span><span class="o">.</span><span class="n">cuda</span><span class="o">.</span><span class="n">is_available</span><span class="p">()</span> <span class="k">else</span> <span class="mi">0</span>
     
     
                 <span class="c1"># RENDER METRICS PROGRESS</span>
                 <span class="c1"># RENDER METRICS PROGRESS</span>
    -            <span class="n">pbar_message_dict</span> <span class="o">=</span> <span class="n">get_train_loop_description_dict</span><span class="p">(</span><span class="n">logging_values</span><span class="p">,</span>
    -                                                                <span class="bp">self</span><span class="o">.</span><span class="n">train_metrics</span><span class="p">,</span>
    -                                                                <span class="bp">self</span><span class="o">.</span><span class="n">loss_logging_items_names</span><span class="p">,</span>
    -                                                                <span class="n">gpu_mem</span><span class="o">=</span><span class="n">gpu_memory_utilization</span><span class="p">)</span>
    +            <span class="n">pbar_message_dict</span> <span class="o">=</span> <span class="n">get_train_loop_description_dict</span><span class="p">(</span>
    +                <span class="n">logging_values</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">train_metrics</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">loss_logging_items_names</span><span class="p">,</span> <span class="n">gpu_mem</span><span class="o">=</span><span class="n">gpu_memory_utilization</span>
    +            <span class="p">)</span>
     
     
                 <span class="n">progress_bar_train_loader</span><span class="o">.</span><span class="n">set_postfix</span><span class="p">(</span><span class="o">**</span><span class="n">pbar_message_dict</span><span class="p">)</span>
                 <span class="n">progress_bar_train_loader</span><span class="o">.</span><span class="n">set_postfix</span><span class="p">(</span><span class="o">**</span><span class="n">pbar_message_dict</span><span class="p">)</span>
     
     
    @@ -475,8 +512,9 @@
             <span class="k">if</span> <span class="ow">not</span> <span class="bp">self</span><span class="o">.</span><span class="n">ddp_silent_mode</span><span class="p">:</span>
             <span class="k">if</span> <span class="ow">not</span> <span class="bp">self</span><span class="o">.</span><span class="n">ddp_silent_mode</span><span class="p">:</span>
                 <span class="bp">self</span><span class="o">.</span><span class="n">sg_logger</span><span class="o">.</span><span class="n">upload</span><span class="p">()</span>
                 <span class="bp">self</span><span class="o">.</span><span class="n">sg_logger</span><span class="o">.</span><span class="n">upload</span><span class="p">()</span>
     
     
    -        <span class="bp">self</span><span class="o">.</span><span class="n">train_monitored_values</span> <span class="o">=</span> <span class="n">sg_model_utils</span><span class="o">.</span><span class="n">update_monitored_values_dict</span><span class="p">(</span>
    -            <span class="n">monitored_values_dict</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">train_monitored_values</span><span class="p">,</span> <span class="n">new_values_dict</span><span class="o">=</span><span class="n">pbar_message_dict</span><span class="p">)</span>
    +        <span class="bp">self</span><span class="o">.</span><span class="n">train_monitored_values</span> <span class="o">=</span> <span class="n">sg_trainer_utils</span><span class="o">.</span><span class="n">update_monitored_values_dict</span><span class="p">(</span>
    +            <span class="n">monitored_values_dict</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">train_monitored_values</span><span class="p">,</span> <span class="n">new_values_dict</span><span class="o">=</span><span class="n">pbar_message_dict</span>
    +        <span class="p">)</span>
     
     
             <span class="k">return</span> <span class="n">logging_values</span>
             <span class="k">return</span> <span class="n">logging_values</span>
     
     
    @@ -489,12 +527,50 @@
             <span class="k">else</span><span class="p">:</span>
             <span class="k">else</span><span class="p">:</span>
                 <span class="n">loss_logging_items</span> <span class="o">=</span> <span class="n">loss</span><span class="o">.</span><span class="n">unsqueeze</span><span class="p">(</span><span class="mi">0</span><span class="p">)</span><span class="o">.</span><span class="n">detach</span><span class="p">()</span>
                 <span class="n">loss_logging_items</span> <span class="o">=</span> <span class="n">loss</span><span class="o">.</span><span class="n">unsqueeze</span><span class="p">(</span><span class="mi">0</span><span class="p">)</span><span class="o">.</span><span class="n">detach</span><span class="p">()</span>
     
     
    +        <span class="c1"># ON FIRST BACKWARD, DERRIVE THE LOGGING TITLES.</span>
    +        <span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">loss_logging_items_names</span> <span class="ow">is</span> <span class="kc">None</span> <span class="ow">or</span> <span class="bp">self</span><span class="o">.</span><span class="n">_first_backward</span><span class="p">:</span>
    +            <span class="bp">self</span><span class="o">.</span><span class="n">_init_loss_logging_names</span><span class="p">(</span><span class="n">loss_logging_items</span><span class="p">)</span>
    +            <span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">metric_to_watch</span><span class="p">:</span>
    +                <span class="bp">self</span><span class="o">.</span><span class="n">_init_monitored_items</span><span class="p">()</span>
    +            <span class="bp">self</span><span class="o">.</span><span class="n">_first_backward</span> <span class="o">=</span> <span class="kc">False</span>
    +
             <span class="k">if</span> <span class="nb">len</span><span class="p">(</span><span class="n">loss_logging_items</span><span class="p">)</span> <span class="o">!=</span> <span class="nb">len</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">loss_logging_items_names</span><span class="p">):</span>
             <span class="k">if</span> <span class="nb">len</span><span class="p">(</span><span class="n">loss_logging_items</span><span class="p">)</span> <span class="o">!=</span> <span class="nb">len</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">loss_logging_items_names</span><span class="p">):</span>
    -            <span class="k">raise</span> <span class="ne">ValueError</span><span class="p">(</span><span class="s2">&quot;Loss output length must match loss_logging_items_names. Got &quot;</span> <span class="o">+</span> <span class="nb">str</span><span class="p">(</span>
    -                <span class="nb">len</span><span class="p">(</span><span class="n">loss_logging_items</span><span class="p">))</span> <span class="o">+</span> <span class="s1">&#39;, and &#39;</span> <span class="o">+</span> <span class="nb">str</span><span class="p">(</span><span class="nb">len</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">loss_logging_items_names</span><span class="p">)))</span>
    +            <span class="k">raise</span> <span class="ne">ValueError</span><span class="p">(</span>
    +                <span class="s2">&quot;Loss output length must match loss_logging_items_names. Got &quot;</span>
    +                <span class="o">+</span> <span class="nb">str</span><span class="p">(</span><span class="nb">len</span><span class="p">(</span><span class="n">loss_logging_items</span><span class="p">))</span>
    +                <span class="o">+</span> <span class="s2">&quot;, and &quot;</span>
    +                <span class="o">+</span> <span class="nb">str</span><span class="p">(</span><span class="nb">len</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">loss_logging_items_names</span><span class="p">))</span>
    +            <span class="p">)</span>
             <span class="c1"># RETURN AND THE LOSS LOGGING ITEMS COMPUTED DURING LOSS FORWARD PASS</span>
             <span class="c1"># RETURN AND THE LOSS LOGGING ITEMS COMPUTED DURING LOSS FORWARD PASS</span>
             <span class="k">return</span> <span class="n">loss</span><span class="p">,</span> <span class="n">loss_logging_items</span>
             <span class="k">return</span> <span class="n">loss</span><span class="p">,</span> <span class="n">loss_logging_items</span>
     
     
    +    <span class="k">def</span> <span class="nf">_init_monitored_items</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span>
    +        <span class="bp">self</span><span class="o">.</span><span class="n">metric_idx_in_results_tuple</span> <span class="o">=</span> <span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">loss_logging_items_names</span> <span class="o">+</span> <span class="n">get_metrics_titles</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">valid_metrics</span><span class="p">))</span><span class="o">.</span><span class=
    +        <span class="c1"># Instantiate the values to monitor (loss/metric)</span>
    +        <span class="k">for</span> <span class="n">loss_name</span> <span class="ow">in</span> <span class="bp">self</span><span class="o">.</span><span class="n">loss_logging_items_names</span><span class="p">:</span>
    +            <span class="bp">self</span><span class="o">.</span><span class="n">train_monitored_values</span><span class="p">[</span><span class="n">loss_name</span><span class="p">]</span> <span class="o">=</span> <span class="n">MonitoredValue</span><span class="p">(</span><span class="n">name</span><span class="o">=</span><span class="n">loss_name</span><span class="p">,</span> <span class="n">greater_is_better</span><span class="o">=</span><span class="kc">False</span><span class="p">)</spa
    +            <span class="bp">self</span><span class="o">.</span><span class="n">valid_monitored_values</span><span class="p">[</span><span class="n">loss_name</span><span class="p">]</span> <span class="o">=</span> <span class="n">MonitoredValue</span><span class="p">(</span><span class="n">name</span><span class="o">=</span><span class="n">loss_name</span><span class="p">,</span> <span class="n">greater_is_better</span><span class="o">=</span><span class="kc">False</span><span class="p">)</spa
    +
    +        <span class="k">for</span> <span class="n">metric_name</span> <span class="ow">in</span> <span class="n">get_metrics_titles</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">train_metrics</span><span class="p">):</span>
    +            <span class="bp">self</span><span class="o">.</span><span class="n">train_monitored_values</span><span class="p">[</span><span class="n">metric_name</span><span class="p">]</span> <span class="o">=</span> <span class="n">MonitoredValue</span><span class="p">(</span><span class="n">name</span><span class="o">=</span><span class="n">metric_name</span><span class="p">,</span> <span class="n">greater_is_better</span><span class="o">=</span><span class="bp">self</span><span class="o">.</
    +
    +        <span class="k">for</span> <span class="n">metric_name</span> <span class="ow">in</span> <span class="n">get_metrics_titles</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">valid_metrics</span><span class="p">):</span>
    +            <span class="bp">self</span><span class="o">.</span><span class="n">valid_monitored_values</span><span class="p">[</span><span class="n">metric_name</span><span class="p">]</span> <span class="o">=</span> <span class="n">MonitoredValue</span><span class="p">(</span><span class="n">name</span><span class="o">=</span><span class="n">metric_name</span><span class="p">,</span> <span class="n">greater_is_better</span><span class="o">=</span><span class="bp">self</span><span class="o">.</
    +
    +        <span class="bp">self</span><span class="o">.</span><span class="n">results_titles</span> <span class="o">=</span> <span class="p">[</span><span class="s2">&quot;Train_&quot;</span> <span class="o">+</span> <span class="n">t</span> <span class="k">for</span> <span class="n">t</span> <span class="ow">in</span> <span class="bp">self</span><span class="o">.</span><span class="n">loss_logging_items_names</span> <span class="o">+</span> <span class="n">get_metrics_titles</span><span class="p
    +            <span class="s2">&quot;Valid_&quot;</span> <span class="o">+</span> <span class="n">t</span> <span class="k">for</span> <span class="n">t</span> <span class="ow">in</span> <span class="bp">self</span><span class="o">.</span><span class="n">loss_logging_items_names</span> <span class="o">+</span> <span class="n">get_metrics_titles</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">valid_metrics</span><span class="p">)</span>
    +        <span class="p">]</span>
    +
    +        <span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">training_params</span><span class="o">.</span><span class="n">average_best_models</span><span class="p">:</span>
    +            <span class="bp">self</span><span class="o">.</span><span class="n">model_weight_averaging</span> <span class="o">=</span> <span class="n">ModelWeightAveraging</span><span class="p">(</span>
    +                <span class="bp">self</span><span class="o">.</span><span class="n">checkpoints_dir_path</span><span class="p">,</span>
    +                <span class="n">greater_is_better</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">greater_metric_to_watch_is_better</span><span class="p">,</span>
    +                <span class="n">source_ckpt_folder_name</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">source_ckpt_folder_name</span><span class="p">,</span>
    +                <span class="n">metric_to_watch</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">metric_to_watch</span><span class="p">,</span>
    +                <span class="n">metric_idx</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">metric_idx_in_results_tuple</span><span class="p">,</span>
    +                <span class="n">load_checkpoint</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">load_checkpoint</span><span class="p">,</span>
    +            <span class="p">)</span>
    +
         <span class="k">def</span> <span class="nf">_backward_step</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">loss</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">,</span> <span class="n">epoch</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span> <span class="n">batch_idx</span><span class="p">:</span> <span class="nb">int</span><span 
         <span class="k">def</span> <span class="nf">_backward_step</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">loss</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">,</span> <span class="n">epoch</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span> <span class="n">batch_idx</span><span class="p">:</span> <span class="nb">int</span><span 
             <span class="sd">&quot;&quot;&quot;</span>
             <span class="sd">&quot;&quot;&quot;</span>
     <span class="sd">        Run backprop on the loss and perform a step</span>
     <span class="sd">        Run backprop on the loss and perform a step</span>
    @@ -527,44 +603,43 @@
                 <span class="c1"># RUN PHASE CALLBACKS</span>
                 <span class="c1"># RUN PHASE CALLBACKS</span>
                 <span class="bp">self</span><span class="o">.</span><span class="n">phase_callback_handler</span><span class="p">(</span><span class="n">Phase</span><span class="o">.</span><span class="n">TRAIN_BATCH_STEP</span><span class="p">,</span> <span class="n">context</span><span class="p">)</span>
                 <span class="bp">self</span><span class="o">.</span><span class="n">phase_callback_handler</span><span class="p">(</span><span class="n">Phase</span><span class="o">.</span><span class="n">TRAIN_BATCH_STEP</span><span class="p">,</span> <span class="n">context</span><span class="p">)</span>
     
     
    -    <span class="k">def</span> <span class="nf">_save_checkpoint</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">optimizer</span><span class="o">=</span><span class="kc">None</span><span class="p">,</span> <span class="n">epoch</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span> <span class="n">validation_results_tuple</span><span class="p">:</span> <span class
    -                         <span class="n">context</span><span class="p">:</span> <span class="n">PhaseContext</span> <span class="o">=</span> <span class="kc">None</span><span class="p">):</span>
    +    <span class="k">def</span> <span class="nf">_save_checkpoint</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">optimizer</span><span class="o">=</span><span class="kc">None</span><span class="p">,</span> <span class="n">epoch</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span> <span class="n">validation_results_tuple</span><span class="p">:</span> <span class
             <span class="sd">&quot;&quot;&quot;</span>
             <span class="sd">&quot;&quot;&quot;</span>
     <span class="sd">        Save the current state dict as latest (always), best (if metric was improved), epoch# (if determined in training</span>
     <span class="sd">        Save the current state dict as latest (always), best (if metric was improved), epoch# (if determined in training</span>
     <span class="sd">        params)</span>
     <span class="sd">        params)</span>
     <span class="sd">        &quot;&quot;&quot;</span>
     <span class="sd">        &quot;&quot;&quot;</span>
             <span class="c1"># WHEN THE validation_results_tuple IS NONE WE SIMPLY SAVE THE state_dict AS LATEST AND Return</span>
             <span class="c1"># WHEN THE validation_results_tuple IS NONE WE SIMPLY SAVE THE state_dict AS LATEST AND Return</span>
             <span class="k">if</span> <span class="n">validation_results_tuple</span> <span class="ow">is</span> <span class="kc">None</span><span class="p">:</span>
             <span class="k">if</span> <span class="n">validation_results_tuple</span> <span class="ow">is</span> <span class="kc">None</span><span class="p">:</span>
    -            <span class="bp">self</span><span class="o">.</span><span class="n">sg_logger</span><span class="o">.</span><span class="n">add_checkpoint</span><span class="p">(</span><span class="n">tag</span><span class="o">=</span><span class="s1">&#39;ckpt_latest_weights_only.pth&#39;</span><span class="p">,</span> <span class="n">state_dict</span><span class="o">=</span><span class="p">{</span><span class="s1">&#39;net&#39;</span><span class="p">:</span> <span class="bp">self</span><span clas
    -                                          <span class="n">global_step</span><span class="o">=</span><span class="n">epoch</span><span class="p">)</span>
    +            <span class="bp">self</span><span class="o">.</span><span class="n">sg_logger</span><span class="o">.</span><span class="n">add_checkpoint</span><span class="p">(</span><span class="n">tag</span><span class="o">=</span><span class="s2">&quot;ckpt_latest_weights_only.pth&quot;</span><span class="p">,</span> <span class="n">state_dict</span><span class="o">=</span><span class="p">{</span><span class="s2">&quot;net&quot;</span><span class="p">:</span> <span class="bp">self</span><span 
                 <span class="k">return</span>
                 <span class="k">return</span>
     
     
             <span class="c1"># COMPUTE THE CURRENT metric</span>
             <span class="c1"># COMPUTE THE CURRENT metric</span>
             <span class="c1"># IF idx IS A LIST - SUM ALL THE VALUES STORED IN THE LIST&#39;S INDICES</span>
             <span class="c1"># IF idx IS A LIST - SUM ALL THE VALUES STORED IN THE LIST&#39;S INDICES</span>
    -        <span class="n">metric</span> <span class="o">=</span> <span class="n">validation_results_tuple</span><span class="p">[</span><span class="bp">self</span><span class="o">.</span><span class="n">metric_idx_in_results_tuple</span><span class="p">]</span> <span class="k">if</span> <span class="nb">isinstance</span><span class="p">(</span>
    -            <span class="bp">self</span><span class="o">.</span><span class="n">metric_idx_in_results_tuple</span><span class="p">,</span> <span class="nb">int</span><span class="p">)</span> <span class="k">else</span> \
    -            <span class="nb">sum</span><span class="p">([</span><span class="n">validation_results_tuple</span><span class="p">[</span><span class="n">idx</span><span class="p">]</span> <span class="k">for</span> <span class="n">idx</span> <span class="ow">in</span> <span class="bp">self</span><span class="o">.</span><span class="n">metric_idx_in_results_tuple</span><span class="p">])</span>
    +        <span class="n">metric</span> <span class="o">=</span> <span class="p">(</span>
    +            <span class="n">validation_results_tuple</span><span class="p">[</span><span class="bp">self</span><span class="o">.</span><span class="n">metric_idx_in_results_tuple</span><span class="p">]</span>
    +            <span class="k">if</span> <span class="nb">isinstance</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">metric_idx_in_results_tuple</span><span class="p">,</span> <span class="nb">int</span><span class="p">)</span>
    +            <span class="k">else</span> <span class="nb">sum</span><span class="p">([</span><span class="n">validation_results_tuple</span><span class="p">[</span><span class="n">idx</span><span class="p">]</span> <span class="k">for</span> <span class="n">idx</span> <span class="ow">in</span> <span class="bp">self</span><span class="o">.</span><span class="n">metric_idx_in_results_tuple</span><span class="p">])</span>
    +        <span class="p">)</span>
     
     
             <span class="c1"># BUILD THE state_dict</span>
             <span class="c1"># BUILD THE state_dict</span>
    -        <span class="n">state</span> <span class="o">=</span> <span class="p">{</span><span class="s1">&#39;net&#39;</span><span class="p">:</span> <span class="bp">self</span><span class="o">.</span><span class="n">net</span><span class="o">.</span><span class="n">state_dict</span><span class="p">(),</span> <span class="s1">&#39;acc&#39;</span><span class="p">:</span> <span class="n">metric</span><span class="p">,</span> <span class="s1">&#39;epoch&#39;</span><span class="p">:</span> <span cla
    +        <span class="n">state</span> <span class="o">=</span> <span class="p">{</span><span class="s2">&quot;net&quot;</span><span class="p">:</span> <span class="bp">self</span><span class="o">.</span><span class="n">net</span><span class="o">.</span><span class="n">state_dict</span><span class="p">(),</span> <span class="s2">&quot;acc&quot;</span><span class="p">:</span> <span class="n">metric</span><span class="p">,</span> <span class="s2">&quot;epoch&quot;</span><span class="p">:</span> <sp
             <span class="k">if</span> <span class="n">optimizer</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span><span class="p">:</span>
             <span class="k">if</span> <span class="n">optimizer</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span><span class="p">:</span>
    -            <span class="n">state</span><span class="p">[</span><span class="s1">&#39;optimizer_state_dict&#39;</span><span class="p">]</span> <span class="o">=</span> <span class="n">optimizer</span><span class="o">.</span><span class="n">state_dict</span><span class="p">()</span>
    +            <span class="n">state</span><span class="p">[</span><span class="s2">&quot;optimizer_state_dict&quot;</span><span class="p">]</span> <span class="o">=</span> <span class="n">optimizer</span><span class="o">.</span><span class="n">state_dict</span><span class="p">()</span>
     
     
             <span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">scaler</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span><span class="p">:</span>
             <span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">scaler</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span><span class="p">:</span>
    -            <span class="n">state</span><span class="p">[</span><span class="s1">&#39;scaler_state_dict&#39;</span><span class="p">]</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">scaler</span><span class="o">.</span><span class="n">state_dict</span><span class="p">()</span>
    +            <span class="n">state</span><span class="p">[</span><span class="s2">&quot;scaler_state_dict&quot;</span><span class="p">]</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">scaler</span><span class="o">.</span><span class="n">state_dict</span><span class="p">()</span>
     
     
             <span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">ema</span><span class="p">:</span>
             <span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">ema</span><span class="p">:</span>
    -            <span class="n">state</span><span class="p">[</span><span class="s1">&#39;ema_net&#39;</span><span class="p">]</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">ema_model</span><span class="o">.</span><span class="n">ema</span><span class="o">.</span><span class="n">state_dict</span><span class="p">()</span>
    +            <span class="n">state</span><span class="p">[</span><span class="s2">&quot;ema_net&quot;</span><span class="p">]</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">ema_model</span><span class="o">.</span><span class="n">ema</span><span class="o">.</span><span class="n">state_dict</span><span class="p">()</span>
             <span class="c1"># SAVES CURRENT MODEL AS ckpt_latest</span>
             <span class="c1"># SAVES CURRENT MODEL AS ckpt_latest</span>
    -        <span class="bp">self</span><span class="o">.</span><span class="n">sg_logger</span><span class="o">.</span><span class="n">add_checkpoint</span><span class="p">(</span><span class="n">tag</span><span class="o">=</span><span class="s1">&#39;ckpt_latest.pth&#39;</span><span class="p">,</span> <span class="n">state_dict</span><span class="o">=</span><span class="n">state</span><span class="p">,</span> <span class="n">global_step</span><span class="o">=</span><span class="n">epoch</span><s
    +        <span class="bp">self</span><span class="o">.</span><span class="n">sg_logger</span><span class="o">.</span><span class="n">add_checkpoint</span><span class="p">(</span><span class="n">tag</span><span class="o">=</span><span class="s2">&quot;ckpt_latest.pth&quot;</span><span class="p">,</span> <span class="n">state_dict</span><span class="o">=</span><span class="n">state</span><span class="p">,</span> <span class="n">global_step</span><span class="o">=</span><span class="n">epoch</span>
     
     
             <span class="c1"># SAVE MODEL AT SPECIFIC EPOCHS DETERMINED BY save_ckpt_epoch_list</span>
             <span class="c1"># SAVE MODEL AT SPECIFIC EPOCHS DETERMINED BY save_ckpt_epoch_list</span>
             <span class="k">if</span> <span class="n">epoch</span> <span class="ow">in</span> <span class="bp">self</span><span class="o">.</span><span class="n">training_params</span><span class="o">.</span><span class="n">save_ckpt_epoch_list</span><span class="p">:</span>
             <span class="k">if</span> <span class="n">epoch</span> <span class="ow">in</span> <span class="bp">self</span><span class="o">.</span><span class="n">training_params</span><span class="o">.</span><span class="n">save_ckpt_epoch_list</span><span class="p">:</span>
    -            <span class="bp">self</span><span class="o">.</span><span class="n">sg_logger</span><span class="o">.</span><span class="n">add_checkpoint</span><span class="p">(</span><span class="n">tag</span><span class="o">=</span><span class="sa">f</span><span class="s1">&#39;ckpt_epoch_</span><span class="si">{</span><span class="n">epoch</span><span class="si">}</span><span class="s1">.pth&#39;</span><span class="p">,</span> <span class="n">state_dict</span><span class="o">=</span><span clas
    +            <span class="bp">self</span><span class="o">.</span><span class="n">sg_logger</span><span class="o">.</span><span class="n">add_checkpoint</span><span class="p">(</span><span class="n">tag</span><span class="o">=</span><span class="sa">f</span><span class="s2">&quot;ckpt_epoch_</span><span class="si">{</span><span class="n">epoch</span><span class="si">}</span><span class="s2">.pth&quot;</span><span class="p">,</span> <span class="n">state_dict</span><span class="o">=</span><span cl
     
     
             <span class="c1"># OVERRIDE THE BEST CHECKPOINT AND best_metric IF metric GOT BETTER THAN THE PREVIOUS BEST</span>
             <span class="c1"># OVERRIDE THE BEST CHECKPOINT AND best_metric IF metric GOT BETTER THAN THE PREVIOUS BEST</span>
    -        <span class="k">if</span> <span class="p">(</span><span class="n">metric</span> <span class="o">&gt;</span> <span class="bp">self</span><span class="o">.</span><span class="n">best_metric</span> <span class="ow">and</span> <span class="bp">self</span><span class="o">.</span><span class="n">greater_metric_to_watch_is_better</span><span class="p">)</span> <span class="ow">or</span> <span class="p">(</span>
    -                <span class="n">metric</span> <span class="o">&lt;</span> <span class="bp">self</span><span class="o">.</span><span class="n">best_metric</span> <span class="ow">and</span> <span class="ow">not</span> <span class="bp">self</span><span class="o">.</span><span class="n">greater_metric_to_watch_is_better</span><span class="p">):</span>
    +        <span class="k">if</span> <span class="p">(</span><span class="n">metric</span> <span class="o">&gt;</span> <span class="bp">self</span><span class="o">.</span><span class="n">best_metric</span> <span class="ow">and</span> <span class="bp">self</span><span class="o">.</span><span class="n">greater_metric_to_watch_is_better</span><span class="p">)</span> <span class="ow">or</span> <span class="p">(</span><span class="n">metric</span> <span class="o">&lt;</span> <span class="bp">self</spa
                 <span class="c1"># STORE THE CURRENT metric AS BEST</span>
                 <span class="c1"># STORE THE CURRENT metric AS BEST</span>
                 <span class="bp">self</span><span class="o">.</span><span class="n">best_metric</span> <span class="o">=</span> <span class="n">metric</span>
                 <span class="bp">self</span><span class="o">.</span><span class="n">best_metric</span> <span class="o">=</span> <span class="n">metric</span>
                 <span class="bp">self</span><span class="o">.</span><span class="n">_save_best_checkpoint</span><span class="p">(</span><span class="n">epoch</span><span class="p">,</span> <span class="n">state</span><span class="p">)</span>
                 <span class="bp">self</span><span class="o">.</span><span class="n">_save_best_checkpoint</span><span class="p">(</span><span class="n">epoch</span><span class="p">,</span> <span class="n">state</span><span class="p">)</span>
    @@ -578,16 +653,50 @@
     
     
             <span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">training_params</span><span class="o">.</span><span class="n">average_best_models</span><span class="p">:</span>
             <span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">training_params</span><span class="o">.</span><span class="n">average_best_models</span><span class="p">:</span>
                 <span class="n">net_for_averaging</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">ema_model</span><span class="o">.</span><span class="n">ema</span> <span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">ema</span> <span class="k">else</span> <span class="bp">self</span><span class="o">.</span><span class="n">net</span>
                 <span class="n">net_for_averaging</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">ema_model</span><span class="o">.</span><span class="n">ema</span> <span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">ema</span> <span class="k">else</span> <span class="bp">self</span><span class="o">.</span><span class="n">net</span>
    -            <span class="n">averaged_model_sd</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">model_weight_averaging</span><span class="o">.</span><span class="n">get_average_model</span><span class="p">(</span><span class="n">net_for_averaging</span><span class="p">,</span>
    -                                                                              <span class="n">validation_results_tuple</span><span class="o">=</span><span class="n">validation_results_tuple</span><span class="p">)</span>
    -            <span class="bp">self</span><span class="o">.</span><span class="n">sg_logger</span><span class="o">.</span><span class="n">add_checkpoint</span><span class="p">(</span><span class="n">tag</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">average_model_checkpoint_filename</span><span class="p">,</span>
    -                                          <span class="n">state_dict</span><span class="o">=</span><span class="p">{</span><span class="s1">&#39;net&#39;</span><span class="p">:</span> <span class="n">averaged_model_sd</span><span class="p">},</span> <span class="n">global_step</span><span class="o">=</span><span class="n">epoch</span><span class="p">)</span>
    +            <span class="n">averaged_model_sd</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">model_weight_averaging</span><span class="o">.</span><span class="n">get_average_model</span><span class="p">(</span><span class="n">net_for_averaging</span><span class="p">,</span> <span class="n">validation_results_tuple</span><span class="o">=</span><span class="n">validation_results_tuple</span><span class="p">)</span>
    +            <span class="bp">self</span><span class="o">.</span><span class="n">sg_logger</span><span class="o">.</span><span class="n">add_checkpoint</span><span class="p">(</span><span class="n">tag</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">average_model_checkpoint_filename</span><span class="p">,</span> <span class="n">state_dict</span><span class="o">=</span><span class="p">{</span><span class="s2">&quot;net&quot;</span><span class="p"
     
     
         <span class="k">def</span> <span class="nf">_save_best_checkpoint</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">epoch</span><span class="p">,</span> <span class="n">state</span><span class="p">):</span>
         <span class="k">def</span> <span class="nf">_save_best_checkpoint</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">epoch</span><span class="p">,</span> <span class="n">state</span><span class="p">):</span>
             <span class="bp">self</span><span class="o">.</span><span class="n">sg_logger</span><span class="o">.</span><span class="n">add_checkpoint</span><span class="p">(</span><span class="n">tag</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">ckpt_best_name</span><span class="p">,</span> <span class="n">state_dict</span><span class="o">=</span><span class="n">state</span><span class="p">,</span> <span class="n">global_step</span><span class="o
             <span class="bp">self</span><span class="o">.</span><span class="n">sg_logger</span><span class="o">.</span><span class="n">add_checkpoint</span><span class="p">(</span><span class="n">tag</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">ckpt_best_name</span><span class="p">,</span> <span class="n">state_dict</span><span class="o">=</span><span class="n">state</span><span class="p">,</span> <span class="n">global_step</span><span class="o
     
     
    +    <span class="k">def</span> <span class="nf">_prep_net_for_train</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span>
    +        <span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">arch_params</span> <span class="ow">is</span> <span class="kc">None</span><span class="p">:</span>
    +            <span class="bp">self</span><span class="o">.</span><span class="n">_init_arch_params</span><span class="p">()</span>
    +
    +        <span class="c1"># TODO: REMOVE THE BELOW LINE (FOR BACKWARD COMPATIBILITY)</span>
    +        <span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">checkpoint_params</span> <span class="ow">is</span> <span class="kc">None</span><span class="p">:</span>
    +            <span class="bp">self</span><span class="o">.</span><span class="n">checkpoint_params</span> <span class="o">=</span> <span class="n">HpmStruct</span><span class="p">(</span><span class="n">load_checkpoint</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">training_params</span><span class="o">.</span><span class="n">resume</span><span class="p">)</span>
    +
    +        <span class="bp">self</span><span class="o">.</span><span class="n">_net_to_device</span><span class="p">()</span>
    +
    +        <span class="c1"># SET THE FLAG FOR DIFFERENT PARAMETER GROUP OPTIMIZER UPDATE</span>
    +        <span class="bp">self</span><span class="o">.</span><span class="n">update_param_groups</span> <span class="o">=</span> <span class="nb">hasattr</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">net</span><span class="o">.</span><span class="n">module</span><span class="p">,</span> <span class="s2">&quot;update_param_groups&quot;</span><span class="p">)</span>
    +
    +        <span class="bp">self</span><span class="o">.</span><span class="n">checkpoint</span> <span class="o">=</span> <span class="p">{}</span>
    +        <span class="bp">self</span><span class="o">.</span><span class="n">strict_load</span> <span class="o">=</span> <span class="n">core_utils</span><span class="o">.</span><span class="n">get_param</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">training_params</span><span class="p">,</span> <span class="s2">&quot;resume_strict_load&quot;</span><span class="p">,</span> <span class="n">StrictLoad</span><span class="o">.</span><span class="n"
    +        <span class="bp">self</span><span class="o">.</span><span class="n">load_ema_as_net</span> <span class="o">=</span> <span class="kc">False</span>
    +        <span class="bp">self</span><span class="o">.</span><span class="n">load_checkpoint</span> <span class="o">=</span> <span class="n">core_utils</span><span class="o">.</span><span class="n">get_param</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">training_params</span><span class="p">,</span> <span class="s2">&quot;resume&quot;</span><span class="p">,</span> <span class="kc">False</span><span class="p">)</span>
    +        <span class="bp">self</span><span class="o">.</span><span class="n">external_checkpoint_path</span> <span class="o">=</span> <span class="n">core_utils</span><span class="o">.</span><span class="n">get_param</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">training_params</span><span class="p">,</span> <span class="s2">&quot;resume_path&quot;</span><span class="p">)</span>
    +        <span class="bp">self</span><span class="o">.</span><span class="n">load_checkpoint</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">load_checkpoint</span> <span class="ow">or</span> <span class="bp">self</span><span class="o">.</span><span class="n">external_checkpoint_path</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span>
    +        <span class="bp">self</span><span class="o">.</span><span class="n">ckpt_name</span> <span class="o">=</span> <span class="n">core_utils</span><span class="o">.</span><span class="n">get_param</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">training_params</span><span class="p">,</span> <span class="s2">&quot;ckpt_name&quot;</span><span class="p">,</span> <span class="s2">&quot;ckpt_latest.pth&quot;</span><span class="p">)</span>
    +        <span class="bp">self</span><span class="o">.</span><span class="n">_load_checkpoint_to_model</span><span class="p">()</span>
    +
    +    <span class="k">def</span> <span class="nf">_init_arch_params</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span>
    +        <span class="n">default_arch_params</span> <span class="o">=</span> <span class="n">HpmStruct</span><span class="p">()</span>
    +        <span class="n">arch_params</span> <span class="o">=</span> <span class="nb">getattr</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">net</span><span class="p">,</span> <span class="s2">&quot;arch_params&quot;</span><span class="p">,</span> <span class="n">default_arch_params</span><span class="p">)</span>
    +        <span class="bp">self</span><span class="o">.</span><span class="n">arch_params</span> <span class="o">=</span> <span class="n">default_arch_params</span>
    +        <span class="k">if</span> <span class="n">arch_params</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span><span class="p">:</span>
    +            <span class="bp">self</span><span class="o">.</span><span class="n">arch_params</span><span class="o">.</span><span class="n">override</span><span class="p">(</span><span class="o">**</span><span class="n">arch_params</span><span class="o">.</span><span class="n">to_dict</span><span class="p">())</span>
    +
         <span class="c1"># FIXME - we need to resolve flake8&#39;s &#39;function is too complex&#39; for this function</span>
         <span class="c1"># FIXME - we need to resolve flake8&#39;s &#39;function is too complex&#39; for this function</span>
    -<div class="viewcode-block" id="SgModel.train"><a class="viewcode-back" href="../../../../super_gradients.training.sg_model.html#super_gradients.training.SgModel.train">[docs]</a>    <span class="k">def</span> <span class="nf">train</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">training_params</span><span class="p">:</span> <span class="nb">dict</span> <span class="o">=</span> <span class="nb">dict</span><span class="p">()):</span>  <span cla
    +<div class="viewcode-block" id="Trainer.train"><a class="viewcode-back" href="../../../../super_gradients.training.html#super_gradients.training.Trainer.train">[docs]</a>    <span class="k">def</span> <span class="nf">train</span><span class="p">(</span>
    +        <span class="bp">self</span><span class="p">,</span>
    +        <span class="n">model</span><span class="p">:</span> <span class="n">nn</span><span class="o">.</span><span class="n">Module</span><span class="p">,</span>
    +        <span class="n">training_params</span><span class="p">:</span> <span class="nb">dict</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span>
    +        <span class="n">train_loader</span><span class="p">:</span> <span class="n">DataLoader</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span>
    +        <span class="n">valid_loader</span><span class="p">:</span> <span class="n">DataLoader</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span>
    +        <span class="n">additional_configs_to_log</span><span class="p">:</span> <span class="n">Dict</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span>
    +    <span class="p">):</span>  <span class="c1"># noqa: C901</span>
             <span class="sd">&quot;&quot;&quot;</span>
             <span class="sd">&quot;&quot;&quot;</span>
     
     
     <span class="sd">        train - Trains the Model</span>
     <span class="sd">        train - Trains the Model</span>
    @@ -596,8 +705,28 @@
     <span class="sd">          the data loaders, as dictionary. The phase context will hold the additional items, under an attribute with</span>
     <span class="sd">          the data loaders, as dictionary. The phase context will hold the additional items, under an attribute with</span>
     <span class="sd">          the same name as the key in this dictionary. Then such items can be accessed through phase callbacks.</span>
     <span class="sd">          the same name as the key in this dictionary. Then such items can be accessed through phase callbacks.</span>
     
     
    +<span class="sd">            :param additional_configs_to_log: Dict, dictionary containing configs that will be added to the training&#39;s</span>
    +<span class="sd">                sg_logger. Format should be {&quot;Config_title_1&quot;: {...}, &quot;Config_title_2&quot;:{..}}.</span>
    +<span class="sd">            :param model: torch.nn.Module, model to train.</span>
     
     
    +<span class="sd">            :param train_loader: Dataloader for train set.</span>
    +<span class="sd">            :param valid_loader: Dataloader for validation.</span>
     <span class="sd">            :param training_params:</span>
     <span class="sd">            :param training_params:</span>
    +
    +<span class="sd">                - `resume` : bool (default=False)</span>
    +
    +<span class="sd">                    Whether to continue training from ckpt with the same experiment name</span>
    +<span class="sd">                     (i.e resume from CKPT_ROOT_DIR/EXPERIMENT_NAME/CKPT_NAME)</span>
    +
    +<span class="sd">                - `ckpt_name` : str (default=ckpt_latest.pth)</span>
    +
    +<span class="sd">                    The checkpoint (.pth file) filename in CKPT_ROOT_DIR/EXPERIMENT_NAME/ to use when resume=True and</span>
    +<span class="sd">                     resume_path=None</span>
    +
    +<span class="sd">                - `resume_path`: str (default=None)</span>
    +
    +<span class="sd">                    Explicit checkpoint path (.pth file) to use to resume training.</span>
    +
     <span class="sd">                - `max_epochs` : int</span>
     <span class="sd">                - `max_epochs` : int</span>
     
     
     <span class="sd">                    Number of epochs to run training.</span>
     <span class="sd">                    Number of epochs to run training.</span>
    @@ -659,10 +788,54 @@
     <span class="sd">                    where the computed loss is the sum of a few components we would like to log- these entries in</span>
     <span class="sd">                    where the computed loss is the sum of a few components we would like to log- these entries in</span>
     <span class="sd">                    loss_items).</span>
     <span class="sd">                    loss_items).</span>
     
     
    -<span class="sd">                    When training, set the loss_logging_items_names parameter in train_params to be a list of</span>
    -<span class="sd">                    strings, of length n_items who&#39;s ith element is the name of the ith entry in loss_items. Then</span>
    -<span class="sd">                    each item will be logged, rendered on tensorboard and &quot;watched&quot; (i.e saving model checkpoints</span>
    -<span class="sd">                    according to it).</span>
    +<span class="sd">                    IMPORTANT:When dealing with external loss classes, to logg/monitor the loss_items as described</span>
    +<span class="sd">                    above by specific string name:</span>
    +
    +<span class="sd">                    Set a &quot;component_names&quot; property in the loss class, whos instance is passed through train_params,</span>
    +<span class="sd">                     to be a list of strings, of length n_items who&#39;s ith element is the name of the ith entry in loss_items.</span>
    +<span class="sd">                     Then each item will be logged, rendered on tensorboard and &quot;watched&quot; (i.e saving model checkpoints</span>
    +<span class="sd">                     according to it) under &lt;LOSS_CLASS.__name__&gt;&quot;/&quot;&lt;COMPONENT_NAME&gt;. If a single item is returned rather then a</span>
    +<span class="sd">                     tuple, it would be logged under &lt;LOSS_CLASS.__name__&gt;. When there is no such attributed, the items</span>
    +<span class="sd">                     will be named &lt;LOSS_CLASS.__name__&gt;&quot;/&quot;Loss_&quot;&lt;IDX&gt; according to the length of loss_items</span>
    +
    +<span class="sd">                    For example:</span>
    +<span class="sd">                        class MyLoss(_Loss):</span>
    +<span class="sd">                            ...</span>
    +<span class="sd">                            def forward(self, inputs, targets):</span>
    +<span class="sd">                                ...</span>
    +<span class="sd">                                total_loss = comp1 + comp2</span>
    +<span class="sd">                                loss_items = torch.cat((total_loss.unsqueeze(0),comp1.unsqueeze(0), comp2.unsqueeze(0)).detach()</span>
    +<span class="sd">                                return total_loss, loss_items</span>
    +<span class="sd">                            ...</span>
    +<span class="sd">                            @property</span>
    +<span class="sd">                            def component_names(self):</span>
    +<span class="sd">                                return [&quot;total_loss&quot;, &quot;my_1st_component&quot;, &quot;my_2nd_component&quot;]</span>
    +
    +<span class="sd">                    Trainer.train(...</span>
    +<span class="sd">                                    train_params={&quot;loss&quot;:MyLoss(),</span>
    +<span class="sd">                                                    ...</span>
    +<span class="sd">                                                    &quot;metric_to_watch&quot;: &quot;MyLoss/my_1st_component&quot;}</span>
    +
    +<span class="sd">                        This will write to log and monitor MyLoss/total_loss, MyLoss/my_1st_component,</span>
    +<span class="sd">                         MyLoss/my_2nd_component.</span>
    +
    +<span class="sd">                   For example:</span>
    +<span class="sd">                        class MyLoss2(_Loss):</span>
    +<span class="sd">                            ...</span>
    +<span class="sd">                            def forward(self, inputs, targets):</span>
    +<span class="sd">                                ...</span>
    +<span class="sd">                                total_loss = comp1 + comp2</span>
    +<span class="sd">                                loss_items = torch.cat((total_loss.unsqueeze(0),comp1.unsqueeze(0), comp2.unsqueeze(0)).detach()</span>
    +<span class="sd">                                return total_loss, loss_items</span>
    +<span class="sd">                            ...</span>
    +
    +<span class="sd">                    Trainer.train(...</span>
    +<span class="sd">                                    train_params={&quot;loss&quot;:MyLoss(),</span>
    +<span class="sd">                                                    ...</span>
    +<span class="sd">                                                    &quot;metric_to_watch&quot;: &quot;MyLoss2/loss_0&quot;}</span>
    +
    +<span class="sd">                        This will write to log and monitor MyLoss2/loss_0, MyLoss2/loss_1, MyLoss2/loss_2</span>
    +<span class="sd">                        as they have been named by their positional index in loss_items.</span>
     
     
     <span class="sd">                    Since running logs will save the loss_items in some internal state, it is recommended that</span>
     <span class="sd">                    Since running logs will save the loss_items in some internal state, it is recommended that</span>
     <span class="sd">                    loss_items are detached from their computational graph for memory efficiency.</span>
     <span class="sd">                    loss_items are detached from their computational graph for memory efficiency.</span>
    @@ -711,7 +884,7 @@
     <span class="sd">                        is a list referring to the names of each entry in the output metric (torch tensor of size n)</span>
     <span class="sd">                        is a list referring to the names of each entry in the output metric (torch tensor of size n)</span>
     
     
     <span class="sd">                        one of &quot;loss_logging_items_names&quot; i.e which will correspond to an item returned during the</span>
     <span class="sd">                        one of &quot;loss_logging_items_names&quot; i.e which will correspond to an item returned during the</span>
    -<span class="sd">                        loss function&#39;s forward pass.</span>
    +<span class="sd">                        loss function&#39;s forward pass (see loss docs abov).</span>
     
     
     <span class="sd">                    At the end of each epoch, if a new best metric_to_watch value is achieved, the models checkpoint</span>
     <span class="sd">                    At the end of each epoch, if a new best metric_to_watch value is achieved, the models checkpoint</span>
     <span class="sd">                    is saved in YOUR_PYTHON_PATH/checkpoints/ckpt_best.pth</span>
     <span class="sd">                    is saved in YOUR_PYTHON_PATH/checkpoints/ckpt_best.pth</span>
    @@ -799,11 +972,6 @@
     <span class="sd">                    will be added to the tensorboard along with some sample images from the dataset. Currently only</span>
     <span class="sd">                    will be added to the tensorboard along with some sample images from the dataset. Currently only</span>
     <span class="sd">                    detection datasets are supported for analysis.</span>
     <span class="sd">                    detection datasets are supported for analysis.</span>
     
     
    -<span class="sd">                -  `save_full_train_log` : bool (default=False)</span>
    -
    -<span class="sd">                    When set, a full log (of all super_gradients modules, including uncaught exceptions from any other</span>
    -<span class="sd">                     module) of the training will be saved in the checkpoint directory under full_train_log.log</span>
    -
     <span class="sd">                -  `sg_logger` : Union[AbstractSGLogger, str] (defauls=base_sg_logger)</span>
     <span class="sd">                -  `sg_logger` : Union[AbstractSGLogger, str] (defauls=base_sg_logger)</span>
     
     
     <span class="sd">                    Define the SGLogger object for this training process. The SGLogger handles all disk writes, logs, TensorBoard, remote logging</span>
     <span class="sd">                    Define the SGLogger object for this training process. The SGLogger handles all disk writes, logs, TensorBoard, remote logging</span>
    @@ -855,49 +1023,49 @@
     
     
     <span class="sd">                        num_calib_batches: int, number of batches to collect the statistics from.</span>
     <span class="sd">                        num_calib_batches: int, number of batches to collect the statistics from.</span>
     
     
    -<span class="sd">                        percentile: float, percentile value to use when SgModel,quant_modules_calib_method=&#39;percentile&#39;.</span>
    +<span class="sd">                        percentile: float, percentile value to use when Trainer,quant_modules_calib_method=&#39;percentile&#39;.</span>
     <span class="sd">                         Discarded when other methods are used (Default=99.99).</span>
     <span class="sd">                         Discarded when other methods are used (Default=99.99).</span>
     
     
     
     
     <span class="sd">        :return:</span>
     <span class="sd">        :return:</span>
     <span class="sd">        &quot;&quot;&quot;</span>
     <span class="sd">        &quot;&quot;&quot;</span>
             <span class="k">global</span> <span class="n">logger</span>
             <span class="k">global</span> <span class="n">logger</span>
    +        <span class="k">if</span> <span class="n">training_params</span> <span class="ow">is</span> <span class="kc">None</span><span class="p">:</span>
    +            <span class="n">training_params</span> <span class="o">=</span> <span class="nb">dict</span><span class="p">()</span>
    +        <span class="bp">self</span><span class="o">.</span><span class="n">train_loader</span> <span class="o">=</span> <span class="n">train_loader</span> <span class="ow">or</span> <span class="bp">self</span><span class="o">.</span><span class="n">train_loader</span>
    +        <span class="bp">self</span><span class="o">.</span><span class="n">valid_loader</span> <span class="o">=</span> <span class="n">valid_loader</span> <span class="ow">or</span> <span class="bp">self</span><span class="o">.</span><span class="n">valid_loader</span>
    +        <span class="bp">self</span><span class="o">.</span><span class="n">_set_dataset_params</span><span class="p">()</span>
     
     
    -        <span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">net</span> <span class="ow">is</span> <span class="kc">None</span><span class="p">:</span>
    -            <span class="k">raise</span> <span class="ne">Exception</span><span class="p">(</span><span class="s1">&#39;Model&#39;</span><span class="p">,</span> <span class="s1">&#39;No model found&#39;</span><span class="p">)</span>
    -        <span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">dataset_interface</span> <span class="ow">is</span> <span class="kc">None</span> <span class="ow">and</span> <span class="bp">self</span><span class="o">.</span><span class="n">train_loader</span> <span class="ow">is</span> <span class="kc">None</span><span class="p">:</span>
    -            <span class="k">raise</span> <span class="ne">Exception</span><span class="p">(</span><span class="s1">&#39;Data&#39;</span><span class="p">,</span> <span class="s1">&#39;No dataset found&#39;</span><span class="p">)</span>
    -
    +        <span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">multi_gpu</span> <span class="o">==</span> <span class="n">MultiGPUMode</span><span class="o">.</span><span class="n">DISTRIBUTED_DATA_PARALLEL</span><span class="p">:</span>
    +            <span class="c1"># Note: the dataloader uses sampler of the batch_sampler when it is not None.</span>
    +            <span class="n">train_sampler</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">train_loader</span><span class="o">.</span><span class="n">batch_sampler</span><span class="o">.</span><span class="n">sampler</span> <span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">train_loader</span><span class="o">.</span><span class="n">batch_sampler</span> <span class="ow">is</span> <span class="ow">not</
    +            <span class="k">if</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">train_sampler</span><span class="p">,</span> <span class="n">SequentialSampler</span><span class="p">):</span>
    +                <span class="k">raise</span> <span class="ne">ValueError</span><span class="p">(</span>
    +                    <span class="s2">&quot;You are using a SequentialSampler on you training dataloader, while working on DDP. &quot;</span>
    +                    <span class="s2">&quot;This cancels the DDP benefits since it makes each process iterate through the entire dataset&quot;</span>
    +                <span class="p">)</span>
    +            <span class="k">if</span> <span class="ow">not</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">train_sampler</span><span class="p">,</span> <span class="p">(</span><span class="n">DistributedSampler</span><span class="p">,</span> <span class="n">InfiniteSampler</span><span class="p">,</span> <span class="n">RepeatAugSampler</span><span class="p">)):</span>
    +                <span class="n">logger</span><span class="o">.</span><span class="n">warning</span><span class="p">(</span>
    +                    <span class="s2">&quot;The training sampler you are using might not support DDP. &quot;</span>
    +                    <span class="s2">&quot;If it doesnt, please use one of the following sampler: DistributedSampler, InfiniteSampler, RepeatAugSampler&quot;</span>
    +                <span class="p">)</span>
             <span class="bp">self</span><span class="o">.</span><span class="n">training_params</span> <span class="o">=</span> <span class="n">TrainingParams</span><span class="p">()</span>
             <span class="bp">self</span><span class="o">.</span><span class="n">training_params</span> <span class="o">=</span> <span class="n">TrainingParams</span><span class="p">()</span>
             <span class="bp">self</span><span class="o">.</span><span class="n">training_params</span><span class="o">.</span><span class="n">override</span><span class="p">(</span><span class="o">**</span><span class="n">training_params</span><span class="p">)</span>
             <span class="bp">self</span><span class="o">.</span><span class="n">training_params</span><span class="o">.</span><span class="n">override</span><span class="p">(</span><span class="o">**</span><span class="n">training_params</span><span class="p">)</span>
     
     
    +        <span class="bp">self</span><span class="o">.</span><span class="n">net</span> <span class="o">=</span> <span class="n">model</span>
    +        <span class="bp">self</span><span class="o">.</span><span class="n">_prep_net_for_train</span><span class="p">()</span>
    +
             <span class="c1"># SET RANDOM SEED</span>
             <span class="c1"># SET RANDOM SEED</span>
    -        <span class="n">random_seed</span><span class="p">(</span><span class="n">is_ddp</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">multi_gpu</span> <span class="o">==</span> <span class="n">MultiGPUMode</span><span class="o">.</span><span class="n">DISTRIBUTED_DATA_PARALLEL</span><span class="p">,</span>
    -                    <span class="n">device</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">device</span><span class="p">,</span> <span class="n">seed</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">training_params</span><span class="o">.</span><span class="n">seed</span><span class="p">)</span>
    +        <span class="n">random_seed</span><span class="p">(</span><span class="n">is_ddp</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">multi_gpu</span> <span class="o">==</span> <span class="n">MultiGPUMode</span><span class="o">.</span><span class="n">DISTRIBUTED_DATA_PARALLEL</span><span class="p">,</span> <span class="n">device</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">device</span><sp
     
     
             <span class="n">silent_mode</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">training_params</span><span class="o">.</span><span class="n">silent_mode</span> <span class="ow">or</span> <span class="bp">self</span><span class="o">.</span><span class="n">ddp_silent_mode</span>
             <span class="n">silent_mode</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">training_params</span><span class="o">.</span><span class="n">silent_mode</span> <span class="ow">or</span> <span class="bp">self</span><span class="o">.</span><span class="n">ddp_silent_mode</span>
             <span class="c1"># METRICS</span>
             <span class="c1"># METRICS</span>
             <span class="bp">self</span><span class="o">.</span><span class="n">_set_train_metrics</span><span class="p">(</span><span class="n">train_metrics_list</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">training_params</span><span class="o">.</span><span class="n">train_metrics_list</span><span class="p">)</span>
             <span class="bp">self</span><span class="o">.</span><span class="n">_set_train_metrics</span><span class="p">(</span><span class="n">train_metrics_list</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">training_params</span><span class="o">.</span><span class="n">train_metrics_list</span><span class="p">)</span>
             <span class="bp">self</span><span class="o">.</span><span class="n">_set_valid_metrics</span><span class="p">(</span><span class="n">valid_metrics_list</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">training_params</span><span class="o">.</span><span class="n">valid_metrics_list</span><span class="p">)</span>
             <span class="bp">self</span><span class="o">.</span><span class="n">_set_valid_metrics</span><span class="p">(</span><span class="n">valid_metrics_list</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">training_params</span><span class="o">.</span><span class="n">valid_metrics_list</span><span class="p">)</span>
    -        <span class="bp">self</span><span class="o">.</span><span class="n">loss_logging_items_names</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">training_params</span><span class="o">.</span><span class="n">loss_logging_items_names</span>
    -
    -        <span class="bp">self</span><span class="o">.</span><span class="n">results_titles</span> <span class="o">=</span> <span class="p">[</span><span class="s2">&quot;Train_&quot;</span> <span class="o">+</span> <span class="n">t</span> <span class="k">for</span> <span class="n">t</span> <span class="ow">in</span>
    -                               <span class="bp">self</span><span class="o">.</span><span class="n">loss_logging_items_names</span> <span class="o">+</span> <span class="n">get_metrics_titles</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">train_metrics</span><span class="p">)]</span> <span class="o">+</span> \
    -                              <span class="p">[</span><span class="s2">&quot;Valid_&quot;</span> <span class="o">+</span> <span class="n">t</span> <span class="k">for</span> <span class="n">t</span> <span class="ow">in</span>
    -                               <span class="bp">self</span><span class="o">.</span><span class="n">loss_logging_items_names</span> <span class="o">+</span> <span class="n">get_metrics_titles</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">valid_metrics</span><span class="p">)]</span>
     
     
             <span class="c1"># Store the metric to follow (loss\accuracy) and initialize as the worst value</span>
             <span class="c1"># Store the metric to follow (loss\accuracy) and initialize as the worst value</span>
             <span class="bp">self</span><span class="o">.</span><span class="n">metric_to_watch</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">training_params</span><span class="o">.</span><span class="n">metric_to_watch</span>
             <span class="bp">self</span><span class="o">.</span><span class="n">metric_to_watch</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">training_params</span><span class="o">.</span><span class="n">metric_to_watch</span>
             <span class="bp">self</span><span class="o">.</span><span class="n">greater_metric_to_watch_is_better</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">training_params</span><span class="o">.</span><span class="n">greater_metric_to_watch_is_better</span>
             <span class="bp">self</span><span class="o">.</span><span class="n">greater_metric_to_watch_is_better</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">training_params</span><span class="o">.</span><span class="n">greater_metric_to_watch_is_better</span>
    -        <span class="bp">self</span><span class="o">.</span><span class="n">metric_idx_in_results_tuple</span> <span class="o">=</span> <span class="p">(</span>
    -            <span class="bp">self</span><span class="o">.</span><span class="n">loss_logging_items_names</span> <span class="o">+</span> <span class="n">get_metrics_titles</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">valid_metrics</span><span class="p">))</span><span class="o">.</span><span class="n">index</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">metric_to_watch</span><span class="p">)<
    -
    Discard
    1
    2
    3
    4
    5
    6
    7
    8
    9
    10
    11
    12
    13
    14
    15
    16
    17
    18
    19
    20
    21
    22
    23
    24
    25
    26
    27
    28
    29
    30
    31
    32
    33
    34
    35
    36
    37
    38
    39
    40
    41
    42
    43
    44
    45
    46
    47
    48
    49
    50
    51
    52
    53
    54
    55
    56
    57
    58
    59
    60
    61
    62
    63
    64
    65
    66
    67
    68
    69
    70
    71
    72
    73
    74
    75
    76
    77
    78
    79
    80
    81
    82
    83
    84
    85
    86
    87
    88
    89
    90
    91
    92
    93
    94
    95
    96
    97
    98
    99
    100
    101
    102
    103
    104
    105
    106
    107
    108
    109
    110
    111
    112
    113
    114
    115
    116
    117
    118
    119
    120
    121
    122
    123
    124
    125
    126
    127
    128
    129
    130
    131
    132
    133
    134
    135
    136
    137
    138
    139
    140
    141
    142
    143
    144
    145
    146
    147
    148
    149
    150
    151
    152
    153
    154
    155
    156
    157
    158
    159
    160
    161
    162
    163
    164
    165
    166
    167
    168
    169
    170
    171
    172
    173
    174
    175
    176
    177
    178
    179
    180
    181
    182
    183
    184
    185
    186
    187
    188
    189
    190
    191
    192
    193
    194
    195
    196
    197
    198
    199
    200
    201
    202
    203
    204
    205
    206
    207
    208
    209
    210
    211
    212
    213
    214
    215
    216
    217
    218
    219
    220
    221
    222
    223
    224
    225
    226
    227
    228
    229
    230
    231
    1. <!DOCTYPE html>
    2. <html class="writer-html5" lang="en" >
    3. <head>
    4. <meta charset="utf-8" />
    5. <meta name="viewport" content="width=device-width, initial-scale=1.0" />
    6. <title>super_gradients.training.training_hyperparams.training_hyperparams &mdash; SuperGradients 3.0.3 documentation</title>
    7. <link rel="stylesheet" href="../../../../_static/pygments.css" type="text/css" />
    8. <link rel="stylesheet" href="../../../../_static/css/theme.css" type="text/css" />
    9. <link rel="stylesheet" href="../../../../_static/graphviz.css" type="text/css" />
    10. <link rel="stylesheet" href="../../../../_static/custom.css" type="text/css" />
    11. <!--[if lt IE 9]>
    12. <script src="../../../../_static/js/html5shiv.min.js"></script>
    13. <![endif]-->
    14. <script data-url_root="../../../../" id="documentation_options" src="../../../../_static/documentation_options.js"></script>
    15. <script src="../../../../_static/jquery.js"></script>
    16. <script src="../../../../_static/underscore.js"></script>
    17. <script src="../../../../_static/_sphinx_javascript_frameworks_compat.js"></script>
    18. <script src="../../../../_static/doctools.js"></script>
    19. <script src="../../../../_static/sphinx_highlight.js"></script>
    20. <script src="../../../../_static/js/theme.js"></script>
    21. <link rel="index" title="Index" href="../../../../genindex.html" />
    22. <link rel="search" title="Search" href="../../../../search.html" />
    23. </head>
    24. <body class="wy-body-for-nav">
    25. <div class="wy-grid-for-nav">
    26. <nav data-toggle="wy-nav-shift" class="wy-nav-side">
    27. <div class="wy-side-scroll">
    28. <div class="wy-side-nav-search" >
    29. <a href="../../../../index.html" class="icon icon-home"> SuperGradients
    30. </a>
    31. <div role="search">
    32. <form id="rtd-search-form" class="wy-form" action="../../../../search.html" method="get">
    33. <input type="text" name="q" placeholder="Search docs" />
    34. <input type="hidden" name="check_keywords" value="yes" />
    35. <input type="hidden" name="area" value="default" />
    36. </form>
    37. </div>
    38. </div><div class="wy-menu wy-menu-vertical" data-spy="affix" role="navigation" aria-label="Navigation menu">
    39. <p class="caption" role="heading"><span class="caption-text">Welcome To SuperGradients</span></p>
    40. <ul>
    41. <li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html">Version 3 is out! Notebooks have been updated!</a></li>
    42. <li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#build-with-supergradients">Build with SuperGradients</a></li>
    43. <li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#quick-installation">Quick Installation</a></li>
    44. <li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#what-s-new">What’s New</a></li>
    45. <li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#coming-soon">Coming soon</a></li>
    46. <li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#table-of-content">Table of Content</a></li>
    47. <li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#getting-started">Getting Started</a></li>
    48. <li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#advanced-features">Advanced Features</a></li>
    49. <li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#installation-methods">Installation Methods</a></li>
    50. <li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#implemented-model-architectures">Implemented Model Architectures</a></li>
    51. <li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#documentation">Documentation</a></li>
    52. <li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#contributing">Contributing</a></li>
    53. <li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#citation">Citation</a></li>
    54. <li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#community">Community</a></li>
    55. <li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#license">License</a></li>
    56. <li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#deci-platform">Deci Platform</a></li>
    57. </ul>
    58. <p class="caption" role="heading"><span class="caption-text">Technical Documentation</span></p>
    59. <ul>
    60. <li class="toctree-l1"><a class="reference internal" href="../../../../super_gradients.common.html">Common package</a></li>
    61. <li class="toctree-l1"><a class="reference internal" href="../../../../super_gradients.training.html">Training package</a></li>
    62. </ul>
    63. </div>
    64. </div>
    65. </nav>
    66. <section data-toggle="wy-nav-shift" class="wy-nav-content-wrap"><nav class="wy-nav-top" aria-label="Mobile navigation menu" >
    67. <i data-toggle="wy-nav-top" class="fa fa-bars"></i>
    68. <a href="../../../../index.html">SuperGradients</a>
    69. </nav>
    70. <div class="wy-nav-content">
    71. <div class="rst-content">
    72. <div role="navigation" aria-label="Page navigation">
    73. <ul class="wy-breadcrumbs">
    74. <li><a href="../../../../index.html" class="icon icon-home"></a> &raquo;</li>
    75. <li><a href="../../../index.html">Module code</a> &raquo;</li>
    76. <li>super_gradients.training.training_hyperparams.training_hyperparams</li>
    77. <li class="wy-breadcrumbs-aside">
    78. </li>
    79. </ul>
    80. <hr/>
    81. </div>
    82. <div role="main" class="document" itemscope="itemscope" itemtype="http://schema.org/Article">
    83. <div itemprop="articleBody">
    84. <h1>Source code for super_gradients.training.training_hyperparams.training_hyperparams</h1><div class="highlight"><pre>
    85. <span></span><span class="kn">import</span> <span class="nn">hydra</span>
    86. <span class="kn">import</span> <span class="nn">pkg_resources</span>
    87. <span class="kn">from</span> <span class="nn">hydra</span> <span class="kn">import</span> <span class="n">compose</span><span class="p">,</span> <span class="n">initialize_config_dir</span>
    88. <span class="kn">from</span> <span class="nn">hydra.core.global_hydra</span> <span class="kn">import</span> <span class="n">GlobalHydra</span>
    89. <span class="kn">from</span> <span class="nn">super_gradients.training.utils.utils</span> <span class="kn">import</span> <span class="n">override_default_params_without_nones</span>
    90. <span class="kn">from</span> <span class="nn">super_gradients.training.utils.hydra_utils</span> <span class="kn">import</span> <span class="n">normalize_path</span>
    91. <span class="kn">from</span> <span class="nn">super_gradients.common.abstractions.abstract_logger</span> <span class="kn">import</span> <span class="n">get_logger</span>
    92. <span class="kn">from</span> <span class="nn">typing</span> <span class="kn">import</span> <span class="n">Dict</span>
    93. <span class="n">logger</span> <span class="o">=</span> <span class="n">get_logger</span><span class="p">(</span><span class="vm">__name__</span><span class="p">)</span>
    94. <div class="viewcode-block" id="get"><a class="viewcode-back" href="../../../../super_gradients.training.html#super_gradients.training.training_hyperparams.get">[docs]</a><span class="k">def</span> <span class="nf">get</span><span class="p">(</span><span class="n">config_name</span><span class="p">,</span> <span class="n">overriding_params</span><span class="p">:</span> <span class="n">Dict</span> <span class="o">=</span> <span class="kc">None</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="n">Dict</span><span class="p">:</span>
    95. <span class="sd">&quot;&quot;&quot;</span>
    96. <span class="sd"> Class for creating training hyper parameters dictionary, taking defaults from yaml</span>
    97. <span class="sd"> files in src/super_gradients/recipes.</span>
    98. <span class="sd"> :param overriding_params: Dict, dictionary like object containing entries to override in the recipe&#39;s training</span>
    99. <span class="sd"> hyper parameters dictionary.</span>
    100. <span class="sd"> :param config_name: yaml config filename in recipes (for example coco2017_yolox).</span>
    101. <span class="sd"> &quot;&quot;&quot;</span>
    102. <span class="k">if</span> <span class="n">overriding_params</span> <span class="ow">is</span> <span class="kc">None</span><span class="p">:</span>
    103. <span class="n">overriding_params</span> <span class="o">=</span> <span class="nb">dict</span><span class="p">()</span>
    104. <span class="n">GlobalHydra</span><span class="o">.</span><span class="n">instance</span><span class="p">()</span><span class="o">.</span><span class="n">clear</span><span class="p">()</span>
    105. <span class="n">sg_recipes_dir</span> <span class="o">=</span> <span class="n">pkg_resources</span><span class="o">.</span><span class="n">resource_filename</span><span class="p">(</span><span class="s2">&quot;super_gradients.recipes&quot;</span><span class="p">,</span> <span class="s2">&quot;&quot;</span><span class="p">)</span>
    106. <span class="k">with</span> <span class="n">initialize_config_dir</span><span class="p">(</span><span class="n">config_dir</span><span class="o">=</span><span class="n">normalize_path</span><span class="p">(</span><span class="n">sg_recipes_dir</span><span class="p">),</span> <span class="n">version_base</span><span class="o">=</span><span class="s2">&quot;1.2&quot;</span><span class="p">):</span>
    107. <span class="n">cfg</span> <span class="o">=</span> <span class="n">compose</span><span class="p">(</span><span class="n">config_name</span><span class="o">=</span><span class="n">normalize_path</span><span class="p">(</span><span class="n">config_name</span><span class="p">))</span>
    108. <span class="n">cfg</span> <span class="o">=</span> <span class="n">hydra</span><span class="o">.</span><span class="n">utils</span><span class="o">.</span><span class="n">instantiate</span><span class="p">(</span><span class="n">cfg</span><span class="p">)</span>
    109. <span class="n">training_params</span> <span class="o">=</span> <span class="n">cfg</span><span class="o">.</span><span class="n">training_hyperparams</span>
    110. <span class="n">training_params</span> <span class="o">=</span> <span class="n">override_default_params_without_nones</span><span class="p">(</span><span class="n">overriding_params</span><span class="p">,</span> <span class="n">training_params</span><span class="p">)</span>
    111. <span class="k">return</span> <span class="n">training_params</span></div>
    112. <div class="viewcode-block" id="cifar10_resnet_train_params"><a class="viewcode-back" href="../../../../super_gradients.training.html#super_gradients.training.training_hyperparams.cifar10_resnet_train_params">[docs]</a><span class="k">def</span> <span class="nf">cifar10_resnet_train_params</span><span class="p">(</span><span class="n">overriding_params</span><span class="p">:</span> <span class="n">Dict</span> <span class="o">=</span> <span class="kc">None</span><span class="p">):</span>
    113. <span class="k">return</span> <span class="n">get</span><span class="p">(</span><span class="s2">&quot;cifar10_resnet&quot;</span><span class="p">,</span> <span class="n">overriding_params</span><span class="p">)</span></div>
    114. <div class="viewcode-block" id="cityscapes_ddrnet_train_params"><a class="viewcode-back" href="../../../../super_gradients.training.html#super_gradients.training.training_hyperparams.cityscapes_ddrnet_train_params">[docs]</a><span class="k">def</span> <span class="nf">cityscapes_ddrnet_train_params</span><span class="p">(</span><span class="n">overriding_params</span><span class="p">:</span> <span class="n">Dict</span> <span class="o">=</span> <span class="kc">None</span><span class="p">):</span>
    115. <span class="k">return</span> <span class="n">get</span><span class="p">(</span><span class="s2">&quot;cityscapes_ddrnet&quot;</span><span class="p">,</span> <span class="n">overriding_params</span><span class="p">)</span></div>
    116. <div class="viewcode-block" id="cityscapes_regseg48_train_params"><a class="viewcode-back" href="../../../../super_gradients.training.html#super_gradients.training.training_hyperparams.cityscapes_regseg48_train_params">[docs]</a><span class="k">def</span> <span class="nf">cityscapes_regseg48_train_params</span><span class="p">(</span><span class="n">overriding_params</span><span class="p">:</span> <span class="n">Dict</span> <span class="o">=</span> <span class="kc">None</span><span class="p">):</span>
    117. <span class="k">return</span> <span class="n">get</span><span class="p">(</span><span class="s2">&quot;cityscapes_regseg48&quot;</span><span class="p">,</span> <span class="n">overriding_params</span><span class="p">)</span></div>
    118. <div class="viewcode-block" id="cityscapes_stdc_base_train_params"><a class="viewcode-back" href="../../../../super_gradients.training.html#super_gradients.training.training_hyperparams.cityscapes_stdc_base_train_params">[docs]</a><span class="k">def</span> <span class="nf">cityscapes_stdc_base_train_params</span><span class="p">(</span><span class="n">overriding_params</span><span class="p">:</span> <span class="n">Dict</span> <span class="o">=</span> <span class="kc">None</span><span class="p">):</span>
    119. <span class="k">return</span> <span class="n">get</span><span class="p">(</span><span class="s2">&quot;cityscapes_stdc_base&quot;</span><span class="p">,</span> <span class="n">overriding_params</span><span class="p">)</span></div>
    120. <div class="viewcode-block" id="cityscapes_stdc_seg50_train_params"><a class="viewcode-back" href="../../../../super_gradients.training.html#super_gradients.training.training_hyperparams.cityscapes_stdc_seg50_train_params">[docs]</a><span class="k">def</span> <span class="nf">cityscapes_stdc_seg50_train_params</span><span class="p">(</span><span class="n">overriding_params</span><span class="p">:</span> <span class="n">Dict</span> <span class="o">=</span> <span class="kc">None</span><span class="p">):</span>
    121. <span class="k">return</span> <span class="n">get</span><span class="p">(</span><span class="s2">&quot;cityscapes_stdc_seg50&quot;</span><span class="p">,</span> <span class="n">overriding_params</span><span class="p">)</span></div>
    122. <div class="viewcode-block" id="cityscapes_stdc_seg75_train_params"><a class="viewcode-back" href="../../../../super_gradients.training.html#super_gradients.training.training_hyperparams.cityscapes_stdc_seg75_train_params">[docs]</a><span class="k">def</span> <span class="nf">cityscapes_stdc_seg75_train_params</span><span class="p">(</span><span class="n">overriding_params</span><span class="p">:</span> <span class="n">Dict</span> <span class="o">=</span> <span class="kc">None</span><span class="p">):</span>
    123. <span class="k">return</span> <span class="n">get</span><span class="p">(</span><span class="s2">&quot;cityscapes_stdc_seg75&quot;</span><span class="p">,</span> <span class="n">overriding_params</span><span class="p">)</span></div>
    124. <div class="viewcode-block" id="coco2017_ssd_lite_mobilenet_v2_train_params"><a class="viewcode-back" href="../../../../super_gradients.training.html#super_gradients.training.training_hyperparams.coco2017_ssd_lite_mobilenet_v2_train_params">[docs]</a><span class="k">def</span> <span class="nf">coco2017_ssd_lite_mobilenet_v2_train_params</span><span class="p">(</span><span class="n">overriding_params</span><span class="p">:</span> <span class="n">Dict</span> <span class="o">=</span> <span class="kc">None</span><span class="p">):</span>
    125. <span class="k">return</span> <span class="n">get</span><span class="p">(</span><span class="s2">&quot;coco2017_ssd_lite_mobilenet_v2&quot;</span><span class="p">,</span> <span class="n">overriding_params</span><span class="p">)</span></div>
    126. <div class="viewcode-block" id="coco2017_yolox_train_params"><a class="viewcode-back" href="../../../../super_gradients.training.html#super_gradients.training.training_hyperparams.coco2017_yolox_train_params">[docs]</a><span class="k">def</span> <span class="nf">coco2017_yolox_train_params</span><span class="p">(</span><span class="n">overriding_params</span><span class="p">:</span> <span class="n">Dict</span> <span class="o">=</span> <span class="kc">None</span><span class="p">):</span>
    127. <span class="k">return</span> <span class="n">get</span><span class="p">(</span><span class="s2">&quot;coco2017_yolox&quot;</span><span class="p">,</span> <span class="n">overriding_params</span><span class="p">)</span></div>
    128. <div class="viewcode-block" id="coco_segmentation_shelfnet_lw_train_params"><a class="viewcode-back" href="../../../../super_gradients.training.html#super_gradients.training.training_hyperparams.coco_segmentation_shelfnet_lw_train_params">[docs]</a><span class="k">def</span> <span class="nf">coco_segmentation_shelfnet_lw_train_params</span><span class="p">(</span><span class="n">overriding_params</span><span class="p">:</span> <span class="n">Dict</span> <span class="o">=</span> <span class="kc">None</span><span class="p">):</span>
    129. <span class="k">return</span> <span class="n">get</span><span class="p">(</span><span class="s2">&quot;coco_segmentation_shelfnet_lw&quot;</span><span class="p">,</span> <span class="n">overriding_params</span><span class="p">)</span></div>
    130. <div class="viewcode-block" id="imagenet_efficientnet_train_params"><a class="viewcode-back" href="../../../../super_gradients.training.html#super_gradients.training.training_hyperparams.imagenet_efficientnet_train_params">[docs]</a><span class="k">def</span> <span class="nf">imagenet_efficientnet_train_params</span><span class="p">(</span><span class="n">overriding_params</span><span class="p">:</span> <span class="n">Dict</span> <span class="o">=</span> <span class="kc">None</span><span class="p">):</span>
    131. <span class="k">return</span> <span class="n">get</span><span class="p">(</span><span class="s2">&quot;imagenet_efficientnet&quot;</span><span class="p">,</span> <span class="n">overriding_params</span><span class="p">)</span></div>
    132. <div class="viewcode-block" id="imagenet_mobilenetv2_train_params"><a class="viewcode-back" href="../../../../super_gradients.training.html#super_gradients.training.training_hyperparams.imagenet_mobilenetv2_train_params">[docs]</a><span class="k">def</span> <span class="nf">imagenet_mobilenetv2_train_params</span><span class="p">(</span><span class="n">overriding_params</span><span class="p">:</span> <span class="n">Dict</span> <span class="o">=</span> <span class="kc">None</span><span class="p">):</span>
    133. <span class="k">return</span> <span class="n">get</span><span class="p">(</span><span class="s2">&quot;imagenet_mobilenetv2&quot;</span><span class="p">,</span> <span class="n">overriding_params</span><span class="p">)</span></div>
    134. <div class="viewcode-block" id="imagenet_mobilenetv3_base_train_params"><a class="viewcode-back" href="../../../../super_gradients.training.html#super_gradients.training.training_hyperparams.imagenet_mobilenetv3_base_train_params">[docs]</a><span class="k">def</span> <span class="nf">imagenet_mobilenetv3_base_train_params</span><span class="p">(</span><span class="n">overriding_params</span><span class="p">:</span> <span class="n">Dict</span> <span class="o">=</span> <span class="kc">None</span><span class="p">):</span>
    135. <span class="k">return</span> <span class="n">get</span><span class="p">(</span><span class="s2">&quot;imagenet_mobilenetv3_base&quot;</span><span class="p">,</span> <span class="n">overriding_params</span><span class="p">)</span></div>
    136. <div class="viewcode-block" id="imagenet_mobilenetv3_large_train_params"><a class="viewcode-back" href="../../../../super_gradients.training.html#super_gradients.training.training_hyperparams.imagenet_mobilenetv3_large_train_params">[docs]</a><span class="k">def</span> <span class="nf">imagenet_mobilenetv3_large_train_params</span><span class="p">(</span><span class="n">overriding_params</span><span class="p">:</span> <span class="n">Dict</span> <span class="o">=</span> <span class="kc">None</span><span class="p">):</span>
    137. <span class="k">return</span> <span class="n">get</span><span class="p">(</span><span class="s2">&quot;imagenet_mobilenetv3_large&quot;</span><span class="p">,</span> <span class="n">overriding_params</span><span class="p">)</span></div>
    138. <div class="viewcode-block" id="imagenet_mobilenetv3_small_train_params"><a class="viewcode-back" href="../../../../super_gradients.training.html#super_gradients.training.training_hyperparams.imagenet_mobilenetv3_small_train_params">[docs]</a><span class="k">def</span> <span class="nf">imagenet_mobilenetv3_small_train_params</span><span class="p">(</span><span class="n">overriding_params</span><span class="p">:</span> <span class="n">Dict</span> <span class="o">=</span> <span class="kc">None</span><span class="p">):</span>
    139. <span class="k">return</span> <span class="n">get</span><span class="p">(</span><span class="s2">&quot;imagenet_mobilenetv3_small&quot;</span><span class="p">,</span> <span class="n">overriding_params</span><span class="p">)</span></div>
    140. <div class="viewcode-block" id="imagenet_regnetY_train_params"><a class="viewcode-back" href="../../../../super_gradients.training.html#super_gradients.training.training_hyperparams.imagenet_regnetY_train_params">[docs]</a><span class="k">def</span> <span class="nf">imagenet_regnetY_train_params</span><span class="p">(</span><span class="n">overriding_params</span><span class="p">:</span> <span class="n">Dict</span> <span class="o">=</span> <span class="kc">None</span><span class="p">):</span>
    141. <span class="k">return</span> <span class="n">get</span><span class="p">(</span><span class="s2">&quot;imagenet_regnetY&quot;</span><span class="p">,</span> <span class="n">overriding_params</span><span class="p">)</span></div>
    142. <div class="viewcode-block" id="imagenet_repvgg_train_params"><a class="viewcode-back" href="../../../../super_gradients.training.html#super_gradients.training.training_hyperparams.imagenet_repvgg_train_params">[docs]</a><span class="k">def</span> <span class="nf">imagenet_repvgg_train_params</span><span class="p">(</span><span class="n">overriding_params</span><span class="p">:</span> <span class="n">Dict</span> <span class="o">=</span> <span class="kc">None</span><span class="p">):</span>
    143. <span class="k">return</span> <span class="n">get</span><span class="p">(</span><span class="s2">&quot;imagenet_repvgg&quot;</span><span class="p">,</span> <span class="n">overriding_params</span><span class="p">)</span></div>
    144. <div class="viewcode-block" id="imagenet_resnet50_train_params"><a class="viewcode-back" href="../../../../super_gradients.training.html#super_gradients.training.training_hyperparams.imagenet_resnet50_train_params">[docs]</a><span class="k">def</span> <span class="nf">imagenet_resnet50_train_params</span><span class="p">(</span><span class="n">overriding_params</span><span class="p">:</span> <span class="n">Dict</span> <span class="o">=</span> <span class="kc">None</span><span class="p">):</span>
    145. <span class="k">return</span> <span class="n">get</span><span class="p">(</span><span class="s2">&quot;imagenet_resnet50&quot;</span><span class="p">,</span> <span class="n">overriding_params</span><span class="p">)</span></div>
    146. <div class="viewcode-block" id="imagenet_resnet50_kd_train_params"><a class="viewcode-back" href="../../../../super_gradients.training.html#super_gradients.training.training_hyperparams.imagenet_resnet50_kd_train_params">[docs]</a><span class="k">def</span> <span class="nf">imagenet_resnet50_kd_train_params</span><span class="p">(</span><span class="n">overriding_params</span><span class="p">:</span> <span class="n">Dict</span> <span class="o">=</span> <span class="kc">None</span><span class="p">):</span>
    147. <span class="k">return</span> <span class="n">get</span><span class="p">(</span><span class="s2">&quot;imagenet_resnet50_kd&quot;</span><span class="p">,</span> <span class="n">overriding_params</span><span class="p">)</span></div>
    148. <div class="viewcode-block" id="imagenet_vit_base_train_params"><a class="viewcode-back" href="../../../../super_gradients.training.html#super_gradients.training.training_hyperparams.imagenet_vit_base_train_params">[docs]</a><span class="k">def</span> <span class="nf">imagenet_vit_base_train_params</span><span class="p">(</span><span class="n">overriding_params</span><span class="p">:</span> <span class="n">Dict</span> <span class="o">=</span> <span class="kc">None</span><span class="p">):</span>
    149. <span class="k">return</span> <span class="n">get</span><span class="p">(</span><span class="s2">&quot;imagenet_vit_base&quot;</span><span class="p">,</span> <span class="n">overriding_params</span><span class="p">)</span></div>
    150. <div class="viewcode-block" id="imagenet_vit_large_train_params"><a class="viewcode-back" href="../../../../super_gradients.training.html#super_gradients.training.training_hyperparams.imagenet_vit_large_train_params">[docs]</a><span class="k">def</span> <span class="nf">imagenet_vit_large_train_params</span><span class="p">(</span><span class="n">overriding_params</span><span class="p">:</span> <span class="n">Dict</span> <span class="o">=</span> <span class="kc">None</span><span class="p">):</span>
    151. <span class="k">return</span> <span class="n">get</span><span class="p">(</span><span class="s2">&quot;imagenet_vit_large&quot;</span><span class="p">,</span> <span class="n">overriding_params</span><span class="p">)</span></div>
    152. </pre></div>
    153. </div>
    154. </div>
    155. <footer>
    156. <hr/>
    157. <div role="contentinfo">
    158. <p>&#169; Copyright 2021, SuperGradients team.</p>
    159. </div>
    160. Built with <a href="https://www.sphinx-doc.org/">Sphinx</a> using a
    161. <a href="https://github.com/readthedocs/sphinx_rtd_theme">theme</a>
    162. provided by <a href="https://readthedocs.org">Read the Docs</a>.
    163. </footer>
    164. </div>
    165. </div>
    166. </section>
    167. </div>
    168. <script>
    169. jQuery(function () {
    170. SphinxRtdTheme.Navigation.enable(true);
    171. });
    172. </script>
    173. </body>
    174. </html>
    Discard
    1
    2
    3
    4
    5
    6
    7
    8
    9
    10
    11
    12
    13
    14
    15
    16
    17
    18
    19
    20
    21
    22
    23
    24
    25
    26
    27
    28
    29
    30
    31
    32
    33
    34
    35
    36
    37
    38
    39
    40
    41
    42
    43
    44
    45
    46
    47
    48
    49
    50
    51
    52
    53
    54
    55
    56
    57
    58
    59
    60
    61
    62
    63
    64
    65
    66
    67
    68
    69
    70
    71
    72
    73
    74
    75
    76
    77
    78
    79
    80
    81
    82
    83
    84
    85
    86
    87
    88
    89
    90
    91
    92
    93
    94
    95
    96
    97
    98
    99
    100
    101
    102
    103
    104
    105
    106
    107
    108
    109
    110
    111
    112
    113
    114
    115
    116
    117
    118
    119
    120
    121
    122
    123
    124
    125
    126
    127
    128
    129
    130
    131
    132
    133
    134
    135
    136
    137
    138
    139
    140
    141
    142
    143
    144
    145
    146
    147
    148
    149
    150
    151
    152
    153
    154
    155
    156
    157
    158
    159
    160
    161
    162
    163
    164
    165
    166
    167
    168
    169
    170
    171
    172
    173
    174
    175
    176
    177
    178
    179
    180
    181
    182
    183
    184
    185
    186
    187
    188
    189
    190
    191
    192
    193
    194
    195
    196
    197
    198
    199
    200
    201
    202
    203
    204
    205
    206
    207
    208
    209
    210
    211
    212
    213
    214
    215
    216
    217
    218
    219
    220
    221
    222
    223
    224
    225
    226
    227
    228
    229
    230
    231
    232
    233
    234
    235
    236
    237
    238
    239
    240
    241
    242
    243
    244
    245
    246
    247
    248
    249
    250
    251
    252
    253
    254
    255
    256
    257
    258
    259
    260
    261
    262
    263
    264
    265
    266
    267
    268
    269
    270
    271
    272
    273
    274
    275
    276
    277
    278
    279
    280
    281
    282
    283
    284
    285
    286
    287
    288
    289
    290
    291
    292
    293
    294
    295
    296
    297
    298
    299
    300
    301
    302
    303
    304
    305
    306
    307
    308
    309
    310
    311
    312
    313
    314
    315
    316
    317
    318
    319
    320
    321
    322
    323
    324
    325
    326
    327
    328
    329
    330
    331
    332
    333
    334
    335
    336
    337
    338
    339
    340
    341
    342
    343
    344
    345
    346
    347
    348
    349
    350
    351
    352
    353
    354
    355
    356
    357
    358
    359
    360
    361
    362
    363
    364
    365
    366
    367
    368
    369
    370
    371
    372
    373
    374
    375
    376
    377
    378
    379
    380
    381
    382
    383
    384
    385
    386
    387
    388
    389
    390
    391
    392
    393
    394
    395
    396
    397
    398
    399
    400
    401
    402
    403
    404
    405
    406
    407
    408
    409
    410
    411
    412
    413
    414
    415
    416
    417
    418
    419
    420
    421
    422
    423
    424
    425
    426
    427
    428
    429
    430
    431
    432
    433
    434
    435
    436
    437
    438
    439
    440
    441
    442
    443
    444
    445
    446
    447
    448
    449
    450
    451
    452
    453
    454
    455
    456
    457
    458
    459
    460
    461
    462
    463
    464
    465
    466
    467
    468
    469
    470
    471
    472
    473
    474
    475
    476
    477
    478
    479
    480
    481
    482
    483
    484
    485
    486
    487
    488
    489
    490
    491
    492
    493
    494
    495
    496
    497
    498
    499
    500
    501
    502
    503
    504
    505
    506
    507
    508
    509
    510
    511
    512
    513
    514
    515
    516
    517
    518
    519
    520
    521
    522
    523
    524
    525
    526
    527
    528
    529
    530
    531
    532
    533
    534
    535
    536
    537
    538
    539
    540
    541
    542
    543
    544
    545
    546
    547
    548
    549
    550
    551
    552
    553
    554
    555
    556
    557
    558
    559
    560
    561
    562
    563
    564
    565
    566
    567
    568
    569
    570
    571
    572
    573
    574
    575
    576
    577
    578
    579
    580
    581
    582
    583
    584
    585
    586
    587
    588
    589
    590
    591
    592
    593
    594
    595
    596
    597
    598
    599
    600
    601
    602
    603
    604
    605
    606
    607
    608
    609
    610
    611
    612
    613
    614
    615
    616
    617
    618
    619
    620
    621
    622
    623
    624
    625
    626
    627
    628
    629
    630
    631
    632
    633
    634
    635
    636
    637
    638
    639
    640
    641
    642
    643
    644
    645
    646
    647
    648
    649
    650
    651
    652
    653
    654
    655
    656
    657
    658
    659
    660
    661
    662
    663
    664
    665
    666
    667
    668
    669
    670
    671
    672
    673
    674
    675
    676
    677
    678
    679
    680
    681
    682
    683
    684
    685
    686
    687
    688
    689
    690
    691
    692
    693
    694
    695
    696
    697
    698
    699
    700
    701
    702
    703
    704
    705
    706
    707
    708
    709
    710
    711
    712
    713
    714
    715
    716
    717
    718
    719
    720
    721
    722
    723
    724
    725
    726
    727
    728
    729
    730
    731
    732
    733
    734
    735
    736
    737
    738
    739
    740
    741
    742
    743
    744
    745
    746
    747
    748
    749
    750
    751
    752
    753
    754
    755
    756
    757
    758
    759
    760
    761
    762
    763
    764
    765
    766
    767
    768
    769
    770
    771
    772
    773
    774
    775
    776
    777
    778
    779
    780
    781
    782
    783
    784
    785
    786
    787
    788
    789
    790
    791
    792
    793
    794
    795
    796
    797
    798
    799
    800
    801
    802
    803
    804
    805
    806
    807
    808
    809
    810
    811
    812
    813
    814
    815
    816
    817
    818
    819
    820
    821
    822
    823
    824
    825
    826
    827
    828
    829
    830
    831
    832
    833
    834
    835
    836
    837
    838
    839
    840
    841
    842
    843
    844
    845
    846
    847
    848
    849
    850
    851
    852
    853
    854
    855
    856
    857
    858
    859
    860
    861
    862
    863
    864
    865
    866
    867
    868
    869
    870
    871
    872
    873
    874
    875
    876
    877
    878
    879
    880
    881
    882
    883
    884
    885
    886
    887
    888
    889
    890
    891
    892
    893
    894
    895
    896
    897
    898
    899
    900
    901
    902
    903
    904
    905
    906
    907
    908
    909
    910
    911
    912
    913
    914
    915
    916
    917
    918
    919
    920
    921
    922
    923
    924
    925
    926
    927
    928
    929
    930
    931
    932
    933
    934
    935
    936
    937
    938
    939
    940
    941
    942
    943
    944
    945
    946
    947
    948
    949
    950
    951
    952
    953
    954
    955
    956
    957
    958
    959
    960
    961
    962
    963
    964
    965
    966
    967
    968
    969
    970
    971
    972
    973
    974
    975
    976
    977
    978
    979
    980
    981
    982
    983
    984
    985
    986
    987
    988
    989
    990
    991
    992
    993
    994
    995
    996
    997
    998
    999
    1000
    1001
    1002
    1003
    1004
    1005
    1006
    1007
    1008
    1009
    1010
    1011
    1012
    1013
    1014
    1015
    1016
    1017
    1018
    1019
    1020
    1021
    1022
    1023
    1024
    1025
    1026
    1027
    1028
    1029
    1030
    1031
    1032
    1033
    1034
    1035
    1036
    1037
    1038
    1039
    1040
    1041
    1042
    1043
    1044
    1045
    1046
    1047
    1048
    1049
    1050
    1051
    1052
    1053
    1054
    1055
    1056
    1057
    1058
    1059
    1060
    1061
    1062
    1063
    1064
    1065
    1066
    1067
    1068
    1069
    1070
    1071
    1072
    1073
    1074
    1075
    1076
    1077
    1078
    1079
    1080
    1081
    1082
    1083
    1084
    1085
    1086
    1087
    1088
    1089
    1090
    1091
    1092
    1093
    1094
    1095
    1096
    1097
    1098
    1099
    1100
    1101
    1102
    1103
    1104
    1105
    1106
    1107
    1108
    1109
    1110
    1111
    1112
    1113
    1114
    1115
    1116
    1117
    1118
    1119
    1120
    1121
    1122
    1123
    1124
    1125
    1126
    1127
    1128
    1129
    1130
    1131
    1132
    1133
    1134
    1135
    1136
    1137
    1138
    1139
    1140
    1141
    1142
    1143
    1144
    1145
    1146
    1147
    1148
    1149
    1150
    1151
    1152
    1153
    1154
    1155
    1156
    1157
    1158
    1159
    1160
    1161
    1162
    1163
    1164
    1165
    1166
    1167
    1168
    1169
    1170
    1171
    1172
    1173
    1174
    1175
    1176
    1. <!DOCTYPE html>
    2. <html class="writer-html5" lang="en" >
    3. <head>
    4. <meta charset="utf-8" />
    5. <meta name="viewport" content="width=device-width, initial-scale=1.0" />
    6. <title>super_gradients.training.transforms.transforms &mdash; SuperGradients 3.0.3 documentation</title>
    7. <link rel="stylesheet" href="../../../../_static/pygments.css" type="text/css" />
    8. <link rel="stylesheet" href="../../../../_static/css/theme.css" type="text/css" />
    9. <link rel="stylesheet" href="../../../../_static/graphviz.css" type="text/css" />
    10. <link rel="stylesheet" href="../../../../_static/custom.css" type="text/css" />
    11. <!--[if lt IE 9]>
    12. <script src="../../../../_static/js/html5shiv.min.js"></script>
    13. <![endif]-->
    14. <script data-url_root="../../../../" id="documentation_options" src="../../../../_static/documentation_options.js"></script>
    15. <script src="../../../../_static/jquery.js"></script>
    16. <script src="../../../../_static/underscore.js"></script>
    17. <script src="../../../../_static/_sphinx_javascript_frameworks_compat.js"></script>
    18. <script src="../../../../_static/doctools.js"></script>
    19. <script src="../../../../_static/sphinx_highlight.js"></script>
    20. <script src="../../../../_static/js/theme.js"></script>
    21. <link rel="index" title="Index" href="../../../../genindex.html" />
    22. <link rel="search" title="Search" href="../../../../search.html" />
    23. </head>
    24. <body class="wy-body-for-nav">
    25. <div class="wy-grid-for-nav">
    26. <nav data-toggle="wy-nav-shift" class="wy-nav-side">
    27. <div class="wy-side-scroll">
    28. <div class="wy-side-nav-search" >
    29. <a href="../../../../index.html" class="icon icon-home"> SuperGradients
    30. </a>
    31. <div role="search">
    32. <form id="rtd-search-form" class="wy-form" action="../../../../search.html" method="get">
    33. <input type="text" name="q" placeholder="Search docs" />
    34. <input type="hidden" name="check_keywords" value="yes" />
    35. <input type="hidden" name="area" value="default" />
    36. </form>
    37. </div>
    38. </div><div class="wy-menu wy-menu-vertical" data-spy="affix" role="navigation" aria-label="Navigation menu">
    39. <p class="caption" role="heading"><span class="caption-text">Welcome To SuperGradients</span></p>
    40. <ul>
    41. <li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html">Version 3 is out! Notebooks have been updated!</a></li>
    42. <li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#build-with-supergradients">Build with SuperGradients</a></li>
    43. <li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#quick-installation">Quick Installation</a></li>
    44. <li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#what-s-new">What’s New</a></li>
    45. <li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#coming-soon">Coming soon</a></li>
    46. <li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#table-of-content">Table of Content</a></li>
    47. <li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#getting-started">Getting Started</a></li>
    48. <li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#advanced-features">Advanced Features</a></li>
    49. <li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#installation-methods">Installation Methods</a></li>
    50. <li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#implemented-model-architectures">Implemented Model Architectures</a></li>
    51. <li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#documentation">Documentation</a></li>
    52. <li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#contributing">Contributing</a></li>
    53. <li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#citation">Citation</a></li>
    54. <li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#community">Community</a></li>
    55. <li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#license">License</a></li>
    56. <li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#deci-platform">Deci Platform</a></li>
    57. </ul>
    58. <p class="caption" role="heading"><span class="caption-text">Technical Documentation</span></p>
    59. <ul>
    60. <li class="toctree-l1"><a class="reference internal" href="../../../../super_gradients.common.html">Common package</a></li>
    61. <li class="toctree-l1"><a class="reference internal" href="../../../../super_gradients.training.html">Training package</a></li>
    62. </ul>
    63. </div>
    64. </div>
    65. </nav>
    66. <section data-toggle="wy-nav-shift" class="wy-nav-content-wrap"><nav class="wy-nav-top" aria-label="Mobile navigation menu" >
    67. <i data-toggle="wy-nav-top" class="fa fa-bars"></i>
    68. <a href="../../../../index.html">SuperGradients</a>
    69. </nav>
    70. <div class="wy-nav-content">
    71. <div class="rst-content">
    72. <div role="navigation" aria-label="Page navigation">
    73. <ul class="wy-breadcrumbs">
    74. <li><a href="../../../../index.html" class="icon icon-home"></a> &raquo;</li>
    75. <li><a href="../../../index.html">Module code</a> &raquo;</li>
    76. <li>super_gradients.training.transforms.transforms</li>
    77. <li class="wy-breadcrumbs-aside">
    78. </li>
    79. </ul>
    80. <hr/>
    81. </div>
    82. <div role="main" class="document" itemscope="itemscope" itemtype="http://schema.org/Article">
    83. <div itemprop="articleBody">
    84. <h1>Source code for super_gradients.training.transforms.transforms</h1><div class="highlight"><pre>
    85. <span></span><span class="kn">import</span> <span class="nn">collections</span>
    86. <span class="kn">import</span> <span class="nn">math</span>
    87. <span class="kn">import</span> <span class="nn">random</span>
    88. <span class="kn">from</span> <span class="nn">typing</span> <span class="kn">import</span> <span class="n">Optional</span><span class="p">,</span> <span class="n">Union</span><span class="p">,</span> <span class="n">Tuple</span><span class="p">,</span> <span class="n">List</span><span class="p">,</span> <span class="n">Sequence</span><span class="p">,</span> <span class="n">Dict</span>
    89. <span class="kn">from</span> <span class="nn">PIL</span> <span class="kn">import</span> <span class="n">Image</span><span class="p">,</span> <span class="n">ImageFilter</span><span class="p">,</span> <span class="n">ImageOps</span>
    90. <span class="kn">from</span> <span class="nn">torchvision</span> <span class="kn">import</span> <span class="n">transforms</span> <span class="k">as</span> <span class="n">transforms</span>
    91. <span class="kn">import</span> <span class="nn">numpy</span> <span class="k">as</span> <span class="nn">np</span>
    92. <span class="kn">import</span> <span class="nn">cv2</span>
    93. <span class="kn">from</span> <span class="nn">super_gradients.common.abstractions.abstract_logger</span> <span class="kn">import</span> <span class="n">get_logger</span>
    94. <span class="kn">from</span> <span class="nn">super_gradients.training.utils.detection_utils</span> <span class="kn">import</span> <span class="n">get_mosaic_coordinate</span><span class="p">,</span> <span class="n">adjust_box_anns</span><span class="p">,</span> <span class="n">xyxy2cxcywh</span><span class="p">,</span> <span class="n">cxcywh2xyxy</span><span class="p">,</span> <span class="n">DetectionTargetsFormat</span>
    95. <span class="n">image_resample</span> <span class="o">=</span> <span class="n">Image</span><span class="o">.</span><span class="n">BILINEAR</span>
    96. <span class="n">mask_resample</span> <span class="o">=</span> <span class="n">Image</span><span class="o">.</span><span class="n">NEAREST</span>
    97. <span class="n">logger</span> <span class="o">=</span> <span class="n">get_logger</span><span class="p">(</span><span class="vm">__name__</span><span class="p">)</span>
    98. <span class="k">class</span> <span class="nc">SegmentationTransform</span><span class="p">:</span>
    99. <span class="k">def</span> <span class="fm">__call__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="o">*</span><span class="n">args</span><span class="p">,</span> <span class="o">**</span><span class="n">kwargs</span><span class="p">):</span>
    100. <span class="k">raise</span> <span class="ne">NotImplementedError</span>
    101. <span class="k">def</span> <span class="fm">__repr__</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span>
    102. <span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="vm">__class__</span><span class="o">.</span><span class="vm">__name__</span> <span class="o">+</span> <span class="nb">str</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="vm">__dict__</span><span class="p">)</span><span class="o">.</span><span class="n">replace</span><span class="p">(</span><span class="s2">&quot;{&quot;</span><span class="p">,</span> <span class="s2">&quot;(&quot;</span><span class="p">)</span><span class="o">.</span><span class="n">replace</span><span class="p">(</span><span class="s2">&quot;}&quot;</span><span class="p">,</span> <span class="s2">&quot;)&quot;</span><span class="p">)</span>
    103. <span class="k">class</span> <span class="nc">SegResize</span><span class="p">(</span><span class="n">SegmentationTransform</span><span class="p">):</span>
    104. <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">h</span><span class="p">,</span> <span class="n">w</span><span class="p">):</span>
    105. <span class="bp">self</span><span class="o">.</span><span class="n">h</span> <span class="o">=</span> <span class="n">h</span>
    106. <span class="bp">self</span><span class="o">.</span><span class="n">w</span> <span class="o">=</span> <span class="n">w</span>
    107. <span class="k">def</span> <span class="fm">__call__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">sample</span><span class="p">):</span>
    108. <span class="n">image</span> <span class="o">=</span> <span class="n">sample</span><span class="p">[</span><span class="s2">&quot;image&quot;</span><span class="p">]</span>
    109. <span class="n">mask</span> <span class="o">=</span> <span class="n">sample</span><span class="p">[</span><span class="s2">&quot;mask&quot;</span><span class="p">]</span>
    110. <span class="n">sample</span><span class="p">[</span><span class="s2">&quot;image&quot;</span><span class="p">]</span> <span class="o">=</span> <span class="n">image</span><span class="o">.</span><span class="n">resize</span><span class="p">((</span><span class="bp">self</span><span class="o">.</span><span class="n">w</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">h</span><span class="p">),</span> <span class="n">image_resample</span><span class="p">)</span>
    111. <span class="n">sample</span><span class="p">[</span><span class="s2">&quot;mask&quot;</span><span class="p">]</span> <span class="o">=</span> <span class="n">mask</span><span class="o">.</span><span class="n">resize</span><span class="p">((</span><span class="bp">self</span><span class="o">.</span><span class="n">w</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">h</span><span class="p">),</span> <span class="n">mask_resample</span><span class="p">)</span>
    112. <span class="k">return</span> <span class="n">sample</span>
    113. <span class="k">class</span> <span class="nc">SegRandomFlip</span><span class="p">(</span><span class="n">SegmentationTransform</span><span class="p">):</span>
    114. <span class="sd">&quot;&quot;&quot;</span>
    115. <span class="sd"> Randomly flips the image and mask (synchronously) with probability &#39;prob&#39;.</span>
    116. <span class="sd"> &quot;&quot;&quot;</span>
    117. <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">prob</span><span class="p">:</span> <span class="nb">float</span> <span class="o">=</span> <span class="mf">0.5</span><span class="p">):</span>
    118. <span class="k">assert</span> <span class="mf">0.0</span> <span class="o">&lt;=</span> <span class="n">prob</span> <span class="o">&lt;=</span> <span class="mf">1.0</span><span class="p">,</span> <span class="sa">f</span><span class="s2">&quot;Probability value must be between 0 and 1, found </span><span class="si">{</span><span class="n">prob</span><span class="si">}</span><span class="s2">&quot;</span>
    119. <span class="bp">self</span><span class="o">.</span><span class="n">prob</span> <span class="o">=</span> <span class="n">prob</span>
    120. <span class="k">def</span> <span class="fm">__call__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">sample</span><span class="p">:</span> <span class="nb">dict</span><span class="p">):</span>
    121. <span class="n">image</span> <span class="o">=</span> <span class="n">sample</span><span class="p">[</span><span class="s2">&quot;image&quot;</span><span class="p">]</span>
    122. <span class="n">mask</span> <span class="o">=</span> <span class="n">sample</span><span class="p">[</span><span class="s2">&quot;mask&quot;</span><span class="p">]</span>
    123. <span class="k">if</span> <span class="n">random</span><span class="o">.</span><span class="n">random</span><span class="p">()</span> <span class="o">&lt;</span> <span class="bp">self</span><span class="o">.</span><span class="n">prob</span><span class="p">:</span>
    124. <span class="n">image</span> <span class="o">=</span> <span class="n">image</span><span class="o">.</span><span class="n">transpose</span><span class="p">(</span><span class="n">Image</span><span class="o">.</span><span class="n">FLIP_LEFT_RIGHT</span><span class="p">)</span>
    125. <span class="n">mask</span> <span class="o">=</span> <span class="n">mask</span><span class="o">.</span><span class="n">transpose</span><span class="p">(</span><span class="n">Image</span><span class="o">.</span><span class="n">FLIP_LEFT_RIGHT</span><span class="p">)</span>
    126. <span class="n">sample</span><span class="p">[</span><span class="s2">&quot;image&quot;</span><span class="p">]</span> <span class="o">=</span> <span class="n">image</span>
    127. <span class="n">sample</span><span class="p">[</span><span class="s2">&quot;mask&quot;</span><span class="p">]</span> <span class="o">=</span> <span class="n">mask</span>
    128. <span class="k">return</span> <span class="n">sample</span>
    129. <span class="k">class</span> <span class="nc">SegRescale</span><span class="p">(</span><span class="n">SegmentationTransform</span><span class="p">):</span>
    130. <span class="sd">&quot;&quot;&quot;</span>
    131. <span class="sd"> Rescales the image and mask (synchronously) while preserving aspect ratio.</span>
    132. <span class="sd"> The rescaling can be done according to scale_factor, short_size or long_size.</span>
    133. <span class="sd"> If more than one argument is given, the rescaling mode is determined by this order: scale_factor, then short_size,</span>
    134. <span class="sd"> then long_size.</span>
    135. <span class="sd"> Args:</span>
    136. <span class="sd"> scale_factor: rescaling is done by multiplying input size by scale_factor:</span>
    137. <span class="sd"> out_size = (scale_factor * w, scale_factor * h)</span>
    138. <span class="sd"> short_size: rescaling is done by determining the scale factor by the ratio short_size / min(h, w).</span>
    139. <span class="sd"> long_size: rescaling is done by determining the scale factor by the ratio long_size / max(h, w).</span>
    140. <span class="sd"> &quot;&quot;&quot;</span>
    141. <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">scale_factor</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="nb">float</span><span class="p">]</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span> <span class="n">short_size</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="nb">int</span><span class="p">]</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span> <span class="n">long_size</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="nb">int</span><span class="p">]</span> <span class="o">=</span> <span class="kc">None</span><span class="p">):</span>
    142. <span class="bp">self</span><span class="o">.</span><span class="n">scale_factor</span> <span class="o">=</span> <span class="n">scale_factor</span>
    143. <span class="bp">self</span><span class="o">.</span><span class="n">short_size</span> <span class="o">=</span> <span class="n">short_size</span>
    144. <span class="bp">self</span><span class="o">.</span><span class="n">long_size</span> <span class="o">=</span> <span class="n">long_size</span>
    145. <span class="bp">self</span><span class="o">.</span><span class="n">check_valid_arguments</span><span class="p">()</span>
    146. <span class="k">def</span> <span class="fm">__call__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">sample</span><span class="p">:</span> <span class="nb">dict</span><span class="p">):</span>
    147. <span class="n">image</span> <span class="o">=</span> <span class="n">sample</span><span class="p">[</span><span class="s2">&quot;image&quot;</span><span class="p">]</span>
    148. <span class="n">mask</span> <span class="o">=</span> <span class="n">sample</span><span class="p">[</span><span class="s2">&quot;mask&quot;</span><span class="p">]</span>
    149. <span class="n">w</span><span class="p">,</span> <span class="n">h</span> <span class="o">=</span> <span class="n">image</span><span class="o">.</span><span class="n">size</span>
    150. <span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">scale_factor</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span><span class="p">:</span>
    151. <span class="n">scale</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">scale_factor</span>
    152. <span class="k">elif</span> <span class="bp">self</span><span class="o">.</span><span class="n">short_size</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span><span class="p">:</span>
    153. <span class="n">short_size</span> <span class="o">=</span> <span class="nb">min</span><span class="p">(</span><span class="n">w</span><span class="p">,</span> <span class="n">h</span><span class="p">)</span>
    154. <span class="n">scale</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">short_size</span> <span class="o">/</span> <span class="n">short_size</span>
    155. <span class="k">else</span><span class="p">:</span>
    156. <span class="n">long_size</span> <span class="o">=</span> <span class="nb">max</span><span class="p">(</span><span class="n">w</span><span class="p">,</span> <span class="n">h</span><span class="p">)</span>
    157. <span class="n">scale</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">long_size</span> <span class="o">/</span> <span class="n">long_size</span>
    158. <span class="n">out_size</span> <span class="o">=</span> <span class="nb">int</span><span class="p">(</span><span class="n">scale</span> <span class="o">*</span> <span class="n">w</span><span class="p">),</span> <span class="nb">int</span><span class="p">(</span><span class="n">scale</span> <span class="o">*</span> <span class="n">h</span><span class="p">)</span>
    159. <span class="n">image</span> <span class="o">=</span> <span class="n">image</span><span class="o">.</span><span class="n">resize</span><span class="p">(</span><span class="n">out_size</span><span class="p">,</span> <span class="n">image_resample</span><span class="p">)</span>
    160. <span class="n">mask</span> <span class="o">=</span> <span class="n">mask</span><span class="o">.</span><span class="n">resize</span><span class="p">(</span><span class="n">out_size</span><span class="p">,</span> <span class="n">mask_resample</span><span class="p">)</span>
    161. <span class="n">sample</span><span class="p">[</span><span class="s2">&quot;image&quot;</span><span class="p">]</span> <span class="o">=</span> <span class="n">image</span>
    162. <span class="n">sample</span><span class="p">[</span><span class="s2">&quot;mask&quot;</span><span class="p">]</span> <span class="o">=</span> <span class="n">mask</span>
    163. <span class="k">return</span> <span class="n">sample</span>
    164. <span class="k">def</span> <span class="nf">check_valid_arguments</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span>
    165. <span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">scale_factor</span> <span class="ow">is</span> <span class="kc">None</span> <span class="ow">and</span> <span class="bp">self</span><span class="o">.</span><span class="n">short_size</span> <span class="ow">is</span> <span class="kc">None</span> <span class="ow">and</span> <span class="bp">self</span><span class="o">.</span><span class="n">long_size</span> <span class="ow">is</span> <span class="kc">None</span><span class="p">:</span>
    166. <span class="k">raise</span> <span class="ne">ValueError</span><span class="p">(</span><span class="s2">&quot;Must assign one rescale argument: scale_factor, short_size or long_size&quot;</span><span class="p">)</span>
    167. <span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">scale_factor</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span> <span class="ow">and</span> <span class="bp">self</span><span class="o">.</span><span class="n">scale_factor</span> <span class="o">&lt;=</span> <span class="mi">0</span><span class="p">:</span>
    168. <span class="k">raise</span> <span class="ne">ValueError</span><span class="p">(</span><span class="sa">f</span><span class="s2">&quot;Scale factor must be a positive number, found: </span><span class="si">{</span><span class="bp">self</span><span class="o">.</span><span class="n">scale_factor</span><span class="si">}</span><span class="s2">&quot;</span><span class="p">)</span>
    169. <span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">short_size</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span> <span class="ow">and</span> <span class="bp">self</span><span class="o">.</span><span class="n">short_size</span> <span class="o">&lt;=</span> <span class="mi">0</span><span class="p">:</span>
    170. <span class="k">raise</span> <span class="ne">ValueError</span><span class="p">(</span><span class="sa">f</span><span class="s2">&quot;Short size must be a positive number, found: </span><span class="si">{</span><span class="bp">self</span><span class="o">.</span><span class="n">short_size</span><span class="si">}</span><span class="s2">&quot;</span><span class="p">)</span>
    171. <span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">long_size</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span> <span class="ow">and</span> <span class="bp">self</span><span class="o">.</span><span class="n">long_size</span> <span class="o">&lt;=</span> <span class="mi">0</span><span class="p">:</span>
    172. <span class="k">raise</span> <span class="ne">ValueError</span><span class="p">(</span><span class="sa">f</span><span class="s2">&quot;Long size must be a positive number, found: </span><span class="si">{</span><span class="bp">self</span><span class="o">.</span><span class="n">long_size</span><span class="si">}</span><span class="s2">&quot;</span><span class="p">)</span>
    173. <span class="k">class</span> <span class="nc">SegRandomRescale</span><span class="p">:</span>
    174. <span class="sd">&quot;&quot;&quot;</span>
    175. <span class="sd"> Random rescale the image and mask (synchronously) while preserving aspect ratio.</span>
    176. <span class="sd"> Scale factor is randomly picked between scales [min, max]</span>
    177. <span class="sd"> Args:</span>
    178. <span class="sd"> scales: scale range tuple (min, max), if scales is a float range will be defined as (1, scales) if scales &gt; 1,</span>
    179. <span class="sd"> otherwise (scales, 1). must be a positive number.</span>
    180. <span class="sd"> &quot;&quot;&quot;</span>
    181. <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">scales</span><span class="p">:</span> <span class="n">Union</span><span class="p">[</span><span class="nb">float</span><span class="p">,</span> <span class="n">Tuple</span><span class="p">,</span> <span class="n">List</span><span class="p">]</span> <span class="o">=</span> <span class="p">(</span><span class="mf">0.5</span><span class="p">,</span> <span class="mf">2.0</span><span class="p">)):</span>
    182. <span class="bp">self</span><span class="o">.</span><span class="n">scales</span> <span class="o">=</span> <span class="n">scales</span>
    183. <span class="bp">self</span><span class="o">.</span><span class="n">check_valid_arguments</span><span class="p">()</span>
    184. <span class="k">def</span> <span class="fm">__call__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">sample</span><span class="p">:</span> <span class="nb">dict</span><span class="p">):</span>
    185. <span class="n">image</span> <span class="o">=</span> <span class="n">sample</span><span class="p">[</span><span class="s2">&quot;image&quot;</span><span class="p">]</span>
    186. <span class="n">mask</span> <span class="o">=</span> <span class="n">sample</span><span class="p">[</span><span class="s2">&quot;mask&quot;</span><span class="p">]</span>
    187. <span class="n">w</span><span class="p">,</span> <span class="n">h</span> <span class="o">=</span> <span class="n">image</span><span class="o">.</span><span class="n">size</span>
    188. <span class="n">scale</span> <span class="o">=</span> <span class="n">random</span><span class="o">.</span><span class="n">uniform</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">scales</span><span class="p">[</span><span class="mi">0</span><span class="p">],</span> <span class="bp">self</span><span class="o">.</span><span class="n">scales</span><span class="p">[</span><span class="mi">1</span><span class="p">])</span>
    189. <span class="n">out_size</span> <span class="o">=</span> <span class="nb">int</span><span class="p">(</span><span class="n">scale</span> <span class="o">*</span> <span class="n">w</span><span class="p">),</span> <span class="nb">int</span><span class="p">(</span><span class="n">scale</span> <span class="o">*</span> <span class="n">h</span><span class="p">)</span>
    190. <span class="n">image</span> <span class="o">=</span> <span class="n">image</span><span class="o">.</span><span class="n">resize</span><span class="p">(</span><span class="n">out_size</span><span class="p">,</span> <span class="n">image_resample</span><span class="p">)</span>
    191. <span class="n">mask</span> <span class="o">=</span> <span class="n">mask</span><span class="o">.</span><span class="n">resize</span><span class="p">(</span><span class="n">out_size</span><span class="p">,</span> <span class="n">mask_resample</span><span class="p">)</span>
    192. <span class="n">sample</span><span class="p">[</span><span class="s2">&quot;image&quot;</span><span class="p">]</span> <span class="o">=</span> <span class="n">image</span>
    193. <span class="n">sample</span><span class="p">[</span><span class="s2">&quot;mask&quot;</span><span class="p">]</span> <span class="o">=</span> <span class="n">mask</span>
    194. <span class="k">return</span> <span class="n">sample</span>
    195. <span class="k">def</span> <span class="nf">check_valid_arguments</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span>
    196. <span class="sd">&quot;&quot;&quot;</span>
    197. <span class="sd"> Check the scale values are valid. if order is wrong, flip the order and return the right scale values.</span>
    198. <span class="sd"> &quot;&quot;&quot;</span>
    199. <span class="k">if</span> <span class="ow">not</span> <span class="nb">isinstance</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">scales</span><span class="p">,</span> <span class="n">collections</span><span class="o">.</span><span class="n">abc</span><span class="o">.</span><span class="n">Iterable</span><span class="p">):</span>
    200. <span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">scales</span> <span class="o">&lt;=</span> <span class="mi">1</span><span class="p">:</span>
    201. <span class="bp">self</span><span class="o">.</span><span class="n">scales</span> <span class="o">=</span> <span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">scales</span><span class="p">,</span> <span class="mi">1</span><span class="p">)</span>
    202. <span class="k">else</span><span class="p">:</span>
    203. <span class="bp">self</span><span class="o">.</span><span class="n">scales</span> <span class="o">=</span> <span class="p">(</span><span class="mi">1</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">scales</span><span class="p">)</span>
    204. <span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">scales</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span> <span class="o">&lt;</span> <span class="mi">0</span> <span class="ow">or</span> <span class="bp">self</span><span class="o">.</span><span class="n">scales</span><span class="p">[</span><span class="mi">1</span><span class="p">]</span> <span class="o">&lt;</span> <span class="mi">0</span><span class="p">:</span>
    205. <span class="k">raise</span> <span class="ne">ValueError</span><span class="p">(</span><span class="sa">f</span><span class="s2">&quot;SegRandomRescale scale values must be positive numbers, found: </span><span class="si">{</span><span class="bp">self</span><span class="o">.</span><span class="n">scales</span><span class="si">}</span><span class="s2">&quot;</span><span class="p">)</span>
    206. <span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">scales</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span> <span class="o">&gt;</span> <span class="bp">self</span><span class="o">.</span><span class="n">scales</span><span class="p">[</span><span class="mi">1</span><span class="p">]:</span>
    207. <span class="bp">self</span><span class="o">.</span><span class="n">scales</span> <span class="o">=</span> <span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">scales</span><span class="p">[</span><span class="mi">1</span><span class="p">],</span> <span class="bp">self</span><span class="o">.</span><span class="n">scales</span><span class="p">[</span><span class="mi">0</span><span class="p">])</span>
    208. <span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">scales</span>
    209. <span class="k">class</span> <span class="nc">SegRandomRotate</span><span class="p">(</span><span class="n">SegmentationTransform</span><span class="p">):</span>
    210. <span class="sd">&quot;&quot;&quot;</span>
    211. <span class="sd"> Randomly rotates image and mask (synchronously) between &#39;min_deg&#39; and &#39;max_deg&#39;.</span>
    212. <span class="sd"> &quot;&quot;&quot;</span>
    213. <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">min_deg</span><span class="p">:</span> <span class="nb">float</span> <span class="o">=</span> <span class="o">-</span><span class="mi">10</span><span class="p">,</span> <span class="n">max_deg</span><span class="p">:</span> <span class="nb">float</span> <span class="o">=</span> <span class="mi">10</span><span class="p">,</span> <span class="n">fill_mask</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="mi">0</span><span class="p">,</span> <span class="n">fill_image</span><span class="p">:</span> <span class="n">Union</span><span class="p">[</span><span class="nb">int</span><span class="p">,</span> <span class="n">Tuple</span><span class="p">,</span> <span class="n">List</span><span class="p">]</span> <span class="o">=</span> <span class="mi">0</span><span class="p">):</span>
    214. <span class="bp">self</span><span class="o">.</span><span class="n">min_deg</span> <span class="o">=</span> <span class="n">min_deg</span>
    215. <span class="bp">self</span><span class="o">.</span><span class="n">max_deg</span> <span class="o">=</span> <span class="n">max_deg</span>
    216. <span class="bp">self</span><span class="o">.</span><span class="n">fill_mask</span> <span class="o">=</span> <span class="n">fill_mask</span>
    217. <span class="c1"># grey color in RGB mode</span>
    218. <span class="bp">self</span><span class="o">.</span><span class="n">fill_image</span> <span class="o">=</span> <span class="p">(</span><span class="n">fill_image</span><span class="p">,</span> <span class="n">fill_image</span><span class="p">,</span> <span class="n">fill_image</span><span class="p">)</span>
    219. <span class="bp">self</span><span class="o">.</span><span class="n">check_valid_arguments</span><span class="p">()</span>
    220. <span class="k">def</span> <span class="fm">__call__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">sample</span><span class="p">:</span> <span class="nb">dict</span><span class="p">):</span>
    221. <span class="n">image</span> <span class="o">=</span> <span class="n">sample</span><span class="p">[</span><span class="s2">&quot;image&quot;</span><span class="p">]</span>
    222. <span class="n">mask</span> <span class="o">=</span> <span class="n">sample</span><span class="p">[</span><span class="s2">&quot;mask&quot;</span><span class="p">]</span>
    223. <span class="n">deg</span> <span class="o">=</span> <span class="n">random</span><span class="o">.</span><span class="n">uniform</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">min_deg</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">max_deg</span><span class="p">)</span>
    224. <span class="n">image</span> <span class="o">=</span> <span class="n">image</span><span class="o">.</span><span class="n">rotate</span><span class="p">(</span><span class="n">deg</span><span class="p">,</span> <span class="n">resample</span><span class="o">=</span><span class="n">image_resample</span><span class="p">,</span> <span class="n">fillcolor</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">fill_image</span><span class="p">)</span>
    225. <span class="n">mask</span> <span class="o">=</span> <span class="n">mask</span><span class="o">.</span><span class="n">rotate</span><span class="p">(</span><span class="n">deg</span><span class="p">,</span> <span class="n">resample</span><span class="o">=</span><span class="n">mask_resample</span><span class="p">,</span> <span class="n">fillcolor</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">fill_mask</span><span class="p">)</span>
    226. <span class="n">sample</span><span class="p">[</span><span class="s2">&quot;image&quot;</span><span class="p">]</span> <span class="o">=</span> <span class="n">image</span>
    227. <span class="n">sample</span><span class="p">[</span><span class="s2">&quot;mask&quot;</span><span class="p">]</span> <span class="o">=</span> <span class="n">mask</span>
    228. <span class="k">return</span> <span class="n">sample</span>
    229. <span class="k">def</span> <span class="nf">check_valid_arguments</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span>
    230. <span class="bp">self</span><span class="o">.</span><span class="n">fill_mask</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">fill_image</span> <span class="o">=</span> <span class="n">_validate_fill_values_arguments</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">fill_mask</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">fill_image</span><span class="p">)</span>
    231. <span class="k">class</span> <span class="nc">SegCropImageAndMask</span><span class="p">(</span><span class="n">SegmentationTransform</span><span class="p">):</span>
    232. <span class="sd">&quot;&quot;&quot;</span>
    233. <span class="sd"> Crops image and mask (synchronously).</span>
    234. <span class="sd"> In &quot;center&quot; mode a center crop is performed while, in &quot;random&quot; mode the drop will be positioned around</span>
    235. <span class="sd"> random coordinates.</span>
    236. <span class="sd"> &quot;&quot;&quot;</span>
    237. <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">crop_size</span><span class="p">:</span> <span class="n">Union</span><span class="p">[</span><span class="nb">float</span><span class="p">,</span> <span class="n">Tuple</span><span class="p">,</span> <span class="n">List</span><span class="p">],</span> <span class="n">mode</span><span class="p">:</span> <span class="nb">str</span><span class="p">):</span>
    238. <span class="sd">&quot;&quot;&quot;</span>
    239. <span class="sd"> :param crop_size: tuple of (width, height) for the final crop size, if is scalar size is a</span>
    240. <span class="sd"> square (crop_size, crop_size)</span>
    241. <span class="sd"> :param mode: how to choose the center of the crop, &#39;center&#39; for the center of the input image,</span>
    242. <span class="sd"> &#39;random&#39; center the point is chosen randomally</span>
    243. <span class="sd"> &quot;&quot;&quot;</span>
    244. <span class="bp">self</span><span class="o">.</span><span class="n">crop_size</span> <span class="o">=</span> <span class="n">crop_size</span>
    245. <span class="bp">self</span><span class="o">.</span><span class="n">mode</span> <span class="o">=</span> <span class="n">mode</span>
    246. <span class="bp">self</span><span class="o">.</span><span class="n">check_valid_arguments</span><span class="p">()</span>
    247. <span class="k">def</span> <span class="fm">__call__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">sample</span><span class="p">:</span> <span class="nb">dict</span><span class="p">):</span>
    248. <span class="n">image</span> <span class="o">=</span> <span class="n">sample</span><span class="p">[</span><span class="s2">&quot;image&quot;</span><span class="p">]</span>
    249. <span class="n">mask</span> <span class="o">=</span> <span class="n">sample</span><span class="p">[</span><span class="s2">&quot;mask&quot;</span><span class="p">]</span>
    250. <span class="n">w</span><span class="p">,</span> <span class="n">h</span> <span class="o">=</span> <span class="n">image</span><span class="o">.</span><span class="n">size</span>
    251. <span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">mode</span> <span class="o">==</span> <span class="s2">&quot;random&quot;</span><span class="p">:</span>
    252. <span class="n">x1</span> <span class="o">=</span> <span class="n">random</span><span class="o">.</span><span class="n">randint</span><span class="p">(</span><span class="mi">0</span><span class="p">,</span> <span class="n">w</span> <span class="o">-</span> <span class="bp">self</span><span class="o">.</span><span class="n">crop_size</span><span class="p">[</span><span class="mi">0</span><span class="p">])</span>
    253. <span class="n">y1</span> <span class="o">=</span> <span class="n">random</span><span class="o">.</span><span class="n">randint</span><span class="p">(</span><span class="mi">0</span><span class="p">,</span> <span class="n">h</span> <span class="o">-</span> <span class="bp">self</span><span class="o">.</span><span class="n">crop_size</span><span class="p">[</span><span class="mi">1</span><span class="p">])</span>
    254. <span class="k">else</span><span class="p">:</span>
    255. <span class="n">x1</span> <span class="o">=</span> <span class="nb">int</span><span class="p">(</span><span class="nb">round</span><span class="p">((</span><span class="n">w</span> <span class="o">-</span> <span class="bp">self</span><span class="o">.</span><span class="n">crop_size</span><span class="p">[</span><span class="mi">0</span><span class="p">])</span> <span class="o">/</span> <span class="mf">2.0</span><span class="p">))</span>
    256. <span class="n">y1</span> <span class="o">=</span> <span class="nb">int</span><span class="p">(</span><span class="nb">round</span><span class="p">((</span><span class="n">h</span> <span class="o">-</span> <span class="bp">self</span><span class="o">.</span><span class="n">crop_size</span><span class="p">[</span><span class="mi">1</span><span class="p">])</span> <span class="o">/</span> <span class="mf">2.0</span><span class="p">))</span>
    257. <span class="n">image</span> <span class="o">=</span> <span class="n">image</span><span class="o">.</span><span class="n">crop</span><span class="p">((</span><span class="n">x1</span><span class="p">,</span> <span class="n">y1</span><span class="p">,</span> <span class="n">x1</span> <span class="o">+</span> <span class="bp">self</span><span class="o">.</span><span class="n">crop_size</span><span class="p">[</span><span class="mi">0</span><span class="p">],</span> <span class="n">y1</span> <span class="o">+</span> <span class="bp">self</span><span class="o">.</span><span class="n">crop_size</span><span class="p">[</span><span class="mi">1</span><span class="p">]))</span>
    258. <span class="n">mask</span> <span class="o">=</span> <span class="n">mask</span><span class="o">.</span><span class="n">crop</span><span class="p">((</span><span class="n">x1</span><span class="p">,</span> <span class="n">y1</span><span class="p">,</span> <span class="n">x1</span> <span class="o">+</span> <span class="bp">self</span><span class="o">.</span><span class="n">crop_size</span><span class="p">[</span><span class="mi">0</span><span class="p">],</span> <span class="n">y1</span> <span class="o">+</span> <span class="bp">self</span><span class="o">.</span><span class="n">crop_size</span><span class="p">[</span><span class="mi">1</span><span class="p">]))</span>
    259. <span class="n">sample</span><span class="p">[</span><span class="s2">&quot;image&quot;</span><span class="p">]</span> <span class="o">=</span> <span class="n">image</span>
    260. <span class="n">sample</span><span class="p">[</span><span class="s2">&quot;mask&quot;</span><span class="p">]</span> <span class="o">=</span> <span class="n">mask</span>
    261. <span class="k">return</span> <span class="n">sample</span>
    262. <span class="k">def</span> <span class="nf">check_valid_arguments</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span>
    263. <span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">mode</span> <span class="ow">not</span> <span class="ow">in</span> <span class="p">[</span><span class="s2">&quot;center&quot;</span><span class="p">,</span> <span class="s2">&quot;random&quot;</span><span class="p">]:</span>
    264. <span class="k">raise</span> <span class="ne">ValueError</span><span class="p">(</span><span class="sa">f</span><span class="s2">&quot;Unsupported mode: found: </span><span class="si">{</span><span class="bp">self</span><span class="o">.</span><span class="n">mode</span><span class="si">}</span><span class="s2">, expected: &#39;center&#39; or &#39;random&#39;&quot;</span><span class="p">)</span>
    265. <span class="k">if</span> <span class="ow">not</span> <span class="nb">isinstance</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">crop_size</span><span class="p">,</span> <span class="n">collections</span><span class="o">.</span><span class="n">abc</span><span class="o">.</span><span class="n">Iterable</span><span class="p">):</span>
    266. <span class="bp">self</span><span class="o">.</span><span class="n">crop_size</span> <span class="o">=</span> <span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">crop_size</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">crop_size</span><span class="p">)</span>
    267. <span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">crop_size</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span> <span class="o">&lt;=</span> <span class="mi">0</span> <span class="ow">or</span> <span class="bp">self</span><span class="o">.</span><span class="n">crop_size</span><span class="p">[</span><span class="mi">1</span><span class="p">]</span> <span class="o">&lt;=</span> <span class="mi">0</span><span class="p">:</span>
    268. <span class="k">raise</span> <span class="ne">ValueError</span><span class="p">(</span><span class="sa">f</span><span class="s2">&quot;Crop size must be positive numbers, found: </span><span class="si">{</span><span class="bp">self</span><span class="o">.</span><span class="n">crop_size</span><span class="si">}</span><span class="s2">&quot;</span><span class="p">)</span>
    269. <span class="k">class</span> <span class="nc">SegRandomGaussianBlur</span><span class="p">(</span><span class="n">SegmentationTransform</span><span class="p">):</span>
    270. <span class="sd">&quot;&quot;&quot;</span>
    271. <span class="sd"> Adds random Gaussian Blur to image with probability &#39;prob&#39;.</span>
    272. <span class="sd"> &quot;&quot;&quot;</span>
    273. <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">prob</span><span class="p">:</span> <span class="nb">float</span> <span class="o">=</span> <span class="mf">0.5</span><span class="p">):</span>
    274. <span class="k">assert</span> <span class="mf">0.0</span> <span class="o">&lt;=</span> <span class="n">prob</span> <span class="o">&lt;=</span> <span class="mf">1.0</span><span class="p">,</span> <span class="s2">&quot;Probability value must be between 0 and 1&quot;</span>
    275. <span class="bp">self</span><span class="o">.</span><span class="n">prob</span> <span class="o">=</span> <span class="n">prob</span>
    276. <span class="k">def</span> <span class="fm">__call__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">sample</span><span class="p">:</span> <span class="nb">dict</span><span class="p">):</span>
    277. <span class="n">image</span> <span class="o">=</span> <span class="n">sample</span><span class="p">[</span><span class="s2">&quot;image&quot;</span><span class="p">]</span>
    278. <span class="n">mask</span> <span class="o">=</span> <span class="n">sample</span><span class="p">[</span><span class="s2">&quot;mask&quot;</span><span class="p">]</span>
    279. <span class="k">if</span> <span class="n">random</span><span class="o">.</span><span class="n">random</span><span class="p">()</span> <span class="o">&lt;</span> <span class="bp">self</span><span class="o">.</span><span class="n">prob</span><span class="p">:</span>
    280. <span class="n">image</span> <span class="o">=</span> <span class="n">image</span><span class="o">.</span><span class="n">filter</span><span class="p">(</span><span class="n">ImageFilter</span><span class="o">.</span><span class="n">GaussianBlur</span><span class="p">(</span><span class="n">radius</span><span class="o">=</span><span class="n">random</span><span class="o">.</span><span class="n">random</span><span class="p">()))</span>
    281. <span class="n">sample</span><span class="p">[</span><span class="s2">&quot;image&quot;</span><span class="p">]</span> <span class="o">=</span> <span class="n">image</span>
    282. <span class="n">sample</span><span class="p">[</span><span class="s2">&quot;mask&quot;</span><span class="p">]</span> <span class="o">=</span> <span class="n">mask</span>
    283. <span class="k">return</span> <span class="n">sample</span>
    284. <span class="k">class</span> <span class="nc">SegPadShortToCropSize</span><span class="p">(</span><span class="n">SegmentationTransform</span><span class="p">):</span>
    285. <span class="sd">&quot;&quot;&quot;</span>
    286. <span class="sd"> Pads image to &#39;crop_size&#39;.</span>
    287. <span class="sd"> Should be called only after &quot;SegRescale&quot; or &quot;SegRandomRescale&quot; in augmentations pipeline.</span>
    288. <span class="sd"> &quot;&quot;&quot;</span>
    289. <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">crop_size</span><span class="p">:</span> <span class="n">Union</span><span class="p">[</span><span class="nb">float</span><span class="p">,</span> <span class="n">Tuple</span><span class="p">,</span> <span class="n">List</span><span class="p">],</span> <span class="n">fill_mask</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="mi">0</span><span class="p">,</span> <span class="n">fill_image</span><span class="p">:</span> <span class="n">Union</span><span class="p">[</span><span class="nb">int</span><span class="p">,</span> <span class="n">Tuple</span><span class="p">,</span> <span class="n">List</span><span class="p">]</span> <span class="o">=</span> <span class="mi">0</span><span class="p">):</span>
    290. <span class="sd">&quot;&quot;&quot;</span>
    291. <span class="sd"> :param crop_size: tuple of (width, height) for the final crop size, if is scalar size is a</span>
    292. <span class="sd"> square (crop_size, crop_size)</span>
    293. <span class="sd"> :param fill_mask: value to fill mask labels background.</span>
    294. <span class="sd"> :param fill_image: grey value to fill image padded background.</span>
    295. <span class="sd"> &quot;&quot;&quot;</span>
    296. <span class="c1"># CHECK IF CROP SIZE IS A ITERABLE OR SCALAR</span>
    297. <span class="bp">self</span><span class="o">.</span><span class="n">crop_size</span> <span class="o">=</span> <span class="n">crop_size</span>
    298. <span class="bp">self</span><span class="o">.</span><span class="n">fill_mask</span> <span class="o">=</span> <span class="n">fill_mask</span>
    299. <span class="bp">self</span><span class="o">.</span><span class="n">fill_image</span> <span class="o">=</span> <span class="nb">tuple</span><span class="p">(</span><span class="n">fill_image</span><span class="p">)</span> <span class="k">if</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">fill_image</span><span class="p">,</span> <span class="n">Sequence</span><span class="p">)</span> <span class="k">else</span> <span class="n">fill_image</span>
    300. <span class="bp">self</span><span class="o">.</span><span class="n">check_valid_arguments</span><span class="p">()</span>
    301. <span class="k">def</span> <span class="fm">__call__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">sample</span><span class="p">:</span> <span class="nb">dict</span><span class="p">):</span>
    302. <span class="n">image</span> <span class="o">=</span> <span class="n">sample</span><span class="p">[</span><span class="s2">&quot;image&quot;</span><span class="p">]</span>
    303. <span class="n">mask</span> <span class="o">=</span> <span class="n">sample</span><span class="p">[</span><span class="s2">&quot;mask&quot;</span><span class="p">]</span>
    304. <span class="n">w</span><span class="p">,</span> <span class="n">h</span> <span class="o">=</span> <span class="n">image</span><span class="o">.</span><span class="n">size</span>
    305. <span class="c1"># pad images from center symmetrically</span>
    306. <span class="k">if</span> <span class="n">w</span> <span class="o">&lt;</span> <span class="bp">self</span><span class="o">.</span><span class="n">crop_size</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span> <span class="ow">or</span> <span class="n">h</span> <span class="o">&lt;</span> <span class="bp">self</span><span class="o">.</span><span class="n">crop_size</span><span class="p">[</span><span class="mi">1</span><span class="p">]:</span>
    307. <span class="n">padh</span> <span class="o">=</span> <span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">crop_size</span><span class="p">[</span><span class="mi">1</span><span class="p">]</span> <span class="o">-</span> <span class="n">h</span><span class="p">)</span> <span class="o">/</span> <span class="mi">2</span> <span class="k">if</span> <span class="n">h</span> <span class="o">&lt;</span> <span class="bp">self</span><span class="o">.</span><span class="n">crop_size</span><span class="p">[</span><span class="mi">1</span><span class="p">]</span> <span class="k">else</span> <span class="mi">0</span>
    308. <span class="n">pad_top</span><span class="p">,</span> <span class="n">pad_bottom</span> <span class="o">=</span> <span class="n">math</span><span class="o">.</span><span class="n">ceil</span><span class="p">(</span><span class="n">padh</span><span class="p">),</span> <span class="n">math</span><span class="o">.</span><span class="n">floor</span><span class="p">(</span><span class="n">padh</span><span class="p">)</span>
    309. <span class="n">padw</span> <span class="o">=</span> <span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">crop_size</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span> <span class="o">-</span> <span class="n">w</span><span class="p">)</span> <span class="o">/</span> <span class="mi">2</span> <span class="k">if</span> <span class="n">w</span> <span class="o">&lt;</span> <span class="bp">self</span><span class="o">.</span><span class="n">crop_size</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span> <span class="k">else</span> <span class="mi">0</span>
    310. <span class="n">pad_left</span><span class="p">,</span> <span class="n">pad_right</span> <span class="o">=</span> <span class="n">math</span><span class="o">.</span><span class="n">ceil</span><span class="p">(</span><span class="n">padw</span><span class="p">),</span> <span class="n">math</span><span class="o">.</span><span class="n">floor</span><span class="p">(</span><span class="n">padw</span><span class="p">)</span>
    311. <span class="n">image</span> <span class="o">=</span> <span class="n">ImageOps</span><span class="o">.</span><span class="n">expand</span><span class="p">(</span><span class="n">image</span><span class="p">,</span> <span class="n">border</span><span class="o">=</span><span class="p">(</span><span class="n">pad_left</span><span class="p">,</span> <span class="n">pad_top</span><span class="p">,</span> <span class="n">pad_right</span><span class="p">,</span> <span class="n">pad_bottom</span><span class="p">),</span> <span class="n">fill</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">fill_image</span><span class="p">)</span>
    312. <span class="n">mask</span> <span class="o">=</span> <span class="n">ImageOps</span><span class="o">.</span><span class="n">expand</span><span class="p">(</span><span class="n">mask</span><span class="p">,</span> <span class="n">border</span><span class="o">=</span><span class="p">(</span><span class="n">pad_left</span><span class="p">,</span> <span class="n">pad_top</span><span class="p">,</span> <span class="n">pad_right</span><span class="p">,</span> <span class="n">pad_bottom</span><span class="p">),</span> <span class="n">fill</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">fill_mask</span><span class="p">)</span>
    313. <span class="n">sample</span><span class="p">[</span><span class="s2">&quot;image&quot;</span><span class="p">]</span> <span class="o">=</span> <span class="n">image</span>
    314. <span class="n">sample</span><span class="p">[</span><span class="s2">&quot;mask&quot;</span><span class="p">]</span> <span class="o">=</span> <span class="n">mask</span>
    315. <span class="k">return</span> <span class="n">sample</span>
    316. <span class="k">def</span> <span class="nf">check_valid_arguments</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span>
    317. <span class="k">if</span> <span class="ow">not</span> <span class="nb">isinstance</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">crop_size</span><span class="p">,</span> <span class="n">collections</span><span class="o">.</span><span class="n">abc</span><span class="o">.</span><span class="n">Iterable</span><span class="p">):</span>
    318. <span class="bp">self</span><span class="o">.</span><span class="n">crop_size</span> <span class="o">=</span> <span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">crop_size</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">crop_size</span><span class="p">)</span>
    319. <span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">crop_size</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span> <span class="o">&lt;=</span> <span class="mi">0</span> <span class="ow">or</span> <span class="bp">self</span><span class="o">.</span><span class="n">crop_size</span><span class="p">[</span><span class="mi">1</span><span class="p">]</span> <span class="o">&lt;=</span> <span class="mi">0</span><span class="p">:</span>
    320. <span class="k">raise</span> <span class="ne">ValueError</span><span class="p">(</span><span class="sa">f</span><span class="s2">&quot;Crop size must be positive numbers, found: </span><span class="si">{</span><span class="bp">self</span><span class="o">.</span><span class="n">crop_size</span><span class="si">}</span><span class="s2">&quot;</span><span class="p">)</span>
    321. <span class="bp">self</span><span class="o">.</span><span class="n">fill_mask</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">fill_image</span> <span class="o">=</span> <span class="n">_validate_fill_values_arguments</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">fill_mask</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">fill_image</span><span class="p">)</span>
    322. <span class="k">class</span> <span class="nc">SegColorJitter</span><span class="p">(</span><span class="n">transforms</span><span class="o">.</span><span class="n">ColorJitter</span><span class="p">):</span>
    323. <span class="k">def</span> <span class="fm">__call__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">sample</span><span class="p">):</span>
    324. <span class="n">sample</span><span class="p">[</span><span class="s2">&quot;image&quot;</span><span class="p">]</span> <span class="o">=</span> <span class="nb">super</span><span class="p">(</span><span class="n">SegColorJitter</span><span class="p">,</span> <span class="bp">self</span><span class="p">)</span><span class="o">.</span><span class="fm">__call__</span><span class="p">(</span><span class="n">sample</span><span class="p">[</span><span class="s2">&quot;image&quot;</span><span class="p">])</span>
    325. <span class="k">return</span> <span class="n">sample</span>
    326. <span class="k">def</span> <span class="nf">_validate_fill_values_arguments</span><span class="p">(</span><span class="n">fill_mask</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span> <span class="n">fill_image</span><span class="p">:</span> <span class="n">Union</span><span class="p">[</span><span class="nb">int</span><span class="p">,</span> <span class="n">Tuple</span><span class="p">,</span> <span class="n">List</span><span class="p">]):</span>
    327. <span class="k">if</span> <span class="ow">not</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">fill_image</span><span class="p">,</span> <span class="n">collections</span><span class="o">.</span><span class="n">abc</span><span class="o">.</span><span class="n">Iterable</span><span class="p">):</span>
    328. <span class="c1"># If fill_image is single value, turn to grey color in RGB mode.</span>
    329. <span class="n">fill_image</span> <span class="o">=</span> <span class="p">(</span><span class="n">fill_image</span><span class="p">,</span> <span class="n">fill_image</span><span class="p">,</span> <span class="n">fill_image</span><span class="p">)</span>
    330. <span class="k">elif</span> <span class="nb">len</span><span class="p">(</span><span class="n">fill_image</span><span class="p">)</span> <span class="o">!=</span> <span class="mi">3</span><span class="p">:</span>
    331. <span class="k">raise</span> <span class="ne">ValueError</span><span class="p">(</span><span class="sa">f</span><span class="s2">&quot;fill_image must be an RGB tuple of size equal to 3, found: </span><span class="si">{</span><span class="n">fill_image</span><span class="si">}</span><span class="s2">&quot;</span><span class="p">)</span>
    332. <span class="c1"># assert values are integers</span>
    333. <span class="k">if</span> <span class="ow">not</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">fill_mask</span><span class="p">,</span> <span class="nb">int</span><span class="p">)</span> <span class="ow">or</span> <span class="ow">not</span> <span class="nb">all</span><span class="p">(</span><span class="nb">isinstance</span><span class="p">(</span><span class="n">x</span><span class="p">,</span> <span class="nb">int</span><span class="p">)</span> <span class="k">for</span> <span class="n">x</span> <span class="ow">in</span> <span class="n">fill_image</span><span class="p">):</span>
    334. <span class="k">raise</span> <span class="ne">ValueError</span><span class="p">(</span><span class="sa">f</span><span class="s2">&quot;Fill value must be integers,&quot;</span> <span class="sa">f</span><span class="s2">&quot; found: fill_image = </span><span class="si">{</span><span class="n">fill_image</span><span class="si">}</span><span class="s2">, fill_mask = </span><span class="si">{</span><span class="n">fill_mask</span><span class="si">}</span><span class="s2">&quot;</span><span class="p">)</span>
    335. <span class="c1"># assert values in range 0-255</span>
    336. <span class="k">if</span> <span class="nb">min</span><span class="p">(</span><span class="n">fill_image</span><span class="p">)</span> <span class="o">&lt;</span> <span class="mi">0</span> <span class="ow">or</span> <span class="nb">max</span><span class="p">(</span><span class="n">fill_image</span><span class="p">)</span> <span class="o">&gt;</span> <span class="mi">255</span> <span class="ow">or</span> <span class="n">fill_mask</span> <span class="o">&lt;</span> <span class="mi">0</span> <span class="ow">or</span> <span class="n">fill_mask</span> <span class="o">&gt;</span> <span class="mi">255</span><span class="p">:</span>
    337. <span class="k">raise</span> <span class="ne">ValueError</span><span class="p">(</span><span class="sa">f</span><span class="s2">&quot;Fill value must be a value from 0 to 255,&quot;</span> <span class="sa">f</span><span class="s2">&quot; found: fill_image = </span><span class="si">{</span><span class="n">fill_image</span><span class="si">}</span><span class="s2">, fill_mask = </span><span class="si">{</span><span class="n">fill_mask</span><span class="si">}</span><span class="s2">&quot;</span><span class="p">)</span>
    338. <span class="k">return</span> <span class="n">fill_mask</span><span class="p">,</span> <span class="n">fill_image</span>
    339. <span class="k">class</span> <span class="nc">DetectionTransform</span><span class="p">:</span>
    340. <span class="sd">&quot;&quot;&quot;</span>
    341. <span class="sd"> Detection transform base class.</span>
    342. <span class="sd"> Complex transforms that require extra data loading can use the the additional_samples_count attribute in a</span>
    343. <span class="sd"> similar fashion to what&#39;s been done in COCODetectionDataset:</span>
    344. <span class="sd"> self._load_additional_inputs_for_transform(sample, transform)</span>
    345. <span class="sd"> # after the above call, sample[&quot;additional_samples&quot;] holds a list of additional inputs and targets.</span>
    346. <span class="sd"> sample = transform(sample)</span>
    347. <span class="sd"> Attributes:</span>
    348. <span class="sd"> additional_samples_count: (int) additional samples to be loaded.</span>
    349. <span class="sd"> non_empty_targets: (bool) whether the additianl targets can have empty targets or not.</span>
    350. <span class="sd"> &quot;&quot;&quot;</span>
    351. <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">additional_samples_count</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="mi">0</span><span class="p">,</span> <span class="n">non_empty_targets</span><span class="p">:</span> <span class="nb">bool</span> <span class="o">=</span> <span class="kc">False</span><span class="p">):</span>
    352. <span class="bp">self</span><span class="o">.</span><span class="n">additional_samples_count</span> <span class="o">=</span> <span class="n">additional_samples_count</span>
    353. <span class="bp">self</span><span class="o">.</span><span class="n">non_empty_targets</span> <span class="o">=</span> <span class="n">non_empty_targets</span>
    354. <span class="k">def</span> <span class="fm">__call__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">sample</span><span class="p">:</span> <span class="n">Union</span><span class="p">[</span><span class="nb">dict</span><span class="p">,</span> <span class="nb">list</span><span class="p">]):</span>
    355. <span class="k">raise</span> <span class="ne">NotImplementedError</span>
    356. <span class="k">def</span> <span class="fm">__repr__</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span>
    357. <span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="vm">__class__</span><span class="o">.</span><span class="vm">__name__</span> <span class="o">+</span> <span class="nb">str</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="vm">__dict__</span><span class="p">)</span><span class="o">.</span><span class="n">replace</span><span class="p">(</span><span class="s2">&quot;{&quot;</span><span class="p">,</span> <span class="s2">&quot;(&quot;</span><span class="p">)</span><span class="o">.</span><span class="n">replace</span><span class="p">(</span><span class="s2">&quot;}&quot;</span><span class="p">,</span> <span class="s2">&quot;)&quot;</span><span class="p">)</span>
    358. <div class="viewcode-block" id="DetectionMosaic"><a class="viewcode-back" href="../../../../super_gradients.training.html#super_gradients.training.transforms.DetectionMosaic">[docs]</a><span class="k">class</span> <span class="nc">DetectionMosaic</span><span class="p">(</span><span class="n">DetectionTransform</span><span class="p">):</span>
    359. <span class="sd">&quot;&quot;&quot;</span>
    360. <span class="sd"> DetectionMosaic detection transform</span>
    361. <span class="sd"> Attributes:</span>
    362. <span class="sd"> input_dim: (tuple) input dimension.</span>
    363. <span class="sd"> prob: (float) probability of applying mosaic.</span>
    364. <span class="sd"> enable_mosaic: (bool) whether to apply mosaic at all (regardless of prob) (default=True).</span>
    365. <span class="sd"> &quot;&quot;&quot;</span>
    366. <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">input_dim</span><span class="p">:</span> <span class="nb">tuple</span><span class="p">,</span> <span class="n">prob</span><span class="p">:</span> <span class="nb">float</span> <span class="o">=</span> <span class="mf">1.0</span><span class="p">,</span> <span class="n">enable_mosaic</span><span class="p">:</span> <span class="nb">bool</span> <span class="o">=</span> <span class="kc">True</span><span class="p">):</span>
    367. <span class="nb">super</span><span class="p">(</span><span class="n">DetectionMosaic</span><span class="p">,</span> <span class="bp">self</span><span class="p">)</span><span class="o">.</span><span class="fm">__init__</span><span class="p">(</span><span class="n">additional_samples_count</span><span class="o">=</span><span class="mi">3</span><span class="p">)</span>
    368. <span class="bp">self</span><span class="o">.</span><span class="n">prob</span> <span class="o">=</span> <span class="n">prob</span>
    369. <span class="bp">self</span><span class="o">.</span><span class="n">input_dim</span> <span class="o">=</span> <span class="n">input_dim</span>
    370. <span class="bp">self</span><span class="o">.</span><span class="n">enable_mosaic</span> <span class="o">=</span> <span class="n">enable_mosaic</span>
    371. <div class="viewcode-block" id="DetectionMosaic.close"><a class="viewcode-back" href="../../../../super_gradients.training.html#super_gradients.training.transforms.DetectionMosaic.close">[docs]</a> <span class="k">def</span> <span class="nf">close</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span>
    372. <span class="bp">self</span><span class="o">.</span><span class="n">additional_samples_count</span> <span class="o">=</span> <span class="mi">0</span>
    373. <span class="bp">self</span><span class="o">.</span><span class="n">enable_mosaic</span> <span class="o">=</span> <span class="kc">False</span></div>
    374. <span class="k">def</span> <span class="fm">__call__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">sample</span><span class="p">:</span> <span class="n">Union</span><span class="p">[</span><span class="nb">dict</span><span class="p">,</span> <span class="nb">list</span><span class="p">]):</span>
    375. <span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">enable_mosaic</span> <span class="ow">and</span> <span class="n">random</span><span class="o">.</span><span class="n">random</span><span class="p">()</span> <span class="o">&lt;</span> <span class="bp">self</span><span class="o">.</span><span class="n">prob</span><span class="p">:</span>
    376. <span class="n">mosaic_labels</span> <span class="o">=</span> <span class="p">[]</span>
    377. <span class="n">mosaic_labels_seg</span> <span class="o">=</span> <span class="p">[]</span>
    378. <span class="n">input_h</span><span class="p">,</span> <span class="n">input_w</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">input_dim</span><span class="p">[</span><span class="mi">0</span><span class="p">],</span> <span class="bp">self</span><span class="o">.</span><span class="n">input_dim</span><span class="p">[</span><span class="mi">1</span><span class="p">]</span>
    379. <span class="c1"># yc, xc = s, s # mosaic center x, y</span>
    380. <span class="n">yc</span> <span class="o">=</span> <span class="nb">int</span><span class="p">(</span><span class="n">random</span><span class="o">.</span><span class="n">uniform</span><span class="p">(</span><span class="mf">0.5</span> <span class="o">*</span> <span class="n">input_h</span><span class="p">,</span> <span class="mf">1.5</span> <span class="o">*</span> <span class="n">input_h</span><span class="p">))</span>
    381. <span class="n">xc</span> <span class="o">=</span> <span class="nb">int</span><span class="p">(</span><span class="n">random</span><span class="o">.</span><span class="n">uniform</span><span class="p">(</span><span class="mf">0.5</span> <span class="o">*</span> <span class="n">input_w</span><span class="p">,</span> <span class="mf">1.5</span> <span class="o">*</span> <span class="n">input_w</span><span class="p">))</span>
    382. <span class="c1"># 3 additional samples, total of 4</span>
    383. <span class="n">all_samples</span> <span class="o">=</span> <span class="p">[</span><span class="n">sample</span><span class="p">]</span> <span class="o">+</span> <span class="n">sample</span><span class="p">[</span><span class="s2">&quot;additional_samples&quot;</span><span class="p">]</span>
    384. <span class="k">for</span> <span class="n">i_mosaic</span><span class="p">,</span> <span class="n">mosaic_sample</span> <span class="ow">in</span> <span class="nb">enumerate</span><span class="p">(</span><span class="n">all_samples</span><span class="p">):</span>
    385. <span class="n">img</span><span class="p">,</span> <span class="n">_labels</span> <span class="o">=</span> <span class="n">mosaic_sample</span><span class="p">[</span><span class="s2">&quot;image&quot;</span><span class="p">],</span> <span class="n">mosaic_sample</span><span class="p">[</span><span class="s2">&quot;target&quot;</span><span class="p">]</span>
    386. <span class="n">_labels_seg</span> <span class="o">=</span> <span class="n">mosaic_sample</span><span class="o">.</span><span class="n">get</span><span class="p">(</span><span class="s2">&quot;target_seg&quot;</span><span class="p">)</span>
    387. <span class="n">h0</span><span class="p">,</span> <span class="n">w0</span> <span class="o">=</span> <span class="n">img</span><span class="o">.</span><span class="n">shape</span><span class="p">[:</span><span class="mi">2</span><span class="p">]</span> <span class="c1"># orig hw</span>
    388. <span class="n">scale</span> <span class="o">=</span> <span class="nb">min</span><span class="p">(</span><span class="mf">1.0</span> <span class="o">*</span> <span class="n">input_h</span> <span class="o">/</span> <span class="n">h0</span><span class="p">,</span> <span class="mf">1.0</span> <span class="o">*</span> <span class="n">input_w</span> <span class="o">/</span> <span class="n">w0</span><span class="p">)</span>
    389. <span class="n">img</span> <span class="o">=</span> <span class="n">cv2</span><span class="o">.</span><span class="n">resize</span><span class="p">(</span><span class="n">img</span><span class="p">,</span> <span class="p">(</span><span class="nb">int</span><span class="p">(</span><span class="n">w0</span> <span class="o">*</span> <span class="n">scale</span><span class="p">),</span> <span class="nb">int</span><span class="p">(</span><span class="n">h0</span> <span class="o">*</span> <span class="n">scale</span><span class="p">)),</span> <span class="n">interpolation</span><span class="o">=</span><span class="n">cv2</span><span class="o">.</span><span class="n">INTER_LINEAR</span><span class="p">)</span>
    390. <span class="c1"># generate output mosaic image</span>
    391. <span class="p">(</span><span class="n">h</span><span class="p">,</span> <span class="n">w</span><span class="p">,</span> <span class="n">c</span><span class="p">)</span> <span class="o">=</span> <span class="n">img</span><span class="o">.</span><span class="n">shape</span><span class="p">[:</span><span class="mi">3</span><span class="p">]</span>
    392. <span class="k">if</span> <span class="n">i_mosaic</span> <span class="o">==</span> <span class="mi">0</span><span class="p">:</span>
    393. <span class="n">mosaic_img</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">full</span><span class="p">((</span><span class="n">input_h</span> <span class="o">*</span> <span class="mi">2</span><span class="p">,</span> <span class="n">input_w</span> <span class="o">*</span> <span class="mi">2</span><span class="p">,</span> <span class="n">c</span><span class="p">),</span> <span class="mi">114</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="n">np</span><span class="o">.</span><span class="n">uint8</span><span class="p">)</span>
    394. <span class="c1"># suffix l means large image, while s means small image in mosaic aug.</span>
    395. <span class="p">(</span><span class="n">l_x1</span><span class="p">,</span> <span class="n">l_y1</span><span class="p">,</span> <span class="n">l_x2</span><span class="p">,</span> <span class="n">l_y2</span><span class="p">),</span> <span class="p">(</span><span class="n">s_x1</span><span class="p">,</span> <span class="n">s_y1</span><span class="p">,</span> <span class="n">s_x2</span><span class="p">,</span> <span class="n">s_y2</span><span class="p">)</span> <span class="o">=</span> <span class="n">get_mosaic_coordinate</span><span class="p">(</span><span class="n">i_mosaic</span><span class="p">,</span> <span class="n">xc</span><span class="p">,</span> <span class="n">yc</span><span class="p">,</span> <span class="n">w</span><span class="p">,</span> <span class="n">h</span><span class="p">,</span> <span class="n">input_h</span><span class="p">,</span> <span class="n">input_w</span><span class="p">)</span>
    396. <span class="n">mosaic_img</span><span class="p">[</span><span class="n">l_y1</span><span class="p">:</span><span class="n">l_y2</span><span class="p">,</span> <span class="n">l_x1</span><span class="p">:</span><span class="n">l_x2</span><span class="p">]</span> <span class="o">=</span> <span class="n">img</span><span class="p">[</span><span class="n">s_y1</span><span class="p">:</span><span class="n">s_y2</span><span class="p">,</span> <span class="n">s_x1</span><span class="p">:</span><span class="n">s_x2</span><span class="p">]</span>
    397. <span class="n">padw</span><span class="p">,</span> <span class="n">padh</span> <span class="o">=</span> <span class="n">l_x1</span> <span class="o">-</span> <span class="n">s_x1</span><span class="p">,</span> <span class="n">l_y1</span> <span class="o">-</span> <span class="n">s_y1</span>
    398. <span class="n">labels</span> <span class="o">=</span> <span class="n">_labels</span><span class="o">.</span><span class="n">copy</span><span class="p">()</span>
    399. <span class="c1"># Normalized xywh to pixel xyxy format</span>
    400. <span class="k">if</span> <span class="n">_labels</span><span class="o">.</span><span class="n">size</span> <span class="o">&gt;</span> <span class="mi">0</span><span class="p">:</span>
    401. <span class="n">labels</span><span class="p">[:,</span> <span class="mi">0</span><span class="p">]</span> <span class="o">=</span> <span class="n">scale</span> <span class="o">*</span> <span class="n">_labels</span><span class="p">[:,</span> <span class="mi">0</span><span class="p">]</span> <span class="o">+</span> <span class="n">padw</span>
    402. <span class="n">labels</span><span class="p">[:,</span> <span class="mi">1</span><span class="p">]</span> <span class="o">=</span> <span class="n">scale</span> <span class="o">*</span> <span class="n">_labels</span><span class="p">[:,</span> <span class="mi">1</span><span class="p">]</span> <span class="o">+</span> <span class="n">padh</span>
    403. <span class="n">labels</span><span class="p">[:,</span> <span class="mi">2</span><span class="p">]</span> <span class="o">=</span> <span class="n">scale</span> <span class="o">*</span> <span class="n">_labels</span><span class="p">[:,</span> <span class="mi">2</span><span class="p">]</span> <span class="o">+</span> <span class="n">padw</span>
    404. <span class="n">labels</span><span class="p">[:,</span> <span class="mi">3</span><span class="p">]</span> <span class="o">=</span> <span class="n">scale</span> <span class="o">*</span> <span class="n">_labels</span><span class="p">[:,</span> <span class="mi">3</span><span class="p">]</span> <span class="o">+</span> <span class="n">padh</span>
    405. <span class="n">mosaic_labels</span><span class="o">.</span><span class="n">append</span><span class="p">(</span><span class="n">labels</span><span class="p">)</span>
    406. <span class="k">if</span> <span class="n">_labels_seg</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span><span class="p">:</span>
    407. <span class="n">labels_seg</span> <span class="o">=</span> <span class="n">_labels_seg</span><span class="o">.</span><span class="n">copy</span><span class="p">()</span>
    408. <span class="k">if</span> <span class="n">_labels</span><span class="o">.</span><span class="n">size</span> <span class="o">&gt;</span> <span class="mi">0</span><span class="p">:</span>
    409. <span class="n">labels_seg</span><span class="p">[:,</span> <span class="p">::</span><span class="mi">2</span><span class="p">]</span> <span class="o">=</span> <span class="n">scale</span> <span class="o">*</span> <span class="n">labels_seg</span><span class="p">[:,</span> <span class="p">::</span><span class="mi">2</span><span class="p">]</span> <span class="o">+</span> <span class="n">padw</span>
    410. <span class="n">labels_seg</span><span class="p">[:,</span> <span class="mi">1</span><span class="p">::</span><span class="mi">2</span><span class="p">]</span> <span class="o">=</span> <span class="n">scale</span> <span class="o">*</span> <span class="n">labels_seg</span><span class="p">[:,</span> <span class="mi">1</span><span class="p">::</span><span class="mi">2</span><span class="p">]</span> <span class="o">+</span> <span class="n">padh</span>
    411. <span class="n">mosaic_labels_seg</span><span class="o">.</span><span class="n">append</span><span class="p">(</span><span class="n">labels_seg</span><span class="p">)</span>
    412. <span class="k">if</span> <span class="nb">len</span><span class="p">(</span><span class="n">mosaic_labels</span><span class="p">):</span>
    413. <span class="n">mosaic_labels</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">concatenate</span><span class="p">(</span><span class="n">mosaic_labels</span><span class="p">,</span> <span class="mi">0</span><span class="p">)</span>
    414. <span class="n">np</span><span class="o">.</span><span class="n">clip</span><span class="p">(</span><span class="n">mosaic_labels</span><span class="p">[:,</span> <span class="mi">0</span><span class="p">],</span> <span class="mi">0</span><span class="p">,</span> <span class="mi">2</span> <span class="o">*</span> <span class="n">input_w</span><span class="p">,</span> <span class="n">out</span><span class="o">=</span><span class="n">mosaic_labels</span><span class="p">[:,</span> <span class="mi">0</span><span class="p">])</span>
    415. <span class="n">np</span><span class="o">.</span><span class="n">clip</span><span class="p">(</span><span class="n">mosaic_labels</span><span class="p">[:,</span> <span class="mi">1</span><span class="p">],</span> <span class="mi">0</span><span class="p">,</span> <span class="mi">2</span> <span class="o">*</span> <span class="n">input_h</span><span class="p">,</span> <span class="n">out</span><span class="o">=</span><span class="n">mosaic_labels</span><span class="p">[:,</span> <span class="mi">1</span><span class="p">])</span>
    416. <span class="n">np</span><span class="o">.</span><span class="n">clip</span><span class="p">(</span><span class="n">mosaic_labels</span><span class="p">[:,</span> <span class="mi">2</span><span class="p">],</span> <span class="mi">0</span><span class="p">,</span> <span class="mi">2</span> <span class="o">*</span> <span class="n">input_w</span><span class="p">,</span> <span class="n">out</span><span class="o">=</span><span class="n">mosaic_labels</span><span class="p">[:,</span> <span class="mi">2</span><span class="p">])</span>
    417. <span class="n">np</span><span class="o">.</span><span class="n">clip</span><span class="p">(</span><span class="n">mosaic_labels</span><span class="p">[:,</span> <span class="mi">3</span><span class="p">],</span> <span class="mi">0</span><span class="p">,</span> <span class="mi">2</span> <span class="o">*</span> <span class="n">input_h</span><span class="p">,</span> <span class="n">out</span><span class="o">=</span><span class="n">mosaic_labels</span><span class="p">[:,</span> <span class="mi">3</span><span class="p">])</span>
    418. <span class="k">if</span> <span class="nb">len</span><span class="p">(</span><span class="n">mosaic_labels_seg</span><span class="p">):</span>
    419. <span class="n">mosaic_labels_seg</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">concatenate</span><span class="p">(</span><span class="n">mosaic_labels_seg</span><span class="p">,</span> <span class="mi">0</span><span class="p">)</span>
    420. <span class="n">np</span><span class="o">.</span><span class="n">clip</span><span class="p">(</span><span class="n">mosaic_labels_seg</span><span class="p">[:,</span> <span class="p">::</span><span class="mi">2</span><span class="p">],</span> <span class="mi">0</span><span class="p">,</span> <span class="mi">2</span> <span class="o">*</span> <span class="n">input_w</span><span class="p">,</span> <span class="n">out</span><span class="o">=</span><span class="n">mosaic_labels_seg</span><span class="p">[:,</span> <span class="p">::</span><span class="mi">2</span><span class="p">])</span>
    421. <span class="n">np</span><span class="o">.</span><span class="n">clip</span><span class="p">(</span><span class="n">mosaic_labels_seg</span><span class="p">[:,</span> <span class="mi">1</span><span class="p">::</span><span class="mi">2</span><span class="p">],</span> <span class="mi">0</span><span class="p">,</span> <span class="mi">2</span> <span class="o">*</span> <span class="n">input_h</span><span class="p">,</span> <span class="n">out</span><span class="o">=</span><span class="n">mosaic_labels_seg</span><span class="p">[:,</span> <span class="mi">1</span><span class="p">::</span><span class="mi">2</span><span class="p">])</span>
    422. <span class="n">sample</span><span class="p">[</span><span class="s2">&quot;image&quot;</span><span class="p">]</span> <span class="o">=</span> <span class="n">mosaic_img</span>
    423. <span class="n">sample</span><span class="p">[</span><span class="s2">&quot;target&quot;</span><span class="p">]</span> <span class="o">=</span> <span class="n">mosaic_labels</span>
    424. <span class="n">sample</span><span class="p">[</span><span class="s2">&quot;info&quot;</span><span class="p">]</span> <span class="o">=</span> <span class="p">(</span><span class="n">mosaic_img</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">1</span><span class="p">],</span> <span class="n">mosaic_img</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">0</span><span class="p">])</span>
    425. <span class="k">if</span> <span class="nb">len</span><span class="p">(</span><span class="n">mosaic_labels_seg</span><span class="p">):</span>
    426. <span class="n">sample</span><span class="p">[</span><span class="s2">&quot;target_seg&quot;</span><span class="p">]</span> <span class="o">=</span> <span class="n">mosaic_labels_seg</span>
    427. <span class="k">return</span> <span class="n">sample</span></div>
    428. <div class="viewcode-block" id="DetectionRandomAffine"><a class="viewcode-back" href="../../../../super_gradients.training.html#super_gradients.training.transforms.DetectionRandomAffine">[docs]</a><span class="k">class</span> <span class="nc">DetectionRandomAffine</span><span class="p">(</span><span class="n">DetectionTransform</span><span class="p">):</span>
    429. <span class="sd">&quot;&quot;&quot;</span>
    430. <span class="sd"> DetectionRandomAffine detection transform</span>
    431. <span class="sd"> Attributes:</span>
    432. <span class="sd"> target_size: (tuple) desired output shape.</span>
    433. <span class="sd"> degrees: (Union[tuple, float]) degrees for random rotation, when float the random values are drawn uniformly</span>
    434. <span class="sd"> from (-degrees, degrees)</span>
    435. <span class="sd"> translate: (Union[tuple, float]) translate size (in pixels) for random translation, when float the random values</span>
    436. <span class="sd"> are drawn uniformly from (-translate, translate)</span>
    437. <span class="sd"> scales: (Union[tuple, float]) values for random rescale, when float the random values are drawn uniformly</span>
    438. <span class="sd"> from (0.1-scales, 0.1+scales)</span>
    439. <span class="sd"> shear: (Union[tuple, float]) degrees for random shear, when float the random values are drawn uniformly</span>
    440. <span class="sd"> from (shear, shear)</span>
    441. <span class="sd"> enable: (bool) whether to apply the below transform at all.</span>
    442. <span class="sd"> filter_box_candidates: (bool) whether to filter out transformed bboxes by edge size, area ratio, and aspect ratio (default=False).</span>
    443. <span class="sd"> wh_thr: (float) edge size threshold when filter_box_candidates = True. Bounding oxes with edges smaller</span>
    444. <span class="sd"> then this values will be filtered out. (default=2)</span>
    445. <span class="sd"> ar_thr: (float) aspect ratio threshold filter_box_candidates = True. Bounding boxes with aspect ratio larger</span>
    446. <span class="sd"> then this values will be filtered out. (default=20)</span>
    447. <span class="sd"> area_thr:(float) threshold for area ratio between original image and the transformed one, when when filter_box_candidates = True.</span>
    448. <span class="sd"> Bounding boxes with such ratio smaller then this value will be filtered out. (default=0.1)</span>
    449. <span class="sd"> &quot;&quot;&quot;</span>
    450. <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span>
    451. <span class="bp">self</span><span class="p">,</span> <span class="n">degrees</span><span class="o">=</span><span class="mi">10</span><span class="p">,</span> <span class="n">translate</span><span class="o">=</span><span class="mf">0.1</span><span class="p">,</span> <span class="n">scales</span><span class="o">=</span><span class="mf">0.1</span><span class="p">,</span> <span class="n">shear</span><span class="o">=</span><span class="mi">10</span><span class="p">,</span> <span class="n">target_size</span><span class="o">=</span><span class="p">(</span><span class="mi">640</span><span class="p">,</span> <span class="mi">640</span><span class="p">),</span> <span class="n">filter_box_candidates</span><span class="p">:</span> <span class="nb">bool</span> <span class="o">=</span> <span class="kc">False</span><span class="p">,</span> <span class="n">wh_thr</span><span class="o">=</span><span class="mi">2</span><span class="p">,</span> <span class="n">ar_thr</span><span class="o">=</span><span class="mi">20</span><span class="p">,</span> <span class="n">area_thr</span><span class="o">=</span><span class="mf">0.1</span>
    452. <span class="p">):</span>
    453. <span class="nb">super</span><span class="p">(</span><span class="n">DetectionRandomAffine</span><span class="p">,</span> <span class="bp">self</span><span class="p">)</span><span class="o">.</span><span class="fm">__init__</span><span class="p">()</span>
    454. <span class="bp">self</span><span class="o">.</span><span class="n">degrees</span> <span class="o">=</span> <span class="n">degrees</span>
    455. <span class="bp">self</span><span class="o">.</span><span class="n">translate</span> <span class="o">=</span> <span class="n">translate</span>
    456. <span class="bp">self</span><span class="o">.</span><span class="n">scale</span> <span class="o">=</span> <span class="n">scales</span>
    457. <span class="bp">self</span><span class="o">.</span><span class="n">shear</span> <span class="o">=</span> <span class="n">shear</span>
    458. <span class="bp">self</span><span class="o">.</span><span class="n">target_size</span> <span class="o">=</span> <span class="n">target_size</span>
    459. <span class="bp">self</span><span class="o">.</span><span class="n">enable</span> <span class="o">=</span> <span class="kc">True</span>
    460. <span class="bp">self</span><span class="o">.</span><span class="n">filter_box_candidates</span> <span class="o">=</span> <span class="n">filter_box_candidates</span>
    461. <span class="bp">self</span><span class="o">.</span><span class="n">wh_thr</span> <span class="o">=</span> <span class="n">wh_thr</span>
    462. <span class="bp">self</span><span class="o">.</span><span class="n">ar_thr</span> <span class="o">=</span> <span class="n">ar_thr</span>
    463. <span class="bp">self</span><span class="o">.</span><span class="n">area_thr</span> <span class="o">=</span> <span class="n">area_thr</span>
    464. <div class="viewcode-block" id="DetectionRandomAffine.close"><a class="viewcode-back" href="../../../../super_gradients.training.html#super_gradients.training.transforms.DetectionRandomAffine.close">[docs]</a> <span class="k">def</span> <span class="nf">close</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span>
    465. <span class="bp">self</span><span class="o">.</span><span class="n">enable</span> <span class="o">=</span> <span class="kc">False</span></div>
    466. <span class="k">def</span> <span class="fm">__call__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">sample</span><span class="p">:</span> <span class="nb">dict</span><span class="p">):</span>
    467. <span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">enable</span><span class="p">:</span>
    468. <span class="n">img</span><span class="p">,</span> <span class="n">target</span> <span class="o">=</span> <span class="n">random_affine</span><span class="p">(</span>
    469. <span class="n">sample</span><span class="p">[</span><span class="s2">&quot;image&quot;</span><span class="p">],</span>
    470. <span class="n">sample</span><span class="p">[</span><span class="s2">&quot;target&quot;</span><span class="p">],</span>
    471. <span class="n">sample</span><span class="o">.</span><span class="n">get</span><span class="p">(</span><span class="s2">&quot;target_seg&quot;</span><span class="p">),</span>
    472. <span class="n">target_size</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">target_size</span><span class="p">,</span>
    473. <span class="n">degrees</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">degrees</span><span class="p">,</span>
    474. <span class="n">translate</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">translate</span><span class="p">,</span>
    475. <span class="n">scales</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">scale</span><span class="p">,</span>
    476. <span class="n">shear</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">shear</span><span class="p">,</span>
    477. <span class="n">filter_box_candidates</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">filter_box_candidates</span><span class="p">,</span>
    478. <span class="n">wh_thr</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">wh_thr</span><span class="p">,</span>
    479. <span class="n">area_thr</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">area_thr</span><span class="p">,</span>
    480. <span class="n">ar_thr</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">ar_thr</span><span class="p">,</span>
    481. <span class="p">)</span>
    482. <span class="n">sample</span><span class="p">[</span><span class="s2">&quot;image&quot;</span><span class="p">]</span> <span class="o">=</span> <span class="n">img</span>
    483. <span class="n">sample</span><span class="p">[</span><span class="s2">&quot;target&quot;</span><span class="p">]</span> <span class="o">=</span> <span class="n">target</span>
    484. <span class="k">return</span> <span class="n">sample</span></div>
    485. <span class="k">class</span> <span class="nc">DetectionMixup</span><span class="p">(</span><span class="n">DetectionTransform</span><span class="p">):</span>
    486. <span class="sd">&quot;&quot;&quot;</span>
    487. <span class="sd"> Mixup detection transform</span>
    488. <span class="sd"> Attributes:</span>
    489. <span class="sd"> input_dim: (tuple) input dimension.</span>
    490. <span class="sd"> mixup_scale: (tuple) scale range for the additional loaded image for mixup.</span>
    491. <span class="sd"> prob: (float) probability of applying mixup.</span>
    492. <span class="sd"> enable_mixup: (bool) whether to apply mixup at all (regardless of prob) (default=True).</span>
    493. <span class="sd"> flip_prob: (float) prbability to apply horizontal flip to the additional sample.</span>
    494. <span class="sd"> &quot;&quot;&quot;</span>
    495. <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">input_dim</span><span class="p">,</span> <span class="n">mixup_scale</span><span class="p">,</span> <span class="n">prob</span><span class="o">=</span><span class="mf">1.0</span><span class="p">,</span> <span class="n">enable_mixup</span><span class="o">=</span><span class="kc">True</span><span class="p">,</span> <span class="n">flip_prob</span><span class="o">=</span><span class="mf">0.5</span><span class="p">):</span>
    496. <span class="nb">super</span><span class="p">(</span><span class="n">DetectionMixup</span><span class="p">,</span> <span class="bp">self</span><span class="p">)</span><span class="o">.</span><span class="fm">__init__</span><span class="p">(</span><span class="n">additional_samples_count</span><span class="o">=</span><span class="mi">1</span><span class="p">,</span> <span class="n">non_empty_targets</span><span class="o">=</span><span class="kc">True</span><span class="p">)</span>
    497. <span class="bp">self</span><span class="o">.</span><span class="n">input_dim</span> <span class="o">=</span> <span class="n">input_dim</span>
    498. <span class="bp">self</span><span class="o">.</span><span class="n">mixup_scale</span> <span class="o">=</span> <span class="n">mixup_scale</span>
    499. <span class="bp">self</span><span class="o">.</span><span class="n">prob</span> <span class="o">=</span> <span class="n">prob</span>
    500. <span class="bp">self</span><span class="o">.</span><span class="n">enable_mixup</span> <span class="o">=</span> <span class="n">enable_mixup</span>
    501. <span class="bp">self</span><span class="o">.</span><span class="n">flip_prob</span> <span class="o">=</span> <span class="n">flip_prob</span>
    502. <span class="k">def</span> <span class="nf">close</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span>
    503. <span class="bp">self</span><span class="o">.</span><span class="n">additional_samples_count</span> <span class="o">=</span> <span class="mi">0</span>
    504. <span class="bp">self</span><span class="o">.</span><span class="n">enable_mixup</span> <span class="o">=</span> <span class="kc">False</span>
    505. <span class="k">def</span> <span class="fm">__call__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">sample</span><span class="p">:</span> <span class="nb">dict</span><span class="p">):</span>
    506. <span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">enable_mixup</span> <span class="ow">and</span> <span class="n">random</span><span class="o">.</span><span class="n">random</span><span class="p">()</span> <span class="o">&lt;</span> <span class="bp">self</span><span class="o">.</span><span class="n">prob</span><span class="p">:</span>
    507. <span class="n">origin_img</span><span class="p">,</span> <span class="n">origin_labels</span> <span class="o">=</span> <span class="n">sample</span><span class="p">[</span><span class="s2">&quot;image&quot;</span><span class="p">],</span> <span class="n">sample</span><span class="p">[</span><span class="s2">&quot;target&quot;</span><span class="p">]</span>
    508. <span class="n">cp_sample</span> <span class="o">=</span> <span class="n">sample</span><span class="p">[</span><span class="s2">&quot;additional_samples&quot;</span><span class="p">][</span><span class="mi">0</span><span class="p">]</span>
    509. <span class="n">img</span><span class="p">,</span> <span class="n">cp_labels</span> <span class="o">=</span> <span class="n">cp_sample</span><span class="p">[</span><span class="s2">&quot;image&quot;</span><span class="p">],</span> <span class="n">cp_sample</span><span class="p">[</span><span class="s2">&quot;target&quot;</span><span class="p">]</span>
    510. <span class="n">cp_boxes</span> <span class="o">=</span> <span class="n">cp_labels</span><span class="p">[:,</span> <span class="p">:</span><span class="mi">4</span><span class="p">]</span>
    511. <span class="n">img</span><span class="p">,</span> <span class="n">cp_boxes</span> <span class="o">=</span> <span class="n">_mirror</span><span class="p">(</span><span class="n">img</span><span class="p">,</span> <span class="n">cp_boxes</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">flip_prob</span><span class="p">)</span>
    512. <span class="c1"># PLUG IN TARGET THE FLIPPED BOXES</span>
    513. <span class="n">cp_labels</span><span class="p">[:,</span> <span class="p">:</span><span class="mi">4</span><span class="p">]</span> <span class="o">=</span> <span class="n">cp_boxes</span>
    514. <span class="n">jit_factor</span> <span class="o">=</span> <span class="n">random</span><span class="o">.</span><span class="n">uniform</span><span class="p">(</span><span class="o">*</span><span class="bp">self</span><span class="o">.</span><span class="n">mixup_scale</span><span class="p">)</span>
    515. <span class="k">if</span> <span class="nb">len</span><span class="p">(</span><span class="n">img</span><span class="o">.</span><span class="n">shape</span><span class="p">)</span> <span class="o">==</span> <span class="mi">3</span><span class="p">:</span>
    516. <span class="n">cp_img</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">ones</span><span class="p">((</span><span class="bp">self</span><span class="o">.</span><span class="n">input_dim</span><span class="p">[</span><span class="mi">0</span><span class="p">],</span> <span class="bp">self</span><span class="o">.</span><span class="n">input_dim</span><span class="p">[</span><span class="mi">1</span><span class="p">],</span> <span class="mi">3</span><span class="p">),</span> <span class="n">dtype</span><span class="o">=</span><span class="n">np</span><span class="o">.</span><span class="n">uint8</span><span class="p">)</span> <span class="o">*</span> <span class="mi">114</span>
    517. <span class="k">else</span><span class="p">:</span>
    518. <span class="n">cp_img</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">ones</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">input_dim</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="n">np</span><span class="o">.</span><span class="n">uint8</span><span class="p">)</span> <span class="o">*</span> <span class="mi">114</span>
    519. <span class="n">cp_scale_ratio</span> <span class="o">=</span> <span class="nb">min</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">input_dim</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span> <span class="o">/</span> <span class="n">img</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">0</span><span class="p">],</span> <span class="bp">self</span><span class="o">.</span><span class="n">input_dim</span><span class="p">[</span><span class="mi">1</span><span class="p">]</span> <span class="o">/</span> <span class="n">img</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">1</span><span class="p">])</span>
    520. <span class="n">resized_img</span> <span class="o">=</span> <span class="n">cv2</span><span class="o">.</span><span class="n">resize</span><span class="p">(</span>
    521. <span class="n">img</span><span class="p">,</span>
    522. <span class="p">(</span><span class="nb">int</span><span class="p">(</span><span class="n">img</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">1</span><span class="p">]</span> <span class="o">*</span> <span class="n">cp_scale_ratio</span><span class="p">),</span> <span class="nb">int</span><span class="p">(</span><span class="n">img</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span> <span class="o">*</span> <span class="n">cp_scale_ratio</span><span class="p">)),</span>
    523. <span class="n">interpolation</span><span class="o">=</span><span class="n">cv2</span><span class="o">.</span><span class="n">INTER_LINEAR</span><span class="p">,</span>
    524. <span class="p">)</span>
    525. <span class="n">cp_img</span><span class="p">[:</span> <span class="nb">int</span><span class="p">(</span><span class="n">img</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span> <span class="o">*</span> <span class="n">cp_scale_ratio</span><span class="p">),</span> <span class="p">:</span> <span class="nb">int</span><span class="p">(</span><span class="n">img</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">1</span><span class="p">]</span> <span class="o">*</span> <span class="n">cp_scale_ratio</span><span class="p">)]</span> <span class="o">=</span> <span class="n">resized_img</span>
    526. <span class="n">cp_img</span> <span class="o">=</span> <span class="n">cv2</span><span class="o">.</span><span class="n">resize</span><span class="p">(</span>
    527. <span class="n">cp_img</span><span class="p">,</span>
    528. <span class="p">(</span><span class="nb">int</span><span class="p">(</span><span class="n">cp_img</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">1</span><span class="p">]</span> <span class="o">*</span> <span class="n">jit_factor</span><span class="p">),</span> <span class="nb">int</span><span class="p">(</span><span class="n">cp_img</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span> <span class="o">*</span> <span class="n">jit_factor</span><span class="p">)),</span>
    529. <span class="p">)</span>
    530. <span class="n">cp_scale_ratio</span> <span class="o">*=</span> <span class="n">jit_factor</span>
    531. <span class="n">origin_h</span><span class="p">,</span> <span class="n">origin_w</span> <span class="o">=</span> <span class="n">cp_img</span><span class="o">.</span><span class="n">shape</span><span class="p">[:</span><span class="mi">2</span><span class="p">]</span>
    532. <span class="n">target_h</span><span class="p">,</span> <span class="n">target_w</span> <span class="o">=</span> <span class="n">origin_img</span><span class="o">.</span><span class="n">shape</span><span class="p">[:</span><span class="mi">2</span><span class="p">]</span>
    533. <span class="n">padded_img</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">zeros</span><span class="p">((</span><span class="nb">max</span><span class="p">(</span><span class="n">origin_h</span><span class="p">,</span> <span class="n">target_h</span><span class="p">),</span> <span class="nb">max</span><span class="p">(</span><span class="n">origin_w</span><span class="p">,</span> <span class="n">target_w</span><span class="p">),</span> <span class="mi">3</span><span class="p">),</span> <span class="n">dtype</span><span class="o">=</span><span class="n">np</span><span class="o">.</span><span class="n">uint8</span><span class="p">)</span>
    534. <span class="n">padded_img</span><span class="p">[:</span><span class="n">origin_h</span><span class="p">,</span> <span class="p">:</span><span class="n">origin_w</span><span class="p">]</span> <span class="o">=</span> <span class="n">cp_img</span>
    535. <span class="n">x_offset</span><span class="p">,</span> <span class="n">y_offset</span> <span class="o">=</span> <span class="mi">0</span><span class="p">,</span> <span class="mi">0</span>
    536. <span class="k">if</span> <span class="n">padded_img</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span> <span class="o">&gt;</span> <span class="n">target_h</span><span class="p">:</span>
    537. <span class="n">y_offset</span> <span class="o">=</span> <span class="n">random</span><span class="o">.</span><span class="n">randint</span><span class="p">(</span><span class="mi">0</span><span class="p">,</span> <span class="n">padded_img</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span> <span class="o">-</span> <span class="n">target_h</span> <span class="o">-</span> <span class="mi">1</span><span class="p">)</span>
    538. <span class="k">if</span> <span class="n">padded_img</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">1</span><span class="p">]</span> <span class="o">&gt;</span> <span class="n">target_w</span><span class="p">:</span>
    539. <span class="n">x_offset</span> <span class="o">=</span> <span class="n">random</span><span class="o">.</span><span class="n">randint</span><span class="p">(</span><span class="mi">0</span><span class="p">,</span> <span class="n">padded_img</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">1</span><span class="p">]</span> <span class="o">-</span> <span class="n">target_w</span> <span class="o">-</span> <span class="mi">1</span><span class="p">)</span>
    540. <span class="n">padded_cropped_img</span> <span class="o">=</span> <span class="n">padded_img</span><span class="p">[</span><span class="n">y_offset</span> <span class="p">:</span> <span class="n">y_offset</span> <span class="o">+</span> <span class="n">target_h</span><span class="p">,</span> <span class="n">x_offset</span> <span class="p">:</span> <span class="n">x_offset</span> <span class="o">+</span> <span class="n">target_w</span><span class="p">]</span>
    541. <span class="n">cp_bboxes_origin_np</span> <span class="o">=</span> <span class="n">adjust_box_anns</span><span class="p">(</span><span class="n">cp_labels</span><span class="p">[:,</span> <span class="p">:</span><span class="mi">4</span><span class="p">]</span><span class="o">.</span><span class="n">copy</span><span class="p">(),</span> <span class="n">cp_scale_ratio</span><span class="p">,</span> <span class="mi">0</span><span class="p">,</span> <span class="mi">0</span><span class="p">,</span> <span class="n">origin_w</span><span class="p">,</span> <span class="n">origin_h</span><span class="p">)</span>
    542. <span class="n">cp_bboxes_transformed_np</span> <span class="o">=</span> <span class="n">cp_bboxes_origin_np</span><span class="o">.</span><span class="n">copy</span><span class="p">()</span>
    543. <span class="n">cp_bboxes_transformed_np</span><span class="p">[:,</span> <span class="mi">0</span><span class="p">::</span><span class="mi">2</span><span class="p">]</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">clip</span><span class="p">(</span><span class="n">cp_bboxes_transformed_np</span><span class="p">[:,</span> <span class="mi">0</span><span class="p">::</span><span class="mi">2</span><span class="p">]</span> <span class="o">-</span> <span class="n">x_offset</span><span class="p">,</span> <span class="mi">0</span><span class="p">,</span> <span class="n">target_w</span><span class="p">)</span>
    544. <span class="n">cp_bboxes_transformed_np</span><span class="p">[:,</span> <span class="mi">1</span><span class="p">::</span><span class="mi">2</span><span class="p">]</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">clip</span><span class="p">(</span><span class="n">cp_bboxes_transformed_np</span><span class="p">[:,</span> <span class="mi">1</span><span class="p">::</span><span class="mi">2</span><span class="p">]</span> <span class="o">-</span> <span class="n">y_offset</span><span class="p">,</span> <span class="mi">0</span><span class="p">,</span> <span class="n">target_h</span><span class="p">)</span>
    545. <span class="n">cls_labels</span> <span class="o">=</span> <span class="n">cp_labels</span><span class="p">[:,</span> <span class="mi">4</span><span class="p">:</span><span class="mi">5</span><span class="p">]</span><span class="o">.</span><span class="n">copy</span><span class="p">()</span>
    546. <span class="n">box_labels</span> <span class="o">=</span> <span class="n">cp_bboxes_transformed_np</span>
    547. <span class="n">labels</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">hstack</span><span class="p">((</span><span class="n">box_labels</span><span class="p">,</span> <span class="n">cls_labels</span><span class="p">))</span>
    548. <span class="n">origin_labels</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">vstack</span><span class="p">((</span><span class="n">origin_labels</span><span class="p">,</span> <span class="n">labels</span><span class="p">))</span>
    549. <span class="n">origin_img</span> <span class="o">=</span> <span class="n">origin_img</span><span class="o">.</span><span class="n">astype</span><span class="p">(</span><span class="n">np</span><span class="o">.</span><span class="n">float32</span><span class="p">)</span>
    550. <span class="n">origin_img</span> <span class="o">=</span> <span class="mf">0.5</span> <span class="o">*</span> <span class="n">origin_img</span> <span class="o">+</span> <span class="mf">0.5</span> <span class="o">*</span> <span class="n">padded_cropped_img</span><span class="o">.</span><span class="n">astype</span><span class="p">(</span><span class="n">np</span><span class="o">.</span><span class="n">float32</span><span class="p">)</span>
    551. <span class="n">sample</span><span class="p">[</span><span class="s2">&quot;image&quot;</span><span class="p">],</span> <span class="n">sample</span><span class="p">[</span><span class="s2">&quot;target&quot;</span><span class="p">]</span> <span class="o">=</span> <span class="n">origin_img</span><span class="o">.</span><span class="n">astype</span><span class="p">(</span><span class="n">np</span><span class="o">.</span><span class="n">uint8</span><span class="p">),</span> <span class="n">origin_labels</span>
    552. <span class="k">return</span> <span class="n">sample</span>
    553. <div class="viewcode-block" id="DetectionPaddedRescale"><a class="viewcode-back" href="../../../../super_gradients.training.html#super_gradients.training.transforms.DetectionPaddedRescale">[docs]</a><span class="k">class</span> <span class="nc">DetectionPaddedRescale</span><span class="p">(</span><span class="n">DetectionTransform</span><span class="p">):</span>
    554. <span class="sd">&quot;&quot;&quot;</span>
    555. <span class="sd"> Preprocessing transform to be applied last of all transforms for validation.</span>
    556. <span class="sd"> Image- Rescales and pads to self.input_dim.</span>
    557. <span class="sd"> Targets- pads targets to max_targets, moves the class label to first index, converts boxes format- xyxy -&gt; cxcywh.</span>
    558. <span class="sd"> Attributes:</span>
    559. <span class="sd"> input_dim: (tuple) final input dimension (default=(640,640))</span>
    560. <span class="sd"> swap: image axis&#39;s to be rearranged.</span>
    561. <span class="sd"> &quot;&quot;&quot;</span>
    562. <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">input_dim</span><span class="p">,</span> <span class="n">swap</span><span class="o">=</span><span class="p">(</span><span class="mi">2</span><span class="p">,</span> <span class="mi">0</span><span class="p">,</span> <span class="mi">1</span><span class="p">),</span> <span class="n">max_targets</span><span class="o">=</span><span class="mi">50</span><span class="p">,</span> <span class="n">pad_value</span><span class="o">=</span><span class="mi">114</span><span class="p">):</span>
    563. <span class="bp">self</span><span class="o">.</span><span class="n">swap</span> <span class="o">=</span> <span class="n">swap</span>
    564. <span class="bp">self</span><span class="o">.</span><span class="n">input_dim</span> <span class="o">=</span> <span class="n">input_dim</span>
    565. <span class="bp">self</span><span class="o">.</span><span class="n">max_targets</span> <span class="o">=</span> <span class="n">max_targets</span>
    566. <span class="bp">self</span><span class="o">.</span><span class="n">pad_value</span> <span class="o">=</span> <span class="n">pad_value</span>
    567. <span class="k">def</span> <span class="fm">__call__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">sample</span><span class="p">:</span> <span class="n">Dict</span><span class="p">[</span><span class="nb">str</span><span class="p">,</span> <span class="n">np</span><span class="o">.</span><span class="n">array</span><span class="p">]):</span>
    568. <span class="n">img</span><span class="p">,</span> <span class="n">targets</span><span class="p">,</span> <span class="n">crowd_targets</span> <span class="o">=</span> <span class="n">sample</span><span class="p">[</span><span class="s2">&quot;image&quot;</span><span class="p">],</span> <span class="n">sample</span><span class="p">[</span><span class="s2">&quot;target&quot;</span><span class="p">],</span> <span class="n">sample</span><span class="o">.</span><span class="n">get</span><span class="p">(</span><span class="s2">&quot;crowd_target&quot;</span><span class="p">)</span>
    569. <span class="n">img</span><span class="p">,</span> <span class="n">r</span> <span class="o">=</span> <span class="n">rescale_and_pad_to_size</span><span class="p">(</span><span class="n">img</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">input_dim</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">swap</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">pad_value</span><span class="p">)</span>
    570. <span class="n">sample</span><span class="p">[</span><span class="s2">&quot;image&quot;</span><span class="p">]</span> <span class="o">=</span> <span class="n">img</span>
    571. <span class="n">sample</span><span class="p">[</span><span class="s2">&quot;target&quot;</span><span class="p">]</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">_rescale_target</span><span class="p">(</span><span class="n">targets</span><span class="p">,</span> <span class="n">r</span><span class="p">)</span>
    572. <span class="k">if</span> <span class="n">crowd_targets</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span><span class="p">:</span>
    573. <span class="n">sample</span><span class="p">[</span><span class="s2">&quot;crowd_target&quot;</span><span class="p">]</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">_rescale_target</span><span class="p">(</span><span class="n">crowd_targets</span><span class="p">,</span> <span class="n">r</span><span class="p">)</span>
    574. <span class="k">return</span> <span class="n">sample</span>
    575. <span class="k">def</span> <span class="nf">_rescale_target</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">targets</span><span class="p">:</span> <span class="n">np</span><span class="o">.</span><span class="n">array</span><span class="p">,</span> <span class="n">r</span><span class="p">:</span> <span class="nb">float</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="n">np</span><span class="o">.</span><span class="n">array</span><span class="p">:</span>
    576. <span class="sd">&quot;&quot;&quot;SegRescale the target according to a coefficient used to rescale the image.</span>
    577. <span class="sd"> This is done to have images and targets at the same scale.</span>
    578. <span class="sd"> :param targets: Targets to rescale, shape (batch_size, 6)</span>
    579. <span class="sd"> :param r: SegRescale coefficient that was applied to the image</span>
    580. <span class="sd"> :return: Rescaled targets, shape (batch_size, 6)</span>
    581. <span class="sd"> &quot;&quot;&quot;</span>
    582. <span class="n">targets</span> <span class="o">=</span> <span class="n">targets</span><span class="o">.</span><span class="n">copy</span><span class="p">()</span> <span class="k">if</span> <span class="nb">len</span><span class="p">(</span><span class="n">targets</span><span class="p">)</span> <span class="o">&gt;</span> <span class="mi">0</span> <span class="k">else</span> <span class="n">np</span><span class="o">.</span><span class="n">zeros</span><span class="p">((</span><span class="bp">self</span><span class="o">.</span><span class="n">max_targets</span><span class="p">,</span> <span class="mi">5</span><span class="p">),</span> <span class="n">dtype</span><span class="o">=</span><span class="n">np</span><span class="o">.</span><span class="n">float32</span><span class="p">)</span>
    583. <span class="n">boxes</span><span class="p">,</span> <span class="n">labels</span> <span class="o">=</span> <span class="n">targets</span><span class="p">[:,</span> <span class="p">:</span><span class="mi">4</span><span class="p">],</span> <span class="n">targets</span><span class="p">[:,</span> <span class="mi">4</span><span class="p">]</span>
    584. <span class="n">boxes</span> <span class="o">=</span> <span class="n">xyxy2cxcywh</span><span class="p">(</span><span class="n">boxes</span><span class="p">)</span>
    585. <span class="n">boxes</span> <span class="o">*=</span> <span class="n">r</span>
    586. <span class="n">boxes</span> <span class="o">=</span> <span class="n">cxcywh2xyxy</span><span class="p">(</span><span class="n">boxes</span><span class="p">)</span>
    587. <span class="k">return</span> <span class="n">np</span><span class="o">.</span><span class="n">concatenate</span><span class="p">((</span><span class="n">boxes</span><span class="p">,</span> <span class="n">labels</span><span class="p">[:,</span> <span class="n">np</span><span class="o">.</span><span class="n">newaxis</span><span class="p">]),</span> <span class="mi">1</span><span class="p">)</span></div>
    588. <span class="k">class</span> <span class="nc">DetectionHorizontalFlip</span><span class="p">(</span><span class="n">DetectionTransform</span><span class="p">):</span>
    589. <span class="sd">&quot;&quot;&quot;</span>
    590. <span class="sd"> Horizontal Flip for Detection</span>
    591. <span class="sd"> Attributes:</span>
    592. <span class="sd"> prob: float: probability of applying horizontal flip</span>
    593. <span class="sd"> max_targets: int: max objects in single image, padding target to this size in case of empty image.</span>
    594. <span class="sd"> &quot;&quot;&quot;</span>
    595. <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">prob</span><span class="p">,</span> <span class="n">max_targets</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="mi">120</span><span class="p">):</span>
    596. <span class="nb">super</span><span class="p">(</span><span class="n">DetectionHorizontalFlip</span><span class="p">,</span> <span class="bp">self</span><span class="p">)</span><span class="o">.</span><span class="fm">__init__</span><span class="p">()</span>
    597. <span class="bp">self</span><span class="o">.</span><span class="n">prob</span> <span class="o">=</span> <span class="n">prob</span>
    598. <span class="bp">self</span><span class="o">.</span><span class="n">max_targets</span> <span class="o">=</span> <span class="n">max_targets</span>
    599. <span class="k">def</span> <span class="fm">__call__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">sample</span><span class="p">):</span>
    600. <span class="n">image</span><span class="p">,</span> <span class="n">targets</span> <span class="o">=</span> <span class="n">sample</span><span class="p">[</span><span class="s2">&quot;image&quot;</span><span class="p">],</span> <span class="n">sample</span><span class="p">[</span><span class="s2">&quot;target&quot;</span><span class="p">]</span>
    601. <span class="n">boxes</span> <span class="o">=</span> <span class="n">targets</span><span class="p">[:,</span> <span class="p">:</span><span class="mi">4</span><span class="p">]</span>
    602. <span class="k">if</span> <span class="nb">len</span><span class="p">(</span><span class="n">boxes</span><span class="p">)</span> <span class="o">==</span> <span class="mi">0</span><span class="p">:</span>
    603. <span class="n">targets</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">zeros</span><span class="p">((</span><span class="bp">self</span><span class="o">.</span><span class="n">max_targets</span><span class="p">,</span> <span class="mi">5</span><span class="p">),</span> <span class="n">dtype</span><span class="o">=</span><span class="n">np</span><span class="o">.</span><span class="n">float32</span><span class="p">)</span>
    604. <span class="n">boxes</span> <span class="o">=</span> <span class="n">targets</span><span class="p">[:,</span> <span class="p">:</span><span class="mi">4</span><span class="p">]</span>
    605. <span class="n">image</span><span class="p">,</span> <span class="n">boxes</span> <span class="o">=</span> <span class="n">_mirror</span><span class="p">(</span><span class="n">image</span><span class="p">,</span> <span class="n">boxes</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">prob</span><span class="p">)</span>
    606. <span class="n">targets</span><span class="p">[:,</span> <span class="p">:</span><span class="mi">4</span><span class="p">]</span> <span class="o">=</span> <span class="n">boxes</span>
    607. <span class="n">sample</span><span class="p">[</span><span class="s2">&quot;target&quot;</span><span class="p">]</span> <span class="o">=</span> <span class="n">targets</span>
    608. <span class="n">sample</span><span class="p">[</span><span class="s2">&quot;image&quot;</span><span class="p">]</span> <span class="o">=</span> <span class="n">image</span>
    609. <span class="k">return</span> <span class="n">sample</span>
    610. <div class="viewcode-block" id="DetectionHSV"><a class="viewcode-back" href="../../../../super_gradients.training.html#super_gradients.training.transforms.DetectionHSV">[docs]</a><span class="k">class</span> <span class="nc">DetectionHSV</span><span class="p">(</span><span class="n">DetectionTransform</span><span class="p">):</span>
    611. <span class="sd">&quot;&quot;&quot;</span>
    612. <span class="sd"> Detection HSV transform.</span>
    613. <span class="sd"> &quot;&quot;&quot;</span>
    614. <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">prob</span><span class="p">:</span> <span class="nb">float</span><span class="p">,</span> <span class="n">hgain</span><span class="p">:</span> <span class="nb">float</span> <span class="o">=</span> <span class="mf">0.5</span><span class="p">,</span> <span class="n">sgain</span><span class="p">:</span> <span class="nb">float</span> <span class="o">=</span> <span class="mf">0.5</span><span class="p">,</span> <span class="n">vgain</span><span class="p">:</span> <span class="nb">float</span> <span class="o">=</span> <span class="mf">0.5</span><span class="p">,</span> <span class="n">bgr_channels</span><span class="o">=</span><span class="p">(</span><span class="mi">0</span><span class="p">,</span> <span class="mi">1</span><span class="p">,</span> <span class="mi">2</span><span class="p">)):</span>
    615. <span class="nb">super</span><span class="p">(</span><span class="n">DetectionHSV</span><span class="p">,</span> <span class="bp">self</span><span class="p">)</span><span class="o">.</span><span class="fm">__init__</span><span class="p">()</span>
    616. <span class="bp">self</span><span class="o">.</span><span class="n">prob</span> <span class="o">=</span> <span class="n">prob</span>
    617. <span class="bp">self</span><span class="o">.</span><span class="n">hgain</span> <span class="o">=</span> <span class="n">hgain</span>
    618. <span class="bp">self</span><span class="o">.</span><span class="n">sgain</span> <span class="o">=</span> <span class="n">sgain</span>
    619. <span class="bp">self</span><span class="o">.</span><span class="n">vgain</span> <span class="o">=</span> <span class="n">vgain</span>
    620. <span class="bp">self</span><span class="o">.</span><span class="n">bgr_channels</span> <span class="o">=</span> <span class="n">bgr_channels</span>
    621. <span class="k">def</span> <span class="fm">__call__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">sample</span><span class="p">:</span> <span class="nb">dict</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="nb">dict</span><span class="p">:</span>
    622. <span class="k">if</span> <span class="n">random</span><span class="o">.</span><span class="n">random</span><span class="p">()</span> <span class="o">&lt;</span> <span class="bp">self</span><span class="o">.</span><span class="n">prob</span><span class="p">:</span>
    623. <span class="n">augment_hsv</span><span class="p">(</span><span class="n">sample</span><span class="p">[</span><span class="s2">&quot;image&quot;</span><span class="p">],</span> <span class="bp">self</span><span class="o">.</span><span class="n">hgain</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">sgain</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">vgain</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">bgr_channels</span><span class="p">)</span>
    624. <span class="k">return</span> <span class="n">sample</span></div>
    625. <div class="viewcode-block" id="DetectionTargetsFormatTransform"><a class="viewcode-back" href="../../../../super_gradients.training.html#super_gradients.training.transforms.DetectionTargetsFormatTransform">[docs]</a><span class="k">class</span> <span class="nc">DetectionTargetsFormatTransform</span><span class="p">(</span><span class="n">DetectionTransform</span><span class="p">):</span>
    626. <span class="sd">&quot;&quot;&quot;</span>
    627. <span class="sd"> Detection targets format transform</span>
    628. <span class="sd"> Converts targets in input_format to output_format.</span>
    629. <span class="sd"> Attributes:</span>
    630. <span class="sd"> input_format: DetectionTargetsFormat: input target format</span>
    631. <span class="sd"> output_format: DetectionTargetsFormat: output target format</span>
    632. <span class="sd"> min_bbox_edge_size: int: bboxes with edge size lower then this values will be removed.</span>
    633. <span class="sd"> max_targets: int: max objects in single image, padding target to this size.</span>
    634. <span class="sd"> &quot;&quot;&quot;</span>
    635. <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span>
    636. <span class="bp">self</span><span class="p">,</span>
    637. <span class="n">input_format</span><span class="p">:</span> <span class="n">DetectionTargetsFormat</span> <span class="o">=</span> <span class="n">DetectionTargetsFormat</span><span class="o">.</span><span class="n">XYXY_LABEL</span><span class="p">,</span>
    638. <span class="n">output_format</span><span class="p">:</span> <span class="n">DetectionTargetsFormat</span> <span class="o">=</span> <span class="n">DetectionTargetsFormat</span><span class="o">.</span><span class="n">LABEL_CXCYWH</span><span class="p">,</span>
    639. <span class="n">min_bbox_edge_size</span><span class="p">:</span> <span class="nb">float</span> <span class="o">=</span> <span class="mi">1</span><span class="p">,</span>
    640. <span class="n">max_targets</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="mi">120</span><span class="p">,</span>
    641. <span class="p">):</span>
    642. <span class="nb">super</span><span class="p">(</span><span class="n">DetectionTargetsFormatTransform</span><span class="p">,</span> <span class="bp">self</span><span class="p">)</span><span class="o">.</span><span class="fm">__init__</span><span class="p">()</span>
    643. <span class="bp">self</span><span class="o">.</span><span class="n">input_format</span> <span class="o">=</span> <span class="n">input_format</span>
    644. <span class="bp">self</span><span class="o">.</span><span class="n">output_format</span> <span class="o">=</span> <span class="n">output_format</span>
    645. <span class="bp">self</span><span class="o">.</span><span class="n">min_bbox_edge_size</span> <span class="o">=</span> <span class="n">min_bbox_edge_size</span>
    646. <span class="bp">self</span><span class="o">.</span><span class="n">max_targets</span> <span class="o">=</span> <span class="n">max_targets</span>
    647. <span class="k">def</span> <span class="fm">__call__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">sample</span><span class="p">):</span>
    648. <span class="n">normalized_input</span> <span class="o">=</span> <span class="s2">&quot;NORMALIZED&quot;</span> <span class="ow">in</span> <span class="bp">self</span><span class="o">.</span><span class="n">input_format</span><span class="o">.</span><span class="n">value</span>
    649. <span class="n">normalized_output</span> <span class="o">=</span> <span class="s2">&quot;NORMALIZED&quot;</span> <span class="ow">in</span> <span class="bp">self</span><span class="o">.</span><span class="n">output_format</span><span class="o">.</span><span class="n">value</span>
    650. <span class="n">normalize</span> <span class="o">=</span> <span class="ow">not</span> <span class="n">normalized_input</span> <span class="ow">and</span> <span class="n">normalized_output</span>
    651. <span class="n">denormalize</span> <span class="o">=</span> <span class="n">normalized_input</span> <span class="ow">and</span> <span class="ow">not</span> <span class="n">normalized_output</span>
    652. <span class="n">label_first_in_input</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">input_format</span><span class="o">.</span><span class="n">value</span><span class="o">.</span><span class="n">split</span><span class="p">(</span><span class="s2">&quot;_&quot;</span><span class="p">)[</span><span class="mi">0</span><span class="p">]</span> <span class="o">==</span> <span class="s2">&quot;LABEL&quot;</span>
    653. <span class="n">label_first_in_output</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">output_format</span><span class="o">.</span><span class="n">value</span><span class="o">.</span><span class="n">split</span><span class="p">(</span><span class="s2">&quot;_&quot;</span><span class="p">)[</span><span class="mi">0</span><span class="p">]</span> <span class="o">==</span> <span class="s2">&quot;LABEL&quot;</span>
    654. <span class="n">input_xyxy_format</span> <span class="o">=</span> <span class="s2">&quot;XYXY&quot;</span> <span class="ow">in</span> <span class="bp">self</span><span class="o">.</span><span class="n">input_format</span><span class="o">.</span><span class="n">value</span>
    655. <span class="n">output_xyxy_format</span> <span class="o">=</span> <span class="s2">&quot;XYXY&quot;</span> <span class="ow">in</span> <span class="bp">self</span><span class="o">.</span><span class="n">output_format</span><span class="o">.</span><span class="n">value</span>
    656. <span class="n">convert2xyxy</span> <span class="o">=</span> <span class="ow">not</span> <span class="n">input_xyxy_format</span> <span class="ow">and</span> <span class="n">output_xyxy_format</span>
    657. <span class="n">convert2cxcy</span> <span class="o">=</span> <span class="n">input_xyxy_format</span> <span class="ow">and</span> <span class="ow">not</span> <span class="n">output_xyxy_format</span>
    658. <span class="n">image</span><span class="p">,</span> <span class="n">targets</span><span class="p">,</span> <span class="n">crowd_targets</span> <span class="o">=</span> <span class="n">sample</span><span class="p">[</span><span class="s2">&quot;image&quot;</span><span class="p">],</span> <span class="n">sample</span><span class="p">[</span><span class="s2">&quot;target&quot;</span><span class="p">],</span> <span class="n">sample</span><span class="o">.</span><span class="n">get</span><span class="p">(</span><span class="s2">&quot;crowd_target&quot;</span><span class="p">)</span>
    659. <span class="n">_</span><span class="p">,</span> <span class="n">h</span><span class="p">,</span> <span class="n">w</span> <span class="o">=</span> <span class="n">image</span><span class="o">.</span><span class="n">shape</span>
    660. <span class="k">def</span> <span class="nf">_format_target</span><span class="p">(</span><span class="n">targets_in</span><span class="p">):</span>
    661. <span class="k">if</span> <span class="n">label_first_in_input</span><span class="p">:</span>
    662. <span class="n">labels</span><span class="p">,</span> <span class="n">boxes</span> <span class="o">=</span> <span class="n">targets_in</span><span class="p">[:,</span> <span class="mi">0</span><span class="p">],</span> <span class="n">targets_in</span><span class="p">[:,</span> <span class="mi">1</span><span class="p">:]</span>
    663. <span class="k">else</span><span class="p">:</span>
    664. <span class="n">boxes</span><span class="p">,</span> <span class="n">labels</span> <span class="o">=</span> <span class="n">targets_in</span><span class="p">[:,</span> <span class="p">:</span><span class="mi">4</span><span class="p">],</span> <span class="n">targets_in</span><span class="p">[:,</span> <span class="mi">4</span><span class="p">]</span>
    665. <span class="k">if</span> <span class="n">convert2cxcy</span><span class="p">:</span>
    666. <span class="n">boxes</span> <span class="o">=</span> <span class="n">xyxy2cxcywh</span><span class="p">(</span><span class="n">boxes</span><span class="p">)</span>
    667. <span class="k">elif</span> <span class="n">convert2xyxy</span><span class="p">:</span>
    668. <span class="n">boxes</span> <span class="o">=</span> <span class="n">cxcywh2xyxy</span><span class="p">(</span><span class="n">boxes</span><span class="p">)</span>
    669. <span class="k">if</span> <span class="n">normalize</span><span class="p">:</span>
    670. <span class="n">boxes</span><span class="p">[:,</span> <span class="mi">0</span><span class="p">]</span> <span class="o">=</span> <span class="n">boxes</span><span class="p">[:,</span> <span class="mi">0</span><span class="p">]</span> <span class="o">/</span> <span class="n">w</span>
    671. <span class="n">boxes</span><span class="p">[:,</span> <span class="mi">1</span><span class="p">]</span> <span class="o">=</span> <span class="n">boxes</span><span class="p">[:,</span> <span class="mi">1</span><span class="p">]</span> <span class="o">/</span> <span class="n">h</span>
    672. <span class="n">boxes</span><span class="p">[:,</span> <span class="mi">2</span><span class="p">]</span> <span class="o">=</span> <span class="n">boxes</span><span class="p">[:,</span> <span class="mi">2</span><span class="p">]</span> <span class="o">/</span> <span class="n">w</span>
    673. <span class="n">boxes</span><span class="p">[:,</span> <span class="mi">3</span><span class="p">]</span> <span class="o">=</span> <span class="n">boxes</span><span class="p">[:,</span> <span class="mi">3</span><span class="p">]</span> <span class="o">/</span> <span class="n">h</span>
    674. <span class="k">elif</span> <span class="n">denormalize</span><span class="p">:</span>
    675. <span class="n">boxes</span><span class="p">[:,</span> <span class="mi">0</span><span class="p">]</span> <span class="o">=</span> <span class="n">boxes</span><span class="p">[:,</span> <span class="mi">0</span><span class="p">]</span> <span class="o">*</span> <span class="n">w</span>
    676. <span class="n">boxes</span><span class="p">[:,</span> <span class="mi">1</span><span class="p">]</span> <span class="o">=</span> <span class="n">boxes</span><span class="p">[:,</span> <span class="mi">1</span><span class="p">]</span> <span class="o">*</span> <span class="n">h</span>
    677. <span class="n">boxes</span><span class="p">[:,</span> <span class="mi">2</span><span class="p">]</span> <span class="o">=</span> <span class="n">boxes</span><span class="p">[:,</span> <span class="mi">2</span><span class="p">]</span> <span class="o">*</span> <span class="n">w</span>
    678. <span class="n">boxes</span><span class="p">[:,</span> <span class="mi">3</span><span class="p">]</span> <span class="o">=</span> <span class="n">boxes</span><span class="p">[:,</span> <span class="mi">3</span><span class="p">]</span> <span class="o">*</span> <span class="n">h</span>
    679. <span class="n">min_bbox_edge_size</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">min_bbox_edge_size</span> <span class="o">/</span> <span class="nb">max</span><span class="p">(</span><span class="n">w</span><span class="p">,</span> <span class="n">h</span><span class="p">)</span> <span class="k">if</span> <span class="n">normalized_output</span> <span class="k">else</span> <span class="bp">self</span><span class="o">.</span><span class="n">min_bbox_edge_size</span>
    680. <span class="n">cxcywh_boxes</span> <span class="o">=</span> <span class="n">boxes</span> <span class="k">if</span> <span class="ow">not</span> <span class="n">output_xyxy_format</span> <span class="k">else</span> <span class="n">xyxy2cxcywh</span><span class="p">(</span><span class="n">boxes</span><span class="o">.</span><span class="n">copy</span><span class="p">())</span>
    681. <span class="n">mask_b</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">minimum</span><span class="p">(</span><span class="n">cxcywh_boxes</span><span class="p">[:,</span> <span class="mi">2</span><span class="p">],</span> <span class="n">cxcywh_boxes</span><span class="p">[:,</span> <span class="mi">3</span><span class="p">])</span> <span class="o">&gt;</span> <span class="n">min_bbox_edge_size</span>
    682. <span class="n">boxes_t</span> <span class="o">=</span> <span class="n">boxes</span><span class="p">[</span><span class="n">mask_b</span><span class="p">]</span>
    683. <span class="n">labels_t</span> <span class="o">=</span> <span class="n">labels</span><span class="p">[</span><span class="n">mask_b</span><span class="p">]</span>
    684. <span class="n">labels_t</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">expand_dims</span><span class="p">(</span><span class="n">labels_t</span><span class="p">,</span> <span class="mi">1</span><span class="p">)</span>
    685. <span class="n">targets_t</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">hstack</span><span class="p">((</span><span class="n">labels_t</span><span class="p">,</span> <span class="n">boxes_t</span><span class="p">))</span> <span class="k">if</span> <span class="n">label_first_in_output</span> <span class="k">else</span> <span class="n">np</span><span class="o">.</span><span class="n">hstack</span><span class="p">((</span><span class="n">boxes_t</span><span class="p">,</span> <span class="n">labels_t</span><span class="p">))</span>
    686. <span class="n">padded_targets</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">zeros</span><span class="p">((</span><span class="bp">self</span><span class="o">.</span><span class="n">max_targets</span><span class="p">,</span> <span class="mi">5</span><span class="p">))</span>
    687. <span class="n">padded_targets</span><span class="p">[</span><span class="nb">range</span><span class="p">(</span><span class="nb">len</span><span class="p">(</span><span class="n">targets_t</span><span class="p">))[:</span> <span class="bp">self</span><span class="o">.</span><span class="n">max_targets</span><span class="p">]]</span> <span class="o">=</span> <span class="n">targets_t</span><span class="p">[:</span> <span class="bp">self</span><span class="o">.</span><span class="n">max_targets</span><span class="p">]</span>
    688. <span class="n">padded_targets</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">ascontiguousarray</span><span class="p">(</span><span class="n">padded_targets</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="n">np</span><span class="o">.</span><span class="n">float32</span><span class="p">)</span>
    689. <span class="k">return</span> <span class="n">padded_targets</span>
    690. <span class="n">sample</span><span class="p">[</span><span class="s2">&quot;target&quot;</span><span class="p">]</span> <span class="o">=</span> <span class="n">_format_target</span><span class="p">(</span><span class="n">targets</span><span class="p">)</span>
    691. <span class="k">if</span> <span class="n">crowd_targets</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span><span class="p">:</span>
    692. <span class="n">sample</span><span class="p">[</span><span class="s2">&quot;crowd_target&quot;</span><span class="p">]</span> <span class="o">=</span> <span class="n">_format_target</span><span class="p">(</span><span class="n">crowd_targets</span><span class="p">)</span>
    693. <span class="k">return</span> <span class="n">sample</span></div>
    694. <span class="k">def</span> <span class="nf">get_aug_params</span><span class="p">(</span><span class="n">value</span><span class="p">:</span> <span class="n">Union</span><span class="p">[</span><span class="nb">tuple</span><span class="p">,</span> <span class="nb">float</span><span class="p">],</span> <span class="n">center</span><span class="p">:</span> <span class="nb">float</span> <span class="o">=</span> <span class="mi">0</span><span class="p">):</span>
    695. <span class="sd">&quot;&quot;&quot;</span>
    696. <span class="sd"> Generates a random value for augmentations as described below</span>
    697. <span class="sd"> :param value: Union[tuple, float] defines the range of values for generation. Wen tuple-</span>
    698. <span class="sd"> drawn uniformly between (value[0], value[1]), and (center - value, center + value) when float</span>
    699. <span class="sd"> :param center: float, defines center to subtract when value is float.</span>
    700. <span class="sd"> :return: generated value</span>
    701. <span class="sd"> &quot;&quot;&quot;</span>
    702. <span class="k">if</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">value</span><span class="p">,</span> <span class="nb">float</span><span class="p">):</span>
    703. <span class="k">return</span> <span class="n">random</span><span class="o">.</span><span class="n">uniform</span><span class="p">(</span><span class="n">center</span> <span class="o">-</span> <span class="n">value</span><span class="p">,</span> <span class="n">center</span> <span class="o">+</span> <span class="n">value</span><span class="p">)</span>
    704. <span class="k">elif</span> <span class="nb">len</span><span class="p">(</span><span class="n">value</span><span class="p">)</span> <span class="o">==</span> <span class="mi">2</span><span class="p">:</span>
    705. <span class="k">return</span> <span class="n">random</span><span class="o">.</span><span class="n">uniform</span><span class="p">(</span><span class="n">value</span><span class="p">[</span><span class="mi">0</span><span class="p">],</span> <span class="n">value</span><span class="p">[</span><span class="mi">1</span><span class="p">])</span>
    706. <span class="k">else</span><span class="p">:</span>
    707. <span class="k">raise</span> <span class="ne">ValueError</span><span class="p">(</span>
    708. <span class="s2">&quot;Affine params should be either a sequence containing two values</span><span class="se">\</span>
    709. <span class="s2"> or single float values. Got </span><span class="si">{}</span><span class="s2">&quot;</span><span class="o">.</span><span class="n">format</span><span class="p">(</span>
    710. <span class="n">value</span>
    711. <span class="p">)</span>
    712. <span class="p">)</span>
    713. <span class="k">def</span> <span class="nf">get_affine_matrix</span><span class="p">(</span>
    714. <span class="n">target_size</span><span class="p">,</span>
    715. <span class="n">degrees</span><span class="o">=</span><span class="mi">10</span><span class="p">,</span>
    716. <span class="n">translate</span><span class="o">=</span><span class="mf">0.1</span><span class="p">,</span>
    717. <span class="n">scales</span><span class="o">=</span><span class="mf">0.1</span><span class="p">,</span>
    718. <span class="n">shear</span><span class="o">=</span><span class="mi">10</span><span class="p">,</span>
    719. <span class="p">):</span>
    720. <span class="sd">&quot;&quot;&quot;</span>
    721. <span class="sd"> Returns a random affine transform matrix.</span>
    722. <span class="sd"> :param target_size: (tuple) desired output shape.</span>
    723. <span class="sd"> :param degrees: (Union[tuple, float]) degrees for random rotation, when float the random values are drawn uniformly</span>
    724. <span class="sd"> from (-degrees, degrees)</span>
    725. <span class="sd"> :param translate: (Union[tuple, float]) translate size (in pixels) for random translation, when float the random values</span>
    726. <span class="sd"> are drawn uniformly from (-translate, translate)</span>
    727. <span class="sd"> :param scales: (Union[tuple, float]) values for random rescale, when float the random values are drawn uniformly</span>
    728. <span class="sd"> from (0.1-scales, 0.1+scales)</span>
    729. <span class="sd"> :param shear: (Union[tuple, float]) degrees for random shear, when float the random values are drawn uniformly</span>
    730. <span class="sd"> from (shear, shear)</span>
    731. <span class="sd"> :return: affine_transform_matrix, drawn_scale</span>
    732. <span class="sd"> &quot;&quot;&quot;</span>
    733. <span class="n">twidth</span><span class="p">,</span> <span class="n">theight</span> <span class="o">=</span> <span class="n">target_size</span>
    734. <span class="c1"># Rotation and Scale</span>
    735. <span class="n">angle</span> <span class="o">=</span> <span class="n">get_aug_params</span><span class="p">(</span><span class="n">degrees</span><span class="p">)</span>
    736. <span class="n">scale</span> <span class="o">=</span> <span class="n">get_aug_params</span><span class="p">(</span><span class="n">scales</span><span class="p">,</span> <span class="n">center</span><span class="o">=</span><span class="mf">1.0</span><span class="p">)</span>
    737. <span class="k">if</span> <span class="n">scale</span> <span class="o">&lt;=</span> <span class="mf">0.0</span><span class="p">:</span>
    738. <span class="k">raise</span> <span class="ne">ValueError</span><span class="p">(</span><span class="s2">&quot;Argument scale should be positive&quot;</span><span class="p">)</span>
    739. <span class="n">R</span> <span class="o">=</span> <span class="n">cv2</span><span class="o">.</span><span class="n">getRotationMatrix2D</span><span class="p">(</span><span class="n">angle</span><span class="o">=</span><span class="n">angle</span><span class="p">,</span> <span class="n">center</span><span class="o">=</span><span class="p">(</span><span class="mi">0</span><span class="p">,</span> <span class="mi">0</span><span class="p">),</span> <span class="n">scale</span><span class="o">=</span><span class="n">scale</span><span class="p">)</span>
    740. <span class="n">M</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">ones</span><span class="p">([</span><span class="mi">2</span><span class="p">,</span> <span class="mi">3</span><span class="p">])</span>
    741. <span class="c1"># Shear</span>
    742. <span class="n">shear_x</span> <span class="o">=</span> <span class="n">math</span><span class="o">.</span><span class="n">tan</span><span class="p">(</span><span class="n">get_aug_params</span><span class="p">(</span><span class="n">shear</span><span class="p">)</span> <span class="o">*</span> <span class="n">math</span><span class="o">.</span><span class="n">pi</span> <span class="o">/</span> <span class="mi">180</span><span class="p">)</span>
    743. <span class="n">shear_y</span> <span class="o">=</span> <span class="n">math</span><span class="o">.</span><span class="n">tan</span><span class="p">(</span><span class="n">get_aug_params</span><span class="p">(</span><span class="n">shear</span><span class="p">)</span> <span class="o">*</span> <span class="n">math</span><span class="o">.</span><span class="n">pi</span> <span class="o">/</span> <span class="mi">180</span><span class="p">)</span>
    744. <span class="n">M</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span> <span class="o">=</span> <span class="n">R</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span> <span class="o">+</span> <span class="n">shear_y</span> <span class="o">*</span> <span class="n">R</span><span class="p">[</span><span class="mi">1</span><span class="p">]</span>
    745. <span class="n">M</span><span class="p">[</span><span class="mi">1</span><span class="p">]</span> <span class="o">=</span> <span class="n">R</span><span class="p">[</span><span class="mi">1</span><span class="p">]</span> <span class="o">+</span> <span class="n">shear_x</span> <span class="o">*</span> <span class="n">R</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span>
    746. <span class="c1"># Translation</span>
    747. <span class="n">translation_x</span> <span class="o">=</span> <span class="n">get_aug_params</span><span class="p">(</span><span class="n">translate</span><span class="p">)</span> <span class="o">*</span> <span class="n">twidth</span> <span class="c1"># x translation (pixels)</span>
    748. <span class="n">translation_y</span> <span class="o">=</span> <span class="n">get_aug_params</span><span class="p">(</span><span class="n">translate</span><span class="p">)</span> <span class="o">*</span> <span class="n">theight</span> <span class="c1"># y translation (pixels)</span>
    749. <span class="n">M</span><span class="p">[</span><span class="mi">0</span><span class="p">,</span> <span class="mi">2</span><span class="p">]</span> <span class="o">=</span> <span class="n">translation_x</span>
    750. <span class="n">M</span><span class="p">[</span><span class="mi">1</span><span class="p">,</span> <span class="mi">2</span><span class="p">]</span> <span class="o">=</span> <span class="n">translation_y</span>
    751. <span class="k">return</span> <span class="n">M</span><span class="p">,</span> <span class="n">scale</span>
    752. <span class="k">def</span> <span class="nf">apply_affine_to_bboxes</span><span class="p">(</span><span class="n">targets</span><span class="p">,</span> <span class="n">targets_seg</span><span class="p">,</span> <span class="n">target_size</span><span class="p">,</span> <span class="n">M</span><span class="p">):</span>
    753. <span class="n">num_gts</span> <span class="o">=</span> <span class="nb">len</span><span class="p">(</span><span class="n">targets</span><span class="p">)</span>
    754. <span class="n">twidth</span><span class="p">,</span> <span class="n">theight</span> <span class="o">=</span> <span class="n">target_size</span>
    755. <span class="c1"># targets_seg = [B x w x h]</span>
    756. <span class="c1"># if any is_not_nan in axis = 1</span>
    757. <span class="n">seg_is_present_mask</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">logical_or</span><span class="o">.</span><span class="n">reduce</span><span class="p">(</span><span class="o">~</span><span class="n">np</span><span class="o">.</span><span class="n">isnan</span><span class="p">(</span><span class="n">targets_seg</span><span class="p">),</span> <span class="n">axis</span><span class="o">=</span><span class="mi">1</span><span class="p">)</span>
    758. <span class="n">num_gts_masks</span> <span class="o">=</span> <span class="n">seg_is_present_mask</span><span class="o">.</span><span class="n">sum</span><span class="p">()</span>
    759. <span class="n">num_gts_boxes</span> <span class="o">=</span> <span class="n">num_gts</span> <span class="o">-</span> <span class="n">num_gts_masks</span>
    760. <span class="k">if</span> <span class="n">num_gts_boxes</span><span class="p">:</span>
    761. <span class="c1"># warp corner points</span>
    762. <span class="n">corner_points</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">ones</span><span class="p">((</span><span class="n">num_gts_boxes</span> <span class="o">*</span> <span class="mi">4</span><span class="p">,</span> <span class="mi">3</span><span class="p">))</span>
    763. <span class="c1"># x1y1, x2y2, x1y2, x2y1</span>
    764. <span class="n">corner_points</span><span class="p">[:,</span> <span class="p">:</span><span class="mi">2</span><span class="p">]</span> <span class="o">=</span> <span class="n">targets</span><span class="p">[</span><span class="o">~</span><span class="n">seg_is_present_mask</span><span class="p">][:,</span> <span class="p">[</span><span class="mi">0</span><span class="p">,</span> <span class="mi">1</span><span class="p">,</span> <span class="mi">2</span><span class="p">,</span> <span class="mi">3</span><span class="p">,</span> <span class="mi">0</span><span class="p">,</span> <span class="mi">3</span><span class="p">,</span> <span class="mi">2</span><span class="p">,</span> <span class="mi">1</span><span class="p">]]</span><span class="o">.</span><span class="n">reshape</span><span class="p">(</span><span class="n">num_gts_boxes</span> <span class="o">*</span> <span class="mi">4</span><span class="p">,</span> <span class="mi">2</span><span class="p">)</span>
    765. <span class="n">corner_points</span> <span class="o">=</span> <span class="n">corner_points</span> <span class="o">@</span> <span class="n">M</span><span class="o">.</span><span class="n">T</span> <span class="c1"># apply affine transform</span>
    766. <span class="n">corner_points</span> <span class="o">=</span> <span class="n">corner_points</span><span class="o">.</span><span class="n">reshape</span><span class="p">(</span><span class="n">num_gts_boxes</span><span class="p">,</span> <span class="mi">8</span><span class="p">)</span>
    767. <span class="c1"># create new boxes</span>
    768. <span class="n">corner_xs</span> <span class="o">=</span> <span class="n">corner_points</span><span class="p">[:,</span> <span class="mi">0</span><span class="p">::</span><span class="mi">2</span><span class="p">]</span>
    769. <span class="n">corner_ys</span> <span class="o">=</span> <span class="n">corner_points</span><span class="p">[:,</span> <span class="mi">1</span><span class="p">::</span><span class="mi">2</span><span class="p">]</span>
    770. <span class="n">new_bboxes</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">concatenate</span><span class="p">((</span><span class="n">np</span><span class="o">.</span><span class="n">min</span><span class="p">(</span><span class="n">corner_xs</span><span class="p">,</span> <span class="mi">1</span><span class="p">),</span> <span class="n">np</span><span class="o">.</span><span class="n">min</span><span class="p">(</span><span class="n">corner_ys</span><span class="p">,</span> <span class="mi">1</span><span class="p">),</span> <span class="n">np</span><span class="o">.</span><span class="n">max</span><span class="p">(</span><span class="n">corner_xs</span><span class="p">,</span> <span class="mi">1</span><span class="p">),</span> <span class="n">np</span><span class="o">.</span><span class="n">max</span><span class="p">(</span><span class="n">corner_ys</span><span class="p">,</span> <span class="mi">1</span><span class="p">)))</span><span class="o">.</span><span class="n">reshape</span><span class="p">(</span><span class="mi">4</span><span class="p">,</span> <span class="o">-</span><span class="mi">1</span><span class="p">)</span><span class="o">.</span><span class="n">T</span>
    771. <span class="k">else</span><span class="p">:</span>
    772. <span class="n">new_bboxes</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">ones</span><span class="p">((</span><span class="mi">0</span><span class="p">,</span> <span class="mi">4</span><span class="p">),</span> <span class="n">dtype</span><span class="o">=</span><span class="n">np</span><span class="o">.</span><span class="n">float</span><span class="p">)</span>
    773. <span class="k">if</span> <span class="n">num_gts_masks</span><span class="p">:</span>
    774. <span class="c1"># warp segmentation points</span>
    775. <span class="n">num_seg_points</span> <span class="o">=</span> <span class="n">targets_seg</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">1</span><span class="p">]</span> <span class="o">//</span> <span class="mi">2</span>
    776. <span class="n">corner_points_seg</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">ones</span><span class="p">((</span><span class="n">num_gts_masks</span> <span class="o">*</span> <span class="n">num_seg_points</span><span class="p">,</span> <span class="mi">3</span><span class="p">))</span>
    777. <span class="n">corner_points_seg</span><span class="p">[:,</span> <span class="p">:</span><span class="mi">2</span><span class="p">]</span> <span class="o">=</span> <span class="n">targets_seg</span><span class="p">[</span><span class="n">seg_is_present_mask</span><span class="p">]</span><span class="o">.</span><span class="n">reshape</span><span class="p">(</span><span class="n">num_gts_masks</span> <span class="o">*</span> <span class="n">num_seg_points</span><span class="p">,</span> <span class="mi">2</span><span class="p">)</span>
    778. <span class="n">corner_points_seg</span> <span class="o">=</span> <span class="n">corner_points_seg</span> <span class="o">@</span> <span class="n">M</span><span class="o">.</span><span class="n">T</span>
    779. <span class="n">corner_points_seg</span> <span class="o">=</span> <span class="n">corner_points_seg</span><span class="o">.</span><span class="n">reshape</span><span class="p">(</span><span class="n">num_gts_masks</span><span class="p">,</span> <span class="n">num_seg_points</span> <span class="o">*</span> <span class="mi">2</span><span class="p">)</span>
    780. <span class="c1"># create new boxes</span>
    781. <span class="n">seg_points_xs</span> <span class="o">=</span> <span class="n">corner_points_seg</span><span class="p">[:,</span> <span class="mi">0</span><span class="p">::</span><span class="mi">2</span><span class="p">]</span>
    782. <span class="n">seg_points_ys</span> <span class="o">=</span> <span class="n">corner_points_seg</span><span class="p">[:,</span> <span class="mi">1</span><span class="p">::</span><span class="mi">2</span><span class="p">]</span>
    783. <span class="n">new_tight_bboxes</span> <span class="o">=</span> <span class="p">(</span>
    784. <span class="n">np</span><span class="o">.</span><span class="n">concatenate</span><span class="p">((</span><span class="n">np</span><span class="o">.</span><span class="n">nanmin</span><span class="p">(</span><span class="n">seg_points_xs</span><span class="p">,</span> <span class="mi">1</span><span class="p">),</span> <span class="n">np</span><span class="o">.</span><span class="n">nanmin</span><span class="p">(</span><span class="n">seg_points_ys</span><span class="p">,</span> <span class="mi">1</span><span class="p">),</span> <span class="n">np</span><span class="o">.</span><span class="n">nanmax</span><span class="p">(</span><span class="n">seg_points_xs</span><span class="p">,</span> <span class="mi">1</span><span class="p">),</span> <span class="n">np</span><span class="o">.</span><span class="n">nanmax</span><span class="p">(</span><span class="n">seg_points_ys</span><span class="p">,</span> <span class="mi">1</span><span class="p">)))</span>
    785. <span class="o">.</span><span class="n">reshape</span><span class="p">(</span><span class="mi">4</span><span class="p">,</span> <span class="o">-</span><span class="mi">1</span><span class="p">)</span>
    786. <span class="o">.</span><span class="n">T</span>
    787. <span class="p">)</span>
    788. <span class="k">else</span><span class="p">:</span>
    789. <span class="n">new_tight_bboxes</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">ones</span><span class="p">((</span><span class="mi">0</span><span class="p">,</span> <span class="mi">4</span><span class="p">),</span> <span class="n">dtype</span><span class="o">=</span><span class="n">np</span><span class="o">.</span><span class="n">float</span><span class="p">)</span>
    790. <span class="n">targets</span><span class="p">[</span><span class="o">~</span><span class="n">seg_is_present_mask</span><span class="p">,</span> <span class="p">:</span><span class="mi">4</span><span class="p">]</span> <span class="o">=</span> <span class="n">new_bboxes</span>
    791. <span class="n">targets</span><span class="p">[</span><span class="n">seg_is_present_mask</span><span class="p">,</span> <span class="p">:</span><span class="mi">4</span><span class="p">]</span> <span class="o">=</span> <span class="n">new_tight_bboxes</span>
    792. <span class="c1"># clip boxes</span>
    793. <span class="n">targets</span><span class="p">[:,</span> <span class="p">[</span><span class="mi">0</span><span class="p">,</span> <span class="mi">2</span><span class="p">]]</span> <span class="o">=</span> <span class="n">targets</span><span class="p">[:,</span> <span class="p">[</span><span class="mi">0</span><span class="p">,</span> <span class="mi">2</span><span class="p">]]</span><span class="o">.</span><span class="n">clip</span><span class="p">(</span><span class="mi">0</span><span class="p">,</span> <span class="n">twidth</span><span class="p">)</span>
    794. <span class="n">targets</span><span class="p">[:,</span> <span class="p">[</span><span class="mi">1</span><span class="p">,</span> <span class="mi">3</span><span class="p">]]</span> <span class="o">=</span> <span class="n">targets</span><span class="p">[:,</span> <span class="p">[</span><span class="mi">1</span><span class="p">,</span> <span class="mi">3</span><span class="p">]]</span><span class="o">.</span><span class="n">clip</span><span class="p">(</span><span class="mi">0</span><span class="p">,</span> <span class="n">theight</span><span class="p">)</span>
    795. <span class="k">return</span> <span class="n">targets</span>
    796. <span class="k">def</span> <span class="nf">random_affine</span><span class="p">(</span>
    797. <span class="n">img</span><span class="p">:</span> <span class="n">np</span><span class="o">.</span><span class="n">ndarray</span><span class="p">,</span>
    798. <span class="n">targets</span><span class="p">:</span> <span class="n">np</span><span class="o">.</span><span class="n">ndarray</span> <span class="o">=</span> <span class="p">(),</span>
    799. <span class="n">targets_seg</span><span class="p">:</span> <span class="n">np</span><span class="o">.</span><span class="n">ndarray</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span>
    800. <span class="n">target_size</span><span class="p">:</span> <span class="nb">tuple</span> <span class="o">=</span> <span class="p">(</span><span class="mi">640</span><span class="p">,</span> <span class="mi">640</span><span class="p">),</span>
    801. <span class="n">degrees</span><span class="p">:</span> <span class="n">Union</span><span class="p">[</span><span class="nb">float</span><span class="p">,</span> <span class="nb">tuple</span><span class="p">]</span> <span class="o">=</span> <span class="mi">10</span><span class="p">,</span>
    802. <span class="n">translate</span><span class="p">:</span> <span class="n">Union</span><span class="p">[</span><span class="nb">float</span><span class="p">,</span> <span class="nb">tuple</span><span class="p">]</span> <span class="o">=</span> <span class="mf">0.1</span><span class="p">,</span>
    803. <span class="n">scales</span><span class="p">:</span> <span class="n">Union</span><span class="p">[</span><span class="nb">float</span><span class="p">,</span> <span class="nb">tuple</span><span class="p">]</span> <span class="o">=</span> <span class="mf">0.1</span><span class="p">,</span>
    804. <span class="n">shear</span><span class="p">:</span> <span class="n">Union</span><span class="p">[</span><span class="nb">float</span><span class="p">,</span> <span class="nb">tuple</span><span class="p">]</span> <span class="o">=</span> <span class="mi">10</span><span class="p">,</span>
    805. <span class="n">filter_box_candidates</span><span class="p">:</span> <span class="nb">bool</span> <span class="o">=</span> <span class="kc">False</span><span class="p">,</span>
    806. <span class="n">wh_thr</span><span class="o">=</span><span class="mi">2</span><span class="p">,</span>
    807. <span class="n">ar_thr</span><span class="o">=</span><span class="mi">20</span><span class="p">,</span>
    808. <span class="n">area_thr</span><span class="o">=</span><span class="mf">0.1</span><span class="p">,</span>
    809. <span class="p">):</span>
    810. <span class="sd">&quot;&quot;&quot;</span>
    811. <span class="sd"> Performs random affine transform to img, targets</span>
    812. <span class="sd"> :param img: Input image</span>
    813. <span class="sd"> :param targets: Input target</span>
    814. <span class="sd"> :param targets_seg: Targets derived from segmentation masks</span>
    815. <span class="sd"> :param target_size: Desired output shape</span>
    816. <span class="sd"> :param degrees: Degrees for random rotation, when float the random values are drawn uniformly</span>
    817. <span class="sd"> from (-degrees, degrees).</span>
    818. <span class="sd"> :param translate: Translate size (in pixels) for random translation, when float the random values</span>
    819. <span class="sd"> are drawn uniformly from (-translate, translate)</span>
    820. <span class="sd"> :param scales: Values for random rescale, when float the random values are drawn uniformly</span>
    821. <span class="sd"> from (0.1-scales, 0.1+scales)</span>
    822. <span class="sd"> :param shear: Degrees for random shear, when float the random values are drawn uniformly</span>
    823. <span class="sd"> from (shear, shear)</span>
    824. <span class="sd"> :param filter_box_candidates: whether to filter out transformed bboxes by edge size, area ratio, and aspect ratio.</span>
    825. <span class="sd"> :param wh_thr: (float) edge size threshold when filter_box_candidates = True. Bounding oxes with edges smaller</span>
    826. <span class="sd"> then this values will be filtered out. (default=2)</span>
    827. <span class="sd"> :param ar_thr: (float) aspect ratio threshold filter_box_candidates = True. Bounding boxes with aspect ratio larger</span>
    828. <span class="sd"> then this values will be filtered out. (default=20)</span>
    829. <span class="sd"> :param area_thr:(float) threshold for area ratio between original image and the transformed one, when when filter_box_candidates = True.</span>
    830. <span class="sd"> Bounding boxes with such ratio smaller then this value will be filtered out. (default=0.1)</span>
    831. <span class="sd"> :return: Image and Target with applied random affine</span>
    832. <span class="sd"> &quot;&quot;&quot;</span>
    833. <span class="n">targets_seg</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">zeros</span><span class="p">((</span><span class="n">targets</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">0</span><span class="p">],</span> <span class="mi">0</span><span class="p">))</span> <span class="k">if</span> <span class="n">targets_seg</span> <span class="ow">is</span> <span class="kc">None</span> <span class="k">else</span> <span class="n">targets_seg</span>
    834. <span class="n">M</span><span class="p">,</span> <span class="n">scale</span> <span class="o">=</span> <span class="n">get_affine_matrix</span><span class="p">(</span><span class="n">target_size</span><span class="p">,</span> <span class="n">degrees</span><span class="p">,</span> <span class="n">translate</span><span class="p">,</span> <span class="n">scales</span><span class="p">,</span> <span class="n">shear</span><span class="p">)</span>
    835. <span class="n">img</span> <span class="o">=</span> <span class="n">cv2</span><span class="o">.</span><span class="n">warpAffine</span><span class="p">(</span><span class="n">img</span><span class="p">,</span> <span class="n">M</span><span class="p">,</span> <span class="n">dsize</span><span class="o">=</span><span class="n">target_size</span><span class="p">,</span> <span class="n">borderValue</span><span class="o">=</span><span class="p">(</span><span class="mi">114</span><span class="p">,</span> <span class="mi">114</span><span class="p">,</span> <span class="mi">114</span><span class="p">))</span>
    836. <span class="c1"># Transform label coordinates</span>
    837. <span class="k">if</span> <span class="nb">len</span><span class="p">(</span><span class="n">targets</span><span class="p">)</span> <span class="o">&gt;</span> <span class="mi">0</span><span class="p">:</span>
    838. <span class="n">targets_orig</span> <span class="o">=</span> <span class="n">targets</span><span class="o">.</span><span class="n">copy</span><span class="p">()</span>
    839. <span class="n">targets</span> <span class="o">=</span> <span class="n">apply_affine_to_bboxes</span><span class="p">(</span><span class="n">targets</span><span class="p">,</span> <span class="n">targets_seg</span><span class="p">,</span> <span class="n">target_size</span><span class="p">,</span> <span class="n">M</span><span class="p">)</span>
    840. <span class="k">if</span> <span class="n">filter_box_candidates</span><span class="p">:</span>
    841. <span class="n">box_candidates_ids</span> <span class="o">=</span> <span class="n">_filter_box_candidates</span><span class="p">(</span><span class="n">targets_orig</span><span class="p">[:,</span> <span class="p">:</span><span class="mi">4</span><span class="p">],</span> <span class="n">targets</span><span class="p">[:,</span> <span class="p">:</span><span class="mi">4</span><span class="p">],</span> <span class="n">wh_thr</span><span class="o">=</span><span class="n">wh_thr</span><span class="p">,</span> <span class="n">ar_thr</span><span class="o">=</span><span class="n">ar_thr</span><span class="p">,</span> <span class="n">area_thr</span><span class="o">=</span><span class="n">area_thr</span><span class="p">)</span>
    842. <span class="n">targets</span> <span class="o">=</span> <span class="n">targets</span><span class="p">[</span><span class="n">box_candidates_ids</span><span class="p">]</span>
    843. <span class="k">return</span> <span class="n">img</span><span class="p">,</span> <span class="n">targets</span>
    844. <span class="k">def</span> <span class="nf">_filter_box_candidates</span><span class="p">(</span><span class="n">box1</span><span class="p">,</span> <span class="n">box2</span><span class="p">,</span> <span class="n">wh_thr</span><span class="o">=</span><span class="mi">2</span><span class="p">,</span> <span class="n">ar_thr</span><span class="o">=</span><span class="mi">20</span><span class="p">,</span> <span class="n">area_thr</span><span class="o">=</span><span class="mf">0.1</span><span class="p">):</span>
    845. <span class="sd">&quot;&quot;&quot;</span>
    846. <span class="sd"> compute candidate boxes</span>
    847. <span class="sd"> :param box1: before augment</span>
    848. <span class="sd"> :param box2: after augment</span>
    849. <span class="sd"> :param wh_thr: wh_thr (pixels)</span>
    850. <span class="sd"> :param ar_thr: aspect_ratio_thr</span>
    851. <span class="sd"> :param area_thr: area_ratio</span>
    852. <span class="sd"> :return:</span>
    853. <span class="sd"> &quot;&quot;&quot;</span>
    854. <span class="n">box1</span> <span class="o">=</span> <span class="n">box1</span><span class="o">.</span><span class="n">T</span>
    855. <span class="n">box2</span> <span class="o">=</span> <span class="n">box2</span><span class="o">.</span><span class="n">T</span>
    856. <span class="n">w1</span><span class="p">,</span> <span class="n">h1</span> <span class="o">=</span> <span class="n">box1</span><span class="p">[</span><span class="mi">2</span><span class="p">]</span> <span class="o">-</span> <span class="n">box1</span><span class="p">[</span><span class="mi">0</span><span class="p">],</span> <span class="n">box1</span><span class="p">[</span><span class="mi">3</span><span class="p">]</span> <span class="o">-</span> <span class="n">box1</span><span class="p">[</span><span class="mi">1</span><span class="p">]</span>
    857. <span class="n">w2</span><span class="p">,</span> <span class="n">h2</span> <span class="o">=</span> <span class="n">box2</span><span class="p">[</span><span class="mi">2</span><span class="p">]</span> <span class="o">-</span> <span class="n">box2</span><span class="p">[</span><span class="mi">0</span><span class="p">],</span> <span class="n">box2</span><span class="p">[</span><span class="mi">3</span><span class="p">]</span> <span class="o">-</span> <span class="n">box2</span><span class="p">[</span><span class="mi">1</span><span class="p">]</span>
    858. <span class="n">ar</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">maximum</span><span class="p">(</span><span class="n">w2</span> <span class="o">/</span> <span class="p">(</span><span class="n">h2</span> <span class="o">+</span> <span class="mf">1e-16</span><span class="p">),</span> <span class="n">h2</span> <span class="o">/</span> <span class="p">(</span><span class="n">w2</span> <span class="o">+</span> <span class="mf">1e-16</span><span class="p">))</span> <span class="c1"># aspect ratio</span>
    859. <span class="k">return</span> <span class="p">(</span><span class="n">w2</span> <span class="o">&gt;</span> <span class="n">wh_thr</span><span class="p">)</span> <span class="o">&amp;</span> <span class="p">(</span><span class="n">h2</span> <span class="o">&gt;</span> <span class="n">wh_thr</span><span class="p">)</span> <span class="o">&amp;</span> <span class="p">(</span><span class="n">w2</span> <span class="o">*</span> <span class="n">h2</span> <span class="o">/</span> <span class="p">(</span><span class="n">w1</span> <span class="o">*</span> <span class="n">h1</span> <span class="o">+</span> <span class="mf">1e-16</span><span class="p">)</span> <span class="o">&gt;</span> <span class="n">area_thr</span><span class="p">)</span> <span class="o">&amp;</span> <span class="p">(</span><span class="n">ar</span> <span class="o">&lt;</span> <span class="n">ar_thr</span><span class="p">)</span> <span class="c1"># candidates</span>
    860. <span class="k">def</span> <span class="nf">_mirror</span><span class="p">(</span><span class="n">image</span><span class="p">,</span> <span class="n">boxes</span><span class="p">,</span> <span class="n">prob</span><span class="o">=</span><span class="mf">0.5</span><span class="p">):</span>
    861. <span class="sd">&quot;&quot;&quot;</span>
    862. <span class="sd"> Horizontal flips image and bboxes with probability prob.</span>
    863. <span class="sd"> :param image: (np.array) image to be flipped.</span>
    864. <span class="sd"> :param boxes: (np.array) bboxes to be modified.</span>
    865. <span class="sd"> :param prob: probability to perform flipping.</span>
    866. <span class="sd"> :return: flipped_image, flipped_bboxes</span>
    867. <span class="sd"> &quot;&quot;&quot;</span>
    868. <span class="n">flipped_boxes</span> <span class="o">=</span> <span class="n">boxes</span><span class="o">.</span><span class="n">copy</span><span class="p">()</span>
    869. <span class="n">_</span><span class="p">,</span> <span class="n">width</span><span class="p">,</span> <span class="n">_</span> <span class="o">=</span> <span class="n">image</span><span class="o">.</span><span class="n">shape</span>
    870. <span class="k">if</span> <span class="n">random</span><span class="o">.</span><span class="n">random</span><span class="p">()</span> <span class="o">&lt;</span> <span class="n">prob</span><span class="p">:</span>
    871. <span class="n">image</span> <span class="o">=</span> <span class="n">image</span><span class="p">[:,</span> <span class="p">::</span><span class="o">-</span><span class="mi">1</span><span class="p">]</span>
    872. <span class="n">flipped_boxes</span><span class="p">[:,</span> <span class="mi">0</span><span class="p">::</span><span class="mi">2</span><span class="p">]</span> <span class="o">=</span> <span class="n">width</span> <span class="o">-</span> <span class="n">boxes</span><span class="p">[:,</span> <span class="mi">2</span><span class="p">::</span><span class="o">-</span><span class="mi">2</span><span class="p">]</span>
    873. <span class="k">return</span> <span class="n">image</span><span class="p">,</span> <span class="n">flipped_boxes</span>
    874. <span class="k">def</span> <span class="nf">augment_hsv</span><span class="p">(</span><span class="n">img</span><span class="p">:</span> <span class="n">np</span><span class="o">.</span><span class="n">array</span><span class="p">,</span> <span class="n">hgain</span><span class="p">:</span> <span class="nb">float</span><span class="p">,</span> <span class="n">sgain</span><span class="p">:</span> <span class="nb">float</span><span class="p">,</span> <span class="n">vgain</span><span class="p">:</span> <span class="nb">float</span><span class="p">,</span> <span class="n">bgr_channels</span><span class="o">=</span><span class="p">(</span><span class="mi">0</span><span class="p">,</span> <span class="mi">1</span><span class="p">,</span> <span class="mi">2</span><span class="p">)):</span>
    875. <span class="n">hsv_augs</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">random</span><span class="o">.</span><span class="n">uniform</span><span class="p">(</span><span class="o">-</span><span class="mi">1</span><span class="p">,</span> <span class="mi">1</span><span class="p">,</span> <span class="mi">3</span><span class="p">)</span> <span class="o">*</span> <span class="p">[</span><span class="n">hgain</span><span class="p">,</span> <span class="n">sgain</span><span class="p">,</span> <span class="n">vgain</span><span class="p">]</span> <span class="c1"># random gains</span>
    876. <span class="n">hsv_augs</span> <span class="o">*=</span> <span class="n">np</span><span class="o">.</span><span class="n">random</span><span class="o">.</span><span class="n">randint</span><span class="p">(</span><span class="mi">0</span><span class="p">,</span> <span class="mi">2</span><span class="p">,</span> <span class="mi">3</span><span class="p">)</span> <span class="c1"># random selection of h, s, v</span>
    877. <span class="n">hsv_augs</span> <span class="o">=</span> <span class="n">hsv_augs</span><span class="o">.</span><span class="n">astype</span><span class="p">(</span><span class="n">np</span><span class="o">.</span><span class="n">int16</span><span class="p">)</span>
    878. <span class="n">img_hsv</span> <span class="o">=</span> <span class="n">cv2</span><span class="o">.</span><span class="n">cvtColor</span><span class="p">(</span><span class="n">img</span><span class="p">[</span><span class="o">...</span><span class="p">,</span> <span class="n">bgr_channels</span><span class="p">],</span> <span class="n">cv2</span><span class="o">.</span><span class="n">COLOR_BGR2HSV</span><span class="p">)</span><span class="o">.</span><span class="n">astype</span><span class="p">(</span><span class="n">np</span><span class="o">.</span><span class="n">int16</span><span class="p">)</span>
    879. <span class="n">img_hsv</span><span class="p">[</span><span class="o">...</span><span class="p">,</span> <span class="mi">0</span><span class="p">]</span> <span class="o">=</span> <span class="p">(</span><span class="n">img_hsv</span><span class="p">[</span><span class="o">...</span><span class="p">,</span> <span class="mi">0</span><span class="p">]</span> <span class="o">+</span> <span class="n">hsv_augs</span><span class="p">[</span><span class="mi">0</span><span class="p">])</span> <span class="o">%</span> <span class="mi">180</span>
    880. <span class="n">img_hsv</span><span class="p">[</span><span class="o">...</span><span class="p">,</span> <span class="mi">1</span><span class="p">]</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">clip</span><span class="p">(</span><span class="n">img_hsv</span><span class="p">[</span><span class="o">...</span><span class="p">,</span> <span class="mi">1</span><span class="p">]</span> <span class="o">+</span> <span class="n">hsv_augs</span><span class="p">[</span><span class="mi">1</span><span class="p">],</span> <span class="mi">0</span><span class="p">,</span> <span class="mi">255</span><span class="p">)</span>
    881. <span class="n">img_hsv</span><span class="p">[</span><span class="o">...</span><span class="p">,</span> <span class="mi">2</span><span class="p">]</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">clip</span><span class="p">(</span><span class="n">img_hsv</span><span class="p">[</span><span class="o">...</span><span class="p">,</span> <span class="mi">2</span><span class="p">]</span> <span class="o">+</span> <span class="n">hsv_augs</span><span class="p">[</span><span class="mi">2</span><span class="p">],</span> <span class="mi">0</span><span class="p">,</span> <span class="mi">255</span><span class="p">)</span>
    882. <span class="n">img</span><span class="p">[</span><span class="o">...</span><span class="p">,</span> <span class="n">bgr_channels</span><span class="p">]</span> <span class="o">=</span> <span class="n">cv2</span><span class="o">.</span><span class="n">cvtColor</span><span class="p">(</span><span class="n">img_hsv</span><span class="o">.</span><span class="n">astype</span><span class="p">(</span><span class="n">img</span><span class="o">.</span><span class="n">dtype</span><span class="p">),</span> <span class="n">cv2</span><span class="o">.</span><span class="n">COLOR_HSV2BGR</span><span class="p">)</span> <span class="c1"># no return needed</span>
    883. <span class="k">def</span> <span class="nf">rescale_and_pad_to_size</span><span class="p">(</span><span class="n">img</span><span class="p">,</span> <span class="n">input_size</span><span class="p">,</span> <span class="n">swap</span><span class="o">=</span><span class="p">(</span><span class="mi">2</span><span class="p">,</span> <span class="mi">0</span><span class="p">,</span> <span class="mi">1</span><span class="p">),</span> <span class="n">pad_val</span><span class="o">=</span><span class="mi">114</span><span class="p">):</span>
    884. <span class="sd">&quot;&quot;&quot;</span>
    885. <span class="sd"> Rescales image according to minimum ratio between the target height /image height, target width / image width,</span>
    886. <span class="sd"> and pads the image to the target size.</span>
    887. <span class="sd"> :param img: Image to be rescaled</span>
    888. <span class="sd"> :param input_size: Target size</span>
    889. <span class="sd"> :param swap: Axis&#39;s to be rearranged.</span>
    890. <span class="sd"> :return: rescaled image, ratio</span>
    891. <span class="sd"> &quot;&quot;&quot;</span>
    892. <span class="k">if</span> <span class="nb">len</span><span class="p">(</span><span class="n">img</span><span class="o">.</span><span class="n">shape</span><span class="p">)</span> <span class="o">==</span> <span class="mi">3</span><span class="p">:</span>
    893. <span class="n">padded_img</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">ones</span><span class="p">((</span><span class="n">input_size</span><span class="p">[</span><span class="mi">0</span><span class="p">],</span> <span class="n">input_size</span><span class="p">[</span><span class="mi">1</span><span class="p">],</span> <span class="n">img</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="o">-</span><span class="mi">1</span><span class="p">]),</span> <span class="n">dtype</span><span class="o">=</span><span class="n">np</span><span class="o">.</span><span class="n">uint8</span><span class="p">)</span> <span class="o">*</span> <span class="n">pad_val</span>
    894. <span class="k">else</span><span class="p">:</span>
    895. <span class="n">padded_img</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">ones</span><span class="p">(</span><span class="n">input_size</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="n">np</span><span class="o">.</span><span class="n">uint8</span><span class="p">)</span> <span class="o">*</span> <span class="n">pad_val</span>
    896. <span class="n">r</span> <span class="o">=</span> <span class="nb">min</span><span class="p">(</span><span class="n">input_size</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span> <span class="o">/</span> <span class="n">img</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">0</span><span class="p">],</span> <span class="n">input_size</span><span class="p">[</span><span class="mi">1</span><span class="p">]</span> <span class="o">/</span> <span class="n">img</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">1</span><span class="p">])</span>
    897. <span class="n">resized_img</span> <span class="o">=</span> <span class="n">cv2</span><span class="o">.</span><span class="n">resize</span><span class="p">(</span>
    898. <span class="n">img</span><span class="p">,</span>
    899. <span class="p">(</span><span class="nb">int</span><span class="p">(</span><span class="n">img</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">1</span><span class="p">]</span> <span class="o">*</span> <span class="n">r</span><span class="p">),</span> <span class="nb">int</span><span class="p">(</span><span class="n">img</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span> <span class="o">*</span> <span class="n">r</span><span class="p">)),</span>
    900. <span class="n">interpolation</span><span class="o">=</span><span class="n">cv2</span><span class="o">.</span><span class="n">INTER_LINEAR</span><span class="p">,</span>
    901. <span class="p">)</span><span class="o">.</span><span class="n">astype</span><span class="p">(</span><span class="n">np</span><span class="o">.</span><span class="n">uint8</span><span class="p">)</span>
    902. <span class="n">padded_img</span><span class="p">[:</span> <span class="nb">int</span><span class="p">(</span><span class="n">img</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span> <span class="o">*</span> <span class="n">r</span><span class="p">),</span> <span class="p">:</span> <span class="nb">int</span><span class="p">(</span><span class="n">img</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">1</span><span class="p">]</span> <span class="o">*</span> <span class="n">r</span><span class="p">)]</span> <span class="o">=</span> <span class="n">resized_img</span>
    903. <span class="n">padded_img</span> <span class="o">=</span> <span class="n">padded_img</span><span class="o">.</span><span class="n">transpose</span><span class="p">(</span><span class="n">swap</span><span class="p">)</span>
    904. <span class="n">padded_img</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">ascontiguousarray</span><span class="p">(</span><span class="n">padded_img</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="n">np</span><span class="o">.</span><span class="n">float32</span><span class="p">)</span>
    905. <span class="k">return</span> <span class="n">padded_img</span><span class="p">,</span> <span class="n">r</span>
    906. </pre></div>
    907. </div>
    908. </div>
    909. <footer>
    910. <hr/>
    911. <div role="contentinfo">
    912. <p>&#169; Copyright 2021, SuperGradients team.</p>
    913. </div>
    914. Built with <a href="https://www.sphinx-doc.org/">Sphinx</a> using a
    915. <a href="https://github.com/readthedocs/sphinx_rtd_theme">theme</a>
    916. provided by <a href="https://readthedocs.org">Read the Docs</a>.
    917. </footer>
    918. </div>
    919. </div>
    920. </section>
    921. </div>
    922. <script>
    923. jQuery(function () {
    924. SphinxRtdTheme.Navigation.enable(true);
    925. });
    926. </script>
    927. </body>
    928. </html>
    Discard
    1
    2
    3
    4
    5
    6
    7
    8
    9
    10
    11
    12
    13
    14
    15
    16
    17
    18
    19
    20
    21
    22
    23
    24
    25
    26
    27
    28
    29
    30
    31
    32
    33
    34
    35
    36
    37
    38
    39
    40
    41
    42
    43
    44
    45
    46
    47
    48
    49
    50
    51
    52
    53
    54
    55
    56
    57
    58
    59
    60
    61
    62
    63
    64
    65
    66
    67
    68
    69
    70
    71
    72
    73
    74
    75
    76
    77
    78
    79
    80
    81
    82
    83
    84
    85
    86
    87
    88
    89
    90
    91
    92
    93
    94
    95
    96
    97
    98
    99
    100
    101
    102
    103
    104
    105
    106
    107
    108
    109
    110
    111
    112
    113
    114
    115
    116
    117
    118
    119
    120
    121
    122
    123
    124
    125
    126
    127
    128
    129
    130
    131
    132
    133
    134
    135
    136
    137
    138
    139
    140
    141
    142
    143
    144
    145
    146
    147
    148
    149
    150
    151
    152
    153
    154
    155
    156
    157
    158
    159
    160
    161
    162
    163
    164
    165
    166
    167
    168
    169
    170
    171
    172
    173
    174
    175
    176
    177
    178
    179
    180
    181
    182
    183
    184
    185
    186
    187
    188
    189
    190
    191
    192
    193
    194
    195
    196
    197
    198
    199
    200
    201
    202
    203
    204
    205
    206
    207
    208
    209
    210
    211
    212
    213
    214
    215
    216
    217
    218
    219
    220
    221
    222
    223
    224
    225
    226
    227
    228
    229
    230
    231
    232
    233
    234
    235
    236
    237
    238
    239
    240
    241
    242
    243
    244
    245
    246
    247
    248
    249
    250
    251
    252
    253
    254
    255
    256
    257
    258
    259
    260
    261
    262
    263
    264
    265
    266
    267
    268
    269
    270
    271
    272
    273
    274
    275
    276
    277
    278
    279
    280
    281
    282
    283
    284
    285
    286
    287
    288
    289
    290
    291
    292
    293
    294
    295
    296
    297
    298
    299
    300
    301
    302
    303
    304
    305
    306
    307
    308
    309
    310
    311
    312
    313
    314
    315
    316
    317
    318
    319
    320
    321
    322
    323
    324
    325
    326
    327
    328
    329
    330
    331
    332
    333
    334
    335
    336
    337
    338
    339
    340
    341
    342
    343
    344
    345
    346
    347
    348
    349
    350
    351
    352
    353
    354
    355
    356
    357
    358
    359
    360
    361
    362
    363
    364
    365
    366
    367
    368
    369
    370
    371
    372
    373
    374
    375
    376
    377
    378
    379
    380
    381
    382
    383
    384
    385
    386
    387
    388
    389
    390
    391
    392
    393
    394
    395
    396
    397
    398
    399
    400
    401
    402
    403
    404
    405
    406
    407
    408
    409
    410
    411
    412
    413
    414
    415
    416
    417
    418
    419
    420
    421
    422
    423
    424
    425
    426
    427
    428
    429
    430
    431
    432
    433
    434
    435
    436
    437
    438
    439
    440
    441
    442
    443
    444
    445
    446
    447
    448
    449
    450
    451
    452
    453
    454
    455
    456
    457
    458
    459
    460
    461
    462
    463
    464
    465
    466
    467
    468
    469
    470
    471
    472
    473
    474
    475
    476
    477
    478
    479
    480
    481
    482
    483
    484
    485
    486
    487
    488
    489
    490
    491
    492
    493
    494
    495
    496
    497
    498
    499
    500
    501
    502
    503
    504
    505
    506
    507
    508
    509
    510
    511
    512
    513
    514
    515
    516
    517
    518
    519
    520
    521
    522
    523
    524
    525
    526
    527
    528
    529
    530
    531
    532
    533
    534
    535
    536
    537
    538
    539
    540
    541
    542
    543
    544
    545
    546
    547
    548
    549
    550
    551
    552
    553
    554
    555
    556
    557
    558
    559
    560
    561
    562
    563
    564
    565
    566
    567
    568
    569
    570
    571
    572
    573
    574
    575
    576
    577
    578
    579
    580
    581
    582
    583
    584
    585
    586
    587
    588
    589
    590
    591
    592
    593
    594
    595
    596
    597
    598
    599
    600
    601
    602
    603
    604
    605
    606
    607
    608
    609
    610
    611
    612
    613
    614
    615
    616
    617
    618
    619
    620
    621
    622
    623
    624
    625
    626
    627
    628
    629
    630
    631
    632
    633
    634
    635
    636
    637
    638
    639
    640
    641
    642
    643
    644
    645
    646
    647
    648
    649
    650
    651
    652
    653
    654
    655
    656
    657
    658
    659
    660
    661
    662
    663
    664
    665
    666
    667
    668
    669
    670
    671
    672
    673
    674
    675
    676
    677
    678
    679
    680
    681
    682
    683
    684
    685
    686
    687
    688
    689
    690
    691
    692
    693
    694
    695
    696
    697
    698
    699
    700
    701
    702
    703
    704
    705
    706
    707
    708
    709
    710
    711
    712
    713
    714
    715
    716
    717
    718
    719
    720
    721
    722
    723
    724
    725
    726
    727
    728
    729
    730
    731
    732
    733
    734
    735
    736
    737
    738
    739
    740
    741
    742
    743
    744
    745
    746
    747
    748
    749
    750
    751
    752
    753
    754
    755
    756
    757
    758
    759
    760
    761
    762
    763
    764
    765
    766
    767
    768
    769
    770
    771
    772
    773
    774
    775
    776
    777
    778
    779
    780
    781
    782
    783
    784
    785
    786
    787
    788
    789
    790
    791
    792
    793
    794
    795
    796
    797
    798
    799
    800
    801
    802
    803
    804
    805
    806
    807
    808
    809
    810
    811
    812
    813
    814
    815
    816
    817
    818
    819
    820
    821
    822
    823
    824
    825
    826
    827
    828
    829
    830
    831
    832
    833
    834
    835
    836
    837
    838
    839
    840
    841
    842
    843
    844
    845
    846
    847
    848
    849
    850
    851
    852
    853
    854
    855
    856
    857
    858
    859
    860
    861
    862
    863
    864
    865
    866
    867
    868
    869
    870
    871
    872
    873
    874
    875
    876
    877
    878
    879
    880
    881
    882
    883
    884
    885
    1. <!DOCTYPE html>
    2. <html class="writer-html5" lang="en" >
    3. <head>
    4. <meta charset="utf-8" />
    5. <meta name="viewport" content="width=device-width, initial-scale=1.0" />
    6. <title>super_gradients.training.utils.callbacks &mdash; SuperGradients 1.0 documentation</title>
    7. <link rel="stylesheet" href="../../../../_static/pygments.css" type="text/css" />
    8. <link rel="stylesheet" href="../../../../_static/css/theme.css" type="text/css" />
    9. <link rel="stylesheet" href="../../../../_static/graphviz.css" type="text/css" />
    10. <!--[if lt IE 9]>
    11. <script src="../../../../_static/js/html5shiv.min.js"></script>
    12. <![endif]-->
    13. <script data-url_root="../../../../" id="documentation_options" src="../../../../_static/documentation_options.js"></script>
    14. <script src="../../../../_static/jquery.js"></script>
    15. <script src="../../../../_static/underscore.js"></script>
    16. <script src="../../../../_static/doctools.js"></script>
    17. <script src="../../../../_static/js/theme.js"></script>
    18. <link rel="index" title="Index" href="../../../../genindex.html" />
    19. <link rel="search" title="Search" href="../../../../search.html" />
    20. </head>
    21. <body class="wy-body-for-nav">
    22. <div class="wy-grid-for-nav">
    23. <nav data-toggle="wy-nav-shift" class="wy-nav-side">
    24. <div class="wy-side-scroll">
    25. <div class="wy-side-nav-search" >
    26. <a href="../../../../index.html" class="icon icon-home"> SuperGradients
    27. </a>
    28. <div role="search">
    29. <form id="rtd-search-form" class="wy-form" action="../../../../search.html" method="get">
    30. <input type="text" name="q" placeholder="Search docs" />
    31. <input type="hidden" name="check_keywords" value="yes" />
    32. <input type="hidden" name="area" value="default" />
    33. </form>
    34. </div>
    35. </div><div class="wy-menu wy-menu-vertical" data-spy="affix" role="navigation" aria-label="Navigation menu">
    36. <p class="caption"><span class="caption-text">Welcome To SuperGradients</span></p>
    37. <ul>
    38. <li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html">Fill our 4-question quick survey! We will raffle free SuperGradients swag between those who will participate -&gt; Fill Survey</a></li>
    39. <li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#supergradients">SuperGradients</a></li>
    40. </ul>
    41. <p class="caption"><span class="caption-text">Technical Documentation</span></p>
    42. <ul>
    43. <li class="toctree-l1"><a class="reference internal" href="../../../../super_gradients.common.html">Common package</a></li>
    44. <li class="toctree-l1"><a class="reference internal" href="../../../../super_gradients.training.html">Training package</a></li>
    45. </ul>
    46. <p class="caption"><span class="caption-text">User Guide</span></p>
    47. <ul>
    48. <li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html">What is SuperGradients?</a></li>
    49. <li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#introducing-the-supergradients-library">Introducing the SuperGradients library</a></li>
    50. <li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#installation">Installation</a></li>
    51. <li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#integrating-your-training-code-complete-walkthrough">Integrating Your Training Code - Complete Walkthrough</a></li>
    52. <li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#training-parameters">Training Parameters</a></li>
    53. <li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#logs-and-checkpoints">Logs and Checkpoints</a></li>
    54. <li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#dataset-parameters">Dataset Parameters</a></li>
    55. <li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#network-architectures">Network Architectures</a></li>
    56. <li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#pretrained-models">Pretrained Models</a></li>
    57. <li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#how-to-reproduce-our-training-recipes">How To Reproduce Our Training Recipes</a></li>
    58. <li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#professional-tools-integration">Professional Tools Integration</a></li>
    59. <li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#supergradients-faq">SuperGradients FAQ</a></li>
    60. </ul>
    61. </div>
    62. </div>
    63. </nav>
    64. <section data-toggle="wy-nav-shift" class="wy-nav-content-wrap"><nav class="wy-nav-top" aria-label="Mobile navigation menu" >
    65. <i data-toggle="wy-nav-top" class="fa fa-bars"></i>
    66. <a href="../../../../index.html">SuperGradients</a>
    67. </nav>
    68. <div class="wy-nav-content">
    69. <div class="rst-content">
    70. <div role="navigation" aria-label="Page navigation">
    71. <ul class="wy-breadcrumbs">
    72. <li><a href="../../../../index.html" class="icon icon-home"></a> &raquo;</li>
    73. <li><a href="../../../index.html">Module code</a> &raquo;</li>
    74. <li>super_gradients.training.utils.callbacks</li>
    75. <li class="wy-breadcrumbs-aside">
    76. </li>
    77. </ul>
    78. <hr/>
    79. </div>
    80. <div role="main" class="document" itemscope="itemscope" itemtype="http://schema.org/Article">
    81. <div itemprop="articleBody">
    82. <h1>Source code for super_gradients.training.utils.callbacks</h1><div class="highlight"><pre>
    83. <span></span><span class="kn">import</span> <span class="nn">copy</span>
    84. <span class="kn">import</span> <span class="nn">os</span>
    85. <span class="kn">import</span> <span class="nn">time</span>
    86. <span class="kn">from</span> <span class="nn">enum</span> <span class="kn">import</span> <span class="n">Enum</span>
    87. <span class="kn">import</span> <span class="nn">math</span>
    88. <span class="kn">import</span> <span class="nn">numpy</span> <span class="k">as</span> <span class="nn">np</span>
    89. <span class="kn">import</span> <span class="nn">onnx</span>
    90. <span class="kn">import</span> <span class="nn">onnxruntime</span>
    91. <span class="kn">import</span> <span class="nn">torch</span>
    92. <span class="kn">import</span> <span class="nn">signal</span>
    93. <span class="kn">from</span> <span class="nn">typing</span> <span class="kn">import</span> <span class="n">List</span>
    94. <span class="kn">from</span> <span class="nn">super_gradients.common.abstractions.abstract_logger</span> <span class="kn">import</span> <span class="n">get_logger</span>
    95. <span class="kn">from</span> <span class="nn">super_gradients.training.utils.detection_utils</span> <span class="kn">import</span> <span class="n">DetectionVisualization</span><span class="p">,</span> <span class="n">DetectionPostPredictionCallback</span>
    96. <span class="kn">from</span> <span class="nn">super_gradients.training.utils.segmentation_utils</span> <span class="kn">import</span> <span class="n">BinarySegmentationVisualization</span>
    97. <span class="kn">import</span> <span class="nn">cv2</span>
    98. <span class="n">logger</span> <span class="o">=</span> <span class="n">get_logger</span><span class="p">(</span><span class="vm">__name__</span><span class="p">)</span>
    99. <span class="k">try</span><span class="p">:</span>
    100. <span class="kn">from</span> <span class="nn">deci_lab_client.client</span> <span class="kn">import</span> <span class="n">DeciPlatformClient</span>
    101. <span class="kn">from</span> <span class="nn">deci_lab_client.models</span> <span class="kn">import</span> <span class="n">ModelBenchmarkState</span>
    102. <span class="kn">from</span> <span class="nn">deci_lab_client.models.model_metadata</span> <span class="kn">import</span> <span class="n">ModelMetadata</span>
    103. <span class="n">_imported_deci_lab_failure</span> <span class="o">=</span> <span class="kc">None</span>
    104. <span class="k">except</span> <span class="p">(</span><span class="ne">ImportError</span><span class="p">,</span> <span class="ne">NameError</span><span class="p">,</span> <span class="ne">ModuleNotFoundError</span><span class="p">)</span> <span class="k">as</span> <span class="n">import_err</span><span class="p">:</span>
    105. <span class="n">logger</span><span class="o">.</span><span class="n">warn</span><span class="p">(</span><span class="s2">&quot;Failed to import deci_lab_client&quot;</span><span class="p">)</span>
    106. <span class="n">_imported_deci_lab_failure</span> <span class="o">=</span> <span class="n">import_err</span>
    107. <div class="viewcode-block" id="Phase"><a class="viewcode-back" href="../../../../super_gradients.training.utils.html#super_gradients.training.utils.callbacks.Phase">[docs]</a><span class="k">class</span> <span class="nc">Phase</span><span class="p">(</span><span class="n">Enum</span><span class="p">):</span>
    108. <span class="n">PRE_TRAINING</span> <span class="o">=</span> <span class="s2">&quot;PRE_TRAINING&quot;</span>
    109. <span class="n">TRAIN_BATCH_END</span> <span class="o">=</span> <span class="s2">&quot;TRAIN_BATCH_END&quot;</span>
    110. <span class="n">TRAIN_BATCH_STEP</span> <span class="o">=</span> <span class="s2">&quot;TRAIN_BATCH_STEP&quot;</span>
    111. <span class="n">TRAIN_EPOCH_START</span> <span class="o">=</span> <span class="s2">&quot;TRAIN_EPOCH_START&quot;</span>
    112. <span class="n">TRAIN_EPOCH_END</span> <span class="o">=</span> <span class="s2">&quot;TRAIN_EPOCH_END&quot;</span>
    113. <span class="n">VALIDATION_BATCH_END</span> <span class="o">=</span> <span class="s2">&quot;VALIDATION_BATCH_END&quot;</span>
    114. <span class="n">VALIDATION_EPOCH_END</span> <span class="o">=</span> <span class="s2">&quot;VALIDATION_EPOCH_END&quot;</span>
    115. <span class="n">VALIDATION_END_BEST_EPOCH</span> <span class="o">=</span> <span class="s2">&quot;VALIDATION_END_BEST_EPOCH&quot;</span>
    116. <span class="n">TEST_BATCH_END</span> <span class="o">=</span> <span class="s2">&quot;TEST_BATCH_END&quot;</span>
    117. <span class="n">TEST_END</span> <span class="o">=</span> <span class="s2">&quot;TEST_END&quot;</span>
    118. <span class="n">POST_TRAINING</span> <span class="o">=</span> <span class="s2">&quot;POST_TRAINING&quot;</span></div>
    119. <div class="viewcode-block" id="ContextSgMethods"><a class="viewcode-back" href="../../../../super_gradients.training.utils.html#super_gradients.training.utils.callbacks.ContextSgMethods">[docs]</a><span class="k">class</span> <span class="nc">ContextSgMethods</span><span class="p">:</span>
    120. <span class="sd">&quot;&quot;&quot;</span>
    121. <span class="sd"> Class for delegating SgModel&#39;s methods, so that only the relevant ones are (&quot;phase wise&quot;) are accessible.</span>
    122. <span class="sd"> &quot;&quot;&quot;</span>
    123. <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="o">**</span><span class="n">methods</span><span class="p">):</span>
    124. <span class="k">for</span> <span class="n">attr</span><span class="p">,</span> <span class="n">attr_val</span> <span class="ow">in</span> <span class="n">methods</span><span class="o">.</span><span class="n">items</span><span class="p">():</span>
    125. <span class="nb">setattr</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">attr</span><span class="p">,</span> <span class="n">attr_val</span><span class="p">)</span></div>
    126. <div class="viewcode-block" id="PhaseContext"><a class="viewcode-back" href="../../../../super_gradients.training.utils.html#super_gradients.training.utils.callbacks.PhaseContext">[docs]</a><span class="k">class</span> <span class="nc">PhaseContext</span><span class="p">:</span>
    127. <span class="sd">&quot;&quot;&quot;</span>
    128. <span class="sd"> Represents the input for phase callbacks, and is constantly updated after callback calls.</span>
    129. <span class="sd"> &quot;&quot;&quot;</span>
    130. <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">epoch</span><span class="o">=</span><span class="kc">None</span><span class="p">,</span> <span class="n">batch_idx</span><span class="o">=</span><span class="kc">None</span><span class="p">,</span> <span class="n">optimizer</span><span class="o">=</span><span class="kc">None</span><span class="p">,</span> <span class="n">metrics_dict</span><span class="o">=</span><span class="kc">None</span><span class="p">,</span> <span class="n">inputs</span><span class="o">=</span><span class="kc">None</span><span class="p">,</span> <span class="n">preds</span><span class="o">=</span><span class="kc">None</span><span class="p">,</span>
    131. <span class="n">target</span><span class="o">=</span><span class="kc">None</span><span class="p">,</span> <span class="n">metrics_compute_fn</span><span class="o">=</span><span class="kc">None</span><span class="p">,</span> <span class="n">loss_avg_meter</span><span class="o">=</span><span class="kc">None</span><span class="p">,</span> <span class="n">loss_log_items</span><span class="o">=</span><span class="kc">None</span><span class="p">,</span> <span class="n">criterion</span><span class="o">=</span><span class="kc">None</span><span class="p">,</span>
    132. <span class="n">device</span><span class="o">=</span><span class="kc">None</span><span class="p">,</span> <span class="n">experiment_name</span><span class="o">=</span><span class="kc">None</span><span class="p">,</span> <span class="n">ckpt_dir</span><span class="o">=</span><span class="kc">None</span><span class="p">,</span> <span class="n">net</span><span class="o">=</span><span class="kc">None</span><span class="p">,</span> <span class="n">lr_warmup_epochs</span><span class="o">=</span><span class="kc">None</span><span class="p">,</span> <span class="n">sg_logger</span><span class="o">=</span><span class="kc">None</span><span class="p">,</span>
    133. <span class="n">train_loader</span><span class="o">=</span><span class="kc">None</span><span class="p">,</span> <span class="n">valid_loader</span><span class="o">=</span><span class="kc">None</span><span class="p">,</span>
    134. <span class="n">training_params</span><span class="o">=</span><span class="kc">None</span><span class="p">,</span> <span class="n">ddp_silent_mode</span><span class="o">=</span><span class="kc">None</span><span class="p">,</span> <span class="n">checkpoint_params</span><span class="o">=</span><span class="kc">None</span><span class="p">,</span> <span class="n">architecture</span><span class="o">=</span><span class="kc">None</span><span class="p">,</span>
    135. <span class="n">arch_params</span><span class="o">=</span><span class="kc">None</span><span class="p">,</span> <span class="n">metric_idx_in_results_tuple</span><span class="o">=</span><span class="kc">None</span><span class="p">,</span>
    136. <span class="n">metric_to_watch</span><span class="o">=</span><span class="kc">None</span><span class="p">,</span> <span class="n">valid_metrics</span><span class="o">=</span><span class="kc">None</span><span class="p">,</span> <span class="n">context_methods</span><span class="o">=</span><span class="kc">None</span><span class="p">):</span>
    137. <span class="bp">self</span><span class="o">.</span><span class="n">epoch</span> <span class="o">=</span> <span class="n">epoch</span>
    138. <span class="bp">self</span><span class="o">.</span><span class="n">batch_idx</span> <span class="o">=</span> <span class="n">batch_idx</span>
    139. <span class="bp">self</span><span class="o">.</span><span class="n">optimizer</span> <span class="o">=</span> <span class="n">optimizer</span>
    140. <span class="bp">self</span><span class="o">.</span><span class="n">inputs</span> <span class="o">=</span> <span class="n">inputs</span>
    141. <span class="bp">self</span><span class="o">.</span><span class="n">preds</span> <span class="o">=</span> <span class="n">preds</span>
    142. <span class="bp">self</span><span class="o">.</span><span class="n">target</span> <span class="o">=</span> <span class="n">target</span>
    143. <span class="bp">self</span><span class="o">.</span><span class="n">metrics_dict</span> <span class="o">=</span> <span class="n">metrics_dict</span>
    144. <span class="bp">self</span><span class="o">.</span><span class="n">metrics_compute_fn</span> <span class="o">=</span> <span class="n">metrics_compute_fn</span>
    145. <span class="bp">self</span><span class="o">.</span><span class="n">loss_avg_meter</span> <span class="o">=</span> <span class="n">loss_avg_meter</span>
    146. <span class="bp">self</span><span class="o">.</span><span class="n">loss_log_items</span> <span class="o">=</span> <span class="n">loss_log_items</span>
    147. <span class="bp">self</span><span class="o">.</span><span class="n">criterion</span> <span class="o">=</span> <span class="n">criterion</span>
    148. <span class="bp">self</span><span class="o">.</span><span class="n">device</span> <span class="o">=</span> <span class="n">device</span>
    149. <span class="bp">self</span><span class="o">.</span><span class="n">stop_training</span> <span class="o">=</span> <span class="kc">False</span>
    150. <span class="bp">self</span><span class="o">.</span><span class="n">experiment_name</span> <span class="o">=</span> <span class="n">experiment_name</span>
    151. <span class="bp">self</span><span class="o">.</span><span class="n">ckpt_dir</span> <span class="o">=</span> <span class="n">ckpt_dir</span>
    152. <span class="bp">self</span><span class="o">.</span><span class="n">net</span> <span class="o">=</span> <span class="n">net</span>
    153. <span class="bp">self</span><span class="o">.</span><span class="n">lr_warmup_epochs</span> <span class="o">=</span> <span class="n">lr_warmup_epochs</span>
    154. <span class="bp">self</span><span class="o">.</span><span class="n">sg_logger</span> <span class="o">=</span> <span class="n">sg_logger</span>
    155. <span class="bp">self</span><span class="o">.</span><span class="n">train_loader</span> <span class="o">=</span> <span class="n">train_loader</span>
    156. <span class="bp">self</span><span class="o">.</span><span class="n">valid_loader</span> <span class="o">=</span> <span class="n">valid_loader</span>
    157. <span class="bp">self</span><span class="o">.</span><span class="n">training_params</span> <span class="o">=</span> <span class="n">training_params</span>
    158. <span class="bp">self</span><span class="o">.</span><span class="n">ddp_silent_mode</span> <span class="o">=</span> <span class="n">ddp_silent_mode</span>
    159. <span class="bp">self</span><span class="o">.</span><span class="n">checkpoint_params</span> <span class="o">=</span> <span class="n">checkpoint_params</span>
    160. <span class="bp">self</span><span class="o">.</span><span class="n">architecture</span> <span class="o">=</span> <span class="n">architecture</span>
    161. <span class="bp">self</span><span class="o">.</span><span class="n">arch_params</span> <span class="o">=</span> <span class="n">arch_params</span>
    162. <span class="bp">self</span><span class="o">.</span><span class="n">metric_idx_in_results_tuple</span> <span class="o">=</span> <span class="n">metric_idx_in_results_tuple</span>
    163. <span class="bp">self</span><span class="o">.</span><span class="n">metric_to_watch</span> <span class="o">=</span> <span class="n">metric_to_watch</span>
    164. <span class="bp">self</span><span class="o">.</span><span class="n">valid_metrics</span> <span class="o">=</span> <span class="n">valid_metrics</span>
    165. <span class="bp">self</span><span class="o">.</span><span class="n">context_methods</span> <span class="o">=</span> <span class="n">context_methods</span>
    166. <div class="viewcode-block" id="PhaseContext.update_context"><a class="viewcode-back" href="../../../../super_gradients.training.utils.html#super_gradients.training.utils.callbacks.PhaseContext.update_context">[docs]</a> <span class="k">def</span> <span class="nf">update_context</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="o">**</span><span class="n">kwargs</span><span class="p">):</span>
    167. <span class="k">for</span> <span class="n">attr</span><span class="p">,</span> <span class="n">attr_val</span> <span class="ow">in</span> <span class="n">kwargs</span><span class="o">.</span><span class="n">items</span><span class="p">():</span>
    168. <span class="nb">setattr</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">attr</span><span class="p">,</span> <span class="n">attr_val</span><span class="p">)</span></div></div>
    169. <div class="viewcode-block" id="PhaseCallback"><a class="viewcode-back" href="../../../../super_gradients.training.utils.html#super_gradients.training.utils.callbacks.PhaseCallback">[docs]</a><span class="k">class</span> <span class="nc">PhaseCallback</span><span class="p">:</span>
    170. <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">phase</span><span class="p">:</span> <span class="n">Phase</span><span class="p">):</span>
    171. <span class="bp">self</span><span class="o">.</span><span class="n">phase</span> <span class="o">=</span> <span class="n">phase</span>
    172. <span class="k">def</span> <span class="fm">__call__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="o">*</span><span class="n">args</span><span class="p">,</span> <span class="o">**</span><span class="n">kwargs</span><span class="p">):</span>
    173. <span class="k">raise</span> <span class="ne">NotImplementedError</span>
    174. <span class="k">def</span> <span class="fm">__repr__</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span>
    175. <span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="vm">__class__</span><span class="o">.</span><span class="vm">__name__</span></div>
    176. <div class="viewcode-block" id="ModelConversionCheckCallback"><a class="viewcode-back" href="../../../../super_gradients.training.utils.html#super_gradients.training.utils.callbacks.ModelConversionCheckCallback">[docs]</a><span class="k">class</span> <span class="nc">ModelConversionCheckCallback</span><span class="p">(</span><span class="n">PhaseCallback</span><span class="p">):</span>
    177. <span class="sd">&quot;&quot;&quot;</span>
    178. <span class="sd"> Pre-training callback that verifies model conversion to onnx given specified conversion parameters.</span>
    179. <span class="sd"> The model is converted, then inference is applied with onnx runtime.</span>
    180. <span class="sd"> Use this callback wit hthe same args as DeciPlatformCallback to prevent conversion fails at the end of training.</span>
    181. <span class="sd"> Attributes:</span>
    182. <span class="sd"> model_meta_data: (ModelMetadata) model&#39;s meta-data object.</span>
    183. <span class="sd"> The following parameters may be passed as kwargs in order to control the conversion to onnx:</span>
    184. <span class="sd"> :param opset_version (default=11)</span>
    185. <span class="sd"> :param do_constant_folding (default=True)</span>
    186. <span class="sd"> :param dynamic_axes (default=</span>
    187. <span class="sd"> {&#39;input&#39;: {0: &#39;batch_size&#39;},</span>
    188. <span class="sd"> # Variable length axes</span>
    189. <span class="sd"> &#39;output&#39;: {0: &#39;batch_size&#39;}}</span>
    190. <span class="sd"> )</span>
    191. <span class="sd"> :param input_names (default=[&quot;input&quot;])</span>
    192. <span class="sd"> :param output_names (default=[&quot;output&quot;])</span>
    193. <span class="sd"> :param rtol (default=1e-03)</span>
    194. <span class="sd"> :param atol (default=1e-05)</span>
    195. <span class="sd"> &quot;&quot;&quot;</span>
    196. <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">model_meta_data</span><span class="p">:</span> <span class="n">ModelMetadata</span><span class="p">,</span> <span class="o">**</span><span class="n">kwargs</span><span class="p">):</span>
    197. <span class="nb">super</span><span class="p">(</span><span class="n">ModelConversionCheckCallback</span><span class="p">,</span> <span class="bp">self</span><span class="p">)</span><span class="o">.</span><span class="fm">__init__</span><span class="p">(</span><span class="n">phase</span><span class="o">=</span><span class="n">Phase</span><span class="o">.</span><span class="n">PRE_TRAINING</span><span class="p">)</span>
    198. <span class="bp">self</span><span class="o">.</span><span class="n">model_meta_data</span> <span class="o">=</span> <span class="n">model_meta_data</span>
    199. <span class="bp">self</span><span class="o">.</span><span class="n">opset_version</span> <span class="o">=</span> <span class="n">kwargs</span><span class="o">.</span><span class="n">get</span><span class="p">(</span><span class="s2">&quot;opset_version&quot;</span><span class="p">,</span> <span class="mi">10</span><span class="p">)</span>
    200. <span class="bp">self</span><span class="o">.</span><span class="n">do_constant_folding</span> <span class="o">=</span> <span class="p">(</span>
    201. <span class="n">kwargs</span><span class="o">.</span><span class="n">get</span><span class="p">(</span><span class="s2">&quot;do_constant_folding&quot;</span><span class="p">,</span> <span class="kc">None</span><span class="p">)</span> <span class="k">if</span> <span class="n">kwargs</span><span class="o">.</span><span class="n">get</span><span class="p">(</span><span class="s2">&quot;do_constant_folding&quot;</span><span class="p">,</span> <span class="kc">None</span><span class="p">)</span> <span class="k">else</span> <span class="kc">True</span>
    202. <span class="p">)</span>
    203. <span class="bp">self</span><span class="o">.</span><span class="n">input_names</span> <span class="o">=</span> <span class="n">kwargs</span><span class="o">.</span><span class="n">get</span><span class="p">(</span><span class="s2">&quot;input_names&quot;</span><span class="p">)</span> <span class="ow">or</span> <span class="p">[</span><span class="s2">&quot;input&quot;</span><span class="p">]</span>
    204. <span class="bp">self</span><span class="o">.</span><span class="n">output_names</span> <span class="o">=</span> <span class="n">kwargs</span><span class="o">.</span><span class="n">get</span><span class="p">(</span><span class="s2">&quot;output_names&quot;</span><span class="p">)</span> <span class="ow">or</span> <span class="p">[</span><span class="s2">&quot;output&quot;</span><span class="p">]</span>
    205. <span class="bp">self</span><span class="o">.</span><span class="n">dynamic_axes</span> <span class="o">=</span> <span class="n">kwargs</span><span class="o">.</span><span class="n">get</span><span class="p">(</span><span class="s2">&quot;dynamic_axes&quot;</span><span class="p">)</span> <span class="ow">or</span> <span class="p">{</span><span class="s2">&quot;input&quot;</span><span class="p">:</span> <span class="p">{</span><span class="mi">0</span><span class="p">:</span> <span class="s2">&quot;batch_size&quot;</span><span class="p">},</span> <span class="s2">&quot;output&quot;</span><span class="p">:</span> <span class="p">{</span><span class="mi">0</span><span class="p">:</span> <span class="s2">&quot;batch_size&quot;</span><span class="p">}}</span>
    206. <span class="bp">self</span><span class="o">.</span><span class="n">rtol</span> <span class="o">=</span> <span class="n">kwargs</span><span class="o">.</span><span class="n">get</span><span class="p">(</span><span class="s2">&quot;rtol&quot;</span><span class="p">,</span> <span class="mf">1e-03</span><span class="p">)</span>
    207. <span class="bp">self</span><span class="o">.</span><span class="n">atol</span> <span class="o">=</span> <span class="n">kwargs</span><span class="o">.</span><span class="n">get</span><span class="p">(</span><span class="s2">&quot;atol&quot;</span><span class="p">,</span> <span class="mf">1e-05</span><span class="p">)</span>
    208. <span class="k">def</span> <span class="fm">__call__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">context</span><span class="p">:</span> <span class="n">PhaseContext</span><span class="p">):</span>
    209. <span class="n">model</span> <span class="o">=</span> <span class="n">copy</span><span class="o">.</span><span class="n">deepcopy</span><span class="p">(</span><span class="n">context</span><span class="o">.</span><span class="n">net</span><span class="o">.</span><span class="n">module</span><span class="p">)</span>
    210. <span class="n">model</span> <span class="o">=</span> <span class="n">model</span><span class="o">.</span><span class="n">cpu</span><span class="p">()</span>
    211. <span class="n">model</span><span class="o">.</span><span class="n">eval</span><span class="p">()</span> <span class="c1"># Put model into eval mode</span>
    212. <span class="k">if</span> <span class="nb">hasattr</span><span class="p">(</span><span class="n">model</span><span class="p">,</span> <span class="s2">&quot;prep_model_for_conversion&quot;</span><span class="p">):</span>
    213. <span class="n">model</span><span class="o">.</span><span class="n">prep_model_for_conversion</span><span class="p">(</span><span class="n">input_size</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">model_meta_data</span><span class="o">.</span><span class="n">input_dimensions</span><span class="p">)</span>
    214. <span class="n">x</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">randn</span><span class="p">(</span>
    215. <span class="bp">self</span><span class="o">.</span><span class="n">model_meta_data</span><span class="o">.</span><span class="n">primary_batch_size</span><span class="p">,</span> <span class="o">*</span><span class="bp">self</span><span class="o">.</span><span class="n">model_meta_data</span><span class="o">.</span><span class="n">input_dimensions</span><span class="p">,</span> <span class="n">requires_grad</span><span class="o">=</span><span class="kc">False</span>
    216. <span class="p">)</span>
    217. <span class="n">tmp_model_path</span> <span class="o">=</span> <span class="n">os</span><span class="o">.</span><span class="n">path</span><span class="o">.</span><span class="n">join</span><span class="p">(</span><span class="n">context</span><span class="o">.</span><span class="n">ckpt_dir</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">model_meta_data</span><span class="o">.</span><span class="n">name</span> <span class="o">+</span> <span class="s2">&quot;_tmp.onnx&quot;</span><span class="p">)</span>
    218. <span class="k">with</span> <span class="n">torch</span><span class="o">.</span><span class="n">no_grad</span><span class="p">():</span>
    219. <span class="n">torch_out</span> <span class="o">=</span> <span class="n">model</span><span class="p">(</span><span class="n">x</span><span class="p">)</span>
    220. <span class="n">torch</span><span class="o">.</span><span class="n">onnx</span><span class="o">.</span><span class="n">export</span><span class="p">(</span>
    221. <span class="n">model</span><span class="p">,</span> <span class="c1"># Model being run</span>
    222. <span class="n">x</span><span class="p">,</span> <span class="c1"># Model input (or a tuple for multiple inputs)</span>
    223. <span class="n">tmp_model_path</span><span class="p">,</span> <span class="c1"># Where to save the model (can be a file or file-like object)</span>
    224. <span class="n">export_params</span><span class="o">=</span><span class="kc">True</span><span class="p">,</span> <span class="c1"># Store the trained parameter weights inside the model file</span>
    225. <span class="n">opset_version</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">opset_version</span><span class="p">,</span>
    226. <span class="n">do_constant_folding</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">do_constant_folding</span><span class="p">,</span>
    227. <span class="n">input_names</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">input_names</span><span class="p">,</span>
    228. <span class="n">output_names</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">output_names</span><span class="p">,</span>
    229. <span class="n">dynamic_axes</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">dynamic_axes</span><span class="p">,</span>
    230. <span class="p">)</span>
    231. <span class="n">onnx_model</span> <span class="o">=</span> <span class="n">onnx</span><span class="o">.</span><span class="n">load</span><span class="p">(</span><span class="n">tmp_model_path</span><span class="p">)</span>
    232. <span class="n">onnx</span><span class="o">.</span><span class="n">checker</span><span class="o">.</span><span class="n">check_model</span><span class="p">(</span><span class="n">onnx_model</span><span class="p">)</span>
    233. <span class="n">ort_session</span> <span class="o">=</span> <span class="n">onnxruntime</span><span class="o">.</span><span class="n">InferenceSession</span><span class="p">(</span>
    234. <span class="n">tmp_model_path</span><span class="p">,</span> <span class="n">providers</span><span class="o">=</span><span class="p">[</span><span class="s2">&quot;CUDAExecutionProvider&quot;</span><span class="p">,</span> <span class="s2">&quot;CPUExecutionProvider&quot;</span><span class="p">]</span>
    235. <span class="p">)</span>
    236. <span class="c1"># compute ONNX Runtime output prediction</span>
    237. <span class="n">ort_inputs</span> <span class="o">=</span> <span class="p">{</span><span class="n">ort_session</span><span class="o">.</span><span class="n">get_inputs</span><span class="p">()[</span><span class="mi">0</span><span class="p">]</span><span class="o">.</span><span class="n">name</span><span class="p">:</span> <span class="n">x</span><span class="o">.</span><span class="n">cpu</span><span class="p">()</span><span class="o">.</span><span class="n">numpy</span><span class="p">()}</span>
    238. <span class="n">ort_outs</span> <span class="o">=</span> <span class="n">ort_session</span><span class="o">.</span><span class="n">run</span><span class="p">(</span><span class="kc">None</span><span class="p">,</span> <span class="n">ort_inputs</span><span class="p">)</span>
    239. <span class="c1"># TODO: Ideally we don&#39;t want to check this but have the certainty of just calling torch_out.cpu()</span>
    240. <span class="k">if</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">torch_out</span><span class="p">,</span> <span class="n">List</span><span class="p">)</span> <span class="ow">or</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">torch_out</span><span class="p">,</span> <span class="nb">tuple</span><span class="p">):</span>
    241. <span class="n">torch_out</span> <span class="o">=</span> <span class="n">torch_out</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span>
    242. <span class="c1"># compare ONNX Runtime and PyTorch results</span>
    243. <span class="n">np</span><span class="o">.</span><span class="n">testing</span><span class="o">.</span><span class="n">assert_allclose</span><span class="p">(</span><span class="n">torch_out</span><span class="o">.</span><span class="n">cpu</span><span class="p">()</span><span class="o">.</span><span class="n">numpy</span><span class="p">(),</span> <span class="n">ort_outs</span><span class="p">[</span><span class="mi">0</span><span class="p">],</span> <span class="n">rtol</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">rtol</span><span class="p">,</span> <span class="n">atol</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">atol</span><span class="p">)</span>
    244. <span class="n">os</span><span class="o">.</span><span class="n">remove</span><span class="p">(</span><span class="n">tmp_model_path</span><span class="p">)</span>
    245. <span class="n">logger</span><span class="o">.</span><span class="n">info</span><span class="p">(</span><span class="s2">&quot;Exported model has been tested with ONNXRuntime, and the result looks good!&quot;</span><span class="p">)</span></div>
    246. <div class="viewcode-block" id="DeciLabUploadCallback"><a class="viewcode-back" href="../../../../super_gradients.training.utils.html#super_gradients.training.utils.callbacks.DeciLabUploadCallback">[docs]</a><span class="k">class</span> <span class="nc">DeciLabUploadCallback</span><span class="p">(</span><span class="n">PhaseCallback</span><span class="p">):</span>
    247. <span class="sd">&quot;&quot;&quot;</span>
    248. <span class="sd"> Post-training callback for uploading and optimizing a model.</span>
    249. <span class="sd"> Attributes:</span>
    250. <span class="sd"> email: (str) username for Deci platform.</span>
    251. <span class="sd"> model_meta_data: (ModelMetadata) model&#39;s meta-data object.</span>
    252. <span class="sd"> optimization_request_form: (dict) optimization request form object.</span>
    253. <span class="sd"> password: (str) default=None, should only be used for testing.</span>
    254. <span class="sd"> ckpt_name: (str) default=&quot;ckpt_best&quot; refers to the filename of the checkpoint, inside the checkpoint directory.</span>
    255. <span class="sd"> The following parameters may be passed as kwargs in order to control the conversion to onnx:</span>
    256. <span class="sd"> :param opset_version</span>
    257. <span class="sd"> :param do_constant_folding</span>
    258. <span class="sd"> :param dynamic_axes</span>
    259. <span class="sd"> :param input_names</span>
    260. <span class="sd"> :param output_names</span>
    261. <span class="sd"> &quot;&quot;&quot;</span>
    262. <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">model_meta_data</span><span class="p">,</span> <span class="n">optimization_request_form</span><span class="p">,</span> <span class="n">auth_token</span><span class="p">:</span> <span class="nb">str</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span> <span class="n">ckpt_name</span><span class="o">=</span><span class="s2">&quot;ckpt_best.pth&quot;</span><span class="p">,</span> <span class="o">**</span><span class="n">kwargs</span><span class="p">):</span>
    263. <span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="fm">__init__</span><span class="p">(</span><span class="n">phase</span><span class="o">=</span><span class="n">Phase</span><span class="o">.</span><span class="n">POST_TRAINING</span><span class="p">)</span>
    264. <span class="k">if</span> <span class="n">_imported_deci_lab_failure</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span><span class="p">:</span>
    265. <span class="k">raise</span> <span class="n">_imported_deci_lab_failure</span>
    266. <span class="bp">self</span><span class="o">.</span><span class="n">model_meta_data</span> <span class="o">=</span> <span class="n">model_meta_data</span>
    267. <span class="bp">self</span><span class="o">.</span><span class="n">optimization_request_form</span> <span class="o">=</span> <span class="n">optimization_request_form</span>
    268. <span class="bp">self</span><span class="o">.</span><span class="n">conversion_kwargs</span> <span class="o">=</span> <span class="n">kwargs</span>
    269. <span class="bp">self</span><span class="o">.</span><span class="n">ckpt_name</span> <span class="o">=</span> <span class="n">ckpt_name</span>
    270. <span class="bp">self</span><span class="o">.</span><span class="n">platform_client</span> <span class="o">=</span> <span class="n">DeciPlatformClient</span><span class="p">(</span><span class="s2">&quot;api.deci.ai&quot;</span><span class="p">,</span> <span class="mi">443</span><span class="p">,</span> <span class="n">https</span><span class="o">=</span><span class="kc">True</span><span class="p">)</span>
    271. <span class="bp">self</span><span class="o">.</span><span class="n">platform_client</span><span class="o">.</span><span class="n">login</span><span class="p">(</span><span class="n">token</span><span class="o">=</span><span class="n">auth_token</span><span class="p">)</span>
    272. <div class="viewcode-block" id="DeciLabUploadCallback.log_optimization_failed"><a class="viewcode-back" href="../../../../super_gradients.training.utils.html#super_gradients.training.utils.callbacks.DeciLabUploadCallback.log_optimization_failed">[docs]</a> <span class="nd">@staticmethod</span>
    273. <span class="k">def</span> <span class="nf">log_optimization_failed</span><span class="p">():</span>
    274. <span class="n">logger</span><span class="o">.</span><span class="n">info</span><span class="p">(</span><span class="s2">&quot;We couldn&#39;t finish your model optimization. Visit https://console.deci.ai for details&quot;</span><span class="p">)</span></div>
    275. <div class="viewcode-block" id="DeciLabUploadCallback.upload_model"><a class="viewcode-back" href="../../../../super_gradients.training.utils.html#super_gradients.training.utils.callbacks.DeciLabUploadCallback.upload_model">[docs]</a> <span class="k">def</span> <span class="nf">upload_model</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">model</span><span class="p">):</span>
    276. <span class="sd">&quot;&quot;&quot;</span>
    277. <span class="sd"> This function will upload the trained model to the Deci Lab</span>
    278. <span class="sd"> Args:</span>
    279. <span class="sd"> model: The resulting model from the training process</span>
    280. <span class="sd"> &quot;&quot;&quot;</span>
    281. <span class="bp">self</span><span class="o">.</span><span class="n">platform_client</span><span class="o">.</span><span class="n">add_model</span><span class="p">(</span>
    282. <span class="n">add_model_request</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">model_meta_data</span><span class="p">,</span>
    283. <span class="n">optimization_request</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">optimization_request_form</span><span class="p">,</span>
    284. <span class="n">local_loaded_model</span><span class="o">=</span><span class="n">model</span><span class="p">,</span>
    285. <span class="p">)</span></div>
    286. <div class="viewcode-block" id="DeciLabUploadCallback.get_optimization_status"><a class="viewcode-back" href="../../../../super_gradients.training.utils.html#super_gradients.training.utils.callbacks.DeciLabUploadCallback.get_optimization_status">[docs]</a> <span class="k">def</span> <span class="nf">get_optimization_status</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">optimized_model_name</span><span class="p">:</span> <span class="nb">str</span><span class="p">):</span>
    287. <span class="sd">&quot;&quot;&quot;</span>
    288. <span class="sd"> This function will do fetch the optimized version of the trained model and check on its benchmark status.</span>
    289. <span class="sd"> The status will be checked against the server every 30 seconds and the process will timeout after 30 minutes</span>
    290. <span class="sd"> or log about the successful optimization - whichever happens first.</span>
    291. <span class="sd"> Args:</span>
    292. <span class="sd"> optimized_model_name (str): Optimized model name</span>
    293. <span class="sd"> Returns:</span>
    294. <span class="sd"> bool: whether or not the optimized model has been benchmarked</span>
    295. <span class="sd"> &quot;&quot;&quot;</span>
    296. <span class="k">def</span> <span class="nf">handler</span><span class="p">(</span><span class="n">_signum</span><span class="p">,</span> <span class="n">_frame</span><span class="p">):</span>
    297. <span class="n">logger</span><span class="o">.</span><span class="n">error</span><span class="p">(</span><span class="s2">&quot;Process timed out. Visit https://console.deci.ai for details&quot;</span><span class="p">)</span>
    298. <span class="k">return</span> <span class="kc">False</span>
    299. <span class="n">signal</span><span class="o">.</span><span class="n">signal</span><span class="p">(</span><span class="n">signal</span><span class="o">.</span><span class="n">SIGALRM</span><span class="p">,</span> <span class="n">handler</span><span class="p">)</span>
    300. <span class="n">signal</span><span class="o">.</span><span class="n">alarm</span><span class="p">(</span><span class="mi">1800</span><span class="p">)</span>
    301. <span class="n">finished</span> <span class="o">=</span> <span class="kc">False</span>
    302. <span class="k">while</span> <span class="ow">not</span> <span class="n">finished</span><span class="p">:</span>
    303. <span class="n">optimized_model</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">platform_client</span><span class="o">.</span><span class="n">get_model_by_name</span><span class="p">(</span><span class="n">name</span><span class="o">=</span><span class="n">optimized_model_name</span><span class="p">)</span><span class="o">.</span><span class="n">data</span>
    304. <span class="k">if</span> <span class="n">optimized_model</span><span class="o">.</span><span class="n">benchmark_state</span> <span class="ow">not</span> <span class="ow">in</span> <span class="p">[</span><span class="n">ModelBenchmarkState</span><span class="o">.</span><span class="n">IN_PROGRESS</span><span class="p">,</span> <span class="n">ModelBenchmarkState</span><span class="o">.</span><span class="n">PENDING</span><span class="p">]:</span>
    305. <span class="n">finished</span> <span class="o">=</span> <span class="kc">True</span>
    306. <span class="k">else</span><span class="p">:</span>
    307. <span class="n">time</span><span class="o">.</span><span class="n">sleep</span><span class="p">(</span><span class="mi">30</span><span class="p">)</span>
    308. <span class="n">signal</span><span class="o">.</span><span class="n">alarm</span><span class="p">(</span><span class="mi">0</span><span class="p">)</span>
    309. <span class="k">return</span> <span class="kc">True</span></div>
    310. <span class="k">def</span> <span class="fm">__call__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">context</span><span class="p">:</span> <span class="n">PhaseContext</span><span class="p">):</span>
    311. <span class="sd">&quot;&quot;&quot;</span>
    312. <span class="sd"> This function will attempt to upload the trained model and schedule an optimization for it.</span>
    313. <span class="sd"> Args:</span>
    314. <span class="sd"> context (PhaseContext): Training phase context</span>
    315. <span class="sd"> Returns:</span>
    316. <span class="sd"> bool: whether or not the optimized model has been benchmarked</span>
    317. <span class="sd"> &quot;&quot;&quot;</span>
    318. <span class="k">try</span><span class="p">:</span>
    319. <span class="n">model</span> <span class="o">=</span> <span class="n">copy</span><span class="o">.</span><span class="n">deepcopy</span><span class="p">(</span><span class="n">context</span><span class="o">.</span><span class="n">net</span><span class="p">)</span>
    320. <span class="n">model_state_dict_path</span> <span class="o">=</span> <span class="n">os</span><span class="o">.</span><span class="n">path</span><span class="o">.</span><span class="n">join</span><span class="p">(</span><span class="n">context</span><span class="o">.</span><span class="n">ckpt_dir</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">ckpt_name</span><span class="p">)</span>
    321. <span class="n">model_state_dict</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">load</span><span class="p">(</span><span class="n">model_state_dict_path</span><span class="p">)[</span><span class="s2">&quot;net&quot;</span><span class="p">]</span>
    322. <span class="n">model</span><span class="o">.</span><span class="n">load_state_dict</span><span class="p">(</span><span class="n">state_dict</span><span class="o">=</span><span class="n">model_state_dict</span><span class="p">)</span>
    323. <span class="n">model</span> <span class="o">=</span> <span class="n">model</span><span class="o">.</span><span class="n">module</span><span class="o">.</span><span class="n">cpu</span><span class="p">()</span>
    324. <span class="k">if</span> <span class="nb">hasattr</span><span class="p">(</span><span class="n">model</span><span class="p">,</span> <span class="s2">&quot;prep_model_for_conversion&quot;</span><span class="p">):</span>
    325. <span class="n">model</span><span class="o">.</span><span class="n">prep_model_for_conversion</span><span class="p">(</span><span class="n">input_size</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">model_meta_data</span><span class="o">.</span><span class="n">input_dimensions</span><span class="p">)</span>
    326. <span class="bp">self</span><span class="o">.</span><span class="n">upload_model</span><span class="p">(</span><span class="n">model</span><span class="o">=</span><span class="n">model</span><span class="p">)</span>
    327. <span class="n">model_name</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">model_meta_data</span><span class="o">.</span><span class="n">name</span>
    328. <span class="n">logger</span><span class="o">.</span><span class="n">info</span><span class="p">(</span><span class="sa">f</span><span class="s2">&quot;Successfully added </span><span class="si">{</span><span class="n">model_name</span><span class="si">}</span><span class="s2"> to the model repository&quot;</span><span class="p">)</span>
    329. <span class="n">optimized_model_name</span> <span class="o">=</span> <span class="sa">f</span><span class="s2">&quot;</span><span class="si">{</span><span class="n">model_name</span><span class="si">}</span><span class="s2">_1_1&quot;</span>
    330. <span class="n">logger</span><span class="o">.</span><span class="n">info</span><span class="p">(</span><span class="s2">&quot;We&#39;ll wait for the scheduled optimization to finish. Please don&#39;t close this window&quot;</span><span class="p">)</span>
    331. <span class="n">success</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">get_optimization_status</span><span class="p">(</span><span class="n">optimized_model_name</span><span class="o">=</span><span class="n">optimized_model_name</span><span class="p">)</span>
    332. <span class="k">if</span> <span class="n">success</span><span class="p">:</span>
    333. <span class="n">logger</span><span class="o">.</span><span class="n">info</span><span class="p">(</span><span class="s2">&quot;Successfully finished your model optimization. Visit https://console.deci.ai for details&quot;</span><span class="p">)</span>
    334. <span class="k">else</span><span class="p">:</span>
    335. <span class="n">DeciLabUploadCallback</span><span class="o">.</span><span class="n">log_optimization_failed</span><span class="p">()</span>
    336. <span class="k">except</span> <span class="ne">Exception</span> <span class="k">as</span> <span class="n">ex</span><span class="p">:</span>
    337. <span class="n">DeciLabUploadCallback</span><span class="o">.</span><span class="n">log_optimization_failed</span><span class="p">()</span>
    338. <span class="n">logger</span><span class="o">.</span><span class="n">error</span><span class="p">(</span><span class="n">ex</span><span class="p">)</span></div>
    339. <div class="viewcode-block" id="LRCallbackBase"><a class="viewcode-back" href="../../../../super_gradients.training.utils.html#super_gradients.training.utils.callbacks.LRCallbackBase">[docs]</a><span class="k">class</span> <span class="nc">LRCallbackBase</span><span class="p">(</span><span class="n">PhaseCallback</span><span class="p">):</span>
    340. <span class="sd">&quot;&quot;&quot;</span>
    341. <span class="sd"> Base class for hard coded learning rate scheduling regimes, implemented as callbacks.</span>
    342. <span class="sd"> &quot;&quot;&quot;</span>
    343. <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">phase</span><span class="p">,</span> <span class="n">initial_lr</span><span class="p">,</span> <span class="n">update_param_groups</span><span class="p">,</span> <span class="n">train_loader_len</span><span class="p">,</span> <span class="n">net</span><span class="p">,</span> <span class="n">training_params</span><span class="p">,</span> <span class="o">**</span><span class="n">kwargs</span><span class="p">):</span>
    344. <span class="nb">super</span><span class="p">(</span><span class="n">LRCallbackBase</span><span class="p">,</span> <span class="bp">self</span><span class="p">)</span><span class="o">.</span><span class="fm">__init__</span><span class="p">(</span><span class="n">phase</span><span class="p">)</span>
    345. <span class="bp">self</span><span class="o">.</span><span class="n">initial_lr</span> <span class="o">=</span> <span class="n">initial_lr</span>
    346. <span class="bp">self</span><span class="o">.</span><span class="n">lr</span> <span class="o">=</span> <span class="n">initial_lr</span>
    347. <span class="bp">self</span><span class="o">.</span><span class="n">update_param_groups</span> <span class="o">=</span> <span class="n">update_param_groups</span>
    348. <span class="bp">self</span><span class="o">.</span><span class="n">train_loader_len</span> <span class="o">=</span> <span class="n">train_loader_len</span>
    349. <span class="bp">self</span><span class="o">.</span><span class="n">net</span> <span class="o">=</span> <span class="n">net</span>
    350. <span class="bp">self</span><span class="o">.</span><span class="n">training_params</span> <span class="o">=</span> <span class="n">training_params</span>
    351. <span class="k">def</span> <span class="fm">__call__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">context</span><span class="p">:</span> <span class="n">PhaseContext</span><span class="p">,</span> <span class="o">**</span><span class="n">kwargs</span><span class="p">):</span>
    352. <span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">is_lr_scheduling_enabled</span><span class="p">(</span><span class="n">context</span><span class="p">):</span>
    353. <span class="bp">self</span><span class="o">.</span><span class="n">perform_scheduling</span><span class="p">(</span><span class="n">context</span><span class="p">)</span>
    354. <div class="viewcode-block" id="LRCallbackBase.is_lr_scheduling_enabled"><a class="viewcode-back" href="../../../../super_gradients.training.utils.html#super_gradients.training.utils.callbacks.LRCallbackBase.is_lr_scheduling_enabled">[docs]</a> <span class="k">def</span> <span class="nf">is_lr_scheduling_enabled</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">context</span><span class="p">:</span> <span class="n">PhaseContext</span><span class="p">):</span>
    355. <span class="sd">&quot;&quot;&quot;</span>
    356. <span class="sd"> Predicate that controls whether to perform lr scheduling based on values in context.</span>
    357. <span class="sd"> @param context: PhaseContext: current phase&#39;s context.</span>
    358. <span class="sd"> @return: bool, whether to apply lr scheduling or not.</span>
    359. <span class="sd"> &quot;&quot;&quot;</span>
    360. <span class="k">raise</span> <span class="ne">NotImplementedError</span></div>
    361. <div class="viewcode-block" id="LRCallbackBase.perform_scheduling"><a class="viewcode-back" href="../../../../super_gradients.training.utils.html#super_gradients.training.utils.callbacks.LRCallbackBase.perform_scheduling">[docs]</a> <span class="k">def</span> <span class="nf">perform_scheduling</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">context</span><span class="p">:</span> <span class="n">PhaseContext</span><span class="p">):</span>
    362. <span class="sd">&quot;&quot;&quot;</span>
    363. <span class="sd"> Performs lr scheduling based on values in context.</span>
    364. <span class="sd"> @param context: PhaseContext: current phase&#39;s context.</span>
    365. <span class="sd"> &quot;&quot;&quot;</span>
    366. <span class="k">raise</span> <span class="ne">NotImplementedError</span></div>
    367. <div class="viewcode-block" id="LRCallbackBase.update_lr"><a class="viewcode-back" href="../../../../super_gradients.training.utils.html#super_gradients.training.utils.callbacks.LRCallbackBase.update_lr">[docs]</a> <span class="k">def</span> <span class="nf">update_lr</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">optimizer</span><span class="p">,</span> <span class="n">epoch</span><span class="p">,</span> <span class="n">batch_idx</span><span class="o">=</span><span class="kc">None</span><span class="p">):</span>
    368. <span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">update_param_groups</span><span class="p">:</span>
    369. <span class="n">param_groups</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">net</span><span class="o">.</span><span class="n">module</span><span class="o">.</span><span class="n">update_param_groups</span><span class="p">(</span>
    370. <span class="n">optimizer</span><span class="o">.</span><span class="n">param_groups</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">lr</span><span class="p">,</span> <span class="n">epoch</span><span class="p">,</span> <span class="n">batch_idx</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">training_params</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">train_loader_len</span>
    371. <span class="p">)</span>
    372. <span class="n">optimizer</span><span class="o">.</span><span class="n">param_groups</span> <span class="o">=</span> <span class="n">param_groups</span>
    373. <span class="k">else</span><span class="p">:</span>
    374. <span class="c1"># UPDATE THE OPTIMIZERS PARAMETER</span>
    375. <span class="k">for</span> <span class="n">param_group</span> <span class="ow">in</span> <span class="n">optimizer</span><span class="o">.</span><span class="n">param_groups</span><span class="p">:</span>
    376. <span class="n">param_group</span><span class="p">[</span><span class="s2">&quot;lr&quot;</span><span class="p">]</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">lr</span></div></div>
    377. <div class="viewcode-block" id="WarmupLRCallback"><a class="viewcode-back" href="../../../../super_gradients.training.utils.html#super_gradients.training.utils.callbacks.WarmupLRCallback">[docs]</a><span class="k">class</span> <span class="nc">WarmupLRCallback</span><span class="p">(</span><span class="n">LRCallbackBase</span><span class="p">):</span>
    378. <span class="sd">&quot;&quot;&quot;</span>
    379. <span class="sd"> LR scheduling callback for linear step warmup.</span>
    380. <span class="sd"> LR climbs from warmup_initial_lr with even steps to initial lr. When warmup_initial_lr is None- LR climb starts from</span>
    381. <span class="sd"> initial_lr/(1+warmup_epochs).</span>
    382. <span class="sd"> &quot;&quot;&quot;</span>
    383. <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="o">**</span><span class="n">kwargs</span><span class="p">):</span>
    384. <span class="nb">super</span><span class="p">(</span><span class="n">WarmupLRCallback</span><span class="p">,</span> <span class="bp">self</span><span class="p">)</span><span class="o">.</span><span class="fm">__init__</span><span class="p">(</span><span class="n">Phase</span><span class="o">.</span><span class="n">TRAIN_EPOCH_START</span><span class="p">,</span> <span class="o">**</span><span class="n">kwargs</span><span class="p">)</span>
    385. <span class="bp">self</span><span class="o">.</span><span class="n">warmup_initial_lr</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">training_params</span><span class="o">.</span><span class="n">warmup_initial_lr</span> <span class="ow">or</span> <span class="bp">self</span><span class="o">.</span><span class="n">initial_lr</span> <span class="o">/</span> <span class="p">(</span>
    386. <span class="bp">self</span><span class="o">.</span><span class="n">training_params</span><span class="o">.</span><span class="n">lr_warmup_epochs</span> <span class="o">+</span> <span class="mi">1</span>
    387. <span class="p">)</span>
    388. <span class="bp">self</span><span class="o">.</span><span class="n">warmup_step_size</span> <span class="o">=</span> <span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">initial_lr</span> <span class="o">-</span> <span class="bp">self</span><span class="o">.</span><span class="n">warmup_initial_lr</span><span class="p">)</span> <span class="o">/</span> <span class="bp">self</span><span class="o">.</span><span class="n">training_params</span><span class="o">.</span><span class="n">lr_warmup_epochs</span>
    389. <div class="viewcode-block" id="WarmupLRCallback.perform_scheduling"><a class="viewcode-back" href="../../../../super_gradients.training.utils.html#super_gradients.training.utils.callbacks.WarmupLRCallback.perform_scheduling">[docs]</a> <span class="k">def</span> <span class="nf">perform_scheduling</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">context</span><span class="p">):</span>
    390. <span class="bp">self</span><span class="o">.</span><span class="n">lr</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">warmup_initial_lr</span> <span class="o">+</span> <span class="n">context</span><span class="o">.</span><span class="n">epoch</span> <span class="o">*</span> <span class="bp">self</span><span class="o">.</span><span class="n">warmup_step_size</span>
    391. <span class="bp">self</span><span class="o">.</span><span class="n">update_lr</span><span class="p">(</span><span class="n">context</span><span class="o">.</span><span class="n">optimizer</span><span class="p">,</span> <span class="n">context</span><span class="o">.</span><span class="n">epoch</span><span class="p">,</span> <span class="kc">None</span><span class="p">)</span></div>
    392. <div class="viewcode-block" id="WarmupLRCallback.is_lr_scheduling_enabled"><a class="viewcode-back" href="../../../../super_gradients.training.utils.html#super_gradients.training.utils.callbacks.WarmupLRCallback.is_lr_scheduling_enabled">[docs]</a> <span class="k">def</span> <span class="nf">is_lr_scheduling_enabled</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">context</span><span class="p">):</span>
    393. <span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">training_params</span><span class="o">.</span><span class="n">lr_warmup_epochs</span> <span class="o">&gt;=</span> <span class="n">context</span><span class="o">.</span><span class="n">epoch</span></div></div>
    394. <div class="viewcode-block" id="StepLRCallback"><a class="viewcode-back" href="../../../../super_gradients.training.utils.html#super_gradients.training.utils.callbacks.StepLRCallback">[docs]</a><span class="k">class</span> <span class="nc">StepLRCallback</span><span class="p">(</span><span class="n">LRCallbackBase</span><span class="p">):</span>
    395. <span class="sd">&quot;&quot;&quot;</span>
    396. <span class="sd"> Hard coded step learning rate scheduling (i.e at specific milestones).</span>
    397. <span class="sd"> &quot;&quot;&quot;</span>
    398. <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">lr_updates</span><span class="p">,</span> <span class="n">lr_decay_factor</span><span class="p">,</span> <span class="n">step_lr_update_freq</span><span class="o">=</span><span class="kc">None</span><span class="p">,</span> <span class="o">**</span><span class="n">kwargs</span><span class="p">):</span>
    399. <span class="nb">super</span><span class="p">(</span><span class="n">StepLRCallback</span><span class="p">,</span> <span class="bp">self</span><span class="p">)</span><span class="o">.</span><span class="fm">__init__</span><span class="p">(</span><span class="n">Phase</span><span class="o">.</span><span class="n">TRAIN_EPOCH_END</span><span class="p">,</span> <span class="o">**</span><span class="n">kwargs</span><span class="p">)</span>
    400. <span class="k">if</span> <span class="n">step_lr_update_freq</span> <span class="ow">and</span> <span class="nb">len</span><span class="p">(</span><span class="n">lr_updates</span><span class="p">):</span>
    401. <span class="k">raise</span> <span class="ne">ValueError</span><span class="p">(</span>
    402. <span class="s2">&quot;Only one of [lr_updates, step_lr_update_freq] should be passed to StepLRCallback constructor&quot;</span>
    403. <span class="p">)</span>
    404. <span class="k">if</span> <span class="n">step_lr_update_freq</span><span class="p">:</span>
    405. <span class="n">max_epochs</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">training_params</span><span class="o">.</span><span class="n">max_epochs</span> <span class="o">-</span> <span class="bp">self</span><span class="o">.</span><span class="n">training_params</span><span class="o">.</span><span class="n">lr_cooldown_epochs</span>
    406. <span class="n">warmup_epochs</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">training_params</span><span class="o">.</span><span class="n">lr_warmup_epochs</span>
    407. <span class="n">lr_updates</span> <span class="o">=</span> <span class="p">[</span>
    408. <span class="nb">int</span><span class="p">(</span><span class="n">np</span><span class="o">.</span><span class="n">ceil</span><span class="p">(</span><span class="n">step_lr_update_freq</span> <span class="o">*</span> <span class="n">x</span><span class="p">))</span>
    409. <span class="k">for</span> <span class="n">x</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="mi">1</span><span class="p">,</span> <span class="n">max_epochs</span><span class="p">)</span>
    410. <span class="k">if</span> <span class="n">warmup_epochs</span> <span class="o">&lt;=</span> <span class="nb">int</span><span class="p">(</span><span class="n">np</span><span class="o">.</span><span class="n">ceil</span><span class="p">(</span><span class="n">step_lr_update_freq</span> <span class="o">*</span> <span class="n">x</span><span class="p">))</span> <span class="o">&lt;</span> <span class="n">max_epochs</span>
    411. <span class="p">]</span>
    412. <span class="k">elif</span> <span class="bp">self</span><span class="o">.</span><span class="n">training_params</span><span class="o">.</span><span class="n">lr_cooldown_epochs</span> <span class="o">&gt;</span> <span class="mi">0</span><span class="p">:</span>
    413. <span class="n">logger</span><span class="o">.</span><span class="n">warning</span><span class="p">(</span>
    414. <span class="s2">&quot;Specific lr_updates were passed along with cooldown_epochs &gt; 0,&quot;</span> <span class="s2">&quot; cooldown will have no effect.&quot;</span>
    415. <span class="p">)</span>
    416. <span class="bp">self</span><span class="o">.</span><span class="n">lr_updates</span> <span class="o">=</span> <span class="n">lr_updates</span>
    417. <span class="bp">self</span><span class="o">.</span><span class="n">lr_decay_factor</span> <span class="o">=</span> <span class="n">lr_decay_factor</span>
    418. <div class="viewcode-block" id="StepLRCallback.perform_scheduling"><a class="viewcode-back" href="../../../../super_gradients.training.utils.html#super_gradients.training.utils.callbacks.StepLRCallback.perform_scheduling">[docs]</a> <span class="k">def</span> <span class="nf">perform_scheduling</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">context</span><span class="p">):</span>
    419. <span class="n">num_updates_passed</span> <span class="o">=</span> <span class="p">[</span><span class="n">x</span> <span class="k">for</span> <span class="n">x</span> <span class="ow">in</span> <span class="bp">self</span><span class="o">.</span><span class="n">lr_updates</span> <span class="k">if</span> <span class="n">x</span> <span class="o">&lt;=</span> <span class="n">context</span><span class="o">.</span><span class="n">epoch</span><span class="p">]</span>
    420. <span class="bp">self</span><span class="o">.</span><span class="n">lr</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">initial_lr</span> <span class="o">*</span> <span class="bp">self</span><span class="o">.</span><span class="n">lr_decay_factor</span> <span class="o">**</span> <span class="nb">len</span><span class="p">(</span><span class="n">num_updates_passed</span><span class="p">)</span>
    421. <span class="bp">self</span><span class="o">.</span><span class="n">update_lr</span><span class="p">(</span><span class="n">context</span><span class="o">.</span><span class="n">optimizer</span><span class="p">,</span> <span class="n">context</span><span class="o">.</span><span class="n">epoch</span><span class="p">,</span> <span class="kc">None</span><span class="p">)</span></div>
    422. <div class="viewcode-block" id="StepLRCallback.is_lr_scheduling_enabled"><a class="viewcode-back" href="../../../../super_gradients.training.utils.html#super_gradients.training.utils.callbacks.StepLRCallback.is_lr_scheduling_enabled">[docs]</a> <span class="k">def</span> <span class="nf">is_lr_scheduling_enabled</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">context</span><span class="p">):</span>
    423. <span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">training_params</span><span class="o">.</span><span class="n">lr_warmup_epochs</span> <span class="o">&lt;=</span> <span class="n">context</span><span class="o">.</span><span class="n">epoch</span></div></div>
    424. <div class="viewcode-block" id="ExponentialLRCallback"><a class="viewcode-back" href="../../../../super_gradients.training.utils.html#super_gradients.training.utils.callbacks.ExponentialLRCallback">[docs]</a><span class="k">class</span> <span class="nc">ExponentialLRCallback</span><span class="p">(</span><span class="n">LRCallbackBase</span><span class="p">):</span>
    425. <span class="sd">&quot;&quot;&quot;</span>
    426. <span class="sd"> Exponential decay learning rate scheduling. Decays the learning rate by `lr_decay_factor` every epoch.</span>
    427. <span class="sd"> &quot;&quot;&quot;</span>
    428. <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">lr_decay_factor</span><span class="p">:</span> <span class="nb">float</span><span class="p">,</span> <span class="o">**</span><span class="n">kwargs</span><span class="p">):</span>
    429. <span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="fm">__init__</span><span class="p">(</span><span class="n">phase</span><span class="o">=</span><span class="n">Phase</span><span class="o">.</span><span class="n">TRAIN_BATCH_STEP</span><span class="p">,</span> <span class="o">**</span><span class="n">kwargs</span><span class="p">)</span>
    430. <span class="bp">self</span><span class="o">.</span><span class="n">lr_decay_factor</span> <span class="o">=</span> <span class="n">lr_decay_factor</span>
    431. <div class="viewcode-block" id="ExponentialLRCallback.perform_scheduling"><a class="viewcode-back" href="../../../../super_gradients.training.utils.html#super_gradients.training.utils.callbacks.ExponentialLRCallback.perform_scheduling">[docs]</a> <span class="k">def</span> <span class="nf">perform_scheduling</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">context</span><span class="p">):</span>
    432. <span class="n">effective_epoch</span> <span class="o">=</span> <span class="n">context</span><span class="o">.</span><span class="n">epoch</span> <span class="o">-</span> <span class="bp">self</span><span class="o">.</span><span class="n">training_params</span><span class="o">.</span><span class="n">lr_warmup_epochs</span>
    433. <span class="n">current_iter</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">train_loader_len</span> <span class="o">*</span> <span class="n">effective_epoch</span> <span class="o">+</span> <span class="n">context</span><span class="o">.</span><span class="n">batch_idx</span>
    434. <span class="bp">self</span><span class="o">.</span><span class="n">lr</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">initial_lr</span> <span class="o">*</span> <span class="bp">self</span><span class="o">.</span><span class="n">lr_decay_factor</span> <span class="o">**</span> <span class="p">(</span><span class="n">current_iter</span> <span class="o">/</span> <span class="bp">self</span><span class="o">.</span><span class="n">train_loader_len</span><span class="p">)</span>
    435. <span class="bp">self</span><span class="o">.</span><span class="n">update_lr</span><span class="p">(</span><span class="n">context</span><span class="o">.</span><span class="n">optimizer</span><span class="p">,</span> <span class="n">context</span><span class="o">.</span><span class="n">epoch</span><span class="p">,</span> <span class="n">context</span><span class="o">.</span><span class="n">batch_idx</span><span class="p">)</span></div>
    436. <div class="viewcode-block" id="ExponentialLRCallback.is_lr_scheduling_enabled"><a class="viewcode-back" href="../../../../super_gradients.training.utils.html#super_gradients.training.utils.callbacks.ExponentialLRCallback.is_lr_scheduling_enabled">[docs]</a> <span class="k">def</span> <span class="nf">is_lr_scheduling_enabled</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">context</span><span class="p">):</span>
    437. <span class="n">post_warmup_epochs</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">training_params</span><span class="o">.</span><span class="n">max_epochs</span> <span class="o">-</span> <span class="bp">self</span><span class="o">.</span><span class="n">training_params</span><span class="o">.</span><span class="n">lr_cooldown_epochs</span>
    438. <span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">training_params</span><span class="o">.</span><span class="n">lr_warmup_epochs</span> <span class="o">&lt;=</span> <span class="n">context</span><span class="o">.</span><span class="n">epoch</span> <span class="o">&lt;</span> <span class="n">post_warmup_epochs</span></div></div>
    439. <div class="viewcode-block" id="PolyLRCallback"><a class="viewcode-back" href="../../../../super_gradients.training.utils.html#super_gradients.training.utils.callbacks.PolyLRCallback">[docs]</a><span class="k">class</span> <span class="nc">PolyLRCallback</span><span class="p">(</span><span class="n">LRCallbackBase</span><span class="p">):</span>
    440. <span class="sd">&quot;&quot;&quot;</span>
    441. <span class="sd"> Hard coded polynomial decay learning rate scheduling (i.e at specific milestones).</span>
    442. <span class="sd"> &quot;&quot;&quot;</span>
    443. <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">max_epochs</span><span class="p">,</span> <span class="o">**</span><span class="n">kwargs</span><span class="p">):</span>
    444. <span class="nb">super</span><span class="p">(</span><span class="n">PolyLRCallback</span><span class="p">,</span> <span class="bp">self</span><span class="p">)</span><span class="o">.</span><span class="fm">__init__</span><span class="p">(</span><span class="n">Phase</span><span class="o">.</span><span class="n">TRAIN_BATCH_STEP</span><span class="p">,</span> <span class="o">**</span><span class="n">kwargs</span><span class="p">)</span>
    445. <span class="bp">self</span><span class="o">.</span><span class="n">max_epochs</span> <span class="o">=</span> <span class="n">max_epochs</span>
    446. <div class="viewcode-block" id="PolyLRCallback.perform_scheduling"><a class="viewcode-back" href="../../../../super_gradients.training.utils.html#super_gradients.training.utils.callbacks.PolyLRCallback.perform_scheduling">[docs]</a> <span class="k">def</span> <span class="nf">perform_scheduling</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">context</span><span class="p">):</span>
    447. <span class="n">effective_epoch</span> <span class="o">=</span> <span class="n">context</span><span class="o">.</span><span class="n">epoch</span> <span class="o">-</span> <span class="bp">self</span><span class="o">.</span><span class="n">training_params</span><span class="o">.</span><span class="n">lr_warmup_epochs</span>
    448. <span class="n">effective_max_epochs</span> <span class="o">=</span> <span class="p">(</span>
    449. <span class="bp">self</span><span class="o">.</span><span class="n">max_epochs</span> <span class="o">-</span> <span class="bp">self</span><span class="o">.</span><span class="n">training_params</span><span class="o">.</span><span class="n">lr_warmup_epochs</span> <span class="o">-</span> <span class="bp">self</span><span class="o">.</span><span class="n">training_params</span><span class="o">.</span><span class="n">lr_cooldown_epochs</span>
    450. <span class="p">)</span>
    451. <span class="n">current_iter</span> <span class="o">=</span> <span class="p">(</span>
    452. <span class="bp">self</span><span class="o">.</span><span class="n">train_loader_len</span> <span class="o">*</span> <span class="n">effective_epoch</span> <span class="o">+</span> <span class="n">context</span><span class="o">.</span><span class="n">batch_idx</span>
    453. <span class="p">)</span> <span class="o">/</span> <span class="bp">self</span><span class="o">.</span><span class="n">training_params</span><span class="o">.</span><span class="n">batch_accumulate</span>
    454. <span class="n">max_iter</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">train_loader_len</span> <span class="o">*</span> <span class="n">effective_max_epochs</span> <span class="o">/</span> <span class="bp">self</span><span class="o">.</span><span class="n">training_params</span><span class="o">.</span><span class="n">batch_accumulate</span>
    455. <span class="bp">self</span><span class="o">.</span><span class="n">lr</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">initial_lr</span> <span class="o">*</span> <span class="nb">pow</span><span class="p">((</span><span class="mf">1.0</span> <span class="o">-</span> <span class="p">(</span><span class="n">current_iter</span> <span class="o">/</span> <span class="n">max_iter</span><span class="p">)),</span> <span class="mf">0.9</span><span class="p">)</span>
    456. <span class="bp">self</span><span class="o">.</span><span class="n">update_lr</span><span class="p">(</span><span class="n">context</span><span class="o">.</span><span class="n">optimizer</span><span class="p">,</span> <span class="n">context</span><span class="o">.</span><span class="n">epoch</span><span class="p">,</span> <span class="n">context</span><span class="o">.</span><span class="n">batch_idx</span><span class="p">)</span></div>
    457. <div class="viewcode-block" id="PolyLRCallback.is_lr_scheduling_enabled"><a class="viewcode-back" href="../../../../super_gradients.training.utils.html#super_gradients.training.utils.callbacks.PolyLRCallback.is_lr_scheduling_enabled">[docs]</a> <span class="k">def</span> <span class="nf">is_lr_scheduling_enabled</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">context</span><span class="p">):</span>
    458. <span class="n">post_warmup_epochs</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">training_params</span><span class="o">.</span><span class="n">max_epochs</span> <span class="o">-</span> <span class="bp">self</span><span class="o">.</span><span class="n">training_params</span><span class="o">.</span><span class="n">lr_cooldown_epochs</span>
    459. <span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">training_params</span><span class="o">.</span><span class="n">lr_warmup_epochs</span> <span class="o">&lt;=</span> <span class="n">context</span><span class="o">.</span><span class="n">epoch</span> <span class="o">&lt;</span> <span class="n">post_warmup_epochs</span></div></div>
    460. <div class="viewcode-block" id="CosineLRCallback"><a class="viewcode-back" href="../../../../super_gradients.training.utils.html#super_gradients.training.utils.callbacks.CosineLRCallback">[docs]</a><span class="k">class</span> <span class="nc">CosineLRCallback</span><span class="p">(</span><span class="n">LRCallbackBase</span><span class="p">):</span>
    461. <span class="sd">&quot;&quot;&quot;</span>
    462. <span class="sd"> Hard coded step Cosine anealing learning rate scheduling.</span>
    463. <span class="sd"> &quot;&quot;&quot;</span>
    464. <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">max_epochs</span><span class="p">,</span> <span class="n">cosine_final_lr_ratio</span><span class="p">,</span> <span class="o">**</span><span class="n">kwargs</span><span class="p">):</span>
    465. <span class="nb">super</span><span class="p">(</span><span class="n">CosineLRCallback</span><span class="p">,</span> <span class="bp">self</span><span class="p">)</span><span class="o">.</span><span class="fm">__init__</span><span class="p">(</span><span class="n">Phase</span><span class="o">.</span><span class="n">TRAIN_BATCH_STEP</span><span class="p">,</span> <span class="o">**</span><span class="n">kwargs</span><span class="p">)</span>
    466. <span class="bp">self</span><span class="o">.</span><span class="n">max_epochs</span> <span class="o">=</span> <span class="n">max_epochs</span>
    467. <span class="bp">self</span><span class="o">.</span><span class="n">cosine_final_lr_ratio</span> <span class="o">=</span> <span class="n">cosine_final_lr_ratio</span>
    468. <div class="viewcode-block" id="CosineLRCallback.perform_scheduling"><a class="viewcode-back" href="../../../../super_gradients.training.utils.html#super_gradients.training.utils.callbacks.CosineLRCallback.perform_scheduling">[docs]</a> <span class="k">def</span> <span class="nf">perform_scheduling</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">context</span><span class="p">):</span>
    469. <span class="n">effective_epoch</span> <span class="o">=</span> <span class="n">context</span><span class="o">.</span><span class="n">epoch</span> <span class="o">-</span> <span class="bp">self</span><span class="o">.</span><span class="n">training_params</span><span class="o">.</span><span class="n">lr_warmup_epochs</span>
    470. <span class="n">effective_max_epochs</span> <span class="o">=</span> <span class="p">(</span>
    471. <span class="bp">self</span><span class="o">.</span><span class="n">max_epochs</span> <span class="o">-</span> <span class="bp">self</span><span class="o">.</span><span class="n">training_params</span><span class="o">.</span><span class="n">lr_warmup_epochs</span> <span class="o">-</span> <span class="bp">self</span><span class="o">.</span><span class="n">training_params</span><span class="o">.</span><span class="n">lr_cooldown_epochs</span>
    472. <span class="p">)</span>
    473. <span class="n">current_iter</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">train_loader_len</span> <span class="o">*</span> <span class="n">effective_epoch</span> <span class="o">+</span> <span class="n">context</span><span class="o">.</span><span class="n">batch_idx</span>
    474. <span class="n">max_iter</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">train_loader_len</span> <span class="o">*</span> <span class="n">effective_max_epochs</span>
    475. <span class="n">lr</span> <span class="o">=</span> <span class="mf">0.5</span> <span class="o">*</span> <span class="bp">self</span><span class="o">.</span><span class="n">initial_lr</span> <span class="o">*</span> <span class="p">(</span><span class="mf">1.0</span> <span class="o">+</span> <span class="n">math</span><span class="o">.</span><span class="n">cos</span><span class="p">(</span><span class="n">current_iter</span> <span class="o">/</span> <span class="p">(</span><span class="n">max_iter</span> <span class="o">+</span> <span class="mi">1</span><span class="p">)</span> <span class="o">*</span> <span class="n">math</span><span class="o">.</span><span class="n">pi</span><span class="p">))</span>
    476. <span class="c1"># the cosine starts from initial_lr and reaches initial_lr * cosine_final_lr_ratio in last epoch</span>
    477. <span class="bp">self</span><span class="o">.</span><span class="n">lr</span> <span class="o">=</span> <span class="n">lr</span> <span class="o">*</span> <span class="p">(</span><span class="mi">1</span> <span class="o">-</span> <span class="bp">self</span><span class="o">.</span><span class="n">cosine_final_lr_ratio</span><span class="p">)</span> <span class="o">+</span> <span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">initial_lr</span> <span class="o">*</span> <span class="bp">self</span><span class="o">.</span><span class="n">cosine_final_lr_ratio</span><span class="p">)</span>
    478. <span class="bp">self</span><span class="o">.</span><span class="n">update_lr</span><span class="p">(</span><span class="n">context</span><span class="o">.</span><span class="n">optimizer</span><span class="p">,</span> <span class="n">context</span><span class="o">.</span><span class="n">epoch</span><span class="p">,</span> <span class="n">context</span><span class="o">.</span><span class="n">batch_idx</span><span class="p">)</span></div>
    479. <div class="viewcode-block" id="CosineLRCallback.is_lr_scheduling_enabled"><a class="viewcode-back" href="../../../../super_gradients.training.utils.html#super_gradients.training.utils.callbacks.CosineLRCallback.is_lr_scheduling_enabled">[docs]</a> <span class="k">def</span> <span class="nf">is_lr_scheduling_enabled</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">context</span><span class="p">):</span>
    480. <span class="n">post_warmup_epochs</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">training_params</span><span class="o">.</span><span class="n">max_epochs</span> <span class="o">-</span> <span class="bp">self</span><span class="o">.</span><span class="n">training_params</span><span class="o">.</span><span class="n">lr_cooldown_epochs</span>
    481. <span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">training_params</span><span class="o">.</span><span class="n">lr_warmup_epochs</span> <span class="o">&lt;=</span> <span class="n">context</span><span class="o">.</span><span class="n">epoch</span> <span class="o">&lt;</span> <span class="n">post_warmup_epochs</span></div></div>
    482. <div class="viewcode-block" id="FunctionLRCallback"><a class="viewcode-back" href="../../../../super_gradients.training.utils.html#super_gradients.training.utils.callbacks.FunctionLRCallback">[docs]</a><span class="k">class</span> <span class="nc">FunctionLRCallback</span><span class="p">(</span><span class="n">LRCallbackBase</span><span class="p">):</span>
    483. <span class="sd">&quot;&quot;&quot;</span>
    484. <span class="sd"> Hard coded rate scheduling for user defined lr scheduling function.</span>
    485. <span class="sd"> &quot;&quot;&quot;</span>
    486. <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">max_epochs</span><span class="p">,</span> <span class="n">lr_schedule_function</span><span class="p">,</span> <span class="o">**</span><span class="n">kwargs</span><span class="p">):</span>
    487. <span class="nb">super</span><span class="p">(</span><span class="n">FunctionLRCallback</span><span class="p">,</span> <span class="bp">self</span><span class="p">)</span><span class="o">.</span><span class="fm">__init__</span><span class="p">(</span><span class="n">Phase</span><span class="o">.</span><span class="n">TRAIN_BATCH_STEP</span><span class="p">,</span> <span class="o">**</span><span class="n">kwargs</span><span class="p">)</span>
    488. <span class="k">assert</span> <span class="n">callable</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">lr_schedule_function</span><span class="p">),</span> <span class="s2">&quot;self.lr_function must be callable&quot;</span>
    489. <span class="bp">self</span><span class="o">.</span><span class="n">lr_schedule_function</span> <span class="o">=</span> <span class="n">lr_schedule_function</span>
    490. <span class="bp">self</span><span class="o">.</span><span class="n">max_epochs</span> <span class="o">=</span> <span class="n">max_epochs</span>
    491. <div class="viewcode-block" id="FunctionLRCallback.is_lr_scheduling_enabled"><a class="viewcode-back" href="../../../../super_gradients.training.utils.html#super_gradients.training.utils.callbacks.FunctionLRCallback.is_lr_scheduling_enabled">[docs]</a> <span class="k">def</span> <span class="nf">is_lr_scheduling_enabled</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">context</span><span class="p">):</span>
    492. <span class="n">post_warmup_epochs</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">training_params</span><span class="o">.</span><span class="n">max_epochs</span> <span class="o">-</span> <span class="bp">self</span><span class="o">.</span><span class="n">training_params</span><span class="o">.</span><span class="n">lr_cooldown_epochs</span>
    493. <span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">training_params</span><span class="o">.</span><span class="n">lr_warmup_epochs</span> <span class="o">&lt;=</span> <span class="n">context</span><span class="o">.</span><span class="n">epoch</span> <span class="o">&lt;</span> <span class="n">post_warmup_epochs</span></div>
    494. <div class="viewcode-block" id="FunctionLRCallback.perform_scheduling"><a class="viewcode-back" href="../../../../super_gradients.training.utils.html#super_gradients.training.utils.callbacks.FunctionLRCallback.perform_scheduling">[docs]</a> <span class="k">def</span> <span class="nf">perform_scheduling</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">context</span><span class="p">):</span>
    495. <span class="n">effective_epoch</span> <span class="o">=</span> <span class="n">context</span><span class="o">.</span><span class="n">epoch</span> <span class="o">-</span> <span class="bp">self</span><span class="o">.</span><span class="n">training_params</span><span class="o">.</span><span class="n">lr_warmup_epochs</span>
    496. <span class="n">effective_max_epochs</span> <span class="o">=</span> <span class="p">(</span>
    497. <span class="bp">self</span><span class="o">.</span><span class="n">max_epochs</span> <span class="o">-</span> <span class="bp">self</span><span class="o">.</span><span class="n">training_params</span><span class="o">.</span><span class="n">lr_warmup_epochs</span> <span class="o">-</span> <span class="bp">self</span><span class="o">.</span><span class="n">training_params</span><span class="o">.</span><span class="n">lr_cooldown_epochs</span>
    498. <span class="p">)</span>
    499. <span class="bp">self</span><span class="o">.</span><span class="n">lr</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">lr_schedule_function</span><span class="p">(</span>
    500. <span class="n">initial_lr</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">initial_lr</span><span class="p">,</span>
    501. <span class="n">epoch</span><span class="o">=</span><span class="n">effective_epoch</span><span class="p">,</span>
    502. <span class="nb">iter</span><span class="o">=</span><span class="n">context</span><span class="o">.</span><span class="n">batch_idx</span><span class="p">,</span>
    503. <span class="n">max_epoch</span><span class="o">=</span><span class="n">effective_max_epochs</span><span class="p">,</span>
    504. <span class="n">iters_per_epoch</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">train_loader_len</span><span class="p">,</span>
    505. <span class="p">)</span>
    506. <span class="bp">self</span><span class="o">.</span><span class="n">update_lr</span><span class="p">(</span><span class="n">context</span><span class="o">.</span><span class="n">optimizer</span><span class="p">,</span> <span class="n">context</span><span class="o">.</span><span class="n">epoch</span><span class="p">,</span> <span class="n">context</span><span class="o">.</span><span class="n">batch_idx</span><span class="p">)</span></div></div>
    507. <div class="viewcode-block" id="IllegalLRSchedulerMetric"><a class="viewcode-back" href="../../../../super_gradients.training.utils.html#super_gradients.training.utils.callbacks.IllegalLRSchedulerMetric">[docs]</a><span class="k">class</span> <span class="nc">IllegalLRSchedulerMetric</span><span class="p">(</span><span class="ne">Exception</span><span class="p">):</span>
    508. <span class="sd">&quot;&quot;&quot;Exception raised illegal combination of training parameters.</span>
    509. <span class="sd"> Attributes:</span>
    510. <span class="sd"> message -- explanation of the error</span>
    511. <span class="sd"> &quot;&quot;&quot;</span>
    512. <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">metric_name</span><span class="p">,</span> <span class="n">metrics_dict</span><span class="p">):</span>
    513. <span class="bp">self</span><span class="o">.</span><span class="n">message</span> <span class="o">=</span> <span class="p">(</span>
    514. <span class="s2">&quot;Illegal metric name: &quot;</span> <span class="o">+</span> <span class="n">metric_name</span> <span class="o">+</span> <span class="s2">&quot;. Expected one of metics_dics keys: &quot;</span> <span class="o">+</span> <span class="nb">str</span><span class="p">(</span><span class="n">metrics_dict</span><span class="o">.</span><span class="n">keys</span><span class="p">())</span>
    515. <span class="p">)</span>
    516. <span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">message</span><span class="p">)</span></div>
    517. <div class="viewcode-block" id="LRSchedulerCallback"><a class="viewcode-back" href="../../../../super_gradients.training.utils.html#super_gradients.training.utils.callbacks.LRSchedulerCallback">[docs]</a><span class="k">class</span> <span class="nc">LRSchedulerCallback</span><span class="p">(</span><span class="n">PhaseCallback</span><span class="p">):</span>
    518. <span class="sd">&quot;&quot;&quot;</span>
    519. <span class="sd"> Learning rate scheduler callback.</span>
    520. <span class="sd"> Attributes:</span>
    521. <span class="sd"> scheduler: torch.optim._LRScheduler, the learning rate scheduler to be called step() with.</span>
    522. <span class="sd"> metric_name: str, (default=None) the metric name for ReduceLROnPlateau learning rate scheduler.</span>
    523. <span class="sd"> When passing __call__ a metrics_dict, with a key=self.metric_name, the value of that metric will monitored</span>
    524. <span class="sd"> for ReduceLROnPlateau (i.e step(metrics_dict[self.metric_name]).</span>
    525. <span class="sd"> &quot;&quot;&quot;</span>
    526. <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">scheduler</span><span class="p">,</span> <span class="n">phase</span><span class="p">,</span> <span class="n">metric_name</span><span class="o">=</span><span class="kc">None</span><span class="p">):</span>
    527. <span class="nb">super</span><span class="p">(</span><span class="n">LRSchedulerCallback</span><span class="p">,</span> <span class="bp">self</span><span class="p">)</span><span class="o">.</span><span class="fm">__init__</span><span class="p">(</span><span class="n">phase</span><span class="p">)</span>
    528. <span class="bp">self</span><span class="o">.</span><span class="n">scheduler</span> <span class="o">=</span> <span class="n">scheduler</span>
    529. <span class="bp">self</span><span class="o">.</span><span class="n">metric_name</span> <span class="o">=</span> <span class="n">metric_name</span>
    530. <span class="k">def</span> <span class="fm">__call__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">context</span><span class="p">:</span> <span class="n">PhaseContext</span><span class="p">):</span>
    531. <span class="k">if</span> <span class="n">context</span><span class="o">.</span><span class="n">lr_warmup_epochs</span> <span class="o">&lt;=</span> <span class="n">context</span><span class="o">.</span><span class="n">epoch</span><span class="p">:</span>
    532. <span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">metric_name</span> <span class="ow">and</span> <span class="bp">self</span><span class="o">.</span><span class="n">metric_name</span> <span class="ow">in</span> <span class="n">context</span><span class="o">.</span><span class="n">metrics_dict</span><span class="o">.</span><span class="n">keys</span><span class="p">():</span>
    533. <span class="bp">self</span><span class="o">.</span><span class="n">scheduler</span><span class="o">.</span><span class="n">step</span><span class="p">(</span><span class="n">context</span><span class="o">.</span><span class="n">metrics_dict</span><span class="p">[</span><span class="bp">self</span><span class="o">.</span><span class="n">metric_name</span><span class="p">])</span>
    534. <span class="k">elif</span> <span class="bp">self</span><span class="o">.</span><span class="n">metric_name</span> <span class="ow">is</span> <span class="kc">None</span><span class="p">:</span>
    535. <span class="bp">self</span><span class="o">.</span><span class="n">scheduler</span><span class="o">.</span><span class="n">step</span><span class="p">()</span>
    536. <span class="k">else</span><span class="p">:</span>
    537. <span class="k">raise</span> <span class="n">IllegalLRSchedulerMetric</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">metric_name</span><span class="p">,</span> <span class="n">context</span><span class="o">.</span><span class="n">metrics_dict</span><span class="p">)</span>
    538. <span class="k">def</span> <span class="fm">__repr__</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span>
    539. <span class="k">return</span> <span class="s2">&quot;LRSchedulerCallback: &quot;</span> <span class="o">+</span> <span class="nb">repr</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">scheduler</span><span class="p">)</span></div>
    540. <div class="viewcode-block" id="MetricsUpdateCallback"><a class="viewcode-back" href="../../../../super_gradients.training.utils.html#super_gradients.training.utils.callbacks.MetricsUpdateCallback">[docs]</a><span class="k">class</span> <span class="nc">MetricsUpdateCallback</span><span class="p">(</span><span class="n">PhaseCallback</span><span class="p">):</span>
    541. <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">phase</span><span class="p">:</span> <span class="n">Phase</span><span class="p">):</span>
    542. <span class="nb">super</span><span class="p">(</span><span class="n">MetricsUpdateCallback</span><span class="p">,</span> <span class="bp">self</span><span class="p">)</span><span class="o">.</span><span class="fm">__init__</span><span class="p">(</span><span class="n">phase</span><span class="p">)</span>
    543. <span class="k">def</span> <span class="fm">__call__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">context</span><span class="p">:</span> <span class="n">PhaseContext</span><span class="p">):</span>
    544. <span class="n">context</span><span class="o">.</span><span class="n">metrics_compute_fn</span><span class="o">.</span><span class="n">update</span><span class="p">(</span><span class="o">**</span><span class="n">context</span><span class="o">.</span><span class="vm">__dict__</span><span class="p">)</span>
    545. <span class="k">if</span> <span class="n">context</span><span class="o">.</span><span class="n">criterion</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span><span class="p">:</span>
    546. <span class="n">context</span><span class="o">.</span><span class="n">loss_avg_meter</span><span class="o">.</span><span class="n">update</span><span class="p">(</span><span class="n">context</span><span class="o">.</span><span class="n">loss_log_items</span><span class="p">,</span> <span class="nb">len</span><span class="p">(</span><span class="n">context</span><span class="o">.</span><span class="n">inputs</span><span class="p">))</span></div>
    547. <div class="viewcode-block" id="KDModelMetricsUpdateCallback"><a class="viewcode-back" href="../../../../super_gradients.training.utils.html#super_gradients.training.utils.callbacks.KDModelMetricsUpdateCallback">[docs]</a><span class="k">class</span> <span class="nc">KDModelMetricsUpdateCallback</span><span class="p">(</span><span class="n">MetricsUpdateCallback</span><span class="p">):</span>
    548. <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">phase</span><span class="p">:</span> <span class="n">Phase</span><span class="p">):</span>
    549. <span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="fm">__init__</span><span class="p">(</span><span class="n">phase</span><span class="o">=</span><span class="n">phase</span><span class="p">)</span>
    550. <span class="k">def</span> <span class="fm">__call__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">context</span><span class="p">:</span> <span class="n">PhaseContext</span><span class="p">):</span>
    551. <span class="n">metrics_compute_fn_kwargs</span> <span class="o">=</span> <span class="p">{</span><span class="n">k</span><span class="p">:</span> <span class="n">v</span><span class="o">.</span><span class="n">student_output</span> <span class="k">if</span> <span class="n">k</span> <span class="o">==</span> <span class="s2">&quot;preds&quot;</span> <span class="k">else</span> <span class="n">v</span> <span class="k">for</span> <span class="n">k</span><span class="p">,</span> <span class="n">v</span> <span class="ow">in</span> <span class="n">context</span><span class="o">.</span><span class="vm">__dict__</span><span class="o">.</span><span class="n">items</span><span class="p">()}</span>
    552. <span class="n">context</span><span class="o">.</span><span class="n">metrics_compute_fn</span><span class="o">.</span><span class="n">update</span><span class="p">(</span><span class="o">**</span><span class="n">metrics_compute_fn_kwargs</span><span class="p">)</span>
    553. <span class="k">if</span> <span class="n">context</span><span class="o">.</span><span class="n">criterion</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span><span class="p">:</span>
    554. <span class="n">context</span><span class="o">.</span><span class="n">loss_avg_meter</span><span class="o">.</span><span class="n">update</span><span class="p">(</span><span class="n">context</span><span class="o">.</span><span class="n">loss_log_items</span><span class="p">,</span> <span class="nb">len</span><span class="p">(</span><span class="n">context</span><span class="o">.</span><span class="n">inputs</span><span class="p">))</span></div>
    555. <div class="viewcode-block" id="PhaseContextTestCallback"><a class="viewcode-back" href="../../../../super_gradients.training.utils.html#super_gradients.training.utils.callbacks.PhaseContextTestCallback">[docs]</a><span class="k">class</span> <span class="nc">PhaseContextTestCallback</span><span class="p">(</span><span class="n">PhaseCallback</span><span class="p">):</span>
    556. <span class="sd">&quot;&quot;&quot;</span>
    557. <span class="sd"> A callback that saves the phase context the for testing.</span>
    558. <span class="sd"> &quot;&quot;&quot;</span>
    559. <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">phase</span><span class="p">:</span> <span class="n">Phase</span><span class="p">):</span>
    560. <span class="nb">super</span><span class="p">(</span><span class="n">PhaseContextTestCallback</span><span class="p">,</span> <span class="bp">self</span><span class="p">)</span><span class="o">.</span><span class="fm">__init__</span><span class="p">(</span><span class="n">phase</span><span class="p">)</span>
    561. <span class="bp">self</span><span class="o">.</span><span class="n">context</span> <span class="o">=</span> <span class="kc">None</span>
    562. <span class="k">def</span> <span class="fm">__call__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">context</span><span class="p">:</span> <span class="n">PhaseContext</span><span class="p">):</span>
    563. <span class="bp">self</span><span class="o">.</span><span class="n">context</span> <span class="o">=</span> <span class="n">context</span></div>
    564. <div class="viewcode-block" id="DetectionVisualizationCallback"><a class="viewcode-back" href="../../../../super_gradients.training.utils.html#super_gradients.training.utils.callbacks.DetectionVisualizationCallback">[docs]</a><span class="k">class</span> <span class="nc">DetectionVisualizationCallback</span><span class="p">(</span><span class="n">PhaseCallback</span><span class="p">):</span>
    565. <span class="sd">&quot;&quot;&quot;</span>
    566. <span class="sd"> A callback that adds a visualization of a batch of detection predictions to context.sg_logger</span>
    567. <span class="sd"> Attributes:</span>
    568. <span class="sd"> freq: frequency (in epochs) to perform this callback.</span>
    569. <span class="sd"> batch_idx: batch index to perform visualization for.</span>
    570. <span class="sd"> classes: class list of the dataset.</span>
    571. <span class="sd"> last_img_idx_in_batch: Last image index to add to log. (default=-1, will take entire batch).</span>
    572. <span class="sd"> &quot;&quot;&quot;</span>
    573. <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span>
    574. <span class="bp">self</span><span class="p">,</span>
    575. <span class="n">phase</span><span class="p">:</span> <span class="n">Phase</span><span class="p">,</span>
    576. <span class="n">freq</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span>
    577. <span class="n">post_prediction_callback</span><span class="p">:</span> <span class="n">DetectionPostPredictionCallback</span><span class="p">,</span>
    578. <span class="n">classes</span><span class="p">:</span> <span class="nb">list</span><span class="p">,</span>
    579. <span class="n">batch_idx</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="mi">0</span><span class="p">,</span>
    580. <span class="n">last_img_idx_in_batch</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="o">-</span><span class="mi">1</span><span class="p">,</span>
    581. <span class="p">):</span>
    582. <span class="nb">super</span><span class="p">(</span><span class="n">DetectionVisualizationCallback</span><span class="p">,</span> <span class="bp">self</span><span class="p">)</span><span class="o">.</span><span class="fm">__init__</span><span class="p">(</span><span class="n">phase</span><span class="p">)</span>
    583. <span class="bp">self</span><span class="o">.</span><span class="n">freq</span> <span class="o">=</span> <span class="n">freq</span>
    584. <span class="bp">self</span><span class="o">.</span><span class="n">post_prediction_callback</span> <span class="o">=</span> <span class="n">post_prediction_callback</span>
    585. <span class="bp">self</span><span class="o">.</span><span class="n">batch_idx</span> <span class="o">=</span> <span class="n">batch_idx</span>
    586. <span class="bp">self</span><span class="o">.</span><span class="n">classes</span> <span class="o">=</span> <span class="n">classes</span>
    587. <span class="bp">self</span><span class="o">.</span><span class="n">last_img_idx_in_batch</span> <span class="o">=</span> <span class="n">last_img_idx_in_batch</span>
    588. <span class="k">def</span> <span class="fm">__call__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">context</span><span class="p">:</span> <span class="n">PhaseContext</span><span class="p">):</span>
    589. <span class="k">if</span> <span class="n">context</span><span class="o">.</span><span class="n">epoch</span> <span class="o">%</span> <span class="bp">self</span><span class="o">.</span><span class="n">freq</span> <span class="o">==</span> <span class="mi">0</span> <span class="ow">and</span> <span class="n">context</span><span class="o">.</span><span class="n">batch_idx</span> <span class="o">==</span> <span class="bp">self</span><span class="o">.</span><span class="n">batch_idx</span><span class="p">:</span>
    590. <span class="c1"># SOME CALCULATIONS ARE IN-PLACE IN NMS, SO CLONE THE PREDICTIONS</span>
    591. <span class="n">preds</span> <span class="o">=</span> <span class="p">(</span><span class="n">context</span><span class="o">.</span><span class="n">preds</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span><span class="o">.</span><span class="n">clone</span><span class="p">(),</span> <span class="kc">None</span><span class="p">)</span>
    592. <span class="n">preds</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">post_prediction_callback</span><span class="p">(</span><span class="n">preds</span><span class="p">)</span>
    593. <span class="n">batch_imgs</span> <span class="o">=</span> <span class="n">DetectionVisualization</span><span class="o">.</span><span class="n">visualize_batch</span><span class="p">(</span>
    594. <span class="n">context</span><span class="o">.</span><span class="n">inputs</span><span class="p">,</span> <span class="n">preds</span><span class="p">,</span> <span class="n">context</span><span class="o">.</span><span class="n">target</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">batch_idx</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">classes</span>
    595. <span class="p">)</span>
    596. <span class="n">batch_imgs</span> <span class="o">=</span> <span class="p">[</span><span class="n">cv2</span><span class="o">.</span><span class="n">cvtColor</span><span class="p">(</span><span class="n">image</span><span class="p">,</span> <span class="n">cv2</span><span class="o">.</span><span class="n">COLOR_BGR2RGB</span><span class="p">)</span> <span class="k">for</span> <span class="n">image</span> <span class="ow">in</span> <span class="n">batch_imgs</span><span class="p">]</span>
    597. <span class="n">batch_imgs</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">stack</span><span class="p">(</span><span class="n">batch_imgs</span><span class="p">)</span>
    598. <span class="n">tag</span> <span class="o">=</span> <span class="s2">&quot;batch_&quot;</span> <span class="o">+</span> <span class="nb">str</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">batch_idx</span><span class="p">)</span> <span class="o">+</span> <span class="s2">&quot;_images&quot;</span>
    599. <span class="n">context</span><span class="o">.</span><span class="n">sg_logger</span><span class="o">.</span><span class="n">add_images</span><span class="p">(</span>
    600. <span class="n">tag</span><span class="o">=</span><span class="n">tag</span><span class="p">,</span> <span class="n">images</span><span class="o">=</span><span class="n">batch_imgs</span><span class="p">[:</span> <span class="bp">self</span><span class="o">.</span><span class="n">last_img_idx_in_batch</span><span class="p">],</span> <span class="n">global_step</span><span class="o">=</span><span class="n">context</span><span class="o">.</span><span class="n">epoch</span><span class="p">,</span> <span class="n">data_format</span><span class="o">=</span><span class="s2">&quot;NHWC&quot;</span>
    601. <span class="p">)</span></div>
    602. <div class="viewcode-block" id="BinarySegmentationVisualizationCallback"><a class="viewcode-back" href="../../../../super_gradients.training.utils.html#super_gradients.training.utils.callbacks.BinarySegmentationVisualizationCallback">[docs]</a><span class="k">class</span> <span class="nc">BinarySegmentationVisualizationCallback</span><span class="p">(</span><span class="n">PhaseCallback</span><span class="p">):</span>
    603. <span class="sd">&quot;&quot;&quot;</span>
    604. <span class="sd"> A callback that adds a visualization of a batch of segmentation predictions to context.sg_logger</span>
    605. <span class="sd"> Attributes:</span>
    606. <span class="sd"> freq: frequency (in epochs) to perform this callback.</span>
    607. <span class="sd"> batch_idx: batch index to perform visualization for.</span>
    608. <span class="sd"> last_img_idx_in_batch: Last image index to add to log. (default=-1, will take entire batch).</span>
    609. <span class="sd"> &quot;&quot;&quot;</span>
    610. <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">phase</span><span class="p">:</span> <span class="n">Phase</span><span class="p">,</span> <span class="n">freq</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span> <span class="n">batch_idx</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="mi">0</span><span class="p">,</span> <span class="n">last_img_idx_in_batch</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="o">-</span><span class="mi">1</span><span class="p">):</span>
    611. <span class="nb">super</span><span class="p">(</span><span class="n">BinarySegmentationVisualizationCallback</span><span class="p">,</span> <span class="bp">self</span><span class="p">)</span><span class="o">.</span><span class="fm">__init__</span><span class="p">(</span><span class="n">phase</span><span class="p">)</span>
    612. <span class="bp">self</span><span class="o">.</span><span class="n">freq</span> <span class="o">=</span> <span class="n">freq</span>
    613. <span class="bp">self</span><span class="o">.</span><span class="n">batch_idx</span> <span class="o">=</span> <span class="n">batch_idx</span>
    614. <span class="bp">self</span><span class="o">.</span><span class="n">last_img_idx_in_batch</span> <span class="o">=</span> <span class="n">last_img_idx_in_batch</span>
    615. <span class="k">def</span> <span class="fm">__call__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">context</span><span class="p">:</span> <span class="n">PhaseContext</span><span class="p">):</span>
    616. <span class="k">if</span> <span class="n">context</span><span class="o">.</span><span class="n">epoch</span> <span class="o">%</span> <span class="bp">self</span><span class="o">.</span><span class="n">freq</span> <span class="o">==</span> <span class="mi">0</span> <span class="ow">and</span> <span class="n">context</span><span class="o">.</span><span class="n">batch_idx</span> <span class="o">==</span> <span class="bp">self</span><span class="o">.</span><span class="n">batch_idx</span><span class="p">:</span>
    617. <span class="k">if</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">context</span><span class="o">.</span><span class="n">preds</span><span class="p">,</span> <span class="nb">tuple</span><span class="p">):</span>
    618. <span class="n">preds</span> <span class="o">=</span> <span class="n">context</span><span class="o">.</span><span class="n">preds</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span><span class="o">.</span><span class="n">clone</span><span class="p">()</span>
    619. <span class="k">else</span><span class="p">:</span>
    620. <span class="n">preds</span> <span class="o">=</span> <span class="n">context</span><span class="o">.</span><span class="n">preds</span><span class="o">.</span><span class="n">clone</span><span class="p">()</span>
    621. <span class="n">batch_imgs</span> <span class="o">=</span> <span class="n">BinarySegmentationVisualization</span><span class="o">.</span><span class="n">visualize_batch</span><span class="p">(</span>
    622. <span class="n">context</span><span class="o">.</span><span class="n">inputs</span><span class="p">,</span> <span class="n">preds</span><span class="p">,</span> <span class="n">context</span><span class="o">.</span><span class="n">target</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">batch_idx</span>
    623. <span class="p">)</span>
    624. <span class="n">batch_imgs</span> <span class="o">=</span> <span class="p">[</span><span class="n">cv2</span><span class="o">.</span><span class="n">cvtColor</span><span class="p">(</span><span class="n">image</span><span class="p">,</span> <span class="n">cv2</span><span class="o">.</span><span class="n">COLOR_BGR2RGB</span><span class="p">)</span> <span class="k">for</span> <span class="n">image</span> <span class="ow">in</span> <span class="n">batch_imgs</span><span class="p">]</span>
    625. <span class="n">batch_imgs</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">stack</span><span class="p">(</span><span class="n">batch_imgs</span><span class="p">)</span>
    626. <span class="n">tag</span> <span class="o">=</span> <span class="s2">&quot;batch_&quot;</span> <span class="o">+</span> <span class="nb">str</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">batch_idx</span><span class="p">)</span> <span class="o">+</span> <span class="s2">&quot;_images&quot;</span>
    627. <span class="n">context</span><span class="o">.</span><span class="n">sg_logger</span><span class="o">.</span><span class="n">add_images</span><span class="p">(</span><span class="n">tag</span><span class="o">=</span><span class="n">tag</span><span class="p">,</span> <span class="n">images</span><span class="o">=</span><span class="n">batch_imgs</span><span class="p">[:</span><span class="bp">self</span><span class="o">.</span><span class="n">last_img_idx_in_batch</span><span class="p">],</span>
    628. <span class="n">global_step</span><span class="o">=</span><span class="n">context</span><span class="o">.</span><span class="n">epoch</span><span class="p">,</span> <span class="n">data_format</span><span class="o">=</span><span class="s1">&#39;NHWC&#39;</span><span class="p">)</span></div>
    629. <div class="viewcode-block" id="TrainingStageSwitchCallbackBase"><a class="viewcode-back" href="../../../../super_gradients.training.utils.html#super_gradients.training.utils.callbacks.TrainingStageSwitchCallbackBase">[docs]</a><span class="k">class</span> <span class="nc">TrainingStageSwitchCallbackBase</span><span class="p">(</span><span class="n">PhaseCallback</span><span class="p">):</span>
    630. <span class="sd">&quot;&quot;&quot;</span>
    631. <span class="sd"> TrainingStageSwitchCallback</span>
    632. <span class="sd"> A phase callback that is called at a specific epoch (epoch start) to support multi-stage training.</span>
    633. <span class="sd"> It does so by manipulating the objects inside the context.</span>
    634. <span class="sd"> Attributes:</span>
    635. <span class="sd"> next_stage_start_epoch: int, the epoch idx to apply the stage change.</span>
    636. <span class="sd"> &quot;&quot;&quot;</span>
    637. <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">next_stage_start_epoch</span><span class="p">:</span> <span class="nb">int</span><span class="p">):</span>
    638. <span class="nb">super</span><span class="p">(</span><span class="n">TrainingStageSwitchCallbackBase</span><span class="p">,</span> <span class="bp">self</span><span class="p">)</span><span class="o">.</span><span class="fm">__init__</span><span class="p">(</span><span class="n">phase</span><span class="o">=</span><span class="n">Phase</span><span class="o">.</span><span class="n">TRAIN_EPOCH_START</span><span class="p">)</span>
    639. <span class="bp">self</span><span class="o">.</span><span class="n">next_stage_start_epoch</span> <span class="o">=</span> <span class="n">next_stage_start_epoch</span>
    640. <span class="k">def</span> <span class="fm">__call__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">context</span><span class="p">:</span> <span class="n">PhaseContext</span><span class="p">):</span>
    641. <span class="k">if</span> <span class="n">context</span><span class="o">.</span><span class="n">epoch</span> <span class="o">==</span> <span class="bp">self</span><span class="o">.</span><span class="n">next_stage_start_epoch</span><span class="p">:</span>
    642. <span class="bp">self</span><span class="o">.</span><span class="n">apply_stage_change</span><span class="p">(</span><span class="n">context</span><span class="p">)</span>
    643. <div class="viewcode-block" id="TrainingStageSwitchCallbackBase.apply_stage_change"><a class="viewcode-back" href="../../../../super_gradients.training.utils.html#super_gradients.training.utils.callbacks.TrainingStageSwitchCallbackBase.apply_stage_change">[docs]</a> <span class="k">def</span> <span class="nf">apply_stage_change</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">context</span><span class="p">:</span> <span class="n">PhaseContext</span><span class="p">):</span>
    644. <span class="sd">&quot;&quot;&quot;</span>
    645. <span class="sd"> This method is called when the callback is fired on the next_stage_start_epoch,</span>
    646. <span class="sd"> and holds the stage change logic that should be applied to the context&#39;s objects.</span>
    647. <span class="sd"> :param context: PhaseContext, context of current phase</span>
    648. <span class="sd"> &quot;&quot;&quot;</span>
    649. <span class="k">raise</span> <span class="ne">NotImplementedError</span></div></div>
    650. <div class="viewcode-block" id="YoloXTrainingStageSwitchCallback"><a class="viewcode-back" href="../../../../super_gradients.training.utils.html#super_gradients.training.utils.callbacks.YoloXTrainingStageSwitchCallback">[docs]</a><span class="k">class</span> <span class="nc">YoloXTrainingStageSwitchCallback</span><span class="p">(</span><span class="n">TrainingStageSwitchCallbackBase</span><span class="p">):</span>
    651. <span class="sd">&quot;&quot;&quot;</span>
    652. <span class="sd"> YoloXTrainingStageSwitchCallback</span>
    653. <span class="sd"> Training stage switch for YoloX training.</span>
    654. <span class="sd"> Disables mosaic, and manipulates YoloX loss to use L1.</span>
    655. <span class="sd"> &quot;&quot;&quot;</span>
    656. <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">next_stage_start_epoch</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="mi">285</span><span class="p">):</span>
    657. <span class="nb">super</span><span class="p">(</span><span class="n">YoloXTrainingStageSwitchCallback</span><span class="p">,</span> <span class="bp">self</span><span class="p">)</span><span class="o">.</span><span class="fm">__init__</span><span class="p">(</span><span class="n">next_stage_start_epoch</span><span class="o">=</span><span class="n">next_stage_start_epoch</span><span class="p">)</span>
    658. <div class="viewcode-block" id="YoloXTrainingStageSwitchCallback.apply_stage_change"><a class="viewcode-back" href="../../../../super_gradients.training.utils.html#super_gradients.training.utils.callbacks.YoloXTrainingStageSwitchCallback.apply_stage_change">[docs]</a> <span class="k">def</span> <span class="nf">apply_stage_change</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">context</span><span class="p">:</span> <span class="n">PhaseContext</span><span class="p">):</span>
    659. <span class="k">for</span> <span class="n">transform</span> <span class="ow">in</span> <span class="n">context</span><span class="o">.</span><span class="n">train_loader</span><span class="o">.</span><span class="n">dataset</span><span class="o">.</span><span class="n">transforms</span><span class="p">:</span>
    660. <span class="k">if</span> <span class="nb">hasattr</span><span class="p">(</span><span class="n">transform</span><span class="p">,</span> <span class="s2">&quot;close&quot;</span><span class="p">):</span>
    661. <span class="n">transform</span><span class="o">.</span><span class="n">close</span><span class="p">()</span>
    662. <span class="nb">iter</span><span class="p">(</span><span class="n">context</span><span class="o">.</span><span class="n">train_loader</span><span class="p">)</span>
    663. <span class="n">context</span><span class="o">.</span><span class="n">criterion</span><span class="o">.</span><span class="n">use_l1</span> <span class="o">=</span> <span class="kc">True</span></div></div>
    664. <div class="viewcode-block" id="CallbackHandler"><a class="viewcode-back" href="../../../../super_gradients.training.utils.html#super_gradients.training.utils.callbacks.CallbackHandler">[docs]</a><span class="k">class</span> <span class="nc">CallbackHandler</span><span class="p">:</span>
    665. <span class="sd">&quot;&quot;&quot;</span>
    666. <span class="sd"> Runs all callbacks who&#39;s phase attribute equals to the given phase.</span>
    667. <span class="sd"> Attributes:</span>
    668. <span class="sd"> callbacks: List[PhaseCallback]. Callbacks to be run.</span>
    669. <span class="sd"> &quot;&quot;&quot;</span>
    670. <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">callbacks</span><span class="p">):</span>
    671. <span class="bp">self</span><span class="o">.</span><span class="n">callbacks</span> <span class="o">=</span> <span class="n">callbacks</span>
    672. <span class="k">def</span> <span class="fm">__call__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">phase</span><span class="p">:</span> <span class="n">Phase</span><span class="p">,</span> <span class="n">context</span><span class="p">:</span> <span class="n">PhaseContext</span><span class="p">):</span>
    673. <span class="k">for</span> <span class="n">callback</span> <span class="ow">in</span> <span class="bp">self</span><span class="o">.</span><span class="n">callbacks</span><span class="p">:</span>
    674. <span class="k">if</span> <span class="n">callback</span><span class="o">.</span><span class="n">phase</span> <span class="o">==</span> <span class="n">phase</span><span class="p">:</span>
    675. <span class="n">callback</span><span class="p">(</span><span class="n">context</span><span class="p">)</span></div>
    676. <span class="c1"># DICT FOR LEGACY LR HARD-CODED REGIMES, WILL BE DELETED IN THE FUTURE</span>
    677. <span class="n">LR_SCHEDULERS_CLS_DICT</span> <span class="o">=</span> <span class="p">{</span>
    678. <span class="s2">&quot;step&quot;</span><span class="p">:</span> <span class="n">StepLRCallback</span><span class="p">,</span>
    679. <span class="s2">&quot;poly&quot;</span><span class="p">:</span> <span class="n">PolyLRCallback</span><span class="p">,</span>
    680. <span class="s2">&quot;cosine&quot;</span><span class="p">:</span> <span class="n">CosineLRCallback</span><span class="p">,</span>
    681. <span class="s2">&quot;exp&quot;</span><span class="p">:</span> <span class="n">ExponentialLRCallback</span><span class="p">,</span>
    682. <span class="s2">&quot;function&quot;</span><span class="p">:</span> <span class="n">FunctionLRCallback</span><span class="p">,</span>
    683. <span class="p">}</span>
    684. <span class="n">LR_WARMUP_CLS_DICT</span> <span class="o">=</span> <span class="p">{</span><span class="s2">&quot;linear_step&quot;</span><span class="p">:</span> <span class="n">WarmupLRCallback</span><span class="p">}</span>
    685. <div class="viewcode-block" id="TestLRCallback"><a class="viewcode-back" href="../../../../super_gradients.training.utils.html#super_gradients.training.utils.callbacks.TestLRCallback">[docs]</a><span class="k">class</span> <span class="nc">TestLRCallback</span><span class="p">(</span><span class="n">PhaseCallback</span><span class="p">):</span>
    686. <span class="sd">&quot;&quot;&quot;</span>
    687. <span class="sd"> Phase callback that collects the learning rates in lr_placeholder at the end of each epoch (used for testing). In</span>
    688. <span class="sd"> the case of multiple parameter groups (i.e multiple learning rates) the learning rate is collected from the first</span>
    689. <span class="sd"> one. The phase is VALIDATION_EPOCH_END to ensure all lr updates have been performed before calling this callback.</span>
    690. <span class="sd"> &quot;&quot;&quot;</span>
    691. <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">lr_placeholder</span><span class="p">):</span>
    692. <span class="nb">super</span><span class="p">(</span><span class="n">TestLRCallback</span><span class="p">,</span> <span class="bp">self</span><span class="p">)</span><span class="o">.</span><span class="fm">__init__</span><span class="p">(</span><span class="n">Phase</span><span class="o">.</span><span class="n">VALIDATION_EPOCH_END</span><span class="p">)</span>
    693. <span class="bp">self</span><span class="o">.</span><span class="n">lr_placeholder</span> <span class="o">=</span> <span class="n">lr_placeholder</span>
    694. <span class="k">def</span> <span class="fm">__call__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">context</span><span class="p">:</span> <span class="n">PhaseContext</span><span class="p">):</span>
    695. <span class="bp">self</span><span class="o">.</span><span class="n">lr_placeholder</span><span class="o">.</span><span class="n">append</span><span class="p">(</span><span class="n">context</span><span class="o">.</span><span class="n">optimizer</span><span class="o">.</span><span class="n">param_groups</span><span class="p">[</span><span class="mi">0</span><span class="p">][</span><span class="s2">&quot;lr&quot;</span><span class="p">])</span></div>
    696. </pre></div>
    697. </div>
    698. </div>
    699. <footer>
    700. <hr/>
    701. <div role="contentinfo">
    702. <p>&#169; Copyright 2021, SuperGradients team.</p>
    703. </div>
    704. Built with <a href="https://www.sphinx-doc.org/">Sphinx</a> using a
    705. <a href="https://github.com/readthedocs/sphinx_rtd_theme">theme</a>
    706. provided by <a href="https://readthedocs.org">Read the Docs</a>.
    707. </footer>
    708. </div>
    709. </div>
    710. </section>
    711. </div>
    712. <script>
    713. jQuery(function () {
    714. SphinxRtdTheme.Navigation.enable(true);
    715. });
    716. </script>
    717. </body>
    718. </html>
    Discard
    Some lines were truncated since they exceed the maximum allowed length of 500, please use a local Git client to see the full diff.
    @@ -3,10 +3,11 @@
     <head>
     <head>
       <meta charset="utf-8" />
       <meta charset="utf-8" />
       <meta name="viewport" content="width=device-width, initial-scale=1.0" />
       <meta name="viewport" content="width=device-width, initial-scale=1.0" />
    -  <title>super_gradients.training.utils.checkpoint_utils &mdash; SuperGradients 1.0 documentation</title>
    +  <title>super_gradients.training.utils.checkpoint_utils &mdash; SuperGradients 3.0.3 documentation</title>
           <link rel="stylesheet" href="../../../../_static/pygments.css" type="text/css" />
           <link rel="stylesheet" href="../../../../_static/pygments.css" type="text/css" />
           <link rel="stylesheet" href="../../../../_static/css/theme.css" type="text/css" />
           <link rel="stylesheet" href="../../../../_static/css/theme.css" type="text/css" />
           <link rel="stylesheet" href="../../../../_static/graphviz.css" type="text/css" />
           <link rel="stylesheet" href="../../../../_static/graphviz.css" type="text/css" />
    +      <link rel="stylesheet" href="../../../../_static/custom.css" type="text/css" />
       <!--[if lt IE 9]>
       <!--[if lt IE 9]>
         <script src="../../../../_static/js/html5shiv.min.js"></script>
         <script src="../../../../_static/js/html5shiv.min.js"></script>
       <![endif]-->
       <![endif]-->
    @@ -14,7 +15,9 @@
             <script data-url_root="../../../../" id="documentation_options" src="../../../../_static/documentation_options.js"></script>
             <script data-url_root="../../../../" id="documentation_options" src="../../../../_static/documentation_options.js"></script>
             <script src="../../../../_static/jquery.js"></script>
             <script src="../../../../_static/jquery.js"></script>
             <script src="../../../../_static/underscore.js"></script>
             <script src="../../../../_static/underscore.js"></script>
    +        <script src="../../../../_static/_sphinx_javascript_frameworks_compat.js"></script>
             <script src="../../../../_static/doctools.js"></script>
             <script src="../../../../_static/doctools.js"></script>
    +        <script src="../../../../_static/sphinx_highlight.js"></script>
         <script src="../../../../_static/js/theme.js"></script>
         <script src="../../../../_static/js/theme.js"></script>
         <link rel="index" title="Index" href="../../../../genindex.html" />
         <link rel="index" title="Index" href="../../../../genindex.html" />
         <link rel="search" title="Search" href="../../../../search.html" /> 
         <link rel="search" title="Search" href="../../../../search.html" /> 
    @@ -35,30 +38,29 @@
       </form>
       </form>
     </div>
     </div>
             </div><div class="wy-menu wy-menu-vertical" data-spy="affix" role="navigation" aria-label="Navigation menu">
             </div><div class="wy-menu wy-menu-vertical" data-spy="affix" role="navigation" aria-label="Navigation menu">
    -              <p class="caption"><span class="caption-text">Welcome To SuperGradients</span></p>
    +              <p class="caption" role="heading"><span class="caption-text">Welcome To SuperGradients</span></p>
     <ul>
     <ul>
    -<li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html">Fill our 4-question quick survey! We will raffle free SuperGradients swag between those who will participate -&gt; Fill Survey</a></li>
    -<li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#supergradients">SuperGradients</a></li>
    +<li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html">Version 3 is out! Notebooks have been updated!</a></li>
    +<li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#build-with-supergradients">Build with SuperGradients</a></li>
    +<li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#quick-installation">Quick Installation</a></li>
    +<li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#what-s-new">What’s New</a></li>
    +<li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#coming-soon">Coming soon</a></li>
    +<li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#table-of-content">Table of Content</a></li>
    +<li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#getting-started">Getting Started</a></li>
    +<li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#advanced-features">Advanced Features</a></li>
    +<li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#installation-methods">Installation Methods</a></li>
    +<li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#implemented-model-architectures">Implemented Model Architectures</a></li>
    +<li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#documentation">Documentation</a></li>
    +<li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#contributing">Contributing</a></li>
    +<li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#citation">Citation</a></li>
    +<li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#community">Community</a></li>
    +<li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#license">License</a></li>
    +<li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#deci-platform">Deci Platform</a></li>
     </ul>
     </ul>
    -<p class="caption"><span class="caption-text">Technical Documentation</span></p>
    +<p class="caption" role="heading"><span class="caption-text">Technical Documentation</span></p>
     <ul>
     <ul>
     <li class="toctree-l1"><a class="reference internal" href="../../../../super_gradients.common.html">Common package</a></li>
     <li class="toctree-l1"><a class="reference internal" href="../../../../super_gradients.common.html">Common package</a></li>
     <li class="toctree-l1"><a class="reference internal" href="../../../../super_gradients.training.html">Training package</a></li>
     <li class="toctree-l1"><a class="reference internal" href="../../../../super_gradients.training.html">Training package</a></li>
    -</ul>
    -<p class="caption"><span class="caption-text">User Guide</span></p>
    -<ul>
    -<li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html">What is SuperGradients?</a></li>
    -<li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#introducing-the-supergradients-library">Introducing the SuperGradients library</a></li>
    -<li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#installation">Installation</a></li>
    -<li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#integrating-your-training-code-complete-walkthrough">Integrating Your Training Code - Complete Walkthrough</a></li>
    -<li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#training-parameters">Training Parameters</a></li>
    -<li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#logs-and-checkpoints">Logs and Checkpoints</a></li>
    -<li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#dataset-parameters">Dataset Parameters</a></li>
    -<li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#network-architectures">Network Architectures</a></li>
    -<li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#pretrained-models">Pretrained Models</a></li>
    -<li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#how-to-reproduce-our-training-recipes">How To Reproduce Our Training Recipes</a></li>
    -<li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#professional-tools-integration">Professional Tools Integration</a></li>
    -<li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#supergradients-faq">SuperGradients FAQ</a></li>
     </ul>
     </ul>
     
     
             </div>
             </div>
    @@ -89,17 +91,40 @@
     <span></span><span class="kn">import</span> <span class="nn">os</span>
     <span></span><span class="kn">import</span> <span class="nn">os</span>
     <span class="kn">import</span> <span class="nn">tempfile</span>
     <span class="kn">import</span> <span class="nn">tempfile</span>
     <span class="kn">import</span> <span class="nn">pkg_resources</span>
     <span class="kn">import</span> <span class="nn">pkg_resources</span>
    +
     <span class="kn">import</span> <span class="nn">torch</span>
     <span class="kn">import</span> <span class="nn">torch</span>
    +
    +<span class="kn">from</span> <span class="nn">super_gradients.common.abstractions.abstract_logger</span> <span class="kn">import</span> <span class="n">get_logger</span>
     <span class="kn">from</span> <span class="nn">super_gradients.common</span> <span class="kn">import</span> <span class="n">explicit_params_validation</span><span class="p">,</span> <span class="n">ADNNModelRepositoryDataInterfaces</span>
     <span class="kn">from</span> <span class="nn">super_gradients.common</span> <span class="kn">import</span> <span class="n">explicit_params_validation</span><span class="p">,</span> <span class="n">ADNNModelRepositoryDataInterfaces</span>
     <span class="kn">from</span> <span class="nn">super_gradients.training.pretrained_models</span> <span class="kn">import</span> <span class="n">MODEL_URLS</span>
     <span class="kn">from</span> <span class="nn">super_gradients.training.pretrained_models</span> <span class="kn">import</span> <span class="n">MODEL_URLS</span>
    +<span class="kn">from</span> <span class="nn">super_gradients.common.environment</span> <span class="kn">import</span> <span class="n">environment_config</span>
    +
     <span class="k">try</span><span class="p">:</span>
     <span class="k">try</span><span class="p">:</span>
         <span class="kn">from</span> <span class="nn">torch.hub</span> <span class="kn">import</span> <span class="n">download_url_to_file</span><span class="p">,</span> <span class="n">load_state_dict_from_url</span>
         <span class="kn">from</span> <span class="nn">torch.hub</span> <span class="kn">import</span> <span class="n">download_url_to_file</span><span class="p">,</span> <span class="n">load_state_dict_from_url</span>
     <span class="k">except</span> <span class="p">(</span><span class="ne">ModuleNotFoundError</span><span class="p">,</span> <span class="ne">ImportError</span><span class="p">,</span> <span class="ne">NameError</span><span class="p">):</span>
     <span class="k">except</span> <span class="p">(</span><span class="ne">ModuleNotFoundError</span><span class="p">,</span> <span class="ne">ImportError</span><span class="p">,</span> <span class="ne">NameError</span><span class="p">):</span>
         <span class="kn">from</span> <span class="nn">torch.hub</span> <span class="kn">import</span> <span class="n">_download_url_to_file</span> <span class="k">as</span> <span class="n">download_url_to_file</span>
         <span class="kn">from</span> <span class="nn">torch.hub</span> <span class="kn">import</span> <span class="n">_download_url_to_file</span> <span class="k">as</span> <span class="n">download_url_to_file</span>
     
     
     
     
    -<div class="viewcode-block" id="get_ckpt_local_path"><a class="viewcode-back" href="../../../../super_gradients.training.utils.html#super_gradients.training.utils.get_ckpt_local_path">[docs]</a><span class="k">def</span> <span class="nf">get_ckpt_local_path</span><span class="p">(</span><span class="n">source_ckpt_folder_name</span><span class="p">:</span> <span class="nb">str</span><span class="p">,</span> <span class="n">experiment_name</span><span class="p">:</span> <span class="nb">str</spa
    -                        <span class="n">overwrite_local_checkpoint</span><span class="p">:</span> <span class="nb">bool</span><span class="p">,</span> <span class="n">load_weights_only</span><span class="p">:</span> <span class="nb">bool</span><span class="p">):</span>
    +<span class="n">logger</span> <span class="o">=</span> <span class="n">get_logger</span><span class="p">(</span><span class="vm">__name__</span><span class="p">)</span>
    +
    +
    +<span class="k">def</span> <span class="nf">get_checkpoints_dir_path</span><span class="p">(</span><span class="n">experiment_name</span><span class="p">:</span> <span class="nb">str</span><span class="p">,</span> <span class="n">ckpt_root_dir</span><span class="p">:</span> <span class="nb">str</span> <span class="o">=</span> <span class="kc">None</span><span class="p">):</span>
    +    <span class="sd">&quot;&quot;&quot;Creating the checkpoint directory of a given experiment.</span>
    +<span class="sd">    :param experiment_name:     Name of the experiment.</span>
    +<span class="sd">    :param ckpt_root_dir:       Local root directory path where all experiment logging directories will</span>
    +<span class="sd">                                reside. When none is give, it is assumed that pkg_resources.resource_filename(&#39;checkpoints&#39;, &quot;&quot;)</span>
    +<span class="sd">                                exists and will be used.</span>
    +<span class="sd">    :return:                    checkpoints_dir_path</span>
    +<span class="sd">    &quot;&quot;&quot;</span>
    +    <span class="k">if</span> <span class="n">ckpt_root_dir</span><span class="p">:</span>
    +        <span class="k">return</span> <span class="n">os</span><span class="o">.</span><span class="n">path</span><span class="o">.</span><span class="n">join</span><span class="p">(</span><span class="n">ckpt_root_dir</span><span class="p">,</span> <span class="n">experiment_name</span><span class="p">)</span>
    +    <span class="k">elif</span> <span class="n">os</span><span class="o">.</span><span class="n">path</span><span class="o">.</span><span class="n">exists</span><span class="p">(</span><span class="n">environment_config</span><span class="o">.</span><span class="n">PKG_CHECKPOINTS_DIR</span><span class="p">):</span>
    +        <span class="k">return</span> <span class="n">os</span><span class="o">.</span><span class="n">path</span><span class="o">.</span><span class="n">join</span><span class="p">(</span><span class="n">environment_config</span><span class="o">.</span><span class="n">PKG_CHECKPOINTS_DIR</span><span class="p">,</span> <span class="n">experiment_name</span><span class="p">)</span>
    +    <span class="k">else</span><span class="p">:</span>
    +        <span class="k">raise</span> <span class="ne">ValueError</span><span class="p">(</span><span class="s2">&quot;Illegal checkpoints directory: pass ckpt_root_dir that exists, or add &#39;checkpoints&#39; to resources.&quot;</span><span class="p">)</span>
    +
    +
    +<span class="k">def</span> <span class="nf">get_ckpt_local_path</span><span class="p">(</span><span class="n">source_ckpt_folder_name</span><span class="p">:</span> <span class="nb">str</span><span class="p">,</span> <span class="n">experiment_name</span><span class="p">:</span> <span class="nb">str</span><span class="p">,</span> <span class="n">ckpt_name</span><span class="p">:</span> <span class="nb">str</span><span class="p">,</span> <span class="n">external_checkpoint_path</span><span class
         <span class="sd">&quot;&quot;&quot;</span>
         <span class="sd">&quot;&quot;&quot;</span>
     <span class="sd">    Gets the local path to the checkpoint file, which will be:</span>
     <span class="sd">    Gets the local path to the checkpoint file, which will be:</span>
     <span class="sd">        - By default: YOUR_REPO_ROOT/super_gradients/checkpoints/experiment_name.</span>
     <span class="sd">        - By default: YOUR_REPO_ROOT/super_gradients/checkpoints/experiment_name.</span>
    @@ -110,37 +135,20 @@
     <span class="sd">        - external_checkpoint_path when external_checkpoint_path != None</span>
     <span class="sd">        - external_checkpoint_path when external_checkpoint_path != None</span>
     
     
     <span class="sd">    @param source_ckpt_folder_name: The folder where the checkpoint is saved. When set to None- uses the experiment_name.</span>
     <span class="sd">    @param source_ckpt_folder_name: The folder where the checkpoint is saved. When set to None- uses the experiment_name.</span>
    -<span class="sd">    @param experiment_name: experiment name attr in sg_model</span>
    +<span class="sd">    @param experiment_name: experiment name attr in trainer</span>
     <span class="sd">    @param ckpt_name: checkpoint filename</span>
     <span class="sd">    @param ckpt_name: checkpoint filename</span>
    -<span class="sd">    @param model_checkpoints_location: S3, local ot URL</span>
     <span class="sd">    @param external_checkpoint_path: full path to checkpoint file (that might be located outside of super_gradients/checkpoints directory)</span>
     <span class="sd">    @param external_checkpoint_path: full path to checkpoint file (that might be located outside of super_gradients/checkpoints directory)</span>
    -<span class="sd">    @param overwrite_local_checkpoint: whether to overwrite the checkpoint file with the same name when downloading from S3.</span>
    -<span class="sd">    @param load_weights_only: whether to load the network&#39;s state dict only.</span>
     <span class="sd">    @return:</span>
     <span class="sd">    @return:</span>
     <span class="sd">    &quot;&quot;&quot;</span>
     <span class="sd">    &quot;&quot;&quot;</span>
    -    <span class="n">source_ckpt_folder_name</span> <span class="o">=</span> <span class="n">source_ckpt_folder_name</span> <span class="ow">or</span> <span class="n">experiment_name</span>
    -    <span class="k">if</span> <span class="n">model_checkpoints_location</span> <span class="o">==</span> <span class="s1">&#39;local&#39;</span><span class="p">:</span>
    -        <span class="n">ckpt_local_path</span> <span class="o">=</span> <span class="n">external_checkpoint_path</span> <span class="ow">or</span> <span class="n">pkg_resources</span><span class="o">.</span><span class="n">resource_filename</span><span class="p">(</span><span class="s1">&#39;checkpoints&#39;</span><span class="p">,</span> <span class="n">source_ckpt_folder_name</span> <span class="o">+</span> <span class="n">os</span><span class="o">.</span><span class="n">path</span><span clas
    -
    -    <span class="c1"># COPY THE DATA FROM &#39;S3&#39;/&#39;URL&#39; INTO A LOCAL DIRECTORY</span>
    -    <span class="k">elif</span> <span class="n">model_checkpoints_location</span><span class="o">.</span><span class="n">startswith</span><span class="p">(</span><span class="s1">&#39;s3&#39;</span><span class="p">)</span> <span class="ow">or</span> <span class="n">model_checkpoints_location</span> <span class="o">==</span> <span class="s1">&#39;url&#39;</span><span class="p">:</span>
    -        <span class="c1"># COPY REMOTE DATA TO A LOCAL DIRECTORY AND GET THAT DIRECTORYs NAME</span>
    -        <span class="n">ckpt_local_path</span> <span class="o">=</span> <span class="n">copy_ckpt_to_local_folder</span><span class="p">(</span><span class="n">local_ckpt_destination_dir</span><span class="o">=</span><span class="n">experiment_name</span><span class="p">,</span>
    -                                                    <span class="n">ckpt_filename</span><span class="o">=</span><span class="n">ckpt_name</span><span class="p">,</span>
    -                                                    <span class="n">remote_ckpt_source_dir</span><span class="o">=</span><span class="n">source_ckpt_folder_name</span><span class="p">,</span>
    -                                                    <span class="n">path_src</span><span class="o">=</span><span class="n">model_checkpoints_location</span><span class="p">,</span>
    -                                                    <span class="n">overwrite_local_ckpt</span><span class="o">=</span><span class="n">overwrite_local_checkpoint</span><span class="p">,</span>
    -                                                    <span class="n">load_weights_only</span><span class="o">=</span><span class="n">load_weights_only</span><span class="p">)</span>
    -
    +    <span class="k">if</span> <span class="n">external_checkpoint_path</span><span class="p">:</span>
    +        <span class="k">return</span> <span class="n">external_checkpoint_path</span>
         <span class="k">else</span><span class="p">:</span>
         <span class="k">else</span><span class="p">:</span>
    -        <span class="c1"># ERROR IN USER CODE FLOW - THIS WILL EVENTUALLY RAISE AN EXCEPTION</span>
    -        <span class="k">raise</span> <span class="ne">NotImplementedError</span><span class="p">(</span>
    -            <span class="s1">&#39;model_checkpoints_data_source: &#39;</span> <span class="o">+</span> <span class="nb">str</span><span class="p">(</span><span class="n">model_checkpoints_location</span><span class="p">)</span> <span class="o">+</span> <span class="s1">&#39;not supported&#39;</span><span class="p">)</span>
    +        <span class="n">checkpoints_folder_name</span> <span class="o">=</span> <span class="n">source_ckpt_folder_name</span> <span class="ow">or</span> <span class="n">experiment_name</span>
    +        <span class="n">checkpoints_dir_path</span> <span class="o">=</span> <span class="n">get_checkpoints_dir_path</span><span class="p">(</span><span class="n">checkpoints_folder_name</span><span class="p">)</span>
    +        <span class="k">return</span> <span class="n">os</span><span class="o">.</span><span class="n">path</span><span class="o">.</span><span class="n">join</span><span class="p">(</span><span class="n">checkpoints_dir_path</span><span class="p">,</span> <span class="n">ckpt_name</span><span class="p">)</span>
     
     
    -    <span class="k">return</span> <span class="n">ckpt_local_path</span></div>
     
     
    -
    -<div class="viewcode-block" id="adaptive_load_state_dict"><a class="viewcode-back" href="../../../../super_gradients.training.utils.html#super_gradients.training.utils.adaptive_load_state_dict">[docs]</a><span class="k">def</span> <span class="nf">adaptive_load_state_dict</span><span class="p">(</span><span class="n">net</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">nn</span><span class="o">.</span><span class="n">Module</span><span class="p"
    +<span class="k">def</span> <span class="nf">adaptive_load_state_dict</span><span class="p">(</span><span class="n">net</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">nn</span><span class="o">.</span><span class="n">Module</span><span class="p">,</span> <span class="n">state_dict</span><span class="p">:</span> <span class="nb">dict</span><span class="p">,</span> <span class="n">strict</span><span class="p">:</span> <span class="nb">str</span><s
         <span class="sd">&quot;&quot;&quot;</span>
         <span class="sd">&quot;&quot;&quot;</span>
     <span class="sd">    Adaptively loads state_dict to net, by adapting the state_dict to net&#39;s layer names first.</span>
     <span class="sd">    Adaptively loads state_dict to net, by adapting the state_dict to net&#39;s layer names first.</span>
     
     
    @@ -150,19 +158,24 @@
     <span class="sd">    @return:</span>
     <span class="sd">    @return:</span>
     <span class="sd">    &quot;&quot;&quot;</span>
     <span class="sd">    &quot;&quot;&quot;</span>
         <span class="k">try</span><span class="p">:</span>
         <span class="k">try</span><span class="p">:</span>
    -        <span class="n">net</span><span class="o">.</span><span class="n">load_state_dict</span><span class="p">(</span><span class="n">state_dict</span><span class="p">[</span><span class="s1">&#39;net&#39;</span><span class="p">],</span> <span class="n">strict</span><span class="o">=</span><span class="n">strict</span><span class="p">)</span>
    +        <span class="n">net</span><span class="o">.</span><span class="n">load_state_dict</span><span class="p">(</span><span class="n">state_dict</span><span class="p">[</span><span class="s2">&quot;net&quot;</span><span class="p">]</span> <span class="k">if</span> <span class="s2">&quot;net&quot;</span> <span class="ow">in</span> <span class="n">state_dict</span><span class="o">.</span><span class="n">keys</span><span class="p">()</span> <span class="k">else</span> <span class="n">state_dict<
         <span class="k">except</span> <span class="p">(</span><span class="ne">RuntimeError</span><span class="p">,</span> <span class="ne">ValueError</span><span class="p">,</span> <span class="ne">KeyError</span><span class="p">)</span> <span class="k">as</span> <span class="n">ex</span><span class="p">:</span>
         <span class="k">except</span> <span class="p">(</span><span class="ne">RuntimeError</span><span class="p">,</span> <span class="ne">ValueError</span><span class="p">,</span> <span class="ne">KeyError</span><span class="p">)</span> <span class="k">as</span> <span class="n">ex</span><span class="p">:</span>
    -        <span class="k">if</span> <span class="n">strict</span> <span class="o">==</span> <span class="s1">&#39;no_key_matching&#39;</span><span class="p">:</span>
    +        <span class="k">if</span> <span class="n">strict</span> <span class="o">==</span> <span class="s2">&quot;no_key_matching&quot;</span><span class="p">:</span>
                 <span class="n">adapted_state_dict</span> <span class="o">=</span> <span class="n">adapt_state_dict_to_fit_model_layer_names</span><span class="p">(</span><span class="n">net</span><span class="o">.</span><span class="n">state_dict</span><span class="p">(),</span> <span class="n">state_dict</span><span class="p">)</span>
                 <span class="n">adapted_state_dict</span> <span class="o">=</span> <span class="n">adapt_state_dict_to_fit_model_layer_names</span><span class="p">(</span><span class="n">net</span><span class="o">.</span><span class="n">state_dict</span><span class="p">(),</span> <span class="n">state_dict</span><span class="p">)</span>
    -            <span class="n">net</span><span class="o">.</span><span class="n">load_state_dict</span><span class="p">(</span><span class="n">adapted_state_dict</span><span class="p">[</span><span class="s1">&#39;net&#39;</span><span class="p">],</span> <span class="n">strict</span><span class="o">=</span><span class="kc">True</span><span class="p">)</span>
    +            <span class="n">net</span><span class="o">.</span><span class="n">load_state_dict</span><span class="p">(</span><span class="n">adapted_state_dict</span><span class="p">[</span><span class="s2">&quot;net&quot;</span><span class="p">],</span> <span class="n">strict</span><span class="o">=</span><span class="kc">True</span><span class="p">)</span>
             <span class="k">else</span><span class="p">:</span>
             <span class="k">else</span><span class="p">:</span>
    -            <span class="n">raise_informative_runtime_error</span><span class="p">(</span><span class="n">net</span><span class="o">.</span><span class="n">state_dict</span><span class="p">(),</span> <span class="n">state_dict</span><span class="p">,</span> <span class="n">ex</span><span class="p">)</span></div>
    -
    -
    -<span class="nd">@explicit_params_validation</span><span class="p">(</span><span class="n">validation_type</span><span class="o">=</span><span class="s1">&#39;None&#39;</span><span class="p">)</span>
    -<span class="k">def</span> <span class="nf">copy_ckpt_to_local_folder</span><span class="p">(</span><span class="n">local_ckpt_destination_dir</span><span class="p">:</span> <span class="nb">str</span><span class="p">,</span> <span class="n">ckpt_filename</span><span class="p">:</span> <span class="nb">str</span><span class="p">,</span> <span class="n">remote_ckpt_source_dir</span><span class="p">:</span> <span class="nb">str</span> <span class="o">=</span> <span class="kc">None</span><span cla
    -                              <span class="n">path_src</span><span class="p">:</span> <span class="nb">str</span> <span class="o">=</span> <span class="s1">&#39;local&#39;</span><span class="p">,</span> <span class="n">overwrite_local_ckpt</span><span class="p">:</span> <span class="nb">bool</span> <span class="o">=</span> <span class="kc">False</span><span class="p">,</span>
    -                              <span class="n">load_weights_only</span><span class="p">:</span> <span class="nb">bool</span> <span class="o">=</span> <span class="kc">False</span><span class="p">):</span>
    +            <span class="n">raise_informative_runtime_error</span><span class="p">(</span><span class="n">net</span><span class="o">.</span><span class="n">state_dict</span><span class="p">(),</span> <span class="n">state_dict</span><span class="p">,</span> <span class="n">ex</span><span class="p">)</span>
    +
    +
    +<span class="nd">@explicit_params_validation</span><span class="p">(</span><span class="n">validation_type</span><span class="o">=</span><span class="s2">&quot;None&quot;</span><span class="p">)</span>
    +<span class="k">def</span> <span class="nf">copy_ckpt_to_local_folder</span><span class="p">(</span>
    +    <span class="n">local_ckpt_destination_dir</span><span class="p">:</span> <span class="nb">str</span><span class="p">,</span>
    +    <span class="n">ckpt_filename</span><span class="p">:</span> <span class="nb">str</span><span class="p">,</span>
    +    <span class="n">remote_ckpt_source_dir</span><span class="p">:</span> <span class="nb">str</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span>
    +    <span class="n">path_src</span><span class="p">:</span> <span class="nb">str</span> <span class="o">=</span> <span class="s2">&quot;local&quot;</span><span class="p">,</span>
    +    <span class="n">overwrite_local_ckpt</span><span class="p">:</span> <span class="nb">bool</span> <span class="o">=</span> <span class="kc">False</span><span class="p">,</span>
    +    <span class="n">load_weights_only</span><span class="p">:</span> <span class="nb">bool</span> <span class="o">=</span> <span class="kc">False</span><span class="p">,</span>
    +<span class="p">):</span>
         <span class="sd">&quot;&quot;&quot;</span>
         <span class="sd">&quot;&quot;&quot;</span>
     <span class="sd">    Copy the checkpoint from any supported source to a local destination path</span>
     <span class="sd">    Copy the checkpoint from any supported source to a local destination path</span>
     <span class="sd">        :param local_ckpt_destination_dir:  destination where the checkpoint will be saved to</span>
     <span class="sd">        :param local_ckpt_destination_dir:  destination where the checkpoint will be saved to</span>
    @@ -181,27 +194,31 @@
         <span class="k">if</span> <span class="ow">not</span> <span class="n">overwrite_local_ckpt</span><span class="p">:</span>
         <span class="k">if</span> <span class="ow">not</span> <span class="n">overwrite_local_ckpt</span><span class="p">:</span>
             <span class="c1"># CREATE A TEMP FOLDER TO SAVE THE CHECKPOINT TO</span>
             <span class="c1"># CREATE A TEMP FOLDER TO SAVE THE CHECKPOINT TO</span>
             <span class="n">download_ckpt_destination_dir</span> <span class="o">=</span> <span class="n">tempfile</span><span class="o">.</span><span class="n">gettempdir</span><span class="p">()</span>
             <span class="n">download_ckpt_destination_dir</span> <span class="o">=</span> <span class="n">tempfile</span><span class="o">.</span><span class="n">gettempdir</span><span class="p">()</span>
    -        <span class="nb">print</span><span class="p">(</span><span class="s1">&#39;PLEASE NOTICE - YOU ARE IMPORTING A REMOTE CHECKPOINT WITH overwrite_local_checkpoint = False &#39;</span>
    -              <span class="s1">&#39;-&gt; IT WILL BE REDIRECTED TO A TEMP FOLDER AND DELETED ON MACHINE RESTART&#39;</span><span class="p">)</span>
    +        <span class="nb">print</span><span class="p">(</span>
    +            <span class="s2">&quot;PLEASE NOTICE - YOU ARE IMPORTING A REMOTE CHECKPOINT WITH overwrite_local_checkpoint = False &quot;</span>
    +            <span class="s2">&quot;-&gt; IT WILL BE REDIRECTED TO A TEMP FOLDER AND DELETED ON MACHINE RESTART&quot;</span>
    +        <span class="p">)</span>
         <span class="k">else</span><span class="p">:</span>
         <span class="k">else</span><span class="p">:</span>
             <span class="c1"># SAVE THE CHECKPOINT TO MODEL&#39;s FOLDER</span>
             <span class="c1"># SAVE THE CHECKPOINT TO MODEL&#39;s FOLDER</span>
    -        <span class="n">download_ckpt_destination_dir</span> <span class="o">=</span> <span class="n">pkg_resources</span><span class="o">.</span><span class="n">resource_filename</span><span class="p">(</span><span class="s1">&#39;checkpoints&#39;</span><span class="p">,</span> <span class="n">local_ckpt_destination_dir</span><span class="p">)</span>
    +        <span class="n">download_ckpt_destination_dir</span> <span class="o">=</span> <span class="n">pkg_resources</span><span class="o">.</span><span class="n">resource_filename</span><span class="p">(</span><span class="s2">&quot;checkpoints&quot;</span><span class="p">,</span> <span class="n">local_ckpt_destination_dir</span><span class="p">)</span>
     
     
    -    <span class="k">if</span> <span class="n">path_src</span><span class="o">.</span><span class="n">startswith</span><span class="p">(</span><span class="s1">&#39;s3&#39;</span><span class="p">):</span>
    +    <span class="k">if</span> <span class="n">path_src</span><span class="o">.</span><span class="n">startswith</span><span class="p">(</span><span class="s2">&quot;s3&quot;</span><span class="p">):</span>
             <span class="n">model_checkpoints_data_interface</span> <span class="o">=</span> <span class="n">ADNNModelRepositoryDataInterfaces</span><span class="p">(</span><span class="n">data_connection_location</span><span class="o">=</span><span class="n">path_src</span><span class="p">)</span>
             <span class="n">model_checkpoints_data_interface</span> <span class="o">=</span> <span class="n">ADNNModelRepositoryDataInterfaces</span><span class="p">(</span><span class="n">data_connection_location</span><span class="o">=</span><span class="n">path_src</span><span class="p">)</span>
             <span class="c1"># DOWNLOAD THE FILE FROM S3 TO THE DESTINATION FOLDER</span>
             <span class="c1"># DOWNLOAD THE FILE FROM S3 TO THE DESTINATION FOLDER</span>
             <span class="n">ckpt_file_full_local_path</span> <span class="o">=</span> <span class="n">model_checkpoints_data_interface</span><span class="o">.</span><span class="n">load_remote_checkpoints_file</span><span class="p">(</span>
             <span class="n">ckpt_file_full_local_path</span> <span class="o">=</span> <span class="n">model_checkpoints_data_interface</span><span class="o">.</span><span class="n">load_remote_checkpoints_file</span><span class="p">(</span>
                 <span class="n">ckpt_source_remote_dir</span><span class="o">=</span><span class="n">remote_ckpt_source_dir</span><span class="p">,</span>
                 <span class="n">ckpt_source_remote_dir</span><span class="o">=</span><span class="n">remote_ckpt_source_dir</span><span class="p">,</span>
                 <span class="n">ckpt_destination_local_dir</span><span class="o">=</span><span class="n">download_ckpt_destination_dir</span><span class="p">,</span>
                 <span class="n">ckpt_destination_local_dir</span><span class="o">=</span><span class="n">download_ckpt_destination_dir</span><span class="p">,</span>
                 <span class="n">ckpt_file_name</span><span class="o">=</span><span class="n">ckpt_filename</span><span class="p">,</span>
                 <span class="n">ckpt_file_name</span><span class="o">=</span><span class="n">ckpt_filename</span><span class="p">,</span>
    -            <span class="n">overwrite_local_checkpoints_file</span><span class="o">=</span><span class="n">overwrite_local_ckpt</span><span class="p">)</span>
    +            <span class="n">overwrite_local_checkpoints_file</span><span class="o">=</span><span class="n">overwrite_local_ckpt</span><span class="p">,</span>
    +        <span class="p">)</span>
     
     
             <span class="k">if</span> <span class="ow">not</span> <span class="n">load_weights_only</span><span class="p">:</span>
             <span class="k">if</span> <span class="ow">not</span> <span class="n">load_weights_only</span><span class="p">:</span>
                 <span class="c1"># COPY LOG FILES FROM THE REMOTE DIRECTORY TO THE LOCAL ONE ONLY IF LOADING THE CURRENT MODELs CKPT</span>
                 <span class="c1"># COPY LOG FILES FROM THE REMOTE DIRECTORY TO THE LOCAL ONE ONLY IF LOADING THE CURRENT MODELs CKPT</span>
    -            <span class="n">model_checkpoints_data_interface</span><span class="o">.</span><span class="n">load_all_remote_log_files</span><span class="p">(</span><span class="n">model_name</span><span class="o">=</span><span class="n">remote_ckpt_source_dir</span><span class="p">,</span>
    -                                                                       <span class="n">model_checkpoint_local_dir</span><span class="o">=</span><span class="n">download_ckpt_destination_dir</span><span class="p">)</span>
    +            <span class="n">model_checkpoints_data_interface</span><span class="o">.</span><span class="n">load_all_remote_log_files</span><span class="p">(</span>
    +                <span class="n">model_name</span><span class="o">=</span><span class="n">remote_ckpt_source_dir</span><span class="p">,</span> <span class="n">model_checkpoint_local_dir</span><span class="o">=</span><span class="n">download_ckpt_destination_dir</span>
    +            <span class="p">)</span>
     
     
    -    <span class="k">if</span> <span class="n">path_src</span> <span class="o">==</span> <span class="s1">&#39;url&#39;</span><span class="p">:</span>
    +    <span class="k">if</span> <span class="n">path_src</span> <span class="o">==</span> <span class="s2">&quot;url&quot;</span><span class="p">:</span>
             <span class="n">ckpt_file_full_local_path</span> <span class="o">=</span> <span class="n">download_ckpt_destination_dir</span> <span class="o">+</span> <span class="n">os</span><span class="o">.</span><span class="n">path</span><span class="o">.</span><span class="n">sep</span> <span class="o">+</span> <span class="n">ckpt_filename</span>
             <span class="n">ckpt_file_full_local_path</span> <span class="o">=</span> <span class="n">download_ckpt_destination_dir</span> <span class="o">+</span> <span class="n">os</span><span class="o">.</span><span class="n">path</span><span class="o">.</span><span class="n">sep</span> <span class="o">+</span> <span class="n">ckpt_filename</span>
             <span class="c1"># DOWNLOAD THE FILE FROM URL TO THE DESTINATION FOLDER</span>
             <span class="c1"># DOWNLOAD THE FILE FROM URL TO THE DESTINATION FOLDER</span>
             <span class="n">download_url_to_file</span><span class="p">(</span><span class="n">remote_ckpt_source_dir</span><span class="p">,</span> <span class="n">ckpt_file_full_local_path</span><span class="p">,</span> <span class="n">progress</span><span class="o">=</span><span class="kc">True</span><span class="p">)</span>
             <span class="n">download_url_to_file</span><span class="p">(</span><span class="n">remote_ckpt_source_dir</span><span class="p">,</span> <span class="n">ckpt_file_full_local_path</span><span class="p">,</span> <span class="n">progress</span><span class="o">=</span><span class="kc">True</span><span class="p">)</span>
    @@ -209,20 +226,19 @@
         <span class="k">return</span> <span class="n">ckpt_file_full_local_path</span>
         <span class="k">return</span> <span class="n">ckpt_file_full_local_path</span>
     
     
     
     
    -<div class="viewcode-block" id="read_ckpt_state_dict"><a class="viewcode-back" href="../../../../super_gradients.training.utils.html#super_gradients.training.utils.read_ckpt_state_dict">[docs]</a><span class="k">def</span> <span class="nf">read_ckpt_state_dict</span><span class="p">(</span><span class="n">ckpt_path</span><span class="p">:</span> <span class="nb">str</span><span class="p">,</span> <span class="n">device</span><span class="o">=</span><span class="s2">&quot;cpu&quot;</span><span c
    +<span class="k">def</span> <span class="nf">read_ckpt_state_dict</span><span class="p">(</span><span class="n">ckpt_path</span><span class="p">:</span> <span class="nb">str</span><span class="p">,</span> <span class="n">device</span><span class="o">=</span><span class="s2">&quot;cpu&quot;</span><span class="p">):</span>
         <span class="k">if</span> <span class="ow">not</span> <span class="n">os</span><span class="o">.</span><span class="n">path</span><span class="o">.</span><span class="n">exists</span><span class="p">(</span><span class="n">ckpt_path</span><span class="p">):</span>
         <span class="k">if</span> <span class="ow">not</span> <span class="n">os</span><span class="o">.</span><span class="n">path</span><span class="o">.</span><span class="n">exists</span><span class="p">(</span><span class="n">ckpt_path</span><span class="p">):</span>
    -        <span class="k">raise</span> <span class="ne">ValueError</span><span class="p">(</span><span class="s1">&#39;Incorrect Checkpoint path&#39;</span><span class="p">)</span>
    +        <span class="k">raise</span> <span class="ne">ValueError</span><span class="p">(</span><span class="s2">&quot;Incorrect Checkpoint path&quot;</span><span class="p">)</span>
     
     
         <span class="k">if</span> <span class="n">device</span> <span class="o">==</span> <span class="s2">&quot;cuda&quot;</span><span class="p">:</span>
         <span class="k">if</span> <span class="n">device</span> <span class="o">==</span> <span class="s2">&quot;cuda&quot;</span><span class="p">:</span>
             <span class="n">state_dict</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">load</span><span class="p">(</span><span class="n">ckpt_path</span><span class="p">)</span>
             <span class="n">state_dict</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">load</span><span class="p">(</span><span class="n">ckpt_path</span><span class="p">)</span>
     
     
         <span class="k">else</span><span class="p">:</span>
         <span class="k">else</span><span class="p">:</span>
             <span class="n">state_dict</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">load</span><span class="p">(</span><span class="n">ckpt_path</span><span class="p">,</span> <span class="n">map_location</span><span class="o">=</span><span class="k">lambda</span> <span class="n">storage</span><span class="p">,</span> <span class="n">loc</span><span class="p">:</span> <span class="n">storage</span><span class="p">)</span>
             <span class="n">state_dict</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">load</span><span class="p">(</span><span class="n">ckpt_path</span><span class="p">,</span> <span class="n">map_location</span><span class="o">=</span><span class="k">lambda</span> <span class="n">storage</span><span class="p">,</span> <span class="n">loc</span><span class="p">:</span> <span class="n">storage</span><span class="p">)</span>
    -    <span class="k">return</span> <span class="n">state_dict</span></div>
    +    <span class="k">return</span> <span class="n">state_dict</span>
     
     
     
     
    -<div class="viewcode-block" id="adapt_state_dict_to_fit_model_layer_names"><a class="viewcode-back" href="../../../../super_gradients.training.utils.html#super_gradients.training.utils.adapt_state_dict_to_fit_model_layer_names">[docs]</a><span class="k">def</span> <span class="nf">adapt_state_dict_to_fit_model_layer_names</span><span class="p">(</span><span class="n">model_state_dict</span><span class="p">:</span> <span class="nb">dict</span><span class="p">,</span> <span class="n">source_ckpt<
    -                                              <span class="n">exclude</span><span class="p">:</span> <span class="nb">list</span> <span class="o">=</span> <span class="p">[],</span> <span class="n">solver</span><span class="p">:</span> <span class="n">callable</span> <span class="o">=</span> <span class="kc">None</span><span class="p">):</span>
    +<div class="viewcode-block" id="adapt_state_dict_to_fit_model_layer_names"><a class="viewcode-back" href="../../../../super_gradients.training.html#super_gradients.training.utils.adapt_state_dict_to_fit_model_layer_names">[docs]</a><span class="k">def</span> <span class="nf">adapt_state_dict_to_fit_model_layer_names</span><span class="p">(</span><span class="n">model_state_dict</span><span class="p">:</span> <span class="nb">dict</span><span class="p">,</span> <span class="n">source_ckpt</span>
         <span class="sd">&quot;&quot;&quot;</span>
         <span class="sd">&quot;&quot;&quot;</span>
     <span class="sd">    Given a model state dict and source checkpoints, the method tries to correct the keys in the model_state_dict to fit</span>
     <span class="sd">    Given a model state dict and source checkpoints, the method tries to correct the keys in the model_state_dict to fit</span>
     <span class="sd">    the ckpt in order to properly load the weights into the model. If unsuccessful - returns None</span>
     <span class="sd">    the ckpt in order to properly load the weights into the model. If unsuccessful - returns None</span>
    @@ -233,39 +249,41 @@
     <span class="sd">                                        that returns a desired weight for ckpt_val.</span>
     <span class="sd">                                        that returns a desired weight for ckpt_val.</span>
     <span class="sd">        :return: renamed checkpoint dict (if possible)</span>
     <span class="sd">        :return: renamed checkpoint dict (if possible)</span>
     <span class="sd">    &quot;&quot;&quot;</span>
     <span class="sd">    &quot;&quot;&quot;</span>
    -    <span class="k">if</span> <span class="s1">&#39;net&#39;</span> <span class="ow">in</span> <span class="n">source_ckpt</span><span class="o">.</span><span class="n">keys</span><span class="p">():</span>
    -        <span class="n">source_ckpt</span> <span class="o">=</span> <span class="n">source_ckpt</span><span class="p">[</span><span class="s1">&#39;net&#39;</span><span class="p">]</span>
    +    <span class="k">if</span> <span class="s2">&quot;net&quot;</span> <span class="ow">in</span> <span class="n">source_ckpt</span><span class="o">.</span><span class="n">keys</span><span class="p">():</span>
    +        <span class="n">source_ckpt</span> <span class="o">=</span> <span class="n">source_ckpt</span><span class="p">[</span><span class="s2">&quot;net&quot;</span><span class="p">]</span>
         <span class="n">model_state_dict_excluded</span> <span class="o">=</span> <span class="p">{</span><span class="n">k</span><span class="p">:</span> <span class="n">v</span> <span class="k">for</span> <span class="n">k</span><span class="p">,</span> <span class="n">v</span> <span class="ow">in</span> <span class="n">model_state_dict</span><span class="o">.</span><span class="n">items</span><span class="p">()</span> <span class="k">if</span> <span class="ow">not</span> <span class="nb">any</sp
         <span class="n">model_state_dict_excluded</span> <span class="o">=</span> <span class="p">{</span><span class="n">k</span><span class="p">:</span> <span class="n">v</span> <span class="k">for</span> <span class="n">k</span><span class="p">,</span> <span class="n">v</span> <span class="ow">in</span> <span class="n">model_state_dict</span><span class="o">.</span><span class="n">items</span><span class="p">()</span> <span class="k">if</span> <span class="ow">not</span> <span class="nb">any</sp
         <span class="n">new_ckpt_dict</span> <span class="o">=</span> <span class="p">{}</span>
         <span class="n">new_ckpt_dict</span> <span class="o">=</span> <span class="p">{}</span>
         <span class="k">for</span> <span class="p">(</span><span class="n">ckpt_key</span><span class="p">,</span> <span class="n">ckpt_val</span><span class="p">),</span> <span class="p">(</span><span class="n">model_key</span><span class="p">,</span> <span class="n">model_val</span><span class="p">)</span> <span class="ow">in</span> <span class="nb">zip</span><span class="p">(</span><span class="n">source_ckpt</span><span class="o">.</span><span class="n">items</span><span class="p">(),</span> <s
         <span class="k">for</span> <span class="p">(</span><span class="n">ckpt_key</span><span class="p">,</span> <span class="n">ckpt_val</span><span class="p">),</span> <span class="p">(</span><span class="n">model_key</span><span class="p">,</span> <span class="n">model_val</span><span class="p">)</span> <span class="ow">in</span> <span class="nb">zip</span><span class="p">(</span><span class="n">source_ckpt</span><span class="o">.</span><span class="n">items</span><span class="p">(),</span> <s
             <span class="k">if</span> <span class="n">solver</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span><span class="p">:</span>
             <span class="k">if</span> <span class="n">solver</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span><span class="p">:</span>
                 <span class="n">ckpt_val</span> <span class="o">=</span> <span class="n">solver</span><span class="p">(</span><span class="n">ckpt_key</span><span class="p">,</span> <span class="n">ckpt_val</span><span class="p">,</span> <span class="n">model_key</span><span class="p">,</span> <span class="n">model_val</span><span class="p">)</span>
                 <span class="n">ckpt_val</span> <span class="o">=</span> <span class="n">solver</span><span class="p">(</span><span class="n">ckpt_key</span><span class="p">,</span> <span class="n">ckpt_val</span><span class="p">,</span> <span class="n">model_key</span><span class="p">,</span> <span class="n">model_val</span><span class="p">)</span>
             <span class="k">if</span> <span class="n">ckpt_val</span><span class="o">.</span><span class="n">shape</span> <span class="o">!=</span> <span class="n">model_val</span><span class="o">.</span><span class="n">shape</span><span class="p">:</span>
             <span class="k">if</span> <span class="n">ckpt_val</span><span class="o">.</span><span class="n">shape</span> <span class="o">!=</span> <span class="n">model_val</span><span class="o">.</span><span class="n">shape</span><span class="p">:</span>
    -            <span class="k">raise</span> <span class="ne">ValueError</span><span class="p">(</span><span class="sa">f</span><span class="s1">&#39;ckpt layer </span><span class="si">{</span><span class="n">ckpt_key</span><span class="si">}</span><span class="s1"> with shape </span><span class="si">{</span><span class="n">ckpt_val</span><span class="o">.</span><span class="n">shape</span><span class="si">}</span><span class="s1"> does not match </span><span class="si">{</span><span class="n">mode
    -                             <span class="sa">f</span><span class="s1">&#39; with shape </span><span class="si">{</span><span class="n">model_val</span><span class="o">.</span><span class="n">shape</span><span class="si">}</span><span class="s1"> in the model&#39;</span><span class="p">)</span>
    +            <span class="k">raise</span> <span class="ne">ValueError</span><span class="p">(</span><span class="sa">f</span><span class="s2">&quot;ckpt layer </span><span class="si">{</span><span class="n">ckpt_key</span><span class="si">}</span><span class="s2"> with shape </span><span class="si">{</span><span class="n">ckpt_val</span><span class="o">.</span><span class="n">shape</span><span class="si">}</span><span class="s2"> does not match </span><span class="si">{</span><span class="n">mod
             <span class="n">new_ckpt_dict</span><span class="p">[</span><span class="n">model_key</span><span class="p">]</span> <span class="o">=</span> <span class="n">ckpt_val</span>
             <span class="n">new_ckpt_dict</span><span class="p">[</span><span class="n">model_key</span><span class="p">]</span> <span class="o">=</span> <span class="n">ckpt_val</span>
    -    <span class="k">return</span> <span class="p">{</span><span class="s1">&#39;net&#39;</span><span class="p">:</span> <span class="n">new_ckpt_dict</span><span class="p">}</span></div>
    +    <span class="k">return</span> <span class="p">{</span><span class="s2">&quot;net&quot;</span><span class="p">:</span> <span class="n">new_ckpt_dict</span><span class="p">}</span></div>
     
     
     
     
    -<div class="viewcode-block" id="raise_informative_runtime_error"><a class="viewcode-back" href="../../../../super_gradients.training.utils.html#super_gradients.training.utils.raise_informative_runtime_error">[docs]</a><span class="k">def</span> <span class="nf">raise_informative_runtime_error</span><span class="p">(</span><span class="n">state_dict</span><span class="p">,</span> <span class="n">checkpoint</span><span class="p">,</span> <span class="n">exception_msg</span><span class="p">):</spa
    +<div class="viewcode-block" id="raise_informative_runtime_error"><a class="viewcode-back" href="../../../../super_gradients.training.html#super_gradients.training.utils.raise_informative_runtime_error">[docs]</a><span class="k">def</span> <span class="nf">raise_informative_runtime_error</span><span class="p">(</span><span class="n">state_dict</span><span class="p">,</span> <span class="n">checkpoint</span><span class="p">,</span> <span class="n">exception_msg</span><span class="p">):</span>
         <span class="sd">&quot;&quot;&quot;</span>
         <span class="sd">&quot;&quot;&quot;</span>
     <span class="sd">    Given a model state dict and source checkpoints, the method calls &quot;adapt_state_dict_to_fit_model_layer_names&quot;</span>
     <span class="sd">    Given a model state dict and source checkpoints, the method calls &quot;adapt_state_dict_to_fit_model_layer_names&quot;</span>
     <span class="sd">    and enhances the exception_msg if loading the checkpoint_dict via the conversion method is possible</span>
     <span class="sd">    and enhances the exception_msg if loading the checkpoint_dict via the conversion method is possible</span>
     <span class="sd">    &quot;&quot;&quot;</span>
     <span class="sd">    &quot;&quot;&quot;</span>
         <span class="k">try</span><span class="p">:</span>
         <span class="k">try</span><span class="p">:</span>
             <span class="n">new_ckpt_dict</span> <span class="o">=</span> <span class="n">adapt_state_dict_to_fit_model_layer_names</span><span class="p">(</span><span class="n">state_dict</span><span class="p">,</span> <span class="n">checkpoint</span><span class="p">)</span>
             <span class="n">new_ckpt_dict</span> <span class="o">=</span> <span class="n">adapt_state_dict_to_fit_model_layer_names</span><span class="p">(</span><span class="n">state_dict</span><span class="p">,</span> <span class="n">checkpoint</span><span class="p">)</span>
    -        <span class="n">temp_file</span> <span class="o">=</span> <span class="n">tempfile</span><span class="o">.</span><span class="n">NamedTemporaryFile</span><span class="p">()</span><span class="o">.</span><span class="n">name</span> <span class="o">+</span> <span class="s1">&#39;.pt&#39;</span>
    +        <span class="n">temp_file</span> <span class="o">=</span> <span class="n">tempfile</span><span class="o">.</span><span class="n">NamedTemporaryFile</span><span class="p">()</span><span class="o">.</span><span class="n">name</span> <span class="o">+</span> <span class="s2">&quot;.pt&quot;</span>
             <span class="n">torch</span><span class="o">.</span><span class="n">save</span><span class="p">(</span><span class="n">new_ckpt_dict</span><span class="p">,</span> <span class="n">temp_file</span><span class="p">)</span>
             <span class="n">torch</span><span class="o">.</span><span class="n">save</span><span class="p">(</span><span class="n">new_ckpt_dict</span><span class="p">,</span> <span class="n">temp_file</span><span class="p">)</span>
    -        <span class="n">exception_msg</span> <span class="o">=</span> <span class="sa">f</span><span class="s2">&quot;</span><span class="se">\n</span><span class="si">{</span><span class="s1">&#39;=&#39;</span> <span class="o">*</span> <span class="mi">200</span><span class="si">}</span><span class="se">\n</span><span class="si">{</span><span class="nb">str</span><span class="p">(</span><span class="n">exception_msg</span><span class="p">)</span><span class="si">}</span><span class="s2"> </spa
    -                        <span class="sa">f</span><span class="s2">&quot;model_layer_names method</span><span class="se">\n</span><span class="s2">a converted checkpoint file was saved in the path </span><span class="si">{</span><span class="n">temp_file</span><span class="si">}</span><span class="se">\n</span><span class="si">{</span><span class="s1">&#39;=&#39;</span> <span class="o">*</span> <span class="mi">200</span><span class="si">}</span><span class="s2">&quot;</span>
    +        <span class="n">exception_msg</span> <span class="o">=</span> <span class="p">(</span>
    +            <span class="sa">f</span><span class="s2">&quot;</span><span class="se">\n</span><span class="si">{</span><span class="s1">&#39;=&#39;</span> <span class="o">*</span> <span class="mi">200</span><span class="si">}</span><span class="se">\n</span><span class="si">{</span><span class="nb">str</span><span class="p">(</span><span class="n">exception_msg</span><span class="p">)</span><span class="si">}</span><span class="s2"> </span><span class="se">\n</span><span class="s2">convert ckpt 
    +            <span class="sa">f</span><span class="s2">&quot;model_layer_names method</span><span class="se">\n</span><span class="s2">a converted checkpoint file was saved in the path </span><span class="si">{</span><span class="n">temp_file</span><span class="si">}</span><span class="se">\n</span><span class="si">{</span><span class="s1">&#39;=&#39;</span> <span class="o">*</span> <span class="mi">200</span><span class="si">}</span><span class="s2">&quot;</span>
    +        <span class="p">)</span>
         <span class="k">except</span> <span class="ne">ValueError</span> <span class="k">as</span> <span class="n">ex</span><span class="p">:</span>  <span class="c1"># IN CASE adapt_state_dict_to_fit_model_layer_names WAS UNSUCCESSFUL</span>
         <span class="k">except</span> <span class="ne">ValueError</span> <span class="k">as</span> <span class="n">ex</span><span class="p">:</span>  <span class="c1"># IN CASE adapt_state_dict_to_fit_model_layer_names WAS UNSUCCESSFUL</span>
             <span class="n">exception_msg</span> <span class="o">=</span> <span class="sa">f</span><span class="s2">&quot;</span><span class="se">\n</span><span class="si">{</span><span class="s1">&#39;=&#39;</span> <span class="o">*</span> <span class="mi">200</span><span class="si">}</span><span class="s2"> </span><span class="se">\n</span><span class="s2">The checkpoint and model shapes do no fit, e.g.: </span><span class="si">{</span><span class="n">ex</span><span class="si">}</span><span class
             <span class="n">exception_msg</span> <span class="o">=</span> <span class="sa">f</span><span class="s2">&quot;</span><span class="se">\n</span><span class="si">{</span><span class="s1">&#39;=&#39;</span> <span class="o">*</span> <span class="mi">200</span><span class="si">}</span><span class="s2"> </span><span class="se">\n</span><span class="s2">The checkpoint and model shapes do no fit, e.g.: </span><span class="si">{</span><span class="n">ex</span><span class="si">}</span><span class
         <span class="k">finally</span><span class="p">:</span>
         <span class="k">finally</span><span class="p">:</span>
             <span class="k">raise</span> <span class="ne">RuntimeError</span><span class="p">(</span><span class="n">exception_msg</span><span class="p">)</span></div>
             <span class="k">raise</span> <span class="ne">RuntimeError</span><span class="p">(</span><span class="n">exception_msg</span><span class="p">)</span></div>
     
     
     
     
    -<div class="viewcode-block" id="load_checkpoint_to_model"><a class="viewcode-back" href="../../../../super_gradients.training.utils.html#super_gradients.training.utils.load_checkpoint_to_model">[docs]</a><span class="k">def</span> <span class="nf">load_checkpoint_to_model</span><span class="p">(</span><span class="n">ckpt_local_path</span><span class="p">:</span> <span class="nb">str</span><span class="p">,</span> <span class="n">load_backbone</span><span class="p">:</span> <span class="nb">boo
    -                             <span class="n">load_weights_only</span><span class="p">:</span> <span class="nb">bool</span><span class="p">,</span> <span class="n">load_ema_as_net</span><span class="p">:</span> <span class="nb">bool</span> <span class="o">=</span> <span class="kc">False</span><span class="p">):</span>
    +<span class="k">def</span> <span class="nf">load_checkpoint_to_model</span><span class="p">(</span>
    +    <span class="n">ckpt_local_path</span><span class="p">:</span> <span class="nb">str</span><span class="p">,</span> <span class="n">load_backbone</span><span class="p">:</span> <span class="nb">bool</span><span class="p">,</span> <span class="n">net</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">nn</span><span class="o">.</span><span class="n">Module</span><span class="p">,</span> <span class="n">strict</span><span class="p">:</span> <span 
    +<span class="p">):</span>
         <span class="sd">&quot;&quot;&quot;</span>
         <span class="sd">&quot;&quot;&quot;</span>
     <span class="sd">    Loads the state dict in ckpt_local_path to net and returns the checkpoint&#39;s state dict.</span>
     <span class="sd">    Loads the state dict in ckpt_local_path to net and returns the checkpoint&#39;s state dict.</span>
     
     
    @@ -278,35 +296,39 @@
     <span class="sd">    @return:</span>
     <span class="sd">    @return:</span>
     <span class="sd">    &quot;&quot;&quot;</span>
     <span class="sd">    &quot;&quot;&quot;</span>
         <span class="k">if</span> <span class="n">ckpt_local_path</span> <span class="ow">is</span> <span class="kc">None</span> <span class="ow">or</span> <span class="ow">not</span> <span class="n">os</span><span class="o">.</span><span class="n">path</span><span class="o">.</span><span class="n">exists</span><span class="p">(</span><span class="n">ckpt_local_path</span><span class="p">):</span>
         <span class="k">if</span> <span class="n">ckpt_local_path</span> <span class="ow">is</span> <span class="kc">None</span> <span class="ow">or</span> <span class="ow">not</span> <span class="n">os</span><span class="o">.</span><span class="n">path</span><span class="o">.</span><span class="n">exists</span><span class="p">(</span><span class="n">ckpt_local_path</span><span class="p">):</span>
    -        <span class="n">error_msg</span> <span class="o">=</span> <span class="s1">&#39;Error - loading Model Checkpoint: Path </span><span class="si">{}</span><span class="s1"> does not exist&#39;</span><span class="o">.</span><span class="n">format</span><span class="p">(</span><span class="n">ckpt_local_path</span><span class="p">)</span>
    +        <span class="n">error_msg</span> <span class="o">=</span> <span class="s2">&quot;Error - loading Model Checkpoint: Path </span><span class="si">{}</span><span class="s2"> does not exist&quot;</span><span class="o">.</span><span class="n">format</span><span class="p">(</span><span class="n">ckpt_local_path</span><span class="p">)</span>
             <span class="k">raise</span> <span class="ne">RuntimeError</span><span class="p">(</span><span class="n">error_msg</span><span class="p">)</span>
             <span class="k">raise</span> <span class="ne">RuntimeError</span><span class="p">(</span><span class="n">error_msg</span><span class="p">)</span>
     
     
    -    <span class="k">if</span> <span class="n">load_backbone</span> <span class="ow">and</span> <span class="ow">not</span> <span class="nb">hasattr</span><span class="p">(</span><span class="n">net</span><span class="o">.</span><span class="n">module</span><span class="p">,</span> <span class="s1">&#39;backbone&#39;</span><span class="p">):</span>
    +    <span class="k">if</span> <span class="n">load_backbone</span> <span class="ow">and</span> <span class="ow">not</span> <span class="nb">hasattr</span><span class="p">(</span><span class="n">net</span><span class="p">,</span> <span class="s2">&quot;backbone&quot;</span><span class="p">):</span>
             <span class="k">raise</span> <span class="ne">ValueError</span><span class="p">(</span><span class="s2">&quot;No backbone attribute in net - Can&#39;t load backbone weights&quot;</span><span class="p">)</span>
             <span class="k">raise</span> <span class="ne">ValueError</span><span class="p">(</span><span class="s2">&quot;No backbone attribute in net - Can&#39;t load backbone weights&quot;</span><span class="p">)</span>
     
     
         <span class="c1"># LOAD THE LOCAL CHECKPOINT PATH INTO A state_dict OBJECT</span>
         <span class="c1"># LOAD THE LOCAL CHECKPOINT PATH INTO A state_dict OBJECT</span>
         <span class="n">checkpoint</span> <span class="o">=</span> <span class="n">read_ckpt_state_dict</span><span class="p">(</span><span class="n">ckpt_path</span><span class="o">=</span><span class="n">ckpt_local_path</span><span class="p">)</span>
         <span class="n">checkpoint</span> <span class="o">=</span> <span class="n">read_ckpt_state_dict</span><span class="p">(</span><span class="n">ckpt_path</span><span class="o">=</span><span class="n">ckpt_local_path</span><span class="p">)</span>
     
     
         <span class="k">if</span> <span class="n">load_ema_as_net</span><span class="p">:</span>
         <span class="k">if</span> <span class="n">load_ema_as_net</span><span class="p">:</span>
    -        <span class="k">if</span> <span class="s1">&#39;ema_net&#39;</span> <span class="ow">not</span> <span class="ow">in</span> <span class="n">checkpoint</span><span class="o">.</span><span class="n">keys</span><span class="p">():</span>
    +        <span class="k">if</span> <span class="s2">&quot;ema_net&quot;</span> <span class="ow">not</span> <span class="ow">in</span> <span class="n">checkpoint</span><span class="o">.</span><span class="n">keys</span><span class="p">():</span>
                 <span class="k">raise</span> <span class="ne">ValueError</span><span class="p">(</span><span class="s2">&quot;Can&#39;t load ema network- no EMA network stored in checkpoint file&quot;</span><span class="p">)</span>
                 <span class="k">raise</span> <span class="ne">ValueError</span><span class="p">(</span><span class="s2">&quot;Can&#39;t load ema network- no EMA network stored in checkpoint file&quot;</span><span class="p">)</span>
             <span class="k">else</span><span class="p">:</span>
             <span class="k">else</span><span class="p">:</span>
    -            <span class="n">checkpoint</span><span class="p">[</span><span class="s1">&#39;net&#39;</span><span class="p">]</span> <span class="o">=</span> <span class="n">checkpoint</span><span class="p">[</span><span class="s1">&#39;ema_net&#39;</span><span class="p">]</span>
    +            <span class="n">checkpoint</span><span class="p">[</span><span class="s2">&quot;net&quot;</span><span class="p">]</span> <span class="o">=</span> <span class="n">checkpoint</span><span class="p">[</span><span class="s2">&quot;ema_net&quot;</span><span class="p">]</span>
     
     
         <span class="c1"># LOAD THE CHECKPOINTS WEIGHTS TO THE MODEL</span>
         <span class="c1"># LOAD THE CHECKPOINTS WEIGHTS TO THE MODEL</span>
         <span class="k">if</span> <span class="n">load_backbone</span><span class="p">:</span>
         <span class="k">if</span> <span class="n">load_backbone</span><span class="p">:</span>
    -        <span class="n">adaptive_load_state_dict</span><span class="p">(</span><span class="n">net</span><span class="o">.</span><span class="n">module</span><span class="o">.</span><span class="n">backbone</span><span class="p">,</span> <span class="n">checkpoint</span><span class="p">,</span> <span class="n">strict</span><span class="p">)</span>
    +        <span class="n">adaptive_load_state_dict</span><span class="p">(</span><span class="n">net</span><span class="o">.</span><span class="n">backbone</span><span class="p">,</span> <span class="n">checkpoint</span><span class="p">,</span> <span class="n">strict</span><span class="p">)</span>
         <span class="k">else</span><span class="p">:</span>
         <span class="k">else</span><span class="p">:</span>
             <span class="n">adaptive_load_state_dict</span><span class="p">(</span><span class="n">net</span><span class="p">,</span> <span class="n">checkpoint</span><span class="p">,</span> <span class="n">strict</span><span class="p">)</span>
             <span class="n">adaptive_load_state_dict</span><span class="p">(</span><span class="n">net</span><span class="p">,</span> <span class="n">checkpoint</span><span class="p">,</span> <span class="n">strict</span><span class="p">)</span>
     
     
    +    <span class="n">message_suffix</span> <span class="o">=</span> <span class="s2">&quot; checkpoint.&quot;</span> <span class="k">if</span> <span class="ow">not</span> <span class="n">load_ema_as_net</span> <span class="k">else</span> <span class="s2">&quot; EMA checkpoint.&quot;</span>
    +    <span class="n">message_model</span> <span class="o">=</span> <span class="s2">&quot;model&quot;</span> <span class="k">if</span> <span class="ow">not</span> <span class="n">load_backbone</span> <span class="k">else</span> <span class="s2">&quot;model&#39;s backbone&quot;</span>
    +    <span class="n">logger</span><span class="o">.</span><span class="n">info</span><span class="p">(</span><span class="s2">&quot;Successfully loaded &quot;</span> <span class="o">+</span> <span class="n">message_model</span> <span class="o">+</span> <span class="s2">&quot; weights from &quot;</span> <span class="o">+</span> <span class="n">ckpt_local_path</span> <span class="o">+</span> <span class="n">message_suffix</span><span class="p">)</span>
    +
         <span class="k">if</span> <span class="n">load_weights_only</span> <span class="ow">or</span> <span class="n">load_backbone</span><span class="p">:</span>
         <span class="k">if</span> <span class="n">load_weights_only</span> <span class="ow">or</span> <span class="n">load_backbone</span><span class="p">:</span>
             <span class="c1"># DISCARD ALL THE DATA STORED IN CHECKPOINT OTHER THAN THE WEIGHTS</span>
             <span class="c1"># DISCARD ALL THE DATA STORED IN CHECKPOINT OTHER THAN THE WEIGHTS</span>
    -        <span class="p">[</span><span class="n">checkpoint</span><span class="o">.</span><span class="n">pop</span><span class="p">(</span><span class="n">key</span><span class="p">)</span> <span class="k">for</span> <span class="n">key</span> <span class="ow">in</span> <span class="nb">list</span><span class="p">(</span><span class="n">checkpoint</span><span class="o">.</span><span class="n">keys</span><span class="p">())</span> <span class="k">if</span> <span class="n">key</span> <span class=
    +        <span class="p">[</span><span class="n">checkpoint</span><span class="o">.</span><span class="n">pop</span><span class="p">(</span><span class="n">key</span><span class="p">)</span> <span class="k">for</span> <span class="n">key</span> <span class="ow">in</span> <span class="nb">list</span><span class="p">(</span><span class="n">checkpoint</span><span class="o">.</span><span class="n">keys</span><span class="p">())</span> <span class="k">if</span> <span class="n">key</span> <span class=
     
     
    -    <span class="k">return</span> <span class="n">checkpoint</span></div>
    +    <span class="k">return</span> <span class="n">checkpoint</span>
     
     
     
     
    -<div class="viewcode-block" id="MissingPretrainedWeightsException"><a class="viewcode-back" href="../../../../super_gradients.training.utils.html#super_gradients.training.utils.MissingPretrainedWeightsException">[docs]</a><span class="k">class</span> <span class="nc">MissingPretrainedWeightsException</span><span class="p">(</span><span class="ne">Exception</span><span class="p">):</span>
    +<span class="k">class</span> <span class="nc">MissingPretrainedWeightsException</span><span class="p">(</span><span class="ne">Exception</span><span class="p">):</span>
         <span class="sd">&quot;&quot;&quot;Exception raised by unsupported pretrianed model.</span>
         <span class="sd">&quot;&quot;&quot;Exception raised by unsupported pretrianed model.</span>
     
     
     <span class="sd">    Attributes:</span>
     <span class="sd">    Attributes:</span>
    @@ -315,7 +337,7 @@
     
     
         <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">desc</span><span class="p">):</span>
         <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">desc</span><span class="p">):</span>
             <span class="bp">self</span><span class="o">.</span><span class="n">message</span> <span class="o">=</span> <span class="s2">&quot;Missing pretrained wights: &quot;</span> <span class="o">+</span> <span class="n">desc</span>
             <span class="bp">self</span><span class="o">.</span><span class="n">message</span> <span class="o">=</span> <span class="s2">&quot;Missing pretrained wights: &quot;</span> <span class="o">+</span> <span class="n">desc</span>
    -        <span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">message</span><span class="p">)</span></div>
    +        <span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">message</span><span class="p">)</span>
     
     
     
     
     <span class="k">def</span> <span class="nf">_yolox_ckpt_solver</span><span class="p">(</span><span class="n">ckpt_key</span><span class="p">,</span> <span class="n">ckpt_val</span><span class="p">,</span> <span class="n">model_key</span><span class="p">,</span> <span class="n">model_val</span><span class="p">):</span>
     <span class="k">def</span> <span class="nf">_yolox_ckpt_solver</span><span class="p">(</span><span class="n">ckpt_key</span><span class="p">,</span> <span class="n">ckpt_val</span><span class="p">,</span> <span class="n">model_key</span><span class="p">,</span> <span class="n">model_val</span><span class="p">):</span>
    @@ -323,8 +345,11 @@
     <span class="sd">    Helper method for reshaping old pretrained checkpoint&#39;s focus weights to 6x6 conv weights.</span>
     <span class="sd">    Helper method for reshaping old pretrained checkpoint&#39;s focus weights to 6x6 conv weights.</span>
     <span class="sd">    &quot;&quot;&quot;</span>
     <span class="sd">    &quot;&quot;&quot;</span>
     
     
    -    <span class="k">if</span> <span class="n">ckpt_val</span><span class="o">.</span><span class="n">shape</span> <span class="o">!=</span> <span class="n">model_val</span><span class="o">.</span><span class="n">shape</span> <span class="ow">and</span> <span class="n">ckpt_key</span> <span class="o">==</span> <span class="s1">&#39;module._backbone._modules_list.0.conv.conv.weight&#39;</span> <span class="ow">and</span> \
    -            <span class="n">model_key</span> <span class="o">==</span> <span class="s1">&#39;_backbone._modules_list.0.conv.weight&#39;</span><span class="p">:</span>
    +    <span class="k">if</span> <span class="p">(</span>
    +        <span class="n">ckpt_val</span><span class="o">.</span><span class="n">shape</span> <span class="o">!=</span> <span class="n">model_val</span><span class="o">.</span><span class="n">shape</span>
    +        <span class="ow">and</span> <span class="n">ckpt_key</span> <span class="o">==</span> <span class="s2">&quot;module._backbone._modules_list.0.conv.conv.weight&quot;</span>
    +        <span class="ow">and</span> <span class="n">model_key</span> <span class="o">==</span> <span class="s2">&quot;_backbone._modules_list.0.conv.weight&quot;</span>
    +    <span class="p">):</span>
             <span class="n">model_val</span><span class="o">.</span><span class="n">data</span><span class="p">[:,</span> <span class="p">:,</span> <span class="p">::</span><span class="mi">2</span><span class="p">,</span> <span class="p">::</span><span class="mi">2</span><span class="p">]</span> <span class="o">=</span> <span class="n">ckpt_val</span><span class="o">.</span><span class="n">data</span><span class="p">[:,</span> <span class="p">:</span><span class="mi">3</span><span class="p">]</spa
             <span class="n">model_val</span><span class="o">.</span><span class="n">data</span><span class="p">[:,</span> <span class="p">:,</span> <span class="p">::</span><span class="mi">2</span><span class="p">,</span> <span class="p">::</span><span class="mi">2</span><span class="p">]</span> <span class="o">=</span> <span class="n">ckpt_val</span><span class="o">.</span><span class="n">data</span><span class="p">[:,</span> <span class="p">:</span><span class="mi">3</span><span class="p">]</spa
             <span class="n">model_val</span><span class="o">.</span><span class="n">data</span><span class="p">[:,</span> <span class="p">:,</span> <span class="mi">1</span><span class="p">::</span><span class="mi">2</span><span class="p">,</span> <span class="p">::</span><span class="mi">2</span><span class="p">]</span> <span class="o">=</span> <span class="n">ckpt_val</span><span class="o">.</span><span class="n">data</span><span class="p">[:,</span> <span class="mi">3</span><span class="p">:</sp
             <span class="n">model_val</span><span class="o">.</span><span class="n">data</span><span class="p">[:,</span> <span class="p">:,</span> <span class="mi">1</span><span class="p">::</span><span class="mi">2</span><span class="p">,</span> <span class="p">::</span><span class="mi">2</span><span class="p">]</span> <span class="o">=</span> <span class="n">ckpt_val</span><span class="o">.</span><span class="n">data</span><span class="p">[:,</span> <span class="mi">3</span><span class="p">:</sp
             <span class="n">model_val</span><span class="o">.</span><span class="n">data</span><span class="p">[:,</span> <span class="p">:,</span> <span class="p">::</span><span class="mi">2</span><span class="p">,</span> <span class="mi">1</span><span class="p">::</span><span class="mi">2</span><span class="p">]</span> <span class="o">=</span> <span class="n">ckpt_val</span><span class="o">.</span><span class="n">data</span><span class="p">[:,</span> <span class="mi">6</span><span class="p">:</sp
             <span class="n">model_val</span><span class="o">.</span><span class="n">data</span><span class="p">[:,</span> <span class="p">:,</span> <span class="p">::</span><span class="mi">2</span><span class="p">,</span> <span class="mi">1</span><span class="p">::</span><span class="mi">2</span><span class="p">]</span> <span class="o">=</span> <span class="n">ckpt_val</span><span class="o">.</span><span class="n">data</span><span class="p">[:,</span> <span class="mi">6</span><span class="p">:</sp
    @@ -336,7 +361,7 @@
         <span class="k">return</span> <span class="n">replacement</span>
         <span class="k">return</span> <span class="n">replacement</span>
     
     
     
     
    -<div class="viewcode-block" id="load_pretrained_weights"><a class="viewcode-back" href="../../../../super_gradients.training.utils.html#super_gradients.training.utils.load_pretrained_weights">[docs]</a><span class="k">def</span> <span class="nf">load_pretrained_weights</span><span class="p">(</span><span class="n">model</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">nn</span><span class="o">.</span><span class="n">Module</span><span class="p">
    +<span class="k">def</span> <span class="nf">load_pretrained_weights</span><span class="p">(</span><span class="n">model</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">nn</span><span class="o">.</span><span class="n">Module</span><span class="p">,</span> <span class="n">architecture</span><span class="p">:</span> <span class="nb">str</span><span class="p">,</span> <span class="n">pretrained_weights</span><span class="p">:</span> <span class="nb
     
     
         <span class="sd">&quot;&quot;&quot;</span>
         <span class="sd">&quot;&quot;&quot;</span>
     <span class="sd">    Loads pretrained weights from the MODEL_URLS dictionary to model</span>
     <span class="sd">    Loads pretrained weights from the MODEL_URLS dictionary to model</span>
    @@ -345,21 +370,41 @@
     <span class="sd">    @param pretrained_weights: name for the pretrianed weights (i.e imagenet)</span>
     <span class="sd">    @param pretrained_weights: name for the pretrianed weights (i.e imagenet)</span>
     <span class="sd">    @return: None</span>
     <span class="sd">    @return: None</span>
     <span class="sd">    &quot;&quot;&quot;</span>
     <span class="sd">    &quot;&quot;&quot;</span>
    -    <span class="n">model_url_key</span> <span class="o">=</span> <span class="n">architecture</span> <span class="o">+</span> <span class="s1">&#39;_&#39;</span> <span class="o">+</span> <span class="nb">str</span><span class="p">(</span><span class="n">pretrained_weights</span><span class="p">)</span>
    +    <span class="n">model_url_key</span> <span class="o">=</span> <span class="n">architecture</span> <span class="o">+</span> <span class="s2">&quot;_&quot;</span> <span class="o">+</span> <span class="nb">str</span><span class="p">(</span><span class="n">pretrained_weights</span><span class="p">)</span>
         <span class="k">if</span> <span class="n">model_url_key</span> <span class="ow">not</span> <span class="ow">in</span> <span class="n">MODEL_URLS</span><span class="o">.</span><span class="n">keys</span><span class="p">():</span>
         <span class="k">if</span> <span class="n">model_url_key</span> <span class="ow">not</span> <span class="ow">in</span> <span class="n">MODEL_URLS</span><span class="o">.</span><span class="n">keys</span><span class="p">():</span>
             <span class="k">raise</span> <span class="n">MissingPretrainedWeightsException</span><span class="p">(</span><span class="n">model_url_key</span><span class="p">)</span>
             <span class="k">raise</span> <span class="n">MissingPretrainedWeightsException</span><span class="p">(</span><span class="n">model_url_key</span><span class="p">)</span>
     
     
         <span class="n">url</span> <span class="o">=</span> <span class="n">MODEL_URLS</span><span class="p">[</span><span class="n">model_url_key</span><span class="p">]</span>
         <span class="n">url</span> <span class="o">=</span> <span class="n">MODEL_URLS</span><span class="p">[</span><span class="n">model_url_key</span><span class="p">]</span>
    -    <span class="n">unique_filename</span> <span class="o">=</span> <span class="n">url</span><span class="o">.</span><span class="n">split</span><span class="p">(</span><span class="s2">&quot;https://deci-pretrained-models.s3.amazonaws.com/&quot;</span><span class="p">)[</span><span class="mi">1</span><span class="p">]</span><span class="o">.</span><span class="n">replace</span><span class="p">(</span><span class="s1">&#39;/&#39;</span><span class="p">,</span> <span class="s1">&#39;_&#39;</spa
    -    <span class="n">map_location</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">device</span><span class="p">(</span><span class="s1">&#39;cpu&#39;</span><span class="p">)</span> <span class="k">if</span> <span class="ow">not</span> <span class="n">torch</span><span class="o">.</span><span class="n">cuda</span><span class="o">.</span><span class="n">is_available</span><span class="p">()</span> <span class="k">else</span> <span class="kc">None
    +    <span class="n">unique_filename</span> <span class="o">=</span> <span class="n">url</span><span class="o">.</span><span class="n">split</span><span class="p">(</span><span class="s2">&quot;https://deci-pretrained-models.s3.amazonaws.com/&quot;</span><span class="p">)[</span><span class="mi">1</span><span class="p">]</span><span class="o">.</span><span class="n">replace</span><span class="p">(</span><span class="s2">&quot;/&quot;</span><span class="p">,</span> <span class="s2">&quot;_&quot;<
    +    <span class="n">map_location</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">device</span><span class="p">(</span><span class="s2">&quot;cpu&quot;</span><span class="p">)</span>
         <span class="n">pretrained_state_dict</span> <span class="o">=</span> <span class="n">load_state_dict_from_url</span><span class="p">(</span><span class="n">url</span><span class="o">=</span><span class="n">url</span><span class="p">,</span> <span class="n">map_location</span><span class="o">=</span><span class="n">map_location</span><span class="p">,</span> <span class="n">file_name</span><span class="o">=</span><span class="n">unique_filename</span><span class="p">)</span>
         <span class="n">pretrained_state_dict</span> <span class="o">=</span> <span class="n">load_state_dict_from_url</span><span class="p">(</span><span class="n">url</span><span class="o">=</span><span class="n">url</span><span class="p">,</span> <span class="n">map_location</span><span class="o">=</span><span class="n">map_location</span><span class="p">,</span> <span class="n">file_name</span><span class="o">=</span><span class="n">unique_filename</span><span class="p">)</span>
    -    <span class="k">if</span> <span class="s1">&#39;ema_net&#39;</span> <span class="ow">in</span> <span class="n">pretrained_state_dict</span><span class="o">.</span><span class="n">keys</span><span class="p">():</span>
    -        <span class="n">pretrained_state_dict</span><span class="p">[</span><span class="s1">&#39;net&#39;</span><span class="p">]</span> <span class="o">=</span> <span class="n">pretrained_state_dict</span><span class="p">[</span><span class="s1">&#39;ema_net&#39;</span><span class="p">]</span>
    +    <span class="n">_load_weights</span><span class="p">(</span><span class="n">architecture</span><span class="p">,</span> <span class="n">model</span><span class="p">,</span> <span class="n">pretrained_state_dict</span><span class="p">)</span>
    +
    +
    +<span class="k">def</span> <span class="nf">_load_weights</span><span class="p">(</span><span class="n">architecture</span><span class="p">,</span> <span class="n">model</span><span class="p">,</span> <span class="n">pretrained_state_dict</span><span class="p">):</span>
    +    <span class="k">if</span> <span class="s2">&quot;ema_net&quot;</span> <span class="ow">in</span> <span class="n">pretrained_state_dict</span><span class="o">.</span><span class="n">keys</span><span class="p">():</span>
    +        <span class="n">pretrained_state_dict</span><span class="p">[</span><span class="s2">&quot;net&quot;</span><span class="p">]</span> <span class="o">=</span> <span class="n">pretrained_state_dict</span><span class="p">[</span><span class="s2">&quot;ema_net&quot;</span><span class="p">]</span>
         <span class="n">solver</span> <span class="o">=</span> <span class="n">_yolox_ckpt_solver</span> <span class="k">if</span> <span class="s2">&quot;yolox&quot;</span> <span class="ow">in</span> <span class="n">architecture</span> <span class="k">else</span> <span class="kc">None</span>
         <span class="n">solver</span> <span class="o">=</span> <span class="n">_yolox_ckpt_solver</span> <span class="k">if</span> <span class="s2">&quot;yolox&quot;</span> <span class="ow">in</span> <span class="n">architecture</span> <span class="k">else</span> <span class="kc">None</span>
    -    <span class="n">adapted_pretrained_state_dict</span> <span class="o">=</span> <span class="n">adapt_state_dict_to_fit_model_layer_names</span><span class="p">(</span><span class="n">model_state_dict</span><span class="o">=</span><span class="n">model</span><span class="o">.</span><span class="n">state_dict</span><span class="p">(),</span>
    -                                                                              <span class="n">source_ckpt</span><span class="o">=</span><span class="n">pretrained_state_dict</span><span class="p">,</span>
    -                                                                              <span class="n">solver</span><span class="o">=</span><span class="n">solver</span><span class="p">)</span>
    -    <span class="n">model</span><span class="o">.</span><span class="n">load_state_dict</span><span class="p">(</span><span class="n">adapted_pretrained_state_dict</span><span class="p">[</span><span class="s1">&#39;net&#39;</span><span class="p">],</span> <span class="n">strict</span><span class="o">=</span><span class="kc">False</span><span class="p">)</span></div>
    +    <span class="n">adapted_pretrained_state_dict</span> <span class="o">=</span> <span class="n">adapt_state_dict_to_fit_model_layer_names</span><span class="p">(</span>
    +        <span class="n">model_state_dict</span><span class="o">=</span><span class="n">model</span><span class="o">.</span><span class="n">state_dict</span><span class="p">(),</span> <span class="n">source_ckpt</span><span class="o">=</span><span class="n">pretrained_state_dict</span><span class="p">,</span> <span class="n">solver</span><span class="o">=</span><span class="n">solver</span>
    +    <span class="p">)</span>
    +    <span class="n">model</span><span class="o">.</span><span class="n">load_state_dict</span><span class="p">(</span><span class="n">adapted_pretrained_state_dict</span><span class="p">[</span><span class="s2">&quot;net&quot;</span><span class="p">],</span> <span class="n">strict</span><span class="o">=</span><span class="kc">False</span><span class="p">)</span>
    +
    +
    +<span class="k">def</span> <span class="nf">load_pretrained_weights_local</span><span class="p">(</span><span class="n">model</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">nn</span><span class="o">.</span><span class="n">Module</span><span class="p">,</span> <span class="n">architecture</span><span class="p">:</span> <span class="nb">str</span><span class="p">,</span> <span class="n">pretrained_weights</span><span class="p">:</span> <span cla
    +
    +    <span class="sd">&quot;&quot;&quot;</span>
    +<span class="sd">    Loads pretrained weights from the MODEL_URLS dictionary to model</span>
    +<span class="sd">    @param architecture: name of the model&#39;s architecture</span>
    +<span class="sd">    @param model: model to load pretrinaed weights for</span>
    +<span class="sd">    @param pretrained_weights: path tp pretrained weights</span>
    +<span class="sd">    @return: None</span>
    +<span class="sd">    &quot;&quot;&quot;</span>
    +
    +    <span class="n">map_location</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">device</span><span class="p">(</span><span class="s2">&quot;cpu&quot;</span><span class="p">)</span>
    +
    +    <span class="n">pretrained_state_dict</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">load</span><span class="p">(</span><span class="n">pretrained_weights</span><span class="p">,</span> <span class="n">map_location</span><span class="o">=</span><span class="n">map_location</span><span class="p">)</span>
    +    <span class="n">_load_weights</span><span class="p">(</span><span class="n">architecture</span><span class="p">,</span> <span class="n">model</span><span class="p">,</span> <span class="n">pretrained_state_dict</span><span class="p">)</span>
     </pre></div>
     </pre></div>
     
     
                </div>
                </div>
    @@ -389,4 +434,4 @@
       </script> 
       </script> 
     
     
     </body>
     </body>
    -</html>
    +</html>
    Discard
    1
    2
    3
    4
    5
    6
    7
    8
    9
    10
    11
    12
    13
    14
    15
    16
    17
    18
    19
    20
    21
    22
    23
    24
    25
    26
    27
    28
    29
    30
    31
    32
    33
    34
    35
    36
    37
    38
    39
    40
    41
    42
    43
    44
    45
    46
    47
    48
    49
    50
    51
    52
    53
    54
    55
    56
    57
    58
    59
    60
    61
    62
    63
    64
    65
    66
    67
    68
    69
    70
    71
    72
    73
    74
    75
    76
    77
    78
    79
    80
    81
    82
    83
    84
    85
    86
    87
    88
    89
    90
    91
    92
    93
    94
    95
    96
    97
    98
    99
    100
    101
    102
    103
    104
    105
    106
    107
    108
    109
    110
    111
    112
    113
    114
    115
    116
    117
    118
    119
    120
    121
    122
    123
    124
    125
    126
    127
    128
    129
    130
    131
    132
    133
    134
    135
    136
    137
    138
    139
    140
    141
    142
    143
    144
    145
    146
    147
    148
    149
    150
    151
    152
    153
    154
    155
    156
    157
    158
    159
    160
    161
    162
    163
    164
    165
    166
    167
    168
    169
    170
    171
    172
    173
    174
    175
    176
    177
    178
    179
    180
    181
    182
    183
    184
    185
    186
    187
    188
    189
    190
    191
    192
    193
    194
    195
    196
    197
    198
    199
    200
    201
    202
    203
    204
    205
    206
    207
    208
    209
    210
    211
    212
    213
    214
    215
    216
    217
    218
    219
    220
    221
    222
    223
    224
    225
    226
    227
    228
    229
    230
    231
    232
    233
    234
    235
    236
    237
    238
    239
    240
    241
    242
    243
    244
    245
    246
    247
    248
    249
    250
    251
    252
    253
    254
    255
    256
    257
    258
    259
    260
    261
    262
    263
    264
    265
    266
    267
    268
    269
    270
    271
    272
    273
    274
    275
    276
    277
    278
    279
    280
    281
    282
    283
    284
    285
    286
    287
    288
    289
    290
    291
    292
    293
    294
    295
    296
    297
    298
    299
    300
    301
    302
    303
    304
    305
    306
    307
    308
    309
    310
    311
    312
    313
    314
    315
    316
    317
    318
    319
    320
    321
    322
    323
    324
    325
    326
    327
    328
    329
    330
    331
    332
    333
    334
    335
    336
    337
    338
    339
    340
    341
    342
    343
    344
    345
    346
    347
    348
    349
    350
    351
    352
    353
    354
    355
    356
    357
    358
    359
    360
    361
    362
    363
    364
    365
    366
    367
    368
    369
    370
    371
    372
    373
    374
    375
    376
    377
    378
    379
    380
    381
    382
    383
    384
    385
    386
    387
    388
    389
    390
    391
    392
    393
    394
    395
    396
    397
    398
    399
    400
    401
    402
    403
    404
    405
    406
    407
    408
    409
    410
    411
    412
    413
    414
    415
    416
    417
    418
    419
    420
    421
    422
    423
    424
    425
    426
    427
    428
    429
    430
    431
    432
    433
    434
    435
    436
    437
    438
    439
    440
    441
    442
    443
    444
    445
    446
    447
    448
    449
    450
    451
    452
    453
    454
    455
    456
    457
    458
    459
    460
    461
    462
    463
    464
    465
    466
    467
    468
    469
    470
    471
    472
    473
    474
    475
    476
    477
    478
    479
    480
    481
    482
    483
    484
    485
    486
    487
    488
    489
    490
    491
    492
    493
    494
    495
    496
    497
    498
    499
    500
    501
    502
    503
    504
    505
    506
    507
    508
    509
    510
    511
    512
    513
    514
    515
    516
    517
    518
    519
    520
    521
    522
    523
    524
    525
    526
    527
    528
    529
    530
    531
    532
    533
    534
    535
    536
    537
    538
    539
    540
    541
    542
    543
    544
    545
    546
    547
    548
    549
    550
    551
    552
    553
    554
    555
    556
    557
    558
    559
    560
    561
    562
    563
    564
    565
    566
    567
    568
    569
    570
    571
    572
    573
    574
    575
    576
    577
    578
    579
    580
    581
    582
    583
    584
    585
    586
    587
    588
    589
    590
    591
    592
    593
    594
    595
    596
    597
    598
    599
    600
    601
    602
    603
    604
    605
    606
    607
    608
    609
    610
    611
    612
    613
    614
    615
    616
    617
    618
    619
    620
    621
    622
    623
    624
    625
    626
    627
    628
    629
    630
    631
    632
    633
    634
    635
    636
    637
    638
    639
    640
    641
    642
    643
    644
    645
    646
    647
    648
    649
    650
    651
    652
    653
    654
    655
    656
    657
    658
    659
    660
    661
    662
    663
    664
    665
    666
    667
    668
    669
    670
    671
    672
    673
    674
    675
    676
    677
    678
    679
    680
    681
    682
    683
    684
    685
    686
    687
    688
    689
    690
    691
    692
    693
    694
    695
    696
    697
    698
    699
    700
    701
    702
    703
    704
    705
    706
    707
    708
    709
    710
    711
    712
    713
    714
    715
    716
    717
    718
    719
    720
    721
    722
    723
    724
    725
    726
    727
    728
    729
    730
    731
    732
    733
    734
    735
    736
    737
    738
    739
    740
    741
    742
    743
    744
    745
    746
    747
    748
    749
    750
    751
    752
    753
    754
    755
    756
    757
    758
    759
    760
    761
    762
    763
    764
    765
    766
    767
    768
    769
    770
    771
    772
    773
    774
    775
    776
    777
    778
    779
    780
    781
    782
    783
    784
    785
    786
    787
    788
    789
    790
    791
    792
    793
    794
    795
    796
    797
    798
    799
    800
    801
    802
    803
    804
    805
    806
    807
    808
    809
    810
    811
    812
    813
    814
    815
    816
    817
    818
    819
    820
    821
    822
    823
    824
    825
    826
    827
    828
    829
    830
    831
    832
    833
    834
    835
    836
    837
    838
    839
    840
    841
    842
    843
    844
    845
    846
    847
    848
    849
    850
    851
    852
    853
    854
    855
    856
    857
    858
    859
    860
    861
    862
    863
    864
    865
    866
    867
    868
    869
    870
    871
    872
    873
    874
    875
    876
    877
    878
    879
    880
    881
    882
    883
    884
    885
    886
    887
    888
    889
    890
    891
    892
    893
    894
    895
    896
    897
    898
    899
    900
    901
    902
    903
    904
    905
    906
    907
    908
    909
    910
    911
    912
    913
    914
    915
    916
    917
    918
    919
    920
    921
    922
    923
    924
    925
    926
    927
    928
    929
    930
    931
    932
    933
    934
    935
    936
    937
    938
    939
    940
    941
    942
    943
    944
    945
    946
    947
    948
    949
    950
    951
    952
    953
    954
    955
    956
    957
    958
    959
    960
    961
    962
    963
    964
    965
    966
    967
    968
    969
    970
    971
    972
    973
    974
    975
    976
    977
    978
    979
    980
    981
    982
    983
    984
    985
    986
    987
    988
    989
    990
    991
    992
    993
    994
    995
    996
    997
    998
    999
    1000
    1001
    1002
    1003
    1004
    1005
    1006
    1007
    1008
    1009
    1010
    1011
    1012
    1013
    1014
    1015
    1016
    1017
    1018
    1019
    1020
    1021
    1022
    1023
    1024
    1025
    1026
    1027
    1028
    1029
    1030
    1031
    1032
    1033
    1034
    1035
    1036
    1037
    1038
    1039
    1040
    1041
    1042
    1043
    1044
    1045
    1046
    1047
    1048
    1049
    1050
    1051
    1052
    1053
    1054
    1055
    1056
    1057
    1058
    1059
    1060
    1061
    1062
    1063
    1064
    1065
    1066
    1067
    1068
    1069
    1070
    1071
    1072
    1073
    1074
    1075
    1076
    1077
    1078
    1079
    1080
    1081
    1082
    1083
    1084
    1085
    1086
    1087
    1088
    1089
    1090
    1091
    1092
    1093
    1094
    1095
    1096
    1097
    1098
    1099
    1100
    1101
    1102
    1103
    1104
    1105
    1106
    1107
    1108
    1109
    1110
    1111
    1112
    1113
    1114
    1115
    1116
    1117
    1118
    1119
    1120
    1121
    1122
    1123
    1124
    1125
    1126
    1127
    1128
    1129
    1130
    1131
    1132
    1133
    1134
    1135
    1136
    1137
    1138
    1139
    1140
    1141
    1142
    1143
    1144
    1145
    1146
    1147
    1148
    1149
    1150
    1151
    1152
    1153
    1154
    1155
    1156
    1157
    1158
    1159
    1160
    1161
    1162
    1163
    1164
    1165
    1166
    1167
    1168
    1169
    1170
    1171
    1172
    1173
    1174
    1175
    1176
    1177
    1178
    1179
    1180
    1181
    1182
    1183
    1184
    1185
    1186
    1187
    1188
    1189
    1190
    1191
    1192
    1193
    1194
    1195
    1196
    1197
    1198
    1199
    1200
    1201
    1202
    1. <!DOCTYPE html>
    2. <html class="writer-html5" lang="en" >
    3. <head>
    4. <meta charset="utf-8" />
    5. <meta name="viewport" content="width=device-width, initial-scale=1.0" />
    6. <title>super_gradients.training.utils.detection_utils &mdash; SuperGradients 1.0 documentation</title>
    7. <link rel="stylesheet" href="../../../../_static/pygments.css" type="text/css" />
    8. <link rel="stylesheet" href="../../../../_static/css/theme.css" type="text/css" />
    9. <link rel="stylesheet" href="../../../../_static/graphviz.css" type="text/css" />
    10. <!--[if lt IE 9]>
    11. <script src="../../../../_static/js/html5shiv.min.js"></script>
    12. <![endif]-->
    13. <script data-url_root="../../../../" id="documentation_options" src="../../../../_static/documentation_options.js"></script>
    14. <script src="../../../../_static/jquery.js"></script>
    15. <script src="../../../../_static/underscore.js"></script>
    16. <script src="../../../../_static/doctools.js"></script>
    17. <script src="../../../../_static/js/theme.js"></script>
    18. <link rel="index" title="Index" href="../../../../genindex.html" />
    19. <link rel="search" title="Search" href="../../../../search.html" />
    20. </head>
    21. <body class="wy-body-for-nav">
    22. <div class="wy-grid-for-nav">
    23. <nav data-toggle="wy-nav-shift" class="wy-nav-side">
    24. <div class="wy-side-scroll">
    25. <div class="wy-side-nav-search" >
    26. <a href="../../../../index.html" class="icon icon-home"> SuperGradients
    27. </a>
    28. <div role="search">
    29. <form id="rtd-search-form" class="wy-form" action="../../../../search.html" method="get">
    30. <input type="text" name="q" placeholder="Search docs" />
    31. <input type="hidden" name="check_keywords" value="yes" />
    32. <input type="hidden" name="area" value="default" />
    33. </form>
    34. </div>
    35. </div><div class="wy-menu wy-menu-vertical" data-spy="affix" role="navigation" aria-label="Navigation menu">
    36. <p class="caption"><span class="caption-text">Welcome To SuperGradients</span></p>
    37. <ul>
    38. <li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html">Fill our 4-question quick survey! We will raffle free SuperGradients swag between those who will participate -&gt; Fill Survey</a></li>
    39. <li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#supergradients">SuperGradients</a></li>
    40. </ul>
    41. <p class="caption"><span class="caption-text">Technical Documentation</span></p>
    42. <ul>
    43. <li class="toctree-l1"><a class="reference internal" href="../../../../super_gradients.common.html">Common package</a></li>
    44. <li class="toctree-l1"><a class="reference internal" href="../../../../super_gradients.training.html">Training package</a></li>
    45. </ul>
    46. <p class="caption"><span class="caption-text">User Guide</span></p>
    47. <ul>
    48. <li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html">What is SuperGradients?</a></li>
    49. <li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#introducing-the-supergradients-library">Introducing the SuperGradients library</a></li>
    50. <li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#installation">Installation</a></li>
    51. <li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#integrating-your-training-code-complete-walkthrough">Integrating Your Training Code - Complete Walkthrough</a></li>
    52. <li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#training-parameters">Training Parameters</a></li>
    53. <li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#logs-and-checkpoints">Logs and Checkpoints</a></li>
    54. <li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#dataset-parameters">Dataset Parameters</a></li>
    55. <li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#network-architectures">Network Architectures</a></li>
    56. <li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#pretrained-models">Pretrained Models</a></li>
    57. <li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#how-to-reproduce-our-training-recipes">How To Reproduce Our Training Recipes</a></li>
    58. <li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#professional-tools-integration">Professional Tools Integration</a></li>
    59. <li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#supergradients-faq">SuperGradients FAQ</a></li>
    60. </ul>
    61. </div>
    62. </div>
    63. </nav>
    64. <section data-toggle="wy-nav-shift" class="wy-nav-content-wrap"><nav class="wy-nav-top" aria-label="Mobile navigation menu" >
    65. <i data-toggle="wy-nav-top" class="fa fa-bars"></i>
    66. <a href="../../../../index.html">SuperGradients</a>
    67. </nav>
    68. <div class="wy-nav-content">
    69. <div class="rst-content">
    70. <div role="navigation" aria-label="Page navigation">
    71. <ul class="wy-breadcrumbs">
    72. <li><a href="../../../../index.html" class="icon icon-home"></a> &raquo;</li>
    73. <li><a href="../../../index.html">Module code</a> &raquo;</li>
    74. <li>super_gradients.training.utils.detection_utils</li>
    75. <li class="wy-breadcrumbs-aside">
    76. </li>
    77. </ul>
    78. <hr/>
    79. </div>
    80. <div role="main" class="document" itemscope="itemscope" itemtype="http://schema.org/Article">
    81. <div itemprop="articleBody">
    82. <h1>Source code for super_gradients.training.utils.detection_utils</h1><div class="highlight"><pre>
    83. <span></span><span class="kn">import</span> <span class="nn">math</span>
    84. <span class="kn">import</span> <span class="nn">os</span>
    85. <span class="kn">import</span> <span class="nn">pathlib</span>
    86. <span class="kn">from</span> <span class="nn">abc</span> <span class="kn">import</span> <span class="n">ABC</span><span class="p">,</span> <span class="n">abstractmethod</span>
    87. <span class="kn">from</span> <span class="nn">enum</span> <span class="kn">import</span> <span class="n">Enum</span>
    88. <span class="kn">from</span> <span class="nn">typing</span> <span class="kn">import</span> <span class="n">Callable</span><span class="p">,</span> <span class="n">List</span><span class="p">,</span> <span class="n">Union</span><span class="p">,</span> <span class="n">Tuple</span><span class="p">,</span> <span class="n">Optional</span><span class="p">,</span> <span class="n">Dict</span>
    89. <span class="kn">import</span> <span class="nn">cv2</span>
    90. <span class="kn">import</span> <span class="nn">matplotlib.pyplot</span> <span class="k">as</span> <span class="nn">plt</span>
    91. <span class="kn">import</span> <span class="nn">numpy</span> <span class="k">as</span> <span class="nn">np</span>
    92. <span class="kn">import</span> <span class="nn">torch</span>
    93. <span class="kn">import</span> <span class="nn">torchvision</span>
    94. <span class="kn">from</span> <span class="nn">torch</span> <span class="kn">import</span> <span class="n">nn</span>
    95. <span class="kn">from</span> <span class="nn">torch.utils.data._utils.collate</span> <span class="kn">import</span> <span class="n">default_collate</span>
    96. <span class="kn">from</span> <span class="nn">omegaconf</span> <span class="kn">import</span> <span class="n">ListConfig</span>
    97. <div class="viewcode-block" id="DetectionTargetsFormat"><a class="viewcode-back" href="../../../../super_gradients.training.utils.html#super_gradients.training.utils.detection_utils.DetectionTargetsFormat">[docs]</a><span class="k">class</span> <span class="nc">DetectionTargetsFormat</span><span class="p">(</span><span class="n">Enum</span><span class="p">):</span>
    98. <span class="sd">&quot;&quot;&quot;</span>
    99. <span class="sd"> Enum class for the different detection output formats</span>
    100. <span class="sd"> When NORMALIZED is not specified- the type refers to unnormalized image coordinates (of the bboxes).</span>
    101. <span class="sd"> For example:</span>
    102. <span class="sd"> LABEL_NORMALIZED_XYXY means [class_idx,x1,y1,x2,y2]</span>
    103. <span class="sd"> &quot;&quot;&quot;</span>
    104. <span class="n">LABEL_XYXY</span> <span class="o">=</span> <span class="s2">&quot;LABEL_XYXY&quot;</span>
    105. <span class="n">XYXY_LABEL</span> <span class="o">=</span> <span class="s2">&quot;XYXY_LABEL&quot;</span>
    106. <span class="n">LABEL_NORMALIZED_XYXY</span> <span class="o">=</span> <span class="s2">&quot;LABEL_NORMALIZED_XYXY&quot;</span>
    107. <span class="n">NORMALIZED_XYXY_LABEL</span> <span class="o">=</span> <span class="s2">&quot;NORMALIZED_XYXY_LABEL&quot;</span>
    108. <span class="n">LABEL_CXCYWH</span> <span class="o">=</span> <span class="s2">&quot;LABEL_CXCYWH&quot;</span>
    109. <span class="n">CXCYWH_LABEL</span> <span class="o">=</span> <span class="s2">&quot;CXCYWH_LABEL&quot;</span>
    110. <span class="n">LABEL_NORMALIZED_CXCYWH</span> <span class="o">=</span> <span class="s2">&quot;LABEL_NORMALIZED_CXCYWH&quot;</span>
    111. <span class="n">NORMALIZED_CXCYWH_LABEL</span> <span class="o">=</span> <span class="s2">&quot;NORMALIZED_CXCYWH_LABEL&quot;</span></div>
    112. <div class="viewcode-block" id="get_cls_posx_in_target"><a class="viewcode-back" href="../../../../super_gradients.training.utils.html#super_gradients.training.utils.detection_utils.get_cls_posx_in_target">[docs]</a><span class="k">def</span> <span class="nf">get_cls_posx_in_target</span><span class="p">(</span><span class="n">target_format</span><span class="p">:</span> <span class="n">DetectionTargetsFormat</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="nb">int</span><span class="p">:</span>
    113. <span class="sd">&quot;&quot;&quot;Get the label of a given target</span>
    114. <span class="sd"> :param target_format: Representation of the target (ex: LABEL_XYXY)</span>
    115. <span class="sd"> :return: Position of the class id in a bbox</span>
    116. <span class="sd"> ex: 0 if bbox of format label_xyxy | -1 if bbox of format xyxy_label</span>
    117. <span class="sd"> &quot;&quot;&quot;</span>
    118. <span class="n">format_split</span> <span class="o">=</span> <span class="n">target_format</span><span class="o">.</span><span class="n">value</span><span class="o">.</span><span class="n">split</span><span class="p">(</span><span class="s2">&quot;_&quot;</span><span class="p">)</span>
    119. <span class="k">if</span> <span class="n">format_split</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span> <span class="o">==</span> <span class="s2">&quot;LABEL&quot;</span><span class="p">:</span>
    120. <span class="k">return</span> <span class="mi">0</span>
    121. <span class="k">elif</span> <span class="n">format_split</span><span class="p">[</span><span class="o">-</span><span class="mi">1</span><span class="p">]</span> <span class="o">==</span> <span class="s2">&quot;LABEL&quot;</span><span class="p">:</span>
    122. <span class="k">return</span> <span class="o">-</span><span class="mi">1</span>
    123. <span class="k">else</span><span class="p">:</span>
    124. <span class="k">raise</span> <span class="ne">NotImplementedError</span><span class="p">(</span><span class="sa">f</span><span class="s2">&quot;No implementation to find index of LABEL in </span><span class="si">{</span><span class="n">target_format</span><span class="o">.</span><span class="n">value</span><span class="si">}</span><span class="s2">&quot;</span><span class="p">)</span></div>
    125. <span class="k">def</span> <span class="nf">_set_batch_labels_index</span><span class="p">(</span><span class="n">labels_batch</span><span class="p">):</span>
    126. <span class="k">for</span> <span class="n">i</span><span class="p">,</span> <span class="n">labels</span> <span class="ow">in</span> <span class="nb">enumerate</span><span class="p">(</span><span class="n">labels_batch</span><span class="p">):</span>
    127. <span class="n">labels</span><span class="p">[:,</span> <span class="mi">0</span><span class="p">]</span> <span class="o">=</span> <span class="n">i</span>
    128. <span class="k">return</span> <span class="n">labels_batch</span>
    129. <div class="viewcode-block" id="convert_xywh_bbox_to_xyxy"><a class="viewcode-back" href="../../../../super_gradients.training.utils.html#super_gradients.training.utils.detection_utils.convert_xywh_bbox_to_xyxy">[docs]</a><span class="k">def</span> <span class="nf">convert_xywh_bbox_to_xyxy</span><span class="p">(</span><span class="n">input_bbox</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">):</span>
    130. <span class="sd">&quot;&quot;&quot;</span>
    131. <span class="sd"> Converts bounding box format from [x, y, w, h] to [x1, y1, x2, y2]</span>
    132. <span class="sd"> :param input_bbox: input bbox either 2-dimensional (for all boxes of a single image) or 3-dimensional (for</span>
    133. <span class="sd"> boxes of a batch of images)</span>
    134. <span class="sd"> :return: Converted bbox in same dimensions as the original</span>
    135. <span class="sd"> &quot;&quot;&quot;</span>
    136. <span class="n">need_squeeze</span> <span class="o">=</span> <span class="kc">False</span>
    137. <span class="c1"># the input is always processed as a batch. in case it not a batch, it is unsqueezed, process and than squeeze back.</span>
    138. <span class="k">if</span> <span class="n">input_bbox</span><span class="o">.</span><span class="n">dim</span><span class="p">()</span> <span class="o">&lt;</span> <span class="mi">3</span><span class="p">:</span>
    139. <span class="n">need_squeeze</span> <span class="o">=</span> <span class="kc">True</span>
    140. <span class="n">input_bbox</span> <span class="o">=</span> <span class="n">input_bbox</span><span class="o">.</span><span class="n">unsqueeze</span><span class="p">(</span><span class="mi">0</span><span class="p">)</span>
    141. <span class="n">converted_bbox</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">zeros_like</span><span class="p">(</span><span class="n">input_bbox</span><span class="p">)</span> <span class="k">if</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">input_bbox</span><span class="p">,</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">)</span> <span class="k">else</span> <span class="n">np</span><span class="o">.</span><span class="n">zeros_like</span><span class="p">(</span><span class="n">input_bbox</span><span class="p">)</span>
    142. <span class="n">converted_bbox</span><span class="p">[:,</span> <span class="p">:,</span> <span class="mi">0</span><span class="p">]</span> <span class="o">=</span> <span class="n">input_bbox</span><span class="p">[:,</span> <span class="p">:,</span> <span class="mi">0</span><span class="p">]</span> <span class="o">-</span> <span class="n">input_bbox</span><span class="p">[:,</span> <span class="p">:,</span> <span class="mi">2</span><span class="p">]</span> <span class="o">/</span> <span class="mi">2</span>
    143. <span class="n">converted_bbox</span><span class="p">[:,</span> <span class="p">:,</span> <span class="mi">1</span><span class="p">]</span> <span class="o">=</span> <span class="n">input_bbox</span><span class="p">[:,</span> <span class="p">:,</span> <span class="mi">1</span><span class="p">]</span> <span class="o">-</span> <span class="n">input_bbox</span><span class="p">[:,</span> <span class="p">:,</span> <span class="mi">3</span><span class="p">]</span> <span class="o">/</span> <span class="mi">2</span>
    144. <span class="n">converted_bbox</span><span class="p">[:,</span> <span class="p">:,</span> <span class="mi">2</span><span class="p">]</span> <span class="o">=</span> <span class="n">input_bbox</span><span class="p">[:,</span> <span class="p">:,</span> <span class="mi">0</span><span class="p">]</span> <span class="o">+</span> <span class="n">input_bbox</span><span class="p">[:,</span> <span class="p">:,</span> <span class="mi">2</span><span class="p">]</span> <span class="o">/</span> <span class="mi">2</span>
    145. <span class="n">converted_bbox</span><span class="p">[:,</span> <span class="p">:,</span> <span class="mi">3</span><span class="p">]</span> <span class="o">=</span> <span class="n">input_bbox</span><span class="p">[:,</span> <span class="p">:,</span> <span class="mi">1</span><span class="p">]</span> <span class="o">+</span> <span class="n">input_bbox</span><span class="p">[:,</span> <span class="p">:,</span> <span class="mi">3</span><span class="p">]</span> <span class="o">/</span> <span class="mi">2</span>
    146. <span class="c1"># squeeze back if needed</span>
    147. <span class="k">if</span> <span class="n">need_squeeze</span><span class="p">:</span>
    148. <span class="n">converted_bbox</span> <span class="o">=</span> <span class="n">converted_bbox</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span>
    149. <span class="k">return</span> <span class="n">converted_bbox</span></div>
    150. <span class="k">def</span> <span class="nf">_iou</span><span class="p">(</span><span class="n">CIoU</span><span class="p">:</span> <span class="nb">bool</span><span class="p">,</span> <span class="n">DIoU</span><span class="p">:</span> <span class="nb">bool</span><span class="p">,</span> <span class="n">GIoU</span><span class="p">:</span> <span class="nb">bool</span><span class="p">,</span> <span class="n">b1_x1</span><span class="p">,</span> <span class="n">b1_x2</span><span class="p">,</span> <span class="n">b1_y1</span><span class="p">,</span> <span class="n">b1_y2</span><span class="p">,</span> <span class="n">b2_x1</span><span class="p">,</span> <span class="n">b2_x2</span><span class="p">,</span> <span class="n">b2_y1</span><span class="p">,</span> <span class="n">b2_y2</span><span class="p">,</span> <span class="n">eps</span><span class="p">):</span>
    151. <span class="sd">&quot;&quot;&quot;</span>
    152. <span class="sd"> Internal function for the use of calculate_bbox_iou_matrix and calculate_bbox_iou_elementwise functions</span>
    153. <span class="sd"> DO NOT CALL THIS FUNCTIONS DIRECTLY - use one of the functions mentioned above</span>
    154. <span class="sd"> &quot;&quot;&quot;</span>
    155. <span class="c1"># Intersection area</span>
    156. <span class="n">intersection_area</span> <span class="o">=</span> <span class="p">(</span><span class="n">torch</span><span class="o">.</span><span class="n">min</span><span class="p">(</span><span class="n">b1_x2</span><span class="p">,</span> <span class="n">b2_x2</span><span class="p">)</span> <span class="o">-</span> <span class="n">torch</span><span class="o">.</span><span class="n">max</span><span class="p">(</span><span class="n">b1_x1</span><span class="p">,</span> <span class="n">b2_x1</span><span class="p">))</span><span class="o">.</span><span class="n">clamp</span><span class="p">(</span><span class="mi">0</span><span class="p">)</span> <span class="o">*</span> \
    157. <span class="p">(</span><span class="n">torch</span><span class="o">.</span><span class="n">min</span><span class="p">(</span><span class="n">b1_y2</span><span class="p">,</span> <span class="n">b2_y2</span><span class="p">)</span> <span class="o">-</span> <span class="n">torch</span><span class="o">.</span><span class="n">max</span><span class="p">(</span><span class="n">b1_y1</span><span class="p">,</span> <span class="n">b2_y1</span><span class="p">))</span><span class="o">.</span><span class="n">clamp</span><span class="p">(</span><span class="mi">0</span><span class="p">)</span>
    158. <span class="c1"># Union Area</span>
    159. <span class="n">w1</span><span class="p">,</span> <span class="n">h1</span> <span class="o">=</span> <span class="n">b1_x2</span> <span class="o">-</span> <span class="n">b1_x1</span><span class="p">,</span> <span class="n">b1_y2</span> <span class="o">-</span> <span class="n">b1_y1</span>
    160. <span class="n">w2</span><span class="p">,</span> <span class="n">h2</span> <span class="o">=</span> <span class="n">b2_x2</span> <span class="o">-</span> <span class="n">b2_x1</span><span class="p">,</span> <span class="n">b2_y2</span> <span class="o">-</span> <span class="n">b2_y1</span>
    161. <span class="n">union_area</span> <span class="o">=</span> <span class="n">w1</span> <span class="o">*</span> <span class="n">h1</span> <span class="o">+</span> <span class="n">w2</span> <span class="o">*</span> <span class="n">h2</span> <span class="o">-</span> <span class="n">intersection_area</span> <span class="o">+</span> <span class="n">eps</span>
    162. <span class="n">iou</span> <span class="o">=</span> <span class="n">intersection_area</span> <span class="o">/</span> <span class="n">union_area</span> <span class="c1"># iou</span>
    163. <span class="k">if</span> <span class="n">GIoU</span> <span class="ow">or</span> <span class="n">DIoU</span> <span class="ow">or</span> <span class="n">CIoU</span><span class="p">:</span>
    164. <span class="n">cw</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">max</span><span class="p">(</span><span class="n">b1_x2</span><span class="p">,</span> <span class="n">b2_x2</span><span class="p">)</span> <span class="o">-</span> <span class="n">torch</span><span class="o">.</span><span class="n">min</span><span class="p">(</span><span class="n">b1_x1</span><span class="p">,</span> <span class="n">b2_x1</span><span class="p">)</span> <span class="c1"># convex (smallest enclosing box) width</span>
    165. <span class="n">ch</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">max</span><span class="p">(</span><span class="n">b1_y2</span><span class="p">,</span> <span class="n">b2_y2</span><span class="p">)</span> <span class="o">-</span> <span class="n">torch</span><span class="o">.</span><span class="n">min</span><span class="p">(</span><span class="n">b1_y1</span><span class="p">,</span> <span class="n">b2_y1</span><span class="p">)</span> <span class="c1"># convex height</span>
    166. <span class="c1"># Generalized IoU https://arxiv.org/pdf/1902.09630.pdf</span>
    167. <span class="k">if</span> <span class="n">GIoU</span><span class="p">:</span>
    168. <span class="n">c_area</span> <span class="o">=</span> <span class="n">cw</span> <span class="o">*</span> <span class="n">ch</span> <span class="o">+</span> <span class="n">eps</span> <span class="c1"># convex area</span>
    169. <span class="n">iou</span> <span class="o">-=</span> <span class="p">(</span><span class="n">c_area</span> <span class="o">-</span> <span class="n">union_area</span><span class="p">)</span> <span class="o">/</span> <span class="n">c_area</span> <span class="c1"># GIoU</span>
    170. <span class="c1"># Distance or Complete IoU https://arxiv.org/abs/1911.08287v1</span>
    171. <span class="k">if</span> <span class="n">DIoU</span> <span class="ow">or</span> <span class="n">CIoU</span><span class="p">:</span>
    172. <span class="c1"># convex diagonal squared</span>
    173. <span class="n">c2</span> <span class="o">=</span> <span class="n">cw</span> <span class="o">**</span> <span class="mi">2</span> <span class="o">+</span> <span class="n">ch</span> <span class="o">**</span> <span class="mi">2</span> <span class="o">+</span> <span class="n">eps</span>
    174. <span class="c1"># centerpoint distance squared</span>
    175. <span class="n">rho2</span> <span class="o">=</span> <span class="p">((</span><span class="n">b2_x1</span> <span class="o">+</span> <span class="n">b2_x2</span> <span class="o">-</span> <span class="n">b1_x1</span> <span class="o">-</span> <span class="n">b1_x2</span><span class="p">)</span> <span class="o">**</span> <span class="mi">2</span> <span class="o">+</span> <span class="p">(</span><span class="n">b2_y1</span> <span class="o">+</span> <span class="n">b2_y2</span> <span class="o">-</span> <span class="n">b1_y1</span> <span class="o">-</span> <span class="n">b1_y2</span><span class="p">)</span> <span class="o">**</span> <span class="mi">2</span><span class="p">)</span> <span class="o">/</span> <span class="mi">4</span>
    176. <span class="k">if</span> <span class="n">DIoU</span><span class="p">:</span>
    177. <span class="n">iou</span> <span class="o">-=</span> <span class="n">rho2</span> <span class="o">/</span> <span class="n">c2</span> <span class="c1"># DIoU</span>
    178. <span class="k">elif</span> <span class="n">CIoU</span><span class="p">:</span> <span class="c1"># https://github.com/Zzh-tju/DIoU-SSD-pytorch/blob/master/utils/box/box_utils.py#L47</span>
    179. <span class="n">v</span> <span class="o">=</span> <span class="p">(</span><span class="mi">4</span> <span class="o">/</span> <span class="n">math</span><span class="o">.</span><span class="n">pi</span> <span class="o">**</span> <span class="mi">2</span><span class="p">)</span> <span class="o">*</span> <span class="n">torch</span><span class="o">.</span><span class="n">pow</span><span class="p">(</span><span class="n">torch</span><span class="o">.</span><span class="n">atan</span><span class="p">(</span><span class="n">w2</span> <span class="o">/</span> <span class="n">h2</span><span class="p">)</span> <span class="o">-</span> <span class="n">torch</span><span class="o">.</span><span class="n">atan</span><span class="p">(</span><span class="n">w1</span> <span class="o">/</span> <span class="n">h1</span><span class="p">),</span> <span class="mi">2</span><span class="p">)</span>
    180. <span class="k">with</span> <span class="n">torch</span><span class="o">.</span><span class="n">no_grad</span><span class="p">():</span>
    181. <span class="n">alpha</span> <span class="o">=</span> <span class="n">v</span> <span class="o">/</span> <span class="p">((</span><span class="mi">1</span> <span class="o">+</span> <span class="n">eps</span><span class="p">)</span> <span class="o">-</span> <span class="n">iou</span> <span class="o">+</span> <span class="n">v</span><span class="p">)</span>
    182. <span class="n">iou</span> <span class="o">-=</span> <span class="p">(</span><span class="n">rho2</span> <span class="o">/</span> <span class="n">c2</span> <span class="o">+</span> <span class="n">v</span> <span class="o">*</span> <span class="n">alpha</span><span class="p">)</span> <span class="c1"># CIoU</span>
    183. <span class="k">return</span> <span class="n">iou</span>
    184. <div class="viewcode-block" id="calculate_bbox_iou_matrix"><a class="viewcode-back" href="../../../../super_gradients.training.utils.html#super_gradients.training.utils.detection_utils.calculate_bbox_iou_matrix">[docs]</a><span class="k">def</span> <span class="nf">calculate_bbox_iou_matrix</span><span class="p">(</span><span class="n">box1</span><span class="p">,</span> <span class="n">box2</span><span class="p">,</span> <span class="n">x1y1x2y2</span><span class="o">=</span><span class="kc">True</span><span class="p">,</span> <span class="n">GIoU</span><span class="p">:</span> <span class="nb">bool</span> <span class="o">=</span> <span class="kc">False</span><span class="p">,</span> <span class="n">DIoU</span><span class="o">=</span><span class="kc">False</span><span class="p">,</span> <span class="n">CIoU</span><span class="o">=</span><span class="kc">False</span><span class="p">,</span> <span class="n">eps</span><span class="o">=</span><span class="mf">1e-9</span><span class="p">):</span>
    185. <span class="sd">&quot;&quot;&quot;</span>
    186. <span class="sd"> calculate iou matrix containing the iou of every couple iuo(i,j) where i is in box1 and j is in box2</span>
    187. <span class="sd"> :param box1: a 2D tensor of boxes (shape N x 4)</span>
    188. <span class="sd"> :param box2: a 2D tensor of boxes (shape M x 4)</span>
    189. <span class="sd"> :param x1y1x2y2: boxes format is x1y1x2y2 (True) or xywh where xy is the center (False)</span>
    190. <span class="sd"> :return: a 2D iou matrix (shape NxM)</span>
    191. <span class="sd"> &quot;&quot;&quot;</span>
    192. <span class="k">if</span> <span class="n">box1</span><span class="o">.</span><span class="n">dim</span><span class="p">()</span> <span class="o">&gt;</span> <span class="mi">1</span><span class="p">:</span>
    193. <span class="n">box1</span> <span class="o">=</span> <span class="n">box1</span><span class="o">.</span><span class="n">T</span>
    194. <span class="c1"># Get the coordinates of bounding boxes</span>
    195. <span class="k">if</span> <span class="n">x1y1x2y2</span><span class="p">:</span> <span class="c1"># x1, y1, x2, y2 = box1</span>
    196. <span class="n">b1_x1</span><span class="p">,</span> <span class="n">b1_y1</span><span class="p">,</span> <span class="n">b1_x2</span><span class="p">,</span> <span class="n">b1_y2</span> <span class="o">=</span> <span class="n">box1</span><span class="p">[</span><span class="mi">0</span><span class="p">],</span> <span class="n">box1</span><span class="p">[</span><span class="mi">1</span><span class="p">],</span> <span class="n">box1</span><span class="p">[</span><span class="mi">2</span><span class="p">],</span> <span class="n">box1</span><span class="p">[</span><span class="mi">3</span><span class="p">]</span>
    197. <span class="n">b2_x1</span><span class="p">,</span> <span class="n">b2_y1</span><span class="p">,</span> <span class="n">b2_x2</span><span class="p">,</span> <span class="n">b2_y2</span> <span class="o">=</span> <span class="n">box2</span><span class="p">[:,</span> <span class="mi">0</span><span class="p">],</span> <span class="n">box2</span><span class="p">[:,</span> <span class="mi">1</span><span class="p">],</span> <span class="n">box2</span><span class="p">[:,</span> <span class="mi">2</span><span class="p">],</span> <span class="n">box2</span><span class="p">[:,</span> <span class="mi">3</span><span class="p">]</span>
    198. <span class="k">else</span><span class="p">:</span> <span class="c1"># x, y, w, h = box1</span>
    199. <span class="n">b1_x1</span><span class="p">,</span> <span class="n">b1_x2</span> <span class="o">=</span> <span class="n">box1</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span> <span class="o">-</span> <span class="n">box1</span><span class="p">[</span><span class="mi">2</span><span class="p">]</span> <span class="o">/</span> <span class="mi">2</span><span class="p">,</span> <span class="n">box1</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span> <span class="o">+</span> <span class="n">box1</span><span class="p">[</span><span class="mi">2</span><span class="p">]</span> <span class="o">/</span> <span class="mi">2</span>
    200. <span class="n">b1_y1</span><span class="p">,</span> <span class="n">b1_y2</span> <span class="o">=</span> <span class="n">box1</span><span class="p">[</span><span class="mi">1</span><span class="p">]</span> <span class="o">-</span> <span class="n">box1</span><span class="p">[</span><span class="mi">3</span><span class="p">]</span> <span class="o">/</span> <span class="mi">2</span><span class="p">,</span> <span class="n">box1</span><span class="p">[</span><span class="mi">1</span><span class="p">]</span> <span class="o">+</span> <span class="n">box1</span><span class="p">[</span><span class="mi">3</span><span class="p">]</span> <span class="o">/</span> <span class="mi">2</span>
    201. <span class="n">b2_x1</span><span class="p">,</span> <span class="n">b2_x2</span> <span class="o">=</span> <span class="n">box2</span><span class="p">[:,</span> <span class="mi">0</span><span class="p">]</span> <span class="o">-</span> <span class="n">box2</span><span class="p">[:,</span> <span class="mi">2</span><span class="p">]</span> <span class="o">/</span> <span class="mi">2</span><span class="p">,</span> <span class="n">box2</span><span class="p">[:,</span> <span class="mi">0</span><span class="p">]</span> <span class="o">+</span> <span class="n">box2</span><span class="p">[:,</span> <span class="mi">2</span><span class="p">]</span> <span class="o">/</span> <span class="mi">2</span>
    202. <span class="n">b2_y1</span><span class="p">,</span> <span class="n">b2_y2</span> <span class="o">=</span> <span class="n">box2</span><span class="p">[:,</span> <span class="mi">1</span><span class="p">]</span> <span class="o">-</span> <span class="n">box2</span><span class="p">[:,</span> <span class="mi">3</span><span class="p">]</span> <span class="o">/</span> <span class="mi">2</span><span class="p">,</span> <span class="n">box2</span><span class="p">[:,</span> <span class="mi">1</span><span class="p">]</span> <span class="o">+</span> <span class="n">box2</span><span class="p">[:,</span> <span class="mi">3</span><span class="p">]</span> <span class="o">/</span> <span class="mi">2</span>
    203. <span class="n">b1_x1</span><span class="p">,</span> <span class="n">b1_y1</span><span class="p">,</span> <span class="n">b1_x2</span><span class="p">,</span> <span class="n">b1_y2</span> <span class="o">=</span> <span class="n">b1_x1</span><span class="o">.</span><span class="n">unsqueeze</span><span class="p">(</span><span class="mi">1</span><span class="p">),</span> <span class="n">b1_y1</span><span class="o">.</span><span class="n">unsqueeze</span><span class="p">(</span><span class="mi">1</span><span class="p">),</span> <span class="n">b1_x2</span><span class="o">.</span><span class="n">unsqueeze</span><span class="p">(</span><span class="mi">1</span><span class="p">),</span> <span class="n">b1_y2</span><span class="o">.</span><span class="n">unsqueeze</span><span class="p">(</span><span class="mi">1</span><span class="p">)</span>
    204. <span class="k">return</span> <span class="n">_iou</span><span class="p">(</span><span class="n">CIoU</span><span class="p">,</span> <span class="n">DIoU</span><span class="p">,</span> <span class="n">GIoU</span><span class="p">,</span> <span class="n">b1_x1</span><span class="p">,</span> <span class="n">b1_x2</span><span class="p">,</span> <span class="n">b1_y1</span><span class="p">,</span> <span class="n">b1_y2</span><span class="p">,</span> <span class="n">b2_x1</span><span class="p">,</span> <span class="n">b2_x2</span><span class="p">,</span> <span class="n">b2_y1</span><span class="p">,</span> <span class="n">b2_y2</span><span class="p">,</span> <span class="n">eps</span><span class="p">)</span></div>
    205. <div class="viewcode-block" id="calc_bbox_iou_matrix"><a class="viewcode-back" href="../../../../super_gradients.training.utils.html#super_gradients.training.utils.detection_utils.calc_bbox_iou_matrix">[docs]</a><span class="k">def</span> <span class="nf">calc_bbox_iou_matrix</span><span class="p">(</span><span class="n">pred</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">):</span>
    206. <span class="sd">&quot;&quot;&quot;</span>
    207. <span class="sd"> calculate iou for every pair of boxes in the boxes vector</span>
    208. <span class="sd"> :param pred: a 3-dimensional tensor containing all boxes for a batch of images [N, num_boxes, 4], where</span>
    209. <span class="sd"> each box format is [x1,y1,x2,y2]</span>
    210. <span class="sd"> :return: a 3-dimensional matrix where M_i_j_k is the iou of box j and box k of the i&#39;th image in the batch</span>
    211. <span class="sd"> &quot;&quot;&quot;</span>
    212. <span class="n">box</span> <span class="o">=</span> <span class="n">pred</span><span class="p">[:,</span> <span class="p">:,</span> <span class="p">:</span><span class="mi">4</span><span class="p">]</span> <span class="c1">#</span>
    213. <span class="n">b1_x1</span><span class="p">,</span> <span class="n">b1_y1</span> <span class="o">=</span> <span class="n">box</span><span class="p">[:,</span> <span class="p">:,</span> <span class="mi">0</span><span class="p">]</span><span class="o">.</span><span class="n">unsqueeze</span><span class="p">(</span><span class="mi">1</span><span class="p">),</span> <span class="n">box</span><span class="p">[:,</span> <span class="p">:,</span> <span class="mi">1</span><span class="p">]</span><span class="o">.</span><span class="n">unsqueeze</span><span class="p">(</span><span class="mi">1</span><span class="p">)</span>
    214. <span class="n">b1_x2</span><span class="p">,</span> <span class="n">b1_y2</span> <span class="o">=</span> <span class="n">box</span><span class="p">[:,</span> <span class="p">:,</span> <span class="mi">2</span><span class="p">]</span><span class="o">.</span><span class="n">unsqueeze</span><span class="p">(</span><span class="mi">1</span><span class="p">),</span> <span class="n">box</span><span class="p">[:,</span> <span class="p">:,</span> <span class="mi">3</span><span class="p">]</span><span class="o">.</span><span class="n">unsqueeze</span><span class="p">(</span><span class="mi">1</span><span class="p">)</span>
    215. <span class="n">b2_x1</span> <span class="o">=</span> <span class="n">b1_x1</span><span class="o">.</span><span class="n">transpose</span><span class="p">(</span><span class="mi">2</span><span class="p">,</span> <span class="mi">1</span><span class="p">)</span>
    216. <span class="n">b2_x2</span> <span class="o">=</span> <span class="n">b1_x2</span><span class="o">.</span><span class="n">transpose</span><span class="p">(</span><span class="mi">2</span><span class="p">,</span> <span class="mi">1</span><span class="p">)</span>
    217. <span class="n">b2_y1</span> <span class="o">=</span> <span class="n">b1_y1</span><span class="o">.</span><span class="n">transpose</span><span class="p">(</span><span class="mi">2</span><span class="p">,</span> <span class="mi">1</span><span class="p">)</span>
    218. <span class="n">b2_y2</span> <span class="o">=</span> <span class="n">b1_y2</span><span class="o">.</span><span class="n">transpose</span><span class="p">(</span><span class="mi">2</span><span class="p">,</span> <span class="mi">1</span><span class="p">)</span>
    219. <span class="n">intersection_area</span> <span class="o">=</span> <span class="p">(</span><span class="n">torch</span><span class="o">.</span><span class="n">min</span><span class="p">(</span><span class="n">b1_x2</span><span class="p">,</span> <span class="n">b2_x2</span><span class="p">)</span> <span class="o">-</span> <span class="n">torch</span><span class="o">.</span><span class="n">max</span><span class="p">(</span><span class="n">b1_x1</span><span class="p">,</span> <span class="n">b2_x1</span><span class="p">))</span><span class="o">.</span><span class="n">clamp</span><span class="p">(</span><span class="mi">0</span><span class="p">)</span> <span class="o">*</span> \
    220. <span class="p">(</span><span class="n">torch</span><span class="o">.</span><span class="n">min</span><span class="p">(</span><span class="n">b1_y2</span><span class="p">,</span> <span class="n">b2_y2</span><span class="p">)</span> <span class="o">-</span> <span class="n">torch</span><span class="o">.</span><span class="n">max</span><span class="p">(</span><span class="n">b1_y1</span><span class="p">,</span> <span class="n">b2_y1</span><span class="p">))</span><span class="o">.</span><span class="n">clamp</span><span class="p">(</span><span class="mi">0</span><span class="p">)</span>
    221. <span class="c1"># Union Area</span>
    222. <span class="n">w1</span><span class="p">,</span> <span class="n">h1</span> <span class="o">=</span> <span class="n">b1_x2</span> <span class="o">-</span> <span class="n">b1_x1</span><span class="p">,</span> <span class="n">b1_y2</span> <span class="o">-</span> <span class="n">b1_y1</span>
    223. <span class="n">w2</span><span class="p">,</span> <span class="n">h2</span> <span class="o">=</span> <span class="n">b2_x2</span> <span class="o">-</span> <span class="n">b2_x1</span><span class="p">,</span> <span class="n">b2_y2</span> <span class="o">-</span> <span class="n">b2_y1</span>
    224. <span class="n">union_area</span> <span class="o">=</span> <span class="p">(</span><span class="n">w1</span> <span class="o">*</span> <span class="n">h1</span> <span class="o">+</span> <span class="mf">1e-16</span><span class="p">)</span> <span class="o">+</span> <span class="n">w2</span> <span class="o">*</span> <span class="n">h2</span> <span class="o">-</span> <span class="n">intersection_area</span>
    225. <span class="n">ious</span> <span class="o">=</span> <span class="n">intersection_area</span> <span class="o">/</span> <span class="n">union_area</span>
    226. <span class="k">return</span> <span class="n">ious</span></div>
    227. <div class="viewcode-block" id="change_bbox_bounds_for_image_size"><a class="viewcode-back" href="../../../../super_gradients.training.utils.html#super_gradients.training.utils.detection_utils.change_bbox_bounds_for_image_size">[docs]</a><span class="k">def</span> <span class="nf">change_bbox_bounds_for_image_size</span><span class="p">(</span><span class="n">boxes</span><span class="p">,</span> <span class="n">img_shape</span><span class="p">):</span>
    228. <span class="c1"># CLIP BOUNDING XYXY BOUNDING BOXES TO IMAGE SHAPE (HEIGHT, WIDTH)</span>
    229. <span class="n">boxes</span><span class="p">[:,</span> <span class="p">[</span><span class="mi">0</span><span class="p">,</span> <span class="mi">2</span><span class="p">]]</span> <span class="o">=</span> <span class="n">boxes</span><span class="p">[:,</span> <span class="p">[</span><span class="mi">0</span><span class="p">,</span> <span class="mi">2</span><span class="p">]]</span><span class="o">.</span><span class="n">clamp</span><span class="p">(</span><span class="nb">min</span><span class="o">=</span><span class="mi">0</span><span class="p">,</span> <span class="nb">max</span><span class="o">=</span><span class="n">img_shape</span><span class="p">[</span><span class="mi">1</span><span class="p">])</span>
    230. <span class="n">boxes</span><span class="p">[:,</span> <span class="p">[</span><span class="mi">1</span><span class="p">,</span> <span class="mi">3</span><span class="p">]]</span> <span class="o">=</span> <span class="n">boxes</span><span class="p">[:,</span> <span class="p">[</span><span class="mi">1</span><span class="p">,</span> <span class="mi">3</span><span class="p">]]</span><span class="o">.</span><span class="n">clamp</span><span class="p">(</span><span class="nb">min</span><span class="o">=</span><span class="mi">0</span><span class="p">,</span> <span class="nb">max</span><span class="o">=</span><span class="n">img_shape</span><span class="p">[</span><span class="mi">0</span><span class="p">])</span>
    231. <span class="k">return</span> <span class="n">boxes</span></div>
    232. <div class="viewcode-block" id="DetectionPostPredictionCallback"><a class="viewcode-back" href="../../../../super_gradients.training.utils.html#super_gradients.training.utils.detection_utils.DetectionPostPredictionCallback">[docs]</a><span class="k">class</span> <span class="nc">DetectionPostPredictionCallback</span><span class="p">(</span><span class="n">ABC</span><span class="p">,</span> <span class="n">nn</span><span class="o">.</span><span class="n">Module</span><span class="p">):</span>
    233. <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="kc">None</span><span class="p">:</span>
    234. <span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="fm">__init__</span><span class="p">()</span>
    235. <div class="viewcode-block" id="DetectionPostPredictionCallback.forward"><a class="viewcode-back" href="../../../../super_gradients.training.utils.html#super_gradients.training.utils.detection_utils.DetectionPostPredictionCallback.forward">[docs]</a> <span class="nd">@abstractmethod</span>
    236. <span class="k">def</span> <span class="nf">forward</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">x</span><span class="p">,</span> <span class="n">device</span><span class="p">:</span> <span class="nb">str</span><span class="p">):</span>
    237. <span class="sd">&quot;&quot;&quot;</span>
    238. <span class="sd"> :param x: the output of your model</span>
    239. <span class="sd"> :param device: the device to move all output tensors into</span>
    240. <span class="sd"> :return: a list with length batch_size, each item in the list is a detections</span>
    241. <span class="sd"> with shape: nx6 (x1, y1, x2, y2, confidence, class) where x and y are in range [0,1]</span>
    242. <span class="sd"> &quot;&quot;&quot;</span>
    243. <span class="k">raise</span> <span class="ne">NotImplementedError</span></div></div>
    244. <div class="viewcode-block" id="IouThreshold"><a class="viewcode-back" href="../../../../super_gradients.training.utils.html#super_gradients.training.utils.detection_utils.IouThreshold">[docs]</a><span class="k">class</span> <span class="nc">IouThreshold</span><span class="p">(</span><span class="nb">tuple</span><span class="p">,</span> <span class="n">Enum</span><span class="p">):</span>
    245. <span class="n">MAP_05</span> <span class="o">=</span> <span class="p">(</span><span class="mf">0.5</span><span class="p">,</span> <span class="mf">0.5</span><span class="p">)</span>
    246. <span class="n">MAP_05_TO_095</span> <span class="o">=</span> <span class="p">(</span><span class="mf">0.5</span><span class="p">,</span> <span class="mf">0.95</span><span class="p">)</span>
    247. <div class="viewcode-block" id="IouThreshold.is_range"><a class="viewcode-back" href="../../../../super_gradients.training.utils.html#super_gradients.training.utils.detection_utils.IouThreshold.is_range">[docs]</a> <span class="k">def</span> <span class="nf">is_range</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span>
    248. <span class="k">return</span> <span class="bp">self</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span> <span class="o">!=</span> <span class="bp">self</span><span class="p">[</span><span class="mi">1</span><span class="p">]</span></div>
    249. <div class="viewcode-block" id="IouThreshold.to_tensor"><a class="viewcode-back" href="../../../../super_gradients.training.utils.html#super_gradients.training.utils.detection_utils.IouThreshold.to_tensor">[docs]</a> <span class="k">def</span> <span class="nf">to_tensor</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span>
    250. <span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">is_range</span><span class="p">():</span>
    251. <span class="n">n_iou_thresh</span> <span class="o">=</span> <span class="nb">int</span><span class="p">(</span><span class="nb">round</span><span class="p">((</span><span class="bp">self</span><span class="p">[</span><span class="mi">1</span><span class="p">]</span> <span class="o">-</span> <span class="bp">self</span><span class="p">[</span><span class="mi">0</span><span class="p">])</span> <span class="o">/</span> <span class="mf">0.05</span><span class="p">))</span> <span class="o">+</span> <span class="mi">1</span>
    252. <span class="k">return</span> <span class="n">torch</span><span class="o">.</span><span class="n">linspace</span><span class="p">(</span><span class="bp">self</span><span class="p">[</span><span class="mi">0</span><span class="p">],</span> <span class="bp">self</span><span class="p">[</span><span class="mi">1</span><span class="p">],</span> <span class="n">n_iou_thresh</span><span class="p">)</span>
    253. <span class="k">else</span><span class="p">:</span>
    254. <span class="n">n_iou_thresh</span> <span class="o">=</span> <span class="mi">1</span>
    255. <span class="k">return</span> <span class="n">torch</span><span class="o">.</span><span class="n">tensor</span><span class="p">([</span><span class="bp">self</span><span class="p">[</span><span class="mi">0</span><span class="p">]])</span></div></div>
    256. <div class="viewcode-block" id="box_iou"><a class="viewcode-back" href="../../../../super_gradients.training.utils.html#super_gradients.training.utils.detection_utils.box_iou">[docs]</a><span class="k">def</span> <span class="nf">box_iou</span><span class="p">(</span><span class="n">box1</span><span class="p">,</span> <span class="n">box2</span><span class="p">):</span>
    257. <span class="c1"># https://github.com/pytorch/vision/blob/master/torchvision/ops/boxes.py</span>
    258. <span class="sd">&quot;&quot;&quot;</span>
    259. <span class="sd"> Return intersection-over-union (Jaccard index) of boxes.</span>
    260. <span class="sd"> Both sets of boxes are expected to be in (x1, y1, x2, y2) format.</span>
    261. <span class="sd"> Arguments:</span>
    262. <span class="sd"> box1 (Tensor[N, 4])</span>
    263. <span class="sd"> box2 (Tensor[M, 4])</span>
    264. <span class="sd"> Returns:</span>
    265. <span class="sd"> iou (Tensor[N, M]): the NxM matrix containing the pairwise</span>
    266. <span class="sd"> IoU values for every element in boxes1 and boxes2</span>
    267. <span class="sd"> &quot;&quot;&quot;</span>
    268. <span class="k">def</span> <span class="nf">box_area</span><span class="p">(</span><span class="n">box</span><span class="p">):</span>
    269. <span class="c1"># box = 4xn</span>
    270. <span class="k">return</span> <span class="p">(</span><span class="n">box</span><span class="p">[</span><span class="mi">2</span><span class="p">]</span> <span class="o">-</span> <span class="n">box</span><span class="p">[</span><span class="mi">0</span><span class="p">])</span> <span class="o">*</span> <span class="p">(</span><span class="n">box</span><span class="p">[</span><span class="mi">3</span><span class="p">]</span> <span class="o">-</span> <span class="n">box</span><span class="p">[</span><span class="mi">1</span><span class="p">])</span>
    271. <span class="n">area1</span> <span class="o">=</span> <span class="n">box_area</span><span class="p">(</span><span class="n">box1</span><span class="o">.</span><span class="n">T</span><span class="p">)</span>
    272. <span class="n">area2</span> <span class="o">=</span> <span class="n">box_area</span><span class="p">(</span><span class="n">box2</span><span class="o">.</span><span class="n">T</span><span class="p">)</span>
    273. <span class="c1"># inter(N,M) = (rb(N,M,2) - lt(N,M,2)).clamp(0).prod(2)</span>
    274. <span class="n">inter</span> <span class="o">=</span> <span class="p">(</span><span class="n">torch</span><span class="o">.</span><span class="n">min</span><span class="p">(</span><span class="n">box1</span><span class="p">[:,</span> <span class="kc">None</span><span class="p">,</span> <span class="mi">2</span><span class="p">:],</span> <span class="n">box2</span><span class="p">[:,</span> <span class="mi">2</span><span class="p">:])</span> <span class="o">-</span> <span class="n">torch</span><span class="o">.</span><span class="n">max</span><span class="p">(</span><span class="n">box1</span><span class="p">[:,</span> <span class="kc">None</span><span class="p">,</span> <span class="p">:</span><span class="mi">2</span><span class="p">],</span> <span class="n">box2</span><span class="p">[:,</span> <span class="p">:</span><span class="mi">2</span><span class="p">]))</span><span class="o">.</span><span class="n">clamp</span><span class="p">(</span><span class="mi">0</span><span class="p">)</span><span class="o">.</span><span class="n">prod</span><span class="p">(</span><span class="mi">2</span><span class="p">)</span>
    275. <span class="k">return</span> <span class="n">inter</span> <span class="o">/</span> <span class="p">(</span><span class="n">area1</span><span class="p">[:,</span> <span class="kc">None</span><span class="p">]</span> <span class="o">+</span> <span class="n">area2</span> <span class="o">-</span> <span class="n">inter</span><span class="p">)</span> <span class="c1"># iou = inter / (area1 + area2 - inter)</span></div>
    276. <div class="viewcode-block" id="non_max_suppression"><a class="viewcode-back" href="../../../../super_gradients.training.utils.html#super_gradients.training.utils.detection_utils.non_max_suppression">[docs]</a><span class="k">def</span> <span class="nf">non_max_suppression</span><span class="p">(</span><span class="n">prediction</span><span class="p">,</span> <span class="n">conf_thres</span><span class="o">=</span><span class="mf">0.1</span><span class="p">,</span> <span class="n">iou_thres</span><span class="o">=</span><span class="mf">0.6</span><span class="p">,</span>
    277. <span class="n">multi_label_per_box</span><span class="p">:</span> <span class="nb">bool</span> <span class="o">=</span> <span class="kc">True</span><span class="p">,</span> <span class="n">with_confidence</span><span class="p">:</span> <span class="nb">bool</span> <span class="o">=</span> <span class="kc">False</span><span class="p">):</span>
    278. <span class="sd">&quot;&quot;&quot;</span>
    279. <span class="sd"> Performs Non-Maximum Suppression (NMS) on inference results</span>
    280. <span class="sd"> :param prediction: raw model prediction</span>
    281. <span class="sd"> :param conf_thres: below the confidence threshold - prediction are discarded</span>
    282. <span class="sd"> :param iou_thres: IoU threshold for the nms algorithm</span>
    283. <span class="sd"> :param multi_label_per_box: whether to use re-use each box with all possible labels</span>
    284. <span class="sd"> (instead of the maximum confidence all confidences above threshold</span>
    285. <span class="sd"> will be sent to NMS); by default is set to True</span>
    286. <span class="sd"> :param with_confidence: whether to multiply objectness score with class score.</span>
    287. <span class="sd"> usually valid for Yolo models only.</span>
    288. <span class="sd"> :return: (x1, y1, x2, y2, object_conf, class_conf, class)</span>
    289. <span class="sd"> Returns:</span>
    290. <span class="sd"> detections with shape: nx6 (x1, y1, x2, y2, conf, cls)</span>
    291. <span class="sd"> &quot;&quot;&quot;</span>
    292. <span class="n">candidates_above_thres</span> <span class="o">=</span> <span class="n">prediction</span><span class="p">[</span><span class="o">...</span><span class="p">,</span> <span class="mi">4</span><span class="p">]</span> <span class="o">&gt;</span> <span class="n">conf_thres</span> <span class="c1"># filter by confidence</span>
    293. <span class="n">output</span> <span class="o">=</span> <span class="p">[</span><span class="kc">None</span><span class="p">]</span> <span class="o">*</span> <span class="n">prediction</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span>
    294. <span class="k">for</span> <span class="n">image_idx</span><span class="p">,</span> <span class="n">pred</span> <span class="ow">in</span> <span class="nb">enumerate</span><span class="p">(</span><span class="n">prediction</span><span class="p">):</span>
    295. <span class="n">pred</span> <span class="o">=</span> <span class="n">pred</span><span class="p">[</span><span class="n">candidates_above_thres</span><span class="p">[</span><span class="n">image_idx</span><span class="p">]]</span> <span class="c1"># confident</span>
    296. <span class="k">if</span> <span class="ow">not</span> <span class="n">pred</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">0</span><span class="p">]:</span> <span class="c1"># If none remain process next image</span>
    297. <span class="k">continue</span>
    298. <span class="k">if</span> <span class="n">with_confidence</span><span class="p">:</span>
    299. <span class="n">pred</span><span class="p">[:,</span> <span class="mi">5</span><span class="p">:]</span> <span class="o">*=</span> <span class="n">pred</span><span class="p">[:,</span> <span class="mi">4</span><span class="p">:</span><span class="mi">5</span><span class="p">]</span> <span class="c1"># multiply objectness score with class score</span>
    300. <span class="n">box</span> <span class="o">=</span> <span class="n">convert_xywh_bbox_to_xyxy</span><span class="p">(</span><span class="n">pred</span><span class="p">[:,</span> <span class="p">:</span><span class="mi">4</span><span class="p">])</span> <span class="c1"># xywh to xyxy</span>
    301. <span class="c1"># Detections matrix nx6 (xyxy, conf, cls)</span>
    302. <span class="k">if</span> <span class="n">multi_label_per_box</span><span class="p">:</span> <span class="c1"># try for all good confidence classes</span>
    303. <span class="n">i</span><span class="p">,</span> <span class="n">j</span> <span class="o">=</span> <span class="p">(</span><span class="n">pred</span><span class="p">[:,</span> <span class="mi">5</span><span class="p">:]</span> <span class="o">&gt;</span> <span class="n">conf_thres</span><span class="p">)</span><span class="o">.</span><span class="n">nonzero</span><span class="p">(</span><span class="n">as_tuple</span><span class="o">=</span><span class="kc">False</span><span class="p">)</span><span class="o">.</span><span class="n">T</span>
    304. <span class="n">pred</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">cat</span><span class="p">((</span><span class="n">box</span><span class="p">[</span><span class="n">i</span><span class="p">],</span> <span class="n">pred</span><span class="p">[</span><span class="n">i</span><span class="p">,</span> <span class="n">j</span> <span class="o">+</span> <span class="mi">5</span><span class="p">,</span> <span class="kc">None</span><span class="p">],</span> <span class="n">j</span><span class="p">[:,</span> <span class="kc">None</span><span class="p">]</span><span class="o">.</span><span class="n">float</span><span class="p">()),</span> <span class="mi">1</span><span class="p">)</span>
    305. <span class="k">else</span><span class="p">:</span> <span class="c1"># best class only</span>
    306. <span class="n">conf</span><span class="p">,</span> <span class="n">j</span> <span class="o">=</span> <span class="n">pred</span><span class="p">[:,</span> <span class="mi">5</span><span class="p">:]</span><span class="o">.</span><span class="n">max</span><span class="p">(</span><span class="mi">1</span><span class="p">,</span> <span class="n">keepdim</span><span class="o">=</span><span class="kc">True</span><span class="p">)</span>
    307. <span class="n">pred</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">cat</span><span class="p">((</span><span class="n">box</span><span class="p">,</span> <span class="n">conf</span><span class="p">,</span> <span class="n">j</span><span class="o">.</span><span class="n">float</span><span class="p">()),</span> <span class="mi">1</span><span class="p">)[</span><span class="n">conf</span><span class="o">.</span><span class="n">view</span><span class="p">(</span><span class="o">-</span><span class="mi">1</span><span class="p">)</span> <span class="o">&gt;</span> <span class="n">conf_thres</span><span class="p">]</span>
    308. <span class="k">if</span> <span class="ow">not</span> <span class="n">pred</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">0</span><span class="p">]:</span> <span class="c1"># If none remain process next image</span>
    309. <span class="k">continue</span>
    310. <span class="c1"># Apply torch batched NMS algorithm</span>
    311. <span class="n">boxes</span><span class="p">,</span> <span class="n">scores</span><span class="p">,</span> <span class="n">cls_idx</span> <span class="o">=</span> <span class="n">pred</span><span class="p">[:,</span> <span class="p">:</span><span class="mi">4</span><span class="p">],</span> <span class="n">pred</span><span class="p">[:,</span> <span class="mi">4</span><span class="p">],</span> <span class="n">pred</span><span class="p">[:,</span> <span class="mi">5</span><span class="p">]</span>
    312. <span class="n">idx_to_keep</span> <span class="o">=</span> <span class="n">torchvision</span><span class="o">.</span><span class="n">ops</span><span class="o">.</span><span class="n">boxes</span><span class="o">.</span><span class="n">batched_nms</span><span class="p">(</span><span class="n">boxes</span><span class="p">,</span> <span class="n">scores</span><span class="p">,</span> <span class="n">cls_idx</span><span class="p">,</span> <span class="n">iou_thres</span><span class="p">)</span>
    313. <span class="n">output</span><span class="p">[</span><span class="n">image_idx</span><span class="p">]</span> <span class="o">=</span> <span class="n">pred</span><span class="p">[</span><span class="n">idx_to_keep</span><span class="p">]</span>
    314. <span class="k">return</span> <span class="n">output</span></div>
    315. <div class="viewcode-block" id="matrix_non_max_suppression"><a class="viewcode-back" href="../../../../super_gradients.training.utils.html#super_gradients.training.utils.detection_utils.matrix_non_max_suppression">[docs]</a><span class="k">def</span> <span class="nf">matrix_non_max_suppression</span><span class="p">(</span><span class="n">pred</span><span class="p">,</span> <span class="n">conf_thres</span><span class="p">:</span> <span class="nb">float</span> <span class="o">=</span> <span class="mf">0.1</span><span class="p">,</span> <span class="n">kernel</span><span class="p">:</span> <span class="nb">str</span> <span class="o">=</span> <span class="s1">&#39;gaussian&#39;</span><span class="p">,</span>
    316. <span class="n">sigma</span><span class="p">:</span> <span class="nb">float</span> <span class="o">=</span> <span class="mf">3.0</span><span class="p">,</span> <span class="n">max_num_of_detections</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="mi">500</span><span class="p">):</span>
    317. <span class="sd">&quot;&quot;&quot;Performs Matrix Non-Maximum Suppression (NMS) on inference results</span>
    318. <span class="sd"> https://arxiv.org/pdf/1912.04488.pdf</span>
    319. <span class="sd"> :param pred: raw model prediction (in test mode) - a Tensor of shape [batch, num_predictions, 85]</span>
    320. <span class="sd"> where each item format is (x, y, w, h, object_conf, class_conf, ... 80 classes score ...)</span>
    321. <span class="sd"> :param conf_thres: below the confidence threshold - prediction are discarded</span>
    322. <span class="sd"> :param kernel: type of kernel to use [&#39;gaussian&#39;, &#39;linear&#39;]</span>
    323. <span class="sd"> :param sigma: sigma for the gussian kernel</span>
    324. <span class="sd"> :param max_num_of_detections: maximum number of boxes to output</span>
    325. <span class="sd"> :return: list of (x1, y1, x2, y2, object_conf, class_conf, class)</span>
    326. <span class="sd"> Returns:</span>
    327. <span class="sd"> detections list with shape: (x1, y1, x2, y2, conf, cls)</span>
    328. <span class="sd"> &quot;&quot;&quot;</span>
    329. <span class="c1"># MULTIPLY CONF BY CLASS CONF TO GET COMBINED CONFIDENCE</span>
    330. <span class="n">class_conf</span><span class="p">,</span> <span class="n">class_pred</span> <span class="o">=</span> <span class="n">pred</span><span class="p">[:,</span> <span class="p">:,</span> <span class="mi">5</span><span class="p">:]</span><span class="o">.</span><span class="n">max</span><span class="p">(</span><span class="mi">2</span><span class="p">)</span>
    331. <span class="n">pred</span><span class="p">[:,</span> <span class="p">:,</span> <span class="mi">4</span><span class="p">]</span> <span class="o">*=</span> <span class="n">class_conf</span>
    332. <span class="c1"># BOX (CENTER X, CENTER Y, WIDTH, HEIGHT) TO (X1, Y1, X2, Y2)</span>
    333. <span class="n">pred</span><span class="p">[:,</span> <span class="p">:,</span> <span class="p">:</span><span class="mi">4</span><span class="p">]</span> <span class="o">=</span> <span class="n">convert_xywh_bbox_to_xyxy</span><span class="p">(</span><span class="n">pred</span><span class="p">[:,</span> <span class="p">:,</span> <span class="p">:</span><span class="mi">4</span><span class="p">])</span>
    334. <span class="c1"># DETECTIONS ORDERED AS (x1y1x2y2, obj_conf, class_conf, class_pred)</span>
    335. <span class="n">pred</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">cat</span><span class="p">((</span><span class="n">pred</span><span class="p">[:,</span> <span class="p">:,</span> <span class="p">:</span><span class="mi">5</span><span class="p">],</span> <span class="n">class_pred</span><span class="o">.</span><span class="n">unsqueeze</span><span class="p">(</span><span class="mi">2</span><span class="p">)),</span> <span class="mi">2</span><span class="p">)</span>
    336. <span class="c1"># SORT DETECTIONS BY DECREASING CONFIDENCE SCORES</span>
    337. <span class="n">sort_ind</span> <span class="o">=</span> <span class="p">(</span><span class="o">-</span><span class="n">pred</span><span class="p">[:,</span> <span class="p">:,</span> <span class="mi">4</span><span class="p">])</span><span class="o">.</span><span class="n">argsort</span><span class="p">()</span>
    338. <span class="n">pred</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">stack</span><span class="p">([</span><span class="n">pred</span><span class="p">[</span><span class="n">i</span><span class="p">,</span> <span class="n">sort_ind</span><span class="p">[</span><span class="n">i</span><span class="p">]]</span> <span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="n">pred</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">0</span><span class="p">])])[:,</span> <span class="mi">0</span><span class="p">:</span><span class="n">max_num_of_detections</span><span class="p">]</span>
    339. <span class="n">ious</span> <span class="o">=</span> <span class="n">calc_bbox_iou_matrix</span><span class="p">(</span><span class="n">pred</span><span class="p">)</span>
    340. <span class="n">ious</span> <span class="o">=</span> <span class="n">ious</span><span class="o">.</span><span class="n">triu</span><span class="p">(</span><span class="mi">1</span><span class="p">)</span>
    341. <span class="c1"># CREATE A LABELS MASK, WE WANT ONLY BOXES WITH THE SAME LABEL TO AFFECT EACH OTHER</span>
    342. <span class="n">labels</span> <span class="o">=</span> <span class="n">pred</span><span class="p">[:,</span> <span class="p">:,</span> <span class="mi">5</span><span class="p">:]</span>
    343. <span class="n">labeles_matrix</span> <span class="o">=</span> <span class="p">(</span><span class="n">labels</span> <span class="o">==</span> <span class="n">labels</span><span class="o">.</span><span class="n">transpose</span><span class="p">(</span><span class="mi">2</span><span class="p">,</span> <span class="mi">1</span><span class="p">))</span><span class="o">.</span><span class="n">float</span><span class="p">()</span><span class="o">.</span><span class="n">triu</span><span class="p">(</span><span class="mi">1</span><span class="p">)</span>
    344. <span class="n">ious</span> <span class="o">*=</span> <span class="n">labeles_matrix</span>
    345. <span class="n">ious_cmax</span><span class="p">,</span> <span class="n">_</span> <span class="o">=</span> <span class="n">ious</span><span class="o">.</span><span class="n">max</span><span class="p">(</span><span class="mi">1</span><span class="p">)</span>
    346. <span class="n">ious_cmax</span> <span class="o">=</span> <span class="n">ious_cmax</span><span class="o">.</span><span class="n">unsqueeze</span><span class="p">(</span><span class="mi">2</span><span class="p">)</span><span class="o">.</span><span class="n">repeat</span><span class="p">(</span><span class="mi">1</span><span class="p">,</span> <span class="mi">1</span><span class="p">,</span> <span class="n">max_num_of_detections</span><span class="p">)</span>
    347. <span class="k">if</span> <span class="n">kernel</span> <span class="o">==</span> <span class="s1">&#39;gaussian&#39;</span><span class="p">:</span>
    348. <span class="n">decay_matrix</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">exp</span><span class="p">(</span><span class="o">-</span><span class="mi">1</span> <span class="o">*</span> <span class="n">sigma</span> <span class="o">*</span> <span class="p">(</span><span class="n">ious</span> <span class="o">**</span> <span class="mi">2</span><span class="p">))</span>
    349. <span class="n">compensate_matrix</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">exp</span><span class="p">(</span><span class="o">-</span><span class="mi">1</span> <span class="o">*</span> <span class="n">sigma</span> <span class="o">*</span> <span class="p">(</span><span class="n">ious_cmax</span> <span class="o">**</span> <span class="mi">2</span><span class="p">))</span>
    350. <span class="n">decay</span><span class="p">,</span> <span class="n">_</span> <span class="o">=</span> <span class="p">(</span><span class="n">decay_matrix</span> <span class="o">/</span> <span class="n">compensate_matrix</span><span class="p">)</span><span class="o">.</span><span class="n">min</span><span class="p">(</span><span class="n">dim</span><span class="o">=</span><span class="mi">1</span><span class="p">)</span>
    351. <span class="k">else</span><span class="p">:</span>
    352. <span class="n">decay</span> <span class="o">=</span> <span class="p">(</span><span class="mi">1</span> <span class="o">-</span> <span class="n">ious</span><span class="p">)</span> <span class="o">/</span> <span class="p">(</span><span class="mi">1</span> <span class="o">-</span> <span class="n">ious_cmax</span><span class="p">)</span>
    353. <span class="n">decay</span><span class="p">,</span> <span class="n">_</span> <span class="o">=</span> <span class="n">decay</span><span class="o">.</span><span class="n">min</span><span class="p">(</span><span class="n">dim</span><span class="o">=</span><span class="mi">1</span><span class="p">)</span>
    354. <span class="n">pred</span><span class="p">[:,</span> <span class="p">:,</span> <span class="mi">4</span><span class="p">]</span> <span class="o">*=</span> <span class="n">decay</span>
    355. <span class="n">output</span> <span class="o">=</span> <span class="p">[</span><span class="n">pred</span><span class="p">[</span><span class="n">i</span><span class="p">,</span> <span class="n">pred</span><span class="p">[</span><span class="n">i</span><span class="p">,</span> <span class="p">:,</span> <span class="mi">4</span><span class="p">]</span> <span class="o">&gt;</span> <span class="n">conf_thres</span><span class="p">]</span> <span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="n">pred</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">0</span><span class="p">])]</span>
    356. <span class="k">return</span> <span class="n">output</span></div>
    357. <div class="viewcode-block" id="NMS_Type"><a class="viewcode-back" href="../../../../super_gradients.training.utils.html#super_gradients.training.utils.detection_utils.NMS_Type">[docs]</a><span class="k">class</span> <span class="nc">NMS_Type</span><span class="p">(</span><span class="nb">str</span><span class="p">,</span> <span class="n">Enum</span><span class="p">):</span>
    358. <span class="sd">&quot;&quot;&quot;</span>
    359. <span class="sd"> Type of non max suppression algorithm that can be used for post processing detection</span>
    360. <span class="sd"> &quot;&quot;&quot;</span>
    361. <span class="n">ITERATIVE</span> <span class="o">=</span> <span class="s1">&#39;iterative&#39;</span>
    362. <span class="n">MATRIX</span> <span class="o">=</span> <span class="s1">&#39;matrix&#39;</span></div>
    363. <div class="viewcode-block" id="undo_image_preprocessing"><a class="viewcode-back" href="../../../../super_gradients.training.utils.html#super_gradients.training.utils.detection_utils.undo_image_preprocessing">[docs]</a><span class="k">def</span> <span class="nf">undo_image_preprocessing</span><span class="p">(</span><span class="n">im_tensor</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="n">np</span><span class="o">.</span><span class="n">ndarray</span><span class="p">:</span>
    364. <span class="sd">&quot;&quot;&quot;</span>
    365. <span class="sd"> :param im_tensor: images in a batch after preprocessing for inference, RGB, (B, C, H, W)</span>
    366. <span class="sd"> :return: images in a batch in cv2 format, BGR, (B, H, W, C)</span>
    367. <span class="sd"> &quot;&quot;&quot;</span>
    368. <span class="n">im_np</span> <span class="o">=</span> <span class="n">im_tensor</span><span class="o">.</span><span class="n">cpu</span><span class="p">()</span><span class="o">.</span><span class="n">numpy</span><span class="p">()</span>
    369. <span class="n">im_np</span> <span class="o">=</span> <span class="n">im_np</span><span class="p">[:,</span> <span class="p">::</span><span class="o">-</span><span class="mi">1</span><span class="p">,</span> <span class="p">:,</span> <span class="p">:]</span><span class="o">.</span><span class="n">transpose</span><span class="p">(</span><span class="mi">0</span><span class="p">,</span> <span class="mi">2</span><span class="p">,</span> <span class="mi">3</span><span class="p">,</span> <span class="mi">1</span><span class="p">)</span>
    370. <span class="n">im_np</span> <span class="o">*=</span> <span class="mf">255.</span>
    371. <span class="k">return</span> <span class="n">np</span><span class="o">.</span><span class="n">ascontiguousarray</span><span class="p">(</span><span class="n">im_np</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="n">np</span><span class="o">.</span><span class="n">uint8</span><span class="p">)</span></div>
    372. <div class="viewcode-block" id="DetectionVisualization"><a class="viewcode-back" href="../../../../super_gradients.training.utils.html#super_gradients.training.utils.detection_utils.DetectionVisualization">[docs]</a><span class="k">class</span> <span class="nc">DetectionVisualization</span><span class="p">:</span>
    373. <span class="nd">@staticmethod</span>
    374. <span class="k">def</span> <span class="nf">_generate_color_mapping</span><span class="p">(</span><span class="n">num_classes</span><span class="p">:</span> <span class="nb">int</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="n">List</span><span class="p">[</span><span class="n">Tuple</span><span class="p">[</span><span class="nb">int</span><span class="p">]]:</span>
    375. <span class="sd">&quot;&quot;&quot;</span>
    376. <span class="sd"> Generate a unique BGR color for each class</span>
    377. <span class="sd"> &quot;&quot;&quot;</span>
    378. <span class="n">cmap</span> <span class="o">=</span> <span class="n">plt</span><span class="o">.</span><span class="n">cm</span><span class="o">.</span><span class="n">get_cmap</span><span class="p">(</span><span class="s1">&#39;gist_rainbow&#39;</span><span class="p">,</span> <span class="n">num_classes</span><span class="p">)</span>
    379. <span class="n">colors</span> <span class="o">=</span> <span class="p">[</span><span class="n">cmap</span><span class="p">(</span><span class="n">i</span><span class="p">,</span> <span class="nb">bytes</span><span class="o">=</span><span class="kc">True</span><span class="p">)[:</span><span class="mi">3</span><span class="p">][::</span><span class="o">-</span><span class="mi">1</span><span class="p">]</span> <span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="n">num_classes</span><span class="p">)]</span>
    380. <span class="k">return</span> <span class="p">[</span><span class="nb">tuple</span><span class="p">(</span><span class="nb">int</span><span class="p">(</span><span class="n">v</span><span class="p">)</span> <span class="k">for</span> <span class="n">v</span> <span class="ow">in</span> <span class="n">c</span><span class="p">)</span> <span class="k">for</span> <span class="n">c</span> <span class="ow">in</span> <span class="n">colors</span><span class="p">]</span>
    381. <span class="nd">@staticmethod</span>
    382. <span class="k">def</span> <span class="nf">_draw_box_title</span><span class="p">(</span><span class="n">color_mapping</span><span class="p">:</span> <span class="n">List</span><span class="p">[</span><span class="n">Tuple</span><span class="p">[</span><span class="nb">int</span><span class="p">]],</span> <span class="n">class_names</span><span class="p">:</span> <span class="n">List</span><span class="p">[</span><span class="nb">str</span><span class="p">],</span> <span class="n">box_thickness</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span>
    383. <span class="n">image_np</span><span class="p">:</span> <span class="n">np</span><span class="o">.</span><span class="n">ndarray</span><span class="p">,</span> <span class="n">x1</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span> <span class="n">y1</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span> <span class="n">x2</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span> <span class="n">y2</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span> <span class="n">class_id</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span>
    384. <span class="n">pred_conf</span><span class="p">:</span> <span class="nb">float</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span> <span class="n">is_target</span><span class="p">:</span> <span class="nb">bool</span> <span class="o">=</span> <span class="kc">False</span><span class="p">):</span>
    385. <span class="n">color</span> <span class="o">=</span> <span class="n">color_mapping</span><span class="p">[</span><span class="n">class_id</span><span class="p">]</span>
    386. <span class="n">class_name</span> <span class="o">=</span> <span class="n">class_names</span><span class="p">[</span><span class="n">class_id</span><span class="p">]</span>
    387. <span class="c1"># Draw the box</span>
    388. <span class="n">image_np</span> <span class="o">=</span> <span class="n">cv2</span><span class="o">.</span><span class="n">rectangle</span><span class="p">(</span><span class="n">image_np</span><span class="p">,</span> <span class="p">(</span><span class="n">x1</span><span class="p">,</span> <span class="n">y1</span><span class="p">),</span> <span class="p">(</span><span class="n">x2</span><span class="p">,</span> <span class="n">y2</span><span class="p">),</span> <span class="n">color</span><span class="p">,</span> <span class="n">box_thickness</span><span class="p">)</span>
    389. <span class="c1"># Caption with class name and confidence if given</span>
    390. <span class="n">text_color</span> <span class="o">=</span> <span class="p">(</span><span class="mi">255</span><span class="p">,</span> <span class="mi">255</span><span class="p">,</span> <span class="mi">255</span><span class="p">)</span> <span class="c1"># white</span>
    391. <span class="k">if</span> <span class="n">is_target</span><span class="p">:</span>
    392. <span class="n">title</span> <span class="o">=</span> <span class="sa">f</span><span class="s1">&#39;[GT] </span><span class="si">{</span><span class="n">class_name</span><span class="si">}</span><span class="s1">&#39;</span>
    393. <span class="k">if</span> <span class="ow">not</span> <span class="n">is_target</span><span class="p">:</span>
    394. <span class="n">title</span> <span class="o">=</span> <span class="sa">f</span><span class="s1">&#39;[Pred] </span><span class="si">{</span><span class="n">class_name</span><span class="si">}</span><span class="s1"> </span><span class="si">{</span><span class="nb">str</span><span class="p">(</span><span class="nb">round</span><span class="p">(</span><span class="n">pred_conf</span><span class="p">,</span> <span class="mi">2</span><span class="p">))</span> <span class="k">if</span> <span class="n">pred_conf</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span> <span class="k">else</span> <span class="s2">&quot;&quot;</span><span class="si">}</span><span class="s1">&#39;</span>
    395. <span class="n">image_np</span> <span class="o">=</span> <span class="n">cv2</span><span class="o">.</span><span class="n">rectangle</span><span class="p">(</span><span class="n">image_np</span><span class="p">,</span> <span class="p">(</span><span class="n">x1</span><span class="p">,</span> <span class="n">y1</span> <span class="o">-</span> <span class="mi">15</span><span class="p">),</span> <span class="p">(</span><span class="n">x1</span> <span class="o">+</span> <span class="nb">len</span><span class="p">(</span><span class="n">title</span><span class="p">)</span> <span class="o">*</span> <span class="mi">10</span><span class="p">,</span> <span class="n">y1</span><span class="p">),</span> <span class="n">color</span><span class="p">,</span> <span class="n">cv2</span><span class="o">.</span><span class="n">FILLED</span><span class="p">)</span>
    396. <span class="n">image_np</span> <span class="o">=</span> <span class="n">cv2</span><span class="o">.</span><span class="n">putText</span><span class="p">(</span><span class="n">image_np</span><span class="p">,</span> <span class="n">title</span><span class="p">,</span> <span class="p">(</span><span class="n">x1</span><span class="p">,</span> <span class="n">y1</span> <span class="o">-</span> <span class="n">box_thickness</span><span class="p">),</span> <span class="mi">2</span><span class="p">,</span> <span class="mf">.5</span><span class="p">,</span> <span class="n">text_color</span><span class="p">,</span> <span class="mi">1</span><span class="p">,</span> <span class="n">lineType</span><span class="o">=</span><span class="n">cv2</span><span class="o">.</span><span class="n">LINE_AA</span><span class="p">)</span>
    397. <span class="k">return</span> <span class="n">image_np</span>
    398. <span class="nd">@staticmethod</span>
    399. <span class="k">def</span> <span class="nf">_visualize_image</span><span class="p">(</span><span class="n">image_np</span><span class="p">:</span> <span class="n">np</span><span class="o">.</span><span class="n">ndarray</span><span class="p">,</span> <span class="n">pred_boxes</span><span class="p">:</span> <span class="n">np</span><span class="o">.</span><span class="n">ndarray</span><span class="p">,</span> <span class="n">target_boxes</span><span class="p">:</span> <span class="n">np</span><span class="o">.</span><span class="n">ndarray</span><span class="p">,</span>
    400. <span class="n">class_names</span><span class="p">:</span> <span class="n">List</span><span class="p">[</span><span class="nb">str</span><span class="p">],</span> <span class="n">box_thickness</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span> <span class="n">gt_alpha</span><span class="p">:</span> <span class="nb">float</span><span class="p">,</span> <span class="n">image_scale</span><span class="p">:</span> <span class="nb">float</span><span class="p">,</span>
    401. <span class="n">checkpoint_dir</span><span class="p">:</span> <span class="nb">str</span><span class="p">,</span> <span class="n">image_name</span><span class="p">:</span> <span class="nb">str</span><span class="p">):</span>
    402. <span class="n">image_np</span> <span class="o">=</span> <span class="n">cv2</span><span class="o">.</span><span class="n">resize</span><span class="p">(</span><span class="n">image_np</span><span class="p">,</span> <span class="p">(</span><span class="mi">0</span><span class="p">,</span> <span class="mi">0</span><span class="p">),</span> <span class="n">fx</span><span class="o">=</span><span class="n">image_scale</span><span class="p">,</span> <span class="n">fy</span><span class="o">=</span><span class="n">image_scale</span><span class="p">,</span> <span class="n">interpolation</span><span class="o">=</span><span class="n">cv2</span><span class="o">.</span><span class="n">INTER_NEAREST</span><span class="p">)</span>
    403. <span class="n">color_mapping</span> <span class="o">=</span> <span class="n">DetectionVisualization</span><span class="o">.</span><span class="n">_generate_color_mapping</span><span class="p">(</span><span class="nb">len</span><span class="p">(</span><span class="n">class_names</span><span class="p">))</span>
    404. <span class="c1"># Draw predictions</span>
    405. <span class="n">pred_boxes</span><span class="p">[:,</span> <span class="p">:</span><span class="mi">4</span><span class="p">]</span> <span class="o">*=</span> <span class="n">image_scale</span>
    406. <span class="k">for</span> <span class="n">box</span> <span class="ow">in</span> <span class="n">pred_boxes</span><span class="p">:</span>
    407. <span class="n">image_np</span> <span class="o">=</span> <span class="n">DetectionVisualization</span><span class="o">.</span><span class="n">_draw_box_title</span><span class="p">(</span><span class="n">color_mapping</span><span class="p">,</span> <span class="n">class_names</span><span class="p">,</span> <span class="n">box_thickness</span><span class="p">,</span>
    408. <span class="n">image_np</span><span class="p">,</span> <span class="o">*</span><span class="n">box</span><span class="p">[:</span><span class="mi">4</span><span class="p">]</span><span class="o">.</span><span class="n">astype</span><span class="p">(</span><span class="nb">int</span><span class="p">),</span>
    409. <span class="n">class_id</span><span class="o">=</span><span class="nb">int</span><span class="p">(</span><span class="n">box</span><span class="p">[</span><span class="mi">5</span><span class="p">]),</span> <span class="n">pred_conf</span><span class="o">=</span><span class="n">box</span><span class="p">[</span><span class="mi">4</span><span class="p">])</span>
    410. <span class="c1"># Draw ground truths</span>
    411. <span class="n">target_boxes_image</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">zeros_like</span><span class="p">(</span><span class="n">image_np</span><span class="p">,</span> <span class="n">np</span><span class="o">.</span><span class="n">uint8</span><span class="p">)</span>
    412. <span class="k">for</span> <span class="n">box</span> <span class="ow">in</span> <span class="n">target_boxes</span><span class="p">:</span>
    413. <span class="n">target_boxes_image</span> <span class="o">=</span> <span class="n">DetectionVisualization</span><span class="o">.</span><span class="n">_draw_box_title</span><span class="p">(</span><span class="n">color_mapping</span><span class="p">,</span> <span class="n">class_names</span><span class="p">,</span> <span class="n">box_thickness</span><span class="p">,</span>
    414. <span class="n">target_boxes_image</span><span class="p">,</span> <span class="o">*</span><span class="n">box</span><span class="p">[</span><span class="mi">2</span><span class="p">:],</span>
    415. <span class="n">class_id</span><span class="o">=</span><span class="n">box</span><span class="p">[</span><span class="mi">1</span><span class="p">],</span> <span class="n">is_target</span><span class="o">=</span><span class="kc">True</span><span class="p">)</span>
    416. <span class="c1"># Transparent overlay of ground truth boxes</span>
    417. <span class="n">mask</span> <span class="o">=</span> <span class="n">target_boxes_image</span><span class="o">.</span><span class="n">astype</span><span class="p">(</span><span class="nb">bool</span><span class="p">)</span>
    418. <span class="n">image_np</span><span class="p">[</span><span class="n">mask</span><span class="p">]</span> <span class="o">=</span> <span class="n">cv2</span><span class="o">.</span><span class="n">addWeighted</span><span class="p">(</span><span class="n">image_np</span><span class="p">,</span> <span class="mi">1</span> <span class="o">-</span> <span class="n">gt_alpha</span><span class="p">,</span> <span class="n">target_boxes_image</span><span class="p">,</span> <span class="n">gt_alpha</span><span class="p">,</span> <span class="mi">0</span><span class="p">)[</span><span class="n">mask</span><span class="p">]</span>
    419. <span class="k">if</span> <span class="n">checkpoint_dir</span> <span class="ow">is</span> <span class="kc">None</span><span class="p">:</span>
    420. <span class="k">return</span> <span class="n">image_np</span>
    421. <span class="k">else</span><span class="p">:</span>
    422. <span class="n">pathlib</span><span class="o">.</span><span class="n">Path</span><span class="p">(</span><span class="n">checkpoint_dir</span><span class="p">)</span><span class="o">.</span><span class="n">mkdir</span><span class="p">(</span><span class="n">parents</span><span class="o">=</span><span class="kc">True</span><span class="p">,</span> <span class="n">exist_ok</span><span class="o">=</span><span class="kc">True</span><span class="p">)</span>
    423. <span class="n">cv2</span><span class="o">.</span><span class="n">imwrite</span><span class="p">(</span><span class="n">os</span><span class="o">.</span><span class="n">path</span><span class="o">.</span><span class="n">join</span><span class="p">(</span><span class="n">checkpoint_dir</span><span class="p">,</span> <span class="nb">str</span><span class="p">(</span><span class="n">image_name</span><span class="p">)</span> <span class="o">+</span> <span class="s1">&#39;.jpg&#39;</span><span class="p">),</span> <span class="n">image_np</span><span class="p">)</span>
    424. <span class="nd">@staticmethod</span>
    425. <span class="k">def</span> <span class="nf">_scaled_ccwh_to_xyxy</span><span class="p">(</span><span class="n">target_boxes</span><span class="p">:</span> <span class="n">np</span><span class="o">.</span><span class="n">ndarray</span><span class="p">,</span> <span class="n">h</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span> <span class="n">w</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span> <span class="n">image_scale</span><span class="p">:</span> <span class="nb">float</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="n">np</span><span class="o">.</span><span class="n">ndarray</span><span class="p">:</span>
    426. <span class="sd">&quot;&quot;&quot;</span>
    427. <span class="sd"> Modifies target_boxes inplace</span>
    428. <span class="sd"> :param target_boxes: (c1, c2, w, h) boxes in [0, 1] range</span>
    429. <span class="sd"> :param h: image height</span>
    430. <span class="sd"> :param w: image width</span>
    431. <span class="sd"> :param image_scale: desired scale for the boxes w.r.t. w and h</span>
    432. <span class="sd"> :return: targets in (x1, y1, x2, y2) format</span>
    433. <span class="sd"> in range [0, w * self.image_scale] [0, h * self.image_scale]</span>
    434. <span class="sd"> &quot;&quot;&quot;</span>
    435. <span class="c1"># unscale</span>
    436. <span class="n">target_boxes</span><span class="p">[:,</span> <span class="mi">2</span><span class="p">:]</span> <span class="o">*=</span> <span class="n">np</span><span class="o">.</span><span class="n">array</span><span class="p">([[</span><span class="n">w</span><span class="p">,</span> <span class="n">h</span><span class="p">,</span> <span class="n">w</span><span class="p">,</span> <span class="n">h</span><span class="p">]])</span>
    437. <span class="c1"># x1 = c1 - w // 2; y1 = c2 - h // 2</span>
    438. <span class="n">target_boxes</span><span class="p">[:,</span> <span class="mi">2</span><span class="p">]</span> <span class="o">-=</span> <span class="n">target_boxes</span><span class="p">[:,</span> <span class="mi">4</span><span class="p">]</span> <span class="o">//</span> <span class="mi">2</span>
    439. <span class="n">target_boxes</span><span class="p">[:,</span> <span class="mi">3</span><span class="p">]</span> <span class="o">-=</span> <span class="n">target_boxes</span><span class="p">[:,</span> <span class="mi">5</span><span class="p">]</span> <span class="o">//</span> <span class="mi">2</span>
    440. <span class="c1"># x2 = w + x1; y2 = h + y1</span>
    441. <span class="n">target_boxes</span><span class="p">[:,</span> <span class="mi">4</span><span class="p">]</span> <span class="o">+=</span> <span class="n">target_boxes</span><span class="p">[:,</span> <span class="mi">2</span><span class="p">]</span>
    442. <span class="n">target_boxes</span><span class="p">[:,</span> <span class="mi">5</span><span class="p">]</span> <span class="o">+=</span> <span class="n">target_boxes</span><span class="p">[:,</span> <span class="mi">3</span><span class="p">]</span>
    443. <span class="n">target_boxes</span><span class="p">[:,</span> <span class="mi">2</span><span class="p">:]</span> <span class="o">*=</span> <span class="n">image_scale</span>
    444. <span class="n">target_boxes</span> <span class="o">=</span> <span class="n">target_boxes</span><span class="o">.</span><span class="n">astype</span><span class="p">(</span><span class="nb">int</span><span class="p">)</span>
    445. <span class="k">return</span> <span class="n">target_boxes</span>
    446. <div class="viewcode-block" id="DetectionVisualization.visualize_batch"><a class="viewcode-back" href="../../../../super_gradients.training.utils.html#super_gradients.training.utils.detection_utils.DetectionVisualization.visualize_batch">[docs]</a> <span class="nd">@staticmethod</span>
    447. <span class="k">def</span> <span class="nf">visualize_batch</span><span class="p">(</span><span class="n">image_tensor</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">,</span> <span class="n">pred_boxes</span><span class="p">:</span> <span class="n">List</span><span class="p">[</span><span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">],</span> <span class="n">target_boxes</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">,</span>
    448. <span class="n">batch_name</span><span class="p">:</span> <span class="n">Union</span><span class="p">[</span><span class="nb">int</span><span class="p">,</span> <span class="nb">str</span><span class="p">],</span> <span class="n">class_names</span><span class="p">:</span> <span class="n">List</span><span class="p">[</span><span class="nb">str</span><span class="p">],</span> <span class="n">checkpoint_dir</span><span class="p">:</span> <span class="nb">str</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span>
    449. <span class="n">undo_preprocessing_func</span><span class="p">:</span> <span class="n">Callable</span><span class="p">[[</span><span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">],</span> <span class="n">np</span><span class="o">.</span><span class="n">ndarray</span><span class="p">]</span> <span class="o">=</span> <span class="n">undo_image_preprocessing</span><span class="p">,</span>
    450. <span class="n">box_thickness</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="mi">2</span><span class="p">,</span> <span class="n">image_scale</span><span class="p">:</span> <span class="nb">float</span> <span class="o">=</span> <span class="mf">1.</span><span class="p">,</span> <span class="n">gt_alpha</span><span class="p">:</span> <span class="nb">float</span> <span class="o">=</span> <span class="mf">.4</span><span class="p">):</span>
    451. <span class="sd">&quot;&quot;&quot;</span>
    452. <span class="sd"> A helper function to visualize detections predicted by a network:</span>
    453. <span class="sd"> saves images into a given path with a name that is {batch_name}_{imade_idx_in_the_batch}.jpg, one batch per call.</span>
    454. <span class="sd"> Colors are generated on the fly: uniformly sampled from color wheel to support all given classes.</span>
    455. <span class="sd"> Adjustable:</span>
    456. <span class="sd"> * Ground truth box transparency;</span>
    457. <span class="sd"> * Box width;</span>
    458. <span class="sd"> * Image size (larger or smaller than what&#39;s provided)</span>
    459. <span class="sd"> :param image_tensor: rgb images, (B, H, W, 3)</span>
    460. <span class="sd"> :param pred_boxes: boxes after NMS for each image in a batch, each (Num_boxes, 6),</span>
    461. <span class="sd"> values on dim 1 are: x1, y1, x2, y2, confidence, class</span>
    462. <span class="sd"> :param target_boxes: (Num_targets, 6), values on dim 1 are: image id in a batch, class, x y w h</span>
    463. <span class="sd"> (coordinates scaled to [0, 1])</span>
    464. <span class="sd"> :param batch_name: id of the current batch to use for image naming</span>
    465. <span class="sd"> :param class_names: names of all classes, each on its own index</span>
    466. <span class="sd"> :param checkpoint_dir: a path where images with boxes will be saved. if None, the result images will</span>
    467. <span class="sd"> be returns as a list of numpy image arrays</span>
    468. <span class="sd"> :param undo_preprocessing_func: a function to convert preprocessed images tensor into a batch of cv2-like images</span>
    469. <span class="sd"> :param box_thickness: box line thickness in px</span>
    470. <span class="sd"> :param image_scale: scale of an image w.r.t. given image size,</span>
    471. <span class="sd"> e.g. incoming images are (320x320), use scale = 2. to preview in (640x640)</span>
    472. <span class="sd"> :param gt_alpha: a value in [0., 1.] transparency on ground truth boxes,</span>
    473. <span class="sd"> 0 for invisible, 1 for fully opaque</span>
    474. <span class="sd"> &quot;&quot;&quot;</span>
    475. <span class="n">image_np</span> <span class="o">=</span> <span class="n">undo_preprocessing_func</span><span class="p">(</span><span class="n">image_tensor</span><span class="o">.</span><span class="n">detach</span><span class="p">())</span>
    476. <span class="n">targets</span> <span class="o">=</span> <span class="n">DetectionVisualization</span><span class="o">.</span><span class="n">_scaled_ccwh_to_xyxy</span><span class="p">(</span><span class="n">target_boxes</span><span class="o">.</span><span class="n">detach</span><span class="p">()</span><span class="o">.</span><span class="n">cpu</span><span class="p">()</span><span class="o">.</span><span class="n">numpy</span><span class="p">(),</span> <span class="o">*</span><span class="n">image_np</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">1</span><span class="p">:</span><span class="mi">3</span><span class="p">],</span>
    477. <span class="n">image_scale</span><span class="p">)</span>
    478. <span class="n">out_images</span> <span class="o">=</span> <span class="p">[]</span>
    479. <span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="n">image_np</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">0</span><span class="p">]):</span>
    480. <span class="n">preds</span> <span class="o">=</span> <span class="n">pred_boxes</span><span class="p">[</span><span class="n">i</span><span class="p">]</span><span class="o">.</span><span class="n">detach</span><span class="p">()</span><span class="o">.</span><span class="n">cpu</span><span class="p">()</span><span class="o">.</span><span class="n">numpy</span><span class="p">()</span> <span class="k">if</span> <span class="n">pred_boxes</span><span class="p">[</span><span class="n">i</span><span class="p">]</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span> <span class="k">else</span> <span class="n">np</span><span class="o">.</span><span class="n">empty</span><span class="p">((</span><span class="mi">0</span><span class="p">,</span> <span class="mi">6</span><span class="p">))</span>
    481. <span class="n">targets_cur</span> <span class="o">=</span> <span class="n">targets</span><span class="p">[</span><span class="n">targets</span><span class="p">[:,</span> <span class="mi">0</span><span class="p">]</span> <span class="o">==</span> <span class="n">i</span><span class="p">]</span>
    482. <span class="n">image_name</span> <span class="o">=</span> <span class="s1">&#39;_&#39;</span><span class="o">.</span><span class="n">join</span><span class="p">([</span><span class="nb">str</span><span class="p">(</span><span class="n">batch_name</span><span class="p">),</span> <span class="nb">str</span><span class="p">(</span><span class="n">i</span><span class="p">)])</span>
    483. <span class="n">res_image</span> <span class="o">=</span> <span class="n">DetectionVisualization</span><span class="o">.</span><span class="n">_visualize_image</span><span class="p">(</span><span class="n">image_np</span><span class="p">[</span><span class="n">i</span><span class="p">],</span> <span class="n">preds</span><span class="p">,</span> <span class="n">targets_cur</span><span class="p">,</span> <span class="n">class_names</span><span class="p">,</span> <span class="n">box_thickness</span><span class="p">,</span> <span class="n">gt_alpha</span><span class="p">,</span> <span class="n">image_scale</span><span class="p">,</span>
    484. <span class="n">checkpoint_dir</span><span class="p">,</span> <span class="n">image_name</span><span class="p">)</span>
    485. <span class="k">if</span> <span class="n">res_image</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span><span class="p">:</span>
    486. <span class="n">out_images</span><span class="o">.</span><span class="n">append</span><span class="p">(</span><span class="n">res_image</span><span class="p">)</span>
    487. <span class="k">return</span> <span class="n">out_images</span></div></div>
    488. <div class="viewcode-block" id="Anchors"><a class="viewcode-back" href="../../../../super_gradients.training.utils.html#super_gradients.training.utils.detection_utils.Anchors">[docs]</a><span class="k">class</span> <span class="nc">Anchors</span><span class="p">(</span><span class="n">nn</span><span class="o">.</span><span class="n">Module</span><span class="p">):</span>
    489. <span class="sd">&quot;&quot;&quot;</span>
    490. <span class="sd"> A wrapper function to hold the anchors used by detection models such as Yolo</span>
    491. <span class="sd"> &quot;&quot;&quot;</span>
    492. <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">anchors_list</span><span class="p">:</span> <span class="n">List</span><span class="p">[</span><span class="n">List</span><span class="p">],</span> <span class="n">strides</span><span class="p">:</span> <span class="n">List</span><span class="p">[</span><span class="nb">int</span><span class="p">]):</span>
    493. <span class="sd">&quot;&quot;&quot;</span>
    494. <span class="sd"> :param anchors_list: of the shape [[w1,h1,w2,h2,w3,h3], [w4,h4,w5,h5,w6,h6] .... where each sublist holds</span>
    495. <span class="sd"> the width and height of the anchors of a specific detection layer.</span>
    496. <span class="sd"> i.e. for a model with 3 detection layers, each containing 5 anchors the format will be a of 3 sublists of 10 numbers each</span>
    497. <span class="sd"> The width and height are in pixels (not relative to image size)</span>
    498. <span class="sd"> :param strides: a list containing the stride of the layers from which the detection heads are fed.</span>
    499. <span class="sd"> i.e. if the firs detection head is connected to the backbone after the input dimensions were reduces by 8, the first number will be 8</span>
    500. <span class="sd"> &quot;&quot;&quot;</span>
    501. <span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="fm">__init__</span><span class="p">()</span>
    502. <span class="bp">self</span><span class="o">.</span><span class="n">__anchors_list</span> <span class="o">=</span> <span class="n">anchors_list</span>
    503. <span class="bp">self</span><span class="o">.</span><span class="n">__strides</span> <span class="o">=</span> <span class="n">strides</span>
    504. <span class="bp">self</span><span class="o">.</span><span class="n">_check_all_lists</span><span class="p">(</span><span class="n">anchors_list</span><span class="p">)</span>
    505. <span class="bp">self</span><span class="o">.</span><span class="n">_check_all_len_equal_and_even</span><span class="p">(</span><span class="n">anchors_list</span><span class="p">)</span>
    506. <span class="bp">self</span><span class="o">.</span><span class="n">_stride</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">Parameter</span><span class="p">(</span><span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">(</span><span class="n">strides</span><span class="p">)</span><span class="o">.</span><span class="n">float</span><span class="p">(),</span> <span class="n">requires_grad</span><span class="o">=</span><span class="kc">False</span><span class="p">)</span>
    507. <span class="n">anchors</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">(</span><span class="n">anchors_list</span><span class="p">)</span><span class="o">.</span><span class="n">float</span><span class="p">()</span><span class="o">.</span><span class="n">view</span><span class="p">(</span><span class="nb">len</span><span class="p">(</span><span class="n">anchors_list</span><span class="p">),</span> <span class="o">-</span><span class="mi">1</span><span class="p">,</span> <span class="mi">2</span><span class="p">)</span>
    508. <span class="bp">self</span><span class="o">.</span><span class="n">_anchors</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">Parameter</span><span class="p">(</span><span class="n">anchors</span> <span class="o">/</span> <span class="bp">self</span><span class="o">.</span><span class="n">_stride</span><span class="o">.</span><span class="n">view</span><span class="p">(</span><span class="o">-</span><span class="mi">1</span><span class="p">,</span> <span class="mi">1</span><span class="p">,</span> <span class="mi">1</span><span class="p">),</span> <span class="n">requires_grad</span><span class="o">=</span><span class="kc">False</span><span class="p">)</span>
    509. <span class="bp">self</span><span class="o">.</span><span class="n">_anchor_grid</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">Parameter</span><span class="p">(</span><span class="n">anchors</span><span class="o">.</span><span class="n">clone</span><span class="p">()</span><span class="o">.</span><span class="n">view</span><span class="p">(</span><span class="nb">len</span><span class="p">(</span><span class="n">anchors_list</span><span class="p">),</span> <span class="mi">1</span><span class="p">,</span> <span class="o">-</span><span class="mi">1</span><span class="p">,</span> <span class="mi">1</span><span class="p">,</span> <span class="mi">1</span><span class="p">,</span> <span class="mi">2</span><span class="p">),</span> <span class="n">requires_grad</span><span class="o">=</span><span class="kc">False</span><span class="p">)</span>
    510. <span class="nd">@staticmethod</span>
    511. <span class="k">def</span> <span class="nf">_check_all_lists</span><span class="p">(</span><span class="n">anchors</span><span class="p">:</span> <span class="nb">list</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="nb">bool</span><span class="p">:</span>
    512. <span class="k">for</span> <span class="n">a</span> <span class="ow">in</span> <span class="n">anchors</span><span class="p">:</span>
    513. <span class="k">if</span> <span class="ow">not</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">a</span><span class="p">,</span> <span class="p">(</span><span class="nb">list</span><span class="p">,</span> <span class="n">ListConfig</span><span class="p">)):</span>
    514. <span class="k">raise</span> <span class="ne">RuntimeError</span><span class="p">(</span><span class="s1">&#39;All objects of anchors_list must be lists&#39;</span><span class="p">)</span>
    515. <span class="nd">@staticmethod</span>
    516. <span class="k">def</span> <span class="nf">_check_all_len_equal_and_even</span><span class="p">(</span><span class="n">anchors</span><span class="p">:</span> <span class="nb">list</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="nb">bool</span><span class="p">:</span>
    517. <span class="n">len_of_first</span> <span class="o">=</span> <span class="nb">len</span><span class="p">(</span><span class="n">anchors</span><span class="p">[</span><span class="mi">0</span><span class="p">])</span>
    518. <span class="k">for</span> <span class="n">a</span> <span class="ow">in</span> <span class="n">anchors</span><span class="p">:</span>
    519. <span class="k">if</span> <span class="nb">len</span><span class="p">(</span><span class="n">a</span><span class="p">)</span> <span class="o">%</span> <span class="mi">2</span> <span class="o">==</span> <span class="mi">1</span> <span class="ow">or</span> <span class="nb">len</span><span class="p">(</span><span class="n">a</span><span class="p">)</span> <span class="o">!=</span> <span class="n">len_of_first</span><span class="p">:</span>
    520. <span class="k">raise</span> <span class="ne">RuntimeError</span><span class="p">(</span><span class="s1">&#39;All objects of anchors_list must be of the same even length&#39;</span><span class="p">)</span>
    521. <span class="nd">@property</span>
    522. <span class="k">def</span> <span class="nf">stride</span><span class="p">(</span><span class="bp">self</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="n">nn</span><span class="o">.</span><span class="n">Parameter</span><span class="p">:</span>
    523. <span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">_stride</span>
    524. <span class="nd">@property</span>
    525. <span class="k">def</span> <span class="nf">anchors</span><span class="p">(</span><span class="bp">self</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="n">nn</span><span class="o">.</span><span class="n">Parameter</span><span class="p">:</span>
    526. <span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">_anchors</span>
    527. <span class="nd">@property</span>
    528. <span class="k">def</span> <span class="nf">anchor_grid</span><span class="p">(</span><span class="bp">self</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="n">nn</span><span class="o">.</span><span class="n">Parameter</span><span class="p">:</span>
    529. <span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">_anchor_grid</span>
    530. <span class="nd">@property</span>
    531. <span class="k">def</span> <span class="nf">detection_layers_num</span><span class="p">(</span><span class="bp">self</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="nb">int</span><span class="p">:</span>
    532. <span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">_anchors</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span>
    533. <span class="nd">@property</span>
    534. <span class="k">def</span> <span class="nf">num_anchors</span><span class="p">(</span><span class="bp">self</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="nb">int</span><span class="p">:</span>
    535. <span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">_anchors</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">1</span><span class="p">]</span>
    536. <span class="k">def</span> <span class="fm">__repr__</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span>
    537. <span class="k">return</span> <span class="sa">f</span><span class="s2">&quot;anchors_list: </span><span class="si">{</span><span class="bp">self</span><span class="o">.</span><span class="n">__anchors_list</span><span class="si">}</span><span class="s2"> strides: </span><span class="si">{</span><span class="bp">self</span><span class="o">.</span><span class="n">__strides</span><span class="si">}</span><span class="s2">&quot;</span></div>
    538. <div class="viewcode-block" id="xyxy2cxcywh"><a class="viewcode-back" href="../../../../super_gradients.training.utils.html#super_gradients.training.utils.detection_utils.xyxy2cxcywh">[docs]</a><span class="k">def</span> <span class="nf">xyxy2cxcywh</span><span class="p">(</span><span class="n">bboxes</span><span class="p">):</span>
    539. <span class="sd">&quot;&quot;&quot;</span>
    540. <span class="sd"> Transforms bboxes from xyxy format to centerized xy wh format</span>
    541. <span class="sd"> :param bboxes: array, shaped (nboxes, 4)</span>
    542. <span class="sd"> :return: modified bboxes</span>
    543. <span class="sd"> &quot;&quot;&quot;</span>
    544. <span class="n">bboxes</span><span class="p">[:,</span> <span class="mi">2</span><span class="p">]</span> <span class="o">=</span> <span class="n">bboxes</span><span class="p">[:,</span> <span class="mi">2</span><span class="p">]</span> <span class="o">-</span> <span class="n">bboxes</span><span class="p">[:,</span> <span class="mi">0</span><span class="p">]</span>
    545. <span class="n">bboxes</span><span class="p">[:,</span> <span class="mi">3</span><span class="p">]</span> <span class="o">=</span> <span class="n">bboxes</span><span class="p">[:,</span> <span class="mi">3</span><span class="p">]</span> <span class="o">-</span> <span class="n">bboxes</span><span class="p">[:,</span> <span class="mi">1</span><span class="p">]</span>
    546. <span class="n">bboxes</span><span class="p">[:,</span> <span class="mi">0</span><span class="p">]</span> <span class="o">=</span> <span class="n">bboxes</span><span class="p">[:,</span> <span class="mi">0</span><span class="p">]</span> <span class="o">+</span> <span class="n">bboxes</span><span class="p">[:,</span> <span class="mi">2</span><span class="p">]</span> <span class="o">*</span> <span class="mf">0.5</span>
    547. <span class="n">bboxes</span><span class="p">[:,</span> <span class="mi">1</span><span class="p">]</span> <span class="o">=</span> <span class="n">bboxes</span><span class="p">[:,</span> <span class="mi">1</span><span class="p">]</span> <span class="o">+</span> <span class="n">bboxes</span><span class="p">[:,</span> <span class="mi">3</span><span class="p">]</span> <span class="o">*</span> <span class="mf">0.5</span>
    548. <span class="k">return</span> <span class="n">bboxes</span></div>
    549. <div class="viewcode-block" id="cxcywh2xyxy"><a class="viewcode-back" href="../../../../super_gradients.training.utils.html#super_gradients.training.utils.detection_utils.cxcywh2xyxy">[docs]</a><span class="k">def</span> <span class="nf">cxcywh2xyxy</span><span class="p">(</span><span class="n">bboxes</span><span class="p">):</span>
    550. <span class="sd">&quot;&quot;&quot;</span>
    551. <span class="sd"> Transforms bboxes from centerized xy wh format to xyxy format</span>
    552. <span class="sd"> :param bboxes: array, shaped (nboxes, 4)</span>
    553. <span class="sd"> :return: modified bboxes</span>
    554. <span class="sd"> &quot;&quot;&quot;</span>
    555. <span class="n">bboxes</span><span class="p">[:,</span> <span class="mi">1</span><span class="p">]</span> <span class="o">=</span> <span class="n">bboxes</span><span class="p">[:,</span> <span class="mi">1</span><span class="p">]</span> <span class="o">-</span> <span class="n">bboxes</span><span class="p">[:,</span> <span class="mi">3</span><span class="p">]</span> <span class="o">*</span> <span class="mf">0.5</span>
    556. <span class="n">bboxes</span><span class="p">[:,</span> <span class="mi">0</span><span class="p">]</span> <span class="o">=</span> <span class="n">bboxes</span><span class="p">[:,</span> <span class="mi">0</span><span class="p">]</span> <span class="o">-</span> <span class="n">bboxes</span><span class="p">[:,</span> <span class="mi">2</span><span class="p">]</span> <span class="o">*</span> <span class="mf">0.5</span>
    557. <span class="n">bboxes</span><span class="p">[:,</span> <span class="mi">3</span><span class="p">]</span> <span class="o">=</span> <span class="n">bboxes</span><span class="p">[:,</span> <span class="mi">3</span><span class="p">]</span> <span class="o">+</span> <span class="n">bboxes</span><span class="p">[:,</span> <span class="mi">1</span><span class="p">]</span>
    558. <span class="n">bboxes</span><span class="p">[:,</span> <span class="mi">2</span><span class="p">]</span> <span class="o">=</span> <span class="n">bboxes</span><span class="p">[:,</span> <span class="mi">2</span><span class="p">]</span> <span class="o">+</span> <span class="n">bboxes</span><span class="p">[:,</span> <span class="mi">0</span><span class="p">]</span>
    559. <span class="k">return</span> <span class="n">bboxes</span></div>
    560. <div class="viewcode-block" id="get_mosaic_coordinate"><a class="viewcode-back" href="../../../../super_gradients.training.utils.html#super_gradients.training.utils.detection_utils.get_mosaic_coordinate">[docs]</a><span class="k">def</span> <span class="nf">get_mosaic_coordinate</span><span class="p">(</span><span class="n">mosaic_index</span><span class="p">,</span> <span class="n">xc</span><span class="p">,</span> <span class="n">yc</span><span class="p">,</span> <span class="n">w</span><span class="p">,</span> <span class="n">h</span><span class="p">,</span> <span class="n">input_h</span><span class="p">,</span> <span class="n">input_w</span><span class="p">):</span>
    561. <span class="sd">&quot;&quot;&quot;</span>
    562. <span class="sd"> Returns the mosaic coordinates of final mosaic image according to mosaic image index.</span>
    563. <span class="sd"> :param mosaic_index: (int) mosaic image index</span>
    564. <span class="sd"> :param xc: (int) center x coordinate of the entire mosaic grid.</span>
    565. <span class="sd"> :param yc: (int) center y coordinate of the entire mosaic grid.</span>
    566. <span class="sd"> :param w: (int) width of bbox</span>
    567. <span class="sd"> :param h: (int) height of bbox</span>
    568. <span class="sd"> :param input_h: (int) image input height (should be 1/2 of the final mosaic output image height).</span>
    569. <span class="sd"> :param input_w: (int) image input width (should be 1/2 of the final mosaic output image width).</span>
    570. <span class="sd"> :return: (x1, y1, x2, y2), (x1s, y1s, x2s, y2s) where (x1, y1, x2, y2) are the coordinates in the final mosaic</span>
    571. <span class="sd"> output image, and (x1s, y1s, x2s, y2s) are the coordinates in the placed image.</span>
    572. <span class="sd"> &quot;&quot;&quot;</span>
    573. <span class="c1"># index0 to top left part of image</span>
    574. <span class="k">if</span> <span class="n">mosaic_index</span> <span class="o">==</span> <span class="mi">0</span><span class="p">:</span>
    575. <span class="n">x1</span><span class="p">,</span> <span class="n">y1</span><span class="p">,</span> <span class="n">x2</span><span class="p">,</span> <span class="n">y2</span> <span class="o">=</span> <span class="nb">max</span><span class="p">(</span><span class="n">xc</span> <span class="o">-</span> <span class="n">w</span><span class="p">,</span> <span class="mi">0</span><span class="p">),</span> <span class="nb">max</span><span class="p">(</span><span class="n">yc</span> <span class="o">-</span> <span class="n">h</span><span class="p">,</span> <span class="mi">0</span><span class="p">),</span> <span class="n">xc</span><span class="p">,</span> <span class="n">yc</span>
    576. <span class="n">small_coord</span> <span class="o">=</span> <span class="n">w</span> <span class="o">-</span> <span class="p">(</span><span class="n">x2</span> <span class="o">-</span> <span class="n">x1</span><span class="p">),</span> <span class="n">h</span> <span class="o">-</span> <span class="p">(</span><span class="n">y2</span> <span class="o">-</span> <span class="n">y1</span><span class="p">),</span> <span class="n">w</span><span class="p">,</span> <span class="n">h</span>
    577. <span class="c1"># index1 to top right part of image</span>
    578. <span class="k">elif</span> <span class="n">mosaic_index</span> <span class="o">==</span> <span class="mi">1</span><span class="p">:</span>
    579. <span class="n">x1</span><span class="p">,</span> <span class="n">y1</span><span class="p">,</span> <span class="n">x2</span><span class="p">,</span> <span class="n">y2</span> <span class="o">=</span> <span class="n">xc</span><span class="p">,</span> <span class="nb">max</span><span class="p">(</span><span class="n">yc</span> <span class="o">-</span> <span class="n">h</span><span class="p">,</span> <span class="mi">0</span><span class="p">),</span> <span class="nb">min</span><span class="p">(</span><span class="n">xc</span> <span class="o">+</span> <span class="n">w</span><span class="p">,</span> <span class="n">input_w</span> <span class="o">*</span> <span class="mi">2</span><span class="p">),</span> <span class="n">yc</span>
    580. <span class="n">small_coord</span> <span class="o">=</span> <span class="mi">0</span><span class="p">,</span> <span class="n">h</span> <span class="o">-</span> <span class="p">(</span><span class="n">y2</span> <span class="o">-</span> <span class="n">y1</span><span class="p">),</span> <span class="nb">min</span><span class="p">(</span><span class="n">w</span><span class="p">,</span> <span class="n">x2</span> <span class="o">-</span> <span class="n">x1</span><span class="p">),</span> <span class="n">h</span>
    581. <span class="c1"># index2 to bottom left part of image</span>
    582. <span class="k">elif</span> <span class="n">mosaic_index</span> <span class="o">==</span> <span class="mi">2</span><span class="p">:</span>
    583. <span class="n">x1</span><span class="p">,</span> <span class="n">y1</span><span class="p">,</span> <span class="n">x2</span><span class="p">,</span> <span class="n">y2</span> <span class="o">=</span> <span class="nb">max</span><span class="p">(</span><span class="n">xc</span> <span class="o">-</span> <span class="n">w</span><span class="p">,</span> <span class="mi">0</span><span class="p">),</span> <span class="n">yc</span><span class="p">,</span> <span class="n">xc</span><span class="p">,</span> <span class="nb">min</span><span class="p">(</span><span class="n">input_h</span> <span class="o">*</span> <span class="mi">2</span><span class="p">,</span> <span class="n">yc</span> <span class="o">+</span> <span class="n">h</span><span class="p">)</span>
    584. <span class="n">small_coord</span> <span class="o">=</span> <span class="n">w</span> <span class="o">-</span> <span class="p">(</span><span class="n">x2</span> <span class="o">-</span> <span class="n">x1</span><span class="p">),</span> <span class="mi">0</span><span class="p">,</span> <span class="n">w</span><span class="p">,</span> <span class="nb">min</span><span class="p">(</span><span class="n">y2</span> <span class="o">-</span> <span class="n">y1</span><span class="p">,</span> <span class="n">h</span><span class="p">)</span>
    585. <span class="c1"># index2 to bottom right part of image</span>
    586. <span class="k">elif</span> <span class="n">mosaic_index</span> <span class="o">==</span> <span class="mi">3</span><span class="p">:</span>
    587. <span class="n">x1</span><span class="p">,</span> <span class="n">y1</span><span class="p">,</span> <span class="n">x2</span><span class="p">,</span> <span class="n">y2</span> <span class="o">=</span> <span class="n">xc</span><span class="p">,</span> <span class="n">yc</span><span class="p">,</span> <span class="nb">min</span><span class="p">(</span><span class="n">xc</span> <span class="o">+</span> <span class="n">w</span><span class="p">,</span> <span class="n">input_w</span> <span class="o">*</span> <span class="mi">2</span><span class="p">),</span> <span class="nb">min</span><span class="p">(</span><span class="n">input_h</span> <span class="o">*</span> <span class="mi">2</span><span class="p">,</span> <span class="n">yc</span> <span class="o">+</span> <span class="n">h</span><span class="p">)</span> <span class="c1"># noqa</span>
    588. <span class="n">small_coord</span> <span class="o">=</span> <span class="mi">0</span><span class="p">,</span> <span class="mi">0</span><span class="p">,</span> <span class="nb">min</span><span class="p">(</span><span class="n">w</span><span class="p">,</span> <span class="n">x2</span> <span class="o">-</span> <span class="n">x1</span><span class="p">),</span> <span class="nb">min</span><span class="p">(</span><span class="n">y2</span> <span class="o">-</span> <span class="n">y1</span><span class="p">,</span> <span class="n">h</span><span class="p">)</span>
    589. <span class="k">return</span> <span class="p">(</span><span class="n">x1</span><span class="p">,</span> <span class="n">y1</span><span class="p">,</span> <span class="n">x2</span><span class="p">,</span> <span class="n">y2</span><span class="p">),</span> <span class="n">small_coord</span></div>
    590. <div class="viewcode-block" id="adjust_box_anns"><a class="viewcode-back" href="../../../../super_gradients.training.utils.html#super_gradients.training.utils.detection_utils.adjust_box_anns">[docs]</a><span class="k">def</span> <span class="nf">adjust_box_anns</span><span class="p">(</span><span class="n">bbox</span><span class="p">,</span> <span class="n">scale_ratio</span><span class="p">,</span> <span class="n">padw</span><span class="p">,</span> <span class="n">padh</span><span class="p">,</span> <span class="n">w_max</span><span class="p">,</span> <span class="n">h_max</span><span class="p">):</span>
    591. <span class="sd">&quot;&quot;&quot;</span>
    592. <span class="sd"> Adjusts the bbox annotations of rescaled, padded image.</span>
    593. <span class="sd"> :param bbox: (np.array) bbox to modify.</span>
    594. <span class="sd"> :param scale_ratio: (float) scale ratio between rescale output image and original one.</span>
    595. <span class="sd"> :param padw: (int) width padding size.</span>
    596. <span class="sd"> :param padh: (int) height padding size.</span>
    597. <span class="sd"> :param w_max: (int) width border.</span>
    598. <span class="sd"> :param h_max: (int) height border</span>
    599. <span class="sd"> :return: modified bbox (np.array)</span>
    600. <span class="sd"> &quot;&quot;&quot;</span>
    601. <span class="n">bbox</span><span class="p">[:,</span> <span class="mi">0</span><span class="p">::</span><span class="mi">2</span><span class="p">]</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">clip</span><span class="p">(</span><span class="n">bbox</span><span class="p">[:,</span> <span class="mi">0</span><span class="p">::</span><span class="mi">2</span><span class="p">]</span> <span class="o">*</span> <span class="n">scale_ratio</span> <span class="o">+</span> <span class="n">padw</span><span class="p">,</span> <span class="mi">0</span><span class="p">,</span> <span class="n">w_max</span><span class="p">)</span>
    602. <span class="n">bbox</span><span class="p">[:,</span> <span class="mi">1</span><span class="p">::</span><span class="mi">2</span><span class="p">]</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">clip</span><span class="p">(</span><span class="n">bbox</span><span class="p">[:,</span> <span class="mi">1</span><span class="p">::</span><span class="mi">2</span><span class="p">]</span> <span class="o">*</span> <span class="n">scale_ratio</span> <span class="o">+</span> <span class="n">padh</span><span class="p">,</span> <span class="mi">0</span><span class="p">,</span> <span class="n">h_max</span><span class="p">)</span>
    603. <span class="k">return</span> <span class="n">bbox</span></div>
    604. <div class="viewcode-block" id="DetectionCollateFN"><a class="viewcode-back" href="../../../../super_gradients.training.utils.html#super_gradients.training.utils.detection_utils.DetectionCollateFN">[docs]</a><span class="k">class</span> <span class="nc">DetectionCollateFN</span><span class="p">:</span>
    605. <span class="sd">&quot;&quot;&quot;</span>
    606. <span class="sd"> Collate function for Yolox training</span>
    607. <span class="sd"> &quot;&quot;&quot;</span>
    608. <span class="k">def</span> <span class="fm">__call__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">data</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="n">Tuple</span><span class="p">[</span><span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">,</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">]:</span>
    609. <span class="n">batch</span> <span class="o">=</span> <span class="n">default_collate</span><span class="p">(</span><span class="n">data</span><span class="p">)</span>
    610. <span class="n">ims</span><span class="p">,</span> <span class="n">targets</span> <span class="o">=</span> <span class="n">batch</span><span class="p">[</span><span class="mi">0</span><span class="p">:</span><span class="mi">2</span><span class="p">]</span>
    611. <span class="k">return</span> <span class="n">ims</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">_format_targets</span><span class="p">(</span><span class="n">targets</span><span class="p">)</span>
    612. <span class="k">def</span> <span class="nf">_format_targets</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">targets</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">:</span>
    613. <span class="n">nlabel</span> <span class="o">=</span> <span class="p">(</span><span class="n">targets</span><span class="o">.</span><span class="n">sum</span><span class="p">(</span><span class="n">dim</span><span class="o">=</span><span class="mi">2</span><span class="p">)</span> <span class="o">&gt;</span> <span class="mi">0</span><span class="p">)</span><span class="o">.</span><span class="n">sum</span><span class="p">(</span><span class="n">dim</span><span class="o">=</span><span class="mi">1</span><span class="p">)</span> <span class="c1"># number of label per image</span>
    614. <span class="n">targets_merged</span> <span class="o">=</span> <span class="p">[]</span>
    615. <span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="n">targets</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">0</span><span class="p">]):</span>
    616. <span class="n">targets_im</span> <span class="o">=</span> <span class="n">targets</span><span class="p">[</span><span class="n">i</span><span class="p">,</span> <span class="p">:</span><span class="n">nlabel</span><span class="p">[</span><span class="n">i</span><span class="p">]]</span>
    617. <span class="n">batch_column</span> <span class="o">=</span> <span class="n">targets</span><span class="o">.</span><span class="n">new_ones</span><span class="p">((</span><span class="n">targets_im</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">0</span><span class="p">],</span> <span class="mi">1</span><span class="p">))</span> <span class="o">*</span> <span class="n">i</span>
    618. <span class="n">targets_merged</span><span class="o">.</span><span class="n">append</span><span class="p">(</span><span class="n">torch</span><span class="o">.</span><span class="n">cat</span><span class="p">((</span><span class="n">batch_column</span><span class="p">,</span> <span class="n">targets_im</span><span class="p">),</span> <span class="mi">1</span><span class="p">))</span>
    619. <span class="k">return</span> <span class="n">torch</span><span class="o">.</span><span class="n">cat</span><span class="p">(</span><span class="n">targets_merged</span><span class="p">,</span> <span class="mi">0</span><span class="p">)</span></div>
    620. <div class="viewcode-block" id="CrowdDetectionCollateFN"><a class="viewcode-back" href="../../../../super_gradients.training.utils.html#super_gradients.training.utils.detection_utils.CrowdDetectionCollateFN">[docs]</a><span class="k">class</span> <span class="nc">CrowdDetectionCollateFN</span><span class="p">(</span><span class="n">DetectionCollateFN</span><span class="p">):</span>
    621. <span class="sd">&quot;&quot;&quot;</span>
    622. <span class="sd"> Collate function for Yolox training with additional_batch_items that includes crowd targets</span>
    623. <span class="sd"> &quot;&quot;&quot;</span>
    624. <span class="k">def</span> <span class="fm">__call__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">data</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="n">Tuple</span><span class="p">[</span><span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">,</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">,</span> <span class="n">Dict</span><span class="p">[</span><span class="nb">str</span><span class="p">,</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">]]:</span>
    625. <span class="n">batch</span> <span class="o">=</span> <span class="n">default_collate</span><span class="p">(</span><span class="n">data</span><span class="p">)</span>
    626. <span class="n">ims</span><span class="p">,</span> <span class="n">targets</span><span class="p">,</span> <span class="n">crowd_targets</span> <span class="o">=</span> <span class="n">batch</span><span class="p">[</span><span class="mi">0</span><span class="p">:</span><span class="mi">3</span><span class="p">]</span>
    627. <span class="k">return</span> <span class="n">ims</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">_format_targets</span><span class="p">(</span><span class="n">targets</span><span class="p">),</span> <span class="p">{</span><span class="s2">&quot;crowd_targets&quot;</span><span class="p">:</span> <span class="bp">self</span><span class="o">.</span><span class="n">_format_targets</span><span class="p">(</span><span class="n">crowd_targets</span><span class="p">)}</span></div>
    628. <div class="viewcode-block" id="compute_box_area"><a class="viewcode-back" href="../../../../super_gradients.training.utils.html#super_gradients.training.utils.detection_utils.compute_box_area">[docs]</a><span class="k">def</span> <span class="nf">compute_box_area</span><span class="p">(</span><span class="n">box</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">:</span>
    629. <span class="sd">&quot;&quot;&quot;Compute the area of one or many boxes.</span>
    630. <span class="sd"> :param box: One or many boxes, shape = (4, ?), each box in format (x1, y1, x2, y2)</span>
    631. <span class="sd"> Returns:</span>
    632. <span class="sd"> Area of every box, shape = (1, ?)</span>
    633. <span class="sd"> &quot;&quot;&quot;</span>
    634. <span class="c1"># box = 4xn</span>
    635. <span class="k">return</span> <span class="p">(</span><span class="n">box</span><span class="p">[</span><span class="mi">2</span><span class="p">]</span> <span class="o">-</span> <span class="n">box</span><span class="p">[</span><span class="mi">0</span><span class="p">])</span> <span class="o">*</span> <span class="p">(</span><span class="n">box</span><span class="p">[</span><span class="mi">3</span><span class="p">]</span> <span class="o">-</span> <span class="n">box</span><span class="p">[</span><span class="mi">1</span><span class="p">])</span></div>
    636. <div class="viewcode-block" id="crowd_ioa"><a class="viewcode-back" href="../../../../super_gradients.training.utils.html#super_gradients.training.utils.detection_utils.crowd_ioa">[docs]</a><span class="k">def</span> <span class="nf">crowd_ioa</span><span class="p">(</span><span class="n">det_box</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">,</span> <span class="n">crowd_box</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">:</span>
    637. <span class="sd">&quot;&quot;&quot;</span>
    638. <span class="sd"> Return intersection-over-detection_area of boxes, used for crowd ground truths.</span>
    639. <span class="sd"> Both sets of boxes are expected to be in (x1, y1, x2, y2) format.</span>
    640. <span class="sd"> Arguments:</span>
    641. <span class="sd"> det_box (Tensor[N, 4])</span>
    642. <span class="sd"> crowd_box (Tensor[M, 4])</span>
    643. <span class="sd"> Returns:</span>
    644. <span class="sd"> crowd_ioa (Tensor[N, M]): the NxM matrix containing the pairwise</span>
    645. <span class="sd"> IoA values for every element in det_box and crowd_box</span>
    646. <span class="sd"> &quot;&quot;&quot;</span>
    647. <span class="n">det_area</span> <span class="o">=</span> <span class="n">compute_box_area</span><span class="p">(</span><span class="n">det_box</span><span class="o">.</span><span class="n">T</span><span class="p">)</span>
    648. <span class="c1"># inter(N,M) = (rb(N,M,2) - lt(N,M,2)).clamp(0).prod(2)</span>
    649. <span class="n">inter</span> <span class="o">=</span> <span class="p">(</span><span class="n">torch</span><span class="o">.</span><span class="n">min</span><span class="p">(</span><span class="n">det_box</span><span class="p">[:,</span> <span class="kc">None</span><span class="p">,</span> <span class="mi">2</span><span class="p">:],</span> <span class="n">crowd_box</span><span class="p">[:,</span> <span class="mi">2</span><span class="p">:])</span> <span class="o">-</span> <span class="n">torch</span><span class="o">.</span><span class="n">max</span><span class="p">(</span><span class="n">det_box</span><span class="p">[:,</span> <span class="kc">None</span><span class="p">,</span> <span class="p">:</span><span class="mi">2</span><span class="p">],</span> <span class="n">crowd_box</span><span class="p">[:,</span> <span class="p">:</span><span class="mi">2</span><span class="p">]))</span> \
    650. <span class="o">.</span><span class="n">clamp</span><span class="p">(</span><span class="mi">0</span><span class="p">)</span><span class="o">.</span><span class="n">prod</span><span class="p">(</span><span class="mi">2</span><span class="p">)</span>
    651. <span class="k">return</span> <span class="n">inter</span> <span class="o">/</span> <span class="n">det_area</span><span class="p">[:,</span> <span class="kc">None</span><span class="p">]</span> <span class="c1"># crowd_ioa = inter / det_area</span></div>
    652. <div class="viewcode-block" id="compute_detection_matching"><a class="viewcode-back" href="../../../../super_gradients.training.utils.html#super_gradients.training.utils.detection_utils.compute_detection_matching">[docs]</a><span class="k">def</span> <span class="nf">compute_detection_matching</span><span class="p">(</span>
    653. <span class="n">output</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">,</span>
    654. <span class="n">targets</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">,</span>
    655. <span class="n">height</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span>
    656. <span class="n">width</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span>
    657. <span class="n">iou_thresholds</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">,</span>
    658. <span class="n">denormalize_targets</span><span class="p">:</span> <span class="nb">bool</span><span class="p">,</span>
    659. <span class="n">device</span><span class="p">:</span> <span class="nb">str</span><span class="p">,</span>
    660. <span class="n">crowd_targets</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">]</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span>
    661. <span class="n">top_k</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="mi">100</span><span class="p">,</span>
    662. <span class="n">return_on_cpu</span><span class="p">:</span> <span class="nb">bool</span> <span class="o">=</span> <span class="kc">True</span><span class="p">,</span>
    663. <span class="p">)</span> <span class="o">-&gt;</span> <span class="n">List</span><span class="p">[</span><span class="n">Tuple</span><span class="p">]:</span>
    664. <span class="sd">&quot;&quot;&quot;</span>
    665. <span class="sd"> Match predictions (NMS output) and the targets (ground truth) with respect to IoU and confidence score.</span>
    666. <span class="sd"> :param output: list (of length batch_size) of Tensors of shape (num_predictions, 6)</span>
    667. <span class="sd"> format: (x1, y1, x2, y2, confidence, class_label) where x1,y1,x2,y2 are according to image size</span>
    668. <span class="sd"> :param targets: targets for all images of shape (total_num_targets, 6)</span>
    669. <span class="sd"> format: (index, x, y, w, h, label) where x,y,w,h are in range [0,1]</span>
    670. <span class="sd"> :param height: dimensions of the image</span>
    671. <span class="sd"> :param width: dimensions of the image</span>
    672. <span class="sd"> :param iou_thresholds: Threshold to compute the mAP</span>
    673. <span class="sd"> :param device: Device</span>
    674. <span class="sd"> :param crowd_targets: crowd targets for all images of shape (total_num_crowd_targets, 6)</span>
    675. <span class="sd"> format: (index, x, y, w, h, label) where x,y,w,h are in range [0,1]</span>
    676. <span class="sd"> :param top_k: Number of predictions to keep per class, ordered by confidence score</span>
    677. <span class="sd"> :param denormalize_targets: If True, denormalize the targets and crowd_targets</span>
    678. <span class="sd"> :param return_on_cpu: If True, the output will be returned on &quot;CPU&quot;, otherwise it will be returned on &quot;device&quot;</span>
    679. <span class="sd"> :return: list of the following tensors, for every image:</span>
    680. <span class="sd"> :preds_matched: Tensor of shape (num_img_predictions, n_iou_thresholds)</span>
    681. <span class="sd"> True when prediction (i) is matched with a target with respect to the (j)th IoU threshold</span>
    682. <span class="sd"> :preds_to_ignore: Tensor of shape (num_img_predictions, n_iou_thresholds)</span>
    683. <span class="sd"> True when prediction (i) is matched with a crowd target with respect to the (j)th IoU threshold</span>
    684. <span class="sd"> :preds_scores: Tensor of shape (num_img_predictions), confidence score for every prediction</span>
    685. <span class="sd"> :preds_cls: Tensor of shape (num_img_predictions), predicted class for every prediction</span>
    686. <span class="sd"> :targets_cls: Tensor of shape (num_img_targets), ground truth class for every target</span>
    687. <span class="sd"> &quot;&quot;&quot;</span>
    688. <span class="n">output</span> <span class="o">=</span> <span class="nb">map</span><span class="p">(</span><span class="k">lambda</span> <span class="n">tensor</span><span class="p">:</span> <span class="kc">None</span> <span class="k">if</span> <span class="n">tensor</span> <span class="ow">is</span> <span class="kc">None</span> <span class="k">else</span> <span class="n">tensor</span><span class="o">.</span><span class="n">to</span><span class="p">(</span><span class="n">device</span><span class="p">),</span> <span class="n">output</span><span class="p">)</span>
    689. <span class="n">targets</span><span class="p">,</span> <span class="n">iou_thresholds</span> <span class="o">=</span> <span class="n">targets</span><span class="o">.</span><span class="n">to</span><span class="p">(</span><span class="n">device</span><span class="p">),</span> <span class="n">iou_thresholds</span><span class="o">.</span><span class="n">to</span><span class="p">(</span><span class="n">device</span><span class="p">)</span>
    690. <span class="c1"># If crowd_targets is not provided, we patch it with an empty tensor</span>
    691. <span class="n">crowd_targets</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">zeros</span><span class="p">(</span><span class="n">size</span><span class="o">=</span><span class="p">(</span><span class="mi">0</span><span class="p">,</span> <span class="mi">6</span><span class="p">),</span> <span class="n">device</span><span class="o">=</span><span class="n">device</span><span class="p">)</span> <span class="k">if</span> <span class="n">crowd_targets</span> <span class="ow">is</span> <span class="kc">None</span> <span class="k">else</span> <span class="n">crowd_targets</span><span class="o">.</span><span class="n">to</span><span class="p">(</span><span class="n">device</span><span class="p">)</span>
    692. <span class="n">batch_metrics</span> <span class="o">=</span> <span class="p">[]</span>
    693. <span class="k">for</span> <span class="n">img_i</span><span class="p">,</span> <span class="n">img_preds</span> <span class="ow">in</span> <span class="nb">enumerate</span><span class="p">(</span><span class="n">output</span><span class="p">):</span>
    694. <span class="c1"># If img_preds is None (not prediction for this image), we patch it with an empty tensor</span>
    695. <span class="n">img_preds</span> <span class="o">=</span> <span class="n">img_preds</span> <span class="k">if</span> <span class="n">img_preds</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span> <span class="k">else</span> <span class="n">torch</span><span class="o">.</span><span class="n">zeros</span><span class="p">(</span><span class="n">size</span><span class="o">=</span><span class="p">(</span><span class="mi">0</span><span class="p">,</span> <span class="mi">6</span><span class="p">),</span> <span class="n">device</span><span class="o">=</span><span class="n">device</span><span class="p">)</span>
    696. <span class="n">img_targets</span> <span class="o">=</span> <span class="n">targets</span><span class="p">[</span><span class="n">targets</span><span class="p">[:,</span> <span class="mi">0</span><span class="p">]</span> <span class="o">==</span> <span class="n">img_i</span><span class="p">,</span> <span class="mi">1</span><span class="p">:]</span>
    697. <span class="n">img_crowd_targets</span> <span class="o">=</span> <span class="n">crowd_targets</span><span class="p">[</span><span class="n">crowd_targets</span><span class="p">[:,</span> <span class="mi">0</span><span class="p">]</span> <span class="o">==</span> <span class="n">img_i</span><span class="p">,</span> <span class="mi">1</span><span class="p">:]</span>
    698. <span class="n">img_matching_tensors</span> <span class="o">=</span> <span class="n">compute_img_detection_matching</span><span class="p">(</span>
    699. <span class="n">preds</span><span class="o">=</span><span class="n">img_preds</span><span class="p">,</span>
    700. <span class="n">targets</span><span class="o">=</span><span class="n">img_targets</span><span class="p">,</span>
    701. <span class="n">crowd_targets</span><span class="o">=</span><span class="n">img_crowd_targets</span><span class="p">,</span>
    702. <span class="n">denormalize_targets</span><span class="o">=</span><span class="n">denormalize_targets</span><span class="p">,</span>
    703. <span class="n">height</span><span class="o">=</span><span class="n">height</span><span class="p">,</span>
    704. <span class="n">width</span><span class="o">=</span><span class="n">width</span><span class="p">,</span>
    705. <span class="n">device</span><span class="o">=</span><span class="n">device</span><span class="p">,</span>
    706. <span class="n">iou_thresholds</span><span class="o">=</span><span class="n">iou_thresholds</span><span class="p">,</span>
    707. <span class="n">top_k</span><span class="o">=</span><span class="n">top_k</span><span class="p">,</span>
    708. <span class="n">return_on_cpu</span><span class="o">=</span><span class="n">return_on_cpu</span>
    709. <span class="p">)</span>
    710. <span class="n">batch_metrics</span><span class="o">.</span><span class="n">append</span><span class="p">(</span><span class="n">img_matching_tensors</span><span class="p">)</span>
    711. <span class="k">return</span> <span class="n">batch_metrics</span></div>
    712. <div class="viewcode-block" id="compute_img_detection_matching"><a class="viewcode-back" href="../../../../super_gradients.training.utils.html#super_gradients.training.utils.detection_utils.compute_img_detection_matching">[docs]</a><span class="k">def</span> <span class="nf">compute_img_detection_matching</span><span class="p">(</span>
    713. <span class="n">preds</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">,</span>
    714. <span class="n">targets</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">,</span>
    715. <span class="n">crowd_targets</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">,</span>
    716. <span class="n">height</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span>
    717. <span class="n">width</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span>
    718. <span class="n">iou_thresholds</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">,</span>
    719. <span class="n">device</span><span class="p">:</span> <span class="nb">str</span><span class="p">,</span>
    720. <span class="n">denormalize_targets</span><span class="p">:</span> <span class="nb">bool</span><span class="p">,</span>
    721. <span class="n">top_k</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="mi">100</span><span class="p">,</span>
    722. <span class="n">return_on_cpu</span><span class="p">:</span> <span class="nb">bool</span> <span class="o">=</span> <span class="kc">True</span>
    723. <span class="p">)</span> <span class="o">-&gt;</span> <span class="n">Tuple</span><span class="p">:</span>
    724. <span class="sd">&quot;&quot;&quot;</span>
    725. <span class="sd"> Match predictions (NMS output) and the targets (ground truth) with respect to IoU and confidence score</span>
    726. <span class="sd"> for a given image.</span>
    727. <span class="sd"> :param preds: Tensor of shape (num_img_predictions, 6)</span>
    728. <span class="sd"> format: (x1, y1, x2, y2, confidence, class_label) where x1,y1,x2,y2 are according to image size</span>
    729. <span class="sd"> :param targets: targets for this image of shape (num_img_targets, 6)</span>
    730. <span class="sd"> format: (index, x, y, w, h, label) where x,y,w,h are in range [0,1]</span>
    731. <span class="sd"> :param height: dimensions of the image</span>
    732. <span class="sd"> :param width: dimensions of the image</span>
    733. <span class="sd"> :param iou_thresholds: Threshold to compute the mAP</span>
    734. <span class="sd"> :param device:</span>
    735. <span class="sd"> :param crowd_targets: crowd targets for all images of shape (total_num_crowd_targets, 6)</span>
    736. <span class="sd"> format: (index, x, y, w, h, label) where x,y,w,h are in range [0,1]</span>
    737. <span class="sd"> :param top_k: Number of predictions to keep per class, ordered by confidence score</span>
    738. <span class="sd"> :param device: Device</span>
    739. <span class="sd"> :param denormalize_targets: If True, denormalize the targets and crowd_targets</span>
    740. <span class="sd"> :param return_on_cpu: If True, the output will be returned on &quot;CPU&quot;, otherwise it will be returned on &quot;device&quot;</span>
    741. <span class="sd"> :return:</span>
    742. <span class="sd"> :preds_matched: Tensor of shape (num_img_predictions, n_iou_thresholds)</span>
    743. <span class="sd"> True when prediction (i) is matched with a target with respect to the (j)th IoU threshold</span>
    744. <span class="sd"> :preds_to_ignore: Tensor of shape (num_img_predictions, n_iou_thresholds)</span>
    745. <span class="sd"> True when prediction (i) is matched with a crowd target with respect to the (j)th IoU threshold</span>
    746. <span class="sd"> :preds_scores: Tensor of shape (num_img_predictions), confidence score for every prediction</span>
    747. <span class="sd"> :preds_cls: Tensor of shape (num_img_predictions), predicted class for every prediction</span>
    748. <span class="sd"> :targets_cls: Tensor of shape (num_img_targets), ground truth class for every target</span>
    749. <span class="sd"> &quot;&quot;&quot;</span>
    750. <span class="n">num_iou_thresholds</span> <span class="o">=</span> <span class="nb">len</span><span class="p">(</span><span class="n">iou_thresholds</span><span class="p">)</span>
    751. <span class="k">if</span> <span class="n">preds</span> <span class="ow">is</span> <span class="kc">None</span> <span class="ow">or</span> <span class="nb">len</span><span class="p">(</span><span class="n">preds</span><span class="p">)</span> <span class="o">==</span> <span class="mi">0</span><span class="p">:</span>
    752. <span class="k">if</span> <span class="n">return_on_cpu</span><span class="p">:</span>
    753. <span class="n">device</span> <span class="o">=</span> <span class="s2">&quot;cpu&quot;</span>
    754. <span class="n">preds_matched</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">zeros</span><span class="p">((</span><span class="mi">0</span><span class="p">,</span> <span class="n">num_iou_thresholds</span><span class="p">),</span> <span class="n">dtype</span><span class="o">=</span><span class="n">torch</span><span class="o">.</span><span class="n">bool</span><span class="p">,</span> <span class="n">device</span><span class="o">=</span><span class="n">device</span><span class="p">)</span>
    755. <span class="n">preds_to_ignore</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">zeros</span><span class="p">((</span><span class="mi">0</span><span class="p">,</span> <span class="n">num_iou_thresholds</span><span class="p">),</span> <span class="n">dtype</span><span class="o">=</span><span class="n">torch</span><span class="o">.</span><span class="n">bool</span><span class="p">,</span> <span class="n">device</span><span class="o">=</span><span class="n">device</span><span class="p">)</span>
    756. <span class="n">preds_scores</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">tensor</span><span class="p">([],</span> <span class="n">dtype</span><span class="o">=</span><span class="n">torch</span><span class="o">.</span><span class="n">float32</span><span class="p">,</span> <span class="n">device</span><span class="o">=</span><span class="n">device</span><span class="p">)</span>
    757. <span class="n">preds_cls</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">tensor</span><span class="p">([],</span> <span class="n">dtype</span><span class="o">=</span><span class="n">torch</span><span class="o">.</span><span class="n">float32</span><span class="p">,</span> <span class="n">device</span><span class="o">=</span><span class="n">device</span><span class="p">)</span>
    758. <span class="n">targets_cls</span> <span class="o">=</span> <span class="n">targets</span><span class="p">[:,</span> <span class="mi">0</span><span class="p">]</span><span class="o">.</span><span class="n">to</span><span class="p">(</span><span class="n">device</span><span class="o">=</span><span class="n">device</span><span class="p">)</span>
    759. <span class="k">return</span> <span class="n">preds_matched</span><span class="p">,</span> <span class="n">preds_to_ignore</span><span class="p">,</span> <span class="n">preds_scores</span><span class="p">,</span> <span class="n">preds_cls</span><span class="p">,</span> <span class="n">targets_cls</span>
    760. <span class="n">preds_matched</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">zeros</span><span class="p">(</span><span class="nb">len</span><span class="p">(</span><span class="n">preds</span><span class="p">),</span> <span class="n">num_iou_thresholds</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="n">torch</span><span class="o">.</span><span class="n">bool</span><span class="p">,</span> <span class="n">device</span><span class="o">=</span><span class="n">device</span><span class="p">)</span>
    761. <span class="n">targets_matched</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">zeros</span><span class="p">(</span><span class="nb">len</span><span class="p">(</span><span class="n">targets</span><span class="p">),</span> <span class="n">num_iou_thresholds</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="n">torch</span><span class="o">.</span><span class="n">bool</span><span class="p">,</span> <span class="n">device</span><span class="o">=</span><span class="n">device</span><span class="p">)</span>
    762. <span class="n">preds_to_ignore</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">zeros</span><span class="p">(</span><span class="nb">len</span><span class="p">(</span><span class="n">preds</span><span class="p">),</span> <span class="n">num_iou_thresholds</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="n">torch</span><span class="o">.</span><span class="n">bool</span><span class="p">,</span> <span class="n">device</span><span class="o">=</span><span class="n">device</span><span class="p">)</span>
    763. <span class="n">preds_cls</span><span class="p">,</span> <span class="n">preds_box</span><span class="p">,</span> <span class="n">preds_scores</span> <span class="o">=</span> <span class="n">preds</span><span class="p">[:,</span> <span class="o">-</span><span class="mi">1</span><span class="p">],</span> <span class="n">preds</span><span class="p">[:,</span> <span class="mi">0</span><span class="p">:</span><span class="mi">4</span><span class="p">],</span> <span class="n">preds</span><span class="p">[:,</span> <span class="mi">4</span><span class="p">]</span>
    764. <span class="n">targets_cls</span><span class="p">,</span> <span class="n">targets_box</span> <span class="o">=</span> <span class="n">targets</span><span class="p">[:,</span> <span class="mi">0</span><span class="p">],</span> <span class="n">targets</span><span class="p">[:,</span> <span class="mi">1</span><span class="p">:</span><span class="mi">5</span><span class="p">]</span>
    765. <span class="n">crowd_targets_cls</span><span class="p">,</span> <span class="n">crowd_target_box</span> <span class="o">=</span> <span class="n">crowd_targets</span><span class="p">[:,</span> <span class="mi">0</span><span class="p">],</span> <span class="n">crowd_targets</span><span class="p">[:,</span> <span class="mi">1</span><span class="p">:</span><span class="mi">5</span><span class="p">]</span>
    766. <span class="c1"># Ignore all but the predictions that were top_k for their class</span>
    767. <span class="n">preds_idx_to_use</span> <span class="o">=</span> <span class="n">get_top_k_idx_per_cls</span><span class="p">(</span><span class="n">preds_scores</span><span class="p">,</span> <span class="n">preds_cls</span><span class="p">,</span> <span class="n">top_k</span><span class="p">)</span>
    768. <span class="n">preds_to_ignore</span><span class="p">[:,</span> <span class="p">:]</span> <span class="o">=</span> <span class="kc">True</span>
    769. <span class="n">preds_to_ignore</span><span class="p">[</span><span class="n">preds_idx_to_use</span><span class="p">]</span> <span class="o">=</span> <span class="kc">False</span>
    770. <span class="k">if</span> <span class="nb">len</span><span class="p">(</span><span class="n">targets</span><span class="p">)</span> <span class="o">&gt;</span> <span class="mi">0</span> <span class="ow">or</span> <span class="nb">len</span><span class="p">(</span><span class="n">crowd_targets</span><span class="p">)</span> <span class="o">&gt;</span> <span class="mi">0</span><span class="p">:</span>
    771. <span class="c1"># CHANGE bboxes TO FIT THE IMAGE SIZE</span>
    772. <span class="n">change_bbox_bounds_for_image_size</span><span class="p">(</span><span class="n">preds</span><span class="p">,</span> <span class="p">(</span><span class="n">height</span><span class="p">,</span> <span class="n">width</span><span class="p">))</span>
    773. <span class="c1"># if target_format == &quot;xywh&quot;:</span>
    774. <span class="n">targets_box</span> <span class="o">=</span> <span class="n">convert_xywh_bbox_to_xyxy</span><span class="p">(</span><span class="n">targets_box</span><span class="p">)</span> <span class="c1"># cxcywh2xyxy</span>
    775. <span class="n">crowd_target_box</span> <span class="o">=</span> <span class="n">convert_xywh_bbox_to_xyxy</span><span class="p">(</span><span class="n">crowd_target_box</span><span class="p">)</span> <span class="c1"># convert_xywh_bbox_to_xyxy</span>
    776. <span class="k">if</span> <span class="n">denormalize_targets</span><span class="p">:</span>
    777. <span class="n">targets_box</span><span class="p">[:,</span> <span class="p">[</span><span class="mi">0</span><span class="p">,</span> <span class="mi">2</span><span class="p">]]</span> <span class="o">*=</span> <span class="n">width</span>
    778. <span class="n">targets_box</span><span class="p">[:,</span> <span class="p">[</span><span class="mi">1</span><span class="p">,</span> <span class="mi">3</span><span class="p">]]</span> <span class="o">*=</span> <span class="n">height</span>
    779. <span class="n">crowd_target_box</span><span class="p">[:,</span> <span class="p">[</span><span class="mi">0</span><span class="p">,</span> <span class="mi">2</span><span class="p">]]</span> <span class="o">*=</span> <span class="n">width</span>
    780. <span class="n">crowd_target_box</span><span class="p">[:,</span> <span class="p">[</span><span class="mi">1</span><span class="p">,</span> <span class="mi">3</span><span class="p">]]</span> <span class="o">*=</span> <span class="n">height</span>
    781. <span class="k">if</span> <span class="nb">len</span><span class="p">(</span><span class="n">targets</span><span class="p">)</span> <span class="o">&gt;</span> <span class="mi">0</span><span class="p">:</span>
    782. <span class="c1"># shape = (n_preds x n_targets)</span>
    783. <span class="n">iou</span> <span class="o">=</span> <span class="n">box_iou</span><span class="p">(</span><span class="n">preds_box</span><span class="p">[</span><span class="n">preds_idx_to_use</span><span class="p">],</span> <span class="n">targets_box</span><span class="p">)</span>
    784. <span class="c1"># Fill IoU values at index (i, j) with 0 when the prediction (i) and target(j) are of different class</span>
    785. <span class="c1"># Filling with 0 is equivalent to ignore these values since with want IoU &gt; iou_threshold &gt; 0</span>
    786. <span class="n">cls_mismatch</span> <span class="o">=</span> <span class="p">(</span><span class="n">preds_cls</span><span class="p">[</span><span class="n">preds_idx_to_use</span><span class="p">]</span><span class="o">.</span><span class="n">view</span><span class="p">(</span><span class="o">-</span><span class="mi">1</span><span class="p">,</span> <span class="mi">1</span><span class="p">)</span> <span class="o">!=</span> <span class="n">targets_cls</span><span class="o">.</span><span class="n">view</span><span class="p">(</span><span class="mi">1</span><span class="p">,</span> <span class="o">-</span><span class="mi">1</span><span class="p">))</span>
    787. <span class="n">iou</span><span class="p">[</span><span class="n">cls_mismatch</span><span class="p">]</span> <span class="o">=</span> <span class="mi">0</span>
    788. <span class="c1"># The matching priority is first detection confidence and then IoU value.</span>
    789. <span class="c1"># The detection is already sorted by confidence in NMS, so here for each prediction we order the targets by iou.</span>
    790. <span class="n">sorted_iou</span><span class="p">,</span> <span class="n">target_sorted</span> <span class="o">=</span> <span class="n">iou</span><span class="o">.</span><span class="n">sort</span><span class="p">(</span><span class="n">descending</span><span class="o">=</span><span class="kc">True</span><span class="p">,</span> <span class="n">stable</span><span class="o">=</span><span class="kc">True</span><span class="p">)</span>
    791. <span class="c1"># Only iterate over IoU values higher than min threshold to speed up the process</span>
    792. <span class="k">for</span> <span class="n">pred_selected_i</span><span class="p">,</span> <span class="n">target_sorted_i</span> <span class="ow">in</span> <span class="p">(</span><span class="n">sorted_iou</span> <span class="o">&gt;</span> <span class="n">iou_thresholds</span><span class="p">[</span><span class="mi">0</span><span class="p">])</span><span class="o">.</span><span class="n">nonzero</span><span class="p">(</span><span class="n">as_tuple</span><span class="o">=</span><span class="kc">False</span><span class="p">):</span>
    793. <span class="c1"># pred_selected_i and target_sorted_i are relative to filters/sorting, so we extract their absolute indexes</span>
    794. <span class="n">pred_i</span> <span class="o">=</span> <span class="n">preds_idx_to_use</span><span class="p">[</span><span class="n">pred_selected_i</span><span class="p">]</span>
    795. <span class="n">target_i</span> <span class="o">=</span> <span class="n">target_sorted</span><span class="p">[</span><span class="n">pred_selected_i</span><span class="p">,</span> <span class="n">target_sorted_i</span><span class="p">]</span>
    796. <span class="c1"># Vector[j], True when IoU(pred_i, target_i) is above the (j)th threshold</span>
    797. <span class="n">is_iou_above_threshold</span> <span class="o">=</span> <span class="n">sorted_iou</span><span class="p">[</span><span class="n">pred_selected_i</span><span class="p">,</span> <span class="n">target_sorted_i</span><span class="p">]</span> <span class="o">&gt;</span> <span class="n">iou_thresholds</span>
    798. <span class="c1"># Vector[j], True when both pred_i and target_i are not matched yet for the (j)th threshold</span>
    799. <span class="n">are_candidates_free</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">logical_and</span><span class="p">(</span><span class="o">~</span><span class="n">preds_matched</span><span class="p">[</span><span class="n">pred_i</span><span class="p">,</span> <span class="p">:],</span> <span class="o">~</span><span class="n">targets_matched</span><span class="p">[</span><span class="n">target_i</span><span class="p">,</span> <span class="p">:])</span>
    800. <span class="c1"># Vector[j], True when (pred_i, target_i) can be matched for the (j)th threshold</span>
    801. <span class="n">are_candidates_good</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">logical_and</span><span class="p">(</span><span class="n">is_iou_above_threshold</span><span class="p">,</span> <span class="n">are_candidates_free</span><span class="p">)</span>
    802. <span class="c1"># For every threshold (j) where target_i and pred_i can be matched together ( are_candidates_good[j]==True )</span>
    803. <span class="c1"># fill the matching placeholders with True</span>
    804. <span class="n">targets_matched</span><span class="p">[</span><span class="n">target_i</span><span class="p">,</span> <span class="n">are_candidates_good</span><span class="p">]</span> <span class="o">=</span> <span class="kc">True</span>
    805. <span class="n">preds_matched</span><span class="p">[</span><span class="n">pred_i</span><span class="p">,</span> <span class="n">are_candidates_good</span><span class="p">]</span> <span class="o">=</span> <span class="kc">True</span>
    806. <span class="c1"># When all the targets are matched with a prediction for every IoU Threshold, stop.</span>
    807. <span class="k">if</span> <span class="n">targets_matched</span><span class="o">.</span><span class="n">all</span><span class="p">():</span>
    808. <span class="k">break</span>
    809. <span class="c1"># Crowd targets can be matched with many predictions.</span>
    810. <span class="c1"># Therefore, for every prediction we just need to check if it has IoA large enough with any crowd target.</span>
    811. <span class="k">if</span> <span class="nb">len</span><span class="p">(</span><span class="n">crowd_targets</span><span class="p">)</span> <span class="o">&gt;</span> <span class="mi">0</span><span class="p">:</span>
    812. <span class="c1"># shape = (n_preds_to_use x n_crowd_targets)</span>
    813. <span class="n">ioa</span> <span class="o">=</span> <span class="n">crowd_ioa</span><span class="p">(</span><span class="n">preds_box</span><span class="p">[</span><span class="n">preds_idx_to_use</span><span class="p">],</span> <span class="n">crowd_target_box</span><span class="p">)</span>
    814. <span class="c1"># Fill IoA values at index (i, j) with 0 when the prediction (i) and target(j) are of different class</span>
    815. <span class="c1"># Filling with 0 is equivalent to ignore these values since with want IoA &gt; threshold &gt; 0</span>
    816. <span class="n">cls_mismatch</span> <span class="o">=</span> <span class="p">(</span><span class="n">preds_cls</span><span class="p">[</span><span class="n">preds_idx_to_use</span><span class="p">]</span><span class="o">.</span><span class="n">view</span><span class="p">(</span><span class="o">-</span><span class="mi">1</span><span class="p">,</span> <span class="mi">1</span><span class="p">)</span> <span class="o">!=</span> <span class="n">crowd_targets_cls</span><span class="o">.</span><span class="n">view</span><span class="p">(</span><span class="mi">1</span><span class="p">,</span> <span class="o">-</span><span class="mi">1</span><span class="p">))</span>
    817. <span class="n">ioa</span><span class="p">[</span><span class="n">cls_mismatch</span><span class="p">]</span> <span class="o">=</span> <span class="mi">0</span>
    818. <span class="c1"># For each prediction, we keep it&#39;s highest score with any crowd target (of same class)</span>
    819. <span class="c1"># shape = (n_preds_to_use)</span>
    820. <span class="n">best_ioa</span><span class="p">,</span> <span class="n">_</span> <span class="o">=</span> <span class="n">ioa</span><span class="o">.</span><span class="n">max</span><span class="p">(</span><span class="mi">1</span><span class="p">)</span>
    821. <span class="c1"># If a prediction has IoA higher than threshold (with any target of same class), then there is a match</span>
    822. <span class="c1"># shape = (n_preds_to_use x iou_thresholds)</span>
    823. <span class="n">is_matching_with_crowd</span> <span class="o">=</span> <span class="p">(</span><span class="n">best_ioa</span><span class="o">.</span><span class="n">view</span><span class="p">(</span><span class="o">-</span><span class="mi">1</span><span class="p">,</span> <span class="mi">1</span><span class="p">)</span> <span class="o">&gt;</span> <span class="n">iou_thresholds</span><span class="o">.</span><span class="n">view</span><span class="p">(</span><span class="mi">1</span><span class="p">,</span> <span class="o">-</span><span class="mi">1</span><span class="p">))</span>
    824. <span class="n">preds_to_ignore</span><span class="p">[</span><span class="n">preds_idx_to_use</span><span class="p">]</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">logical_or</span><span class="p">(</span><span class="n">preds_to_ignore</span><span class="p">[</span><span class="n">preds_idx_to_use</span><span class="p">],</span> <span class="n">is_matching_with_crowd</span><span class="p">)</span>
    825. <span class="k">if</span> <span class="n">return_on_cpu</span><span class="p">:</span>
    826. <span class="n">preds_matched</span> <span class="o">=</span> <span class="n">preds_matched</span><span class="o">.</span><span class="n">to</span><span class="p">(</span><span class="s2">&quot;cpu&quot;</span><span class="p">)</span>
    827. <span class="n">preds_to_ignore</span> <span class="o">=</span> <span class="n">preds_to_ignore</span><span class="o">.</span><span class="n">to</span><span class="p">(</span><span class="s2">&quot;cpu&quot;</span><span class="p">)</span>
    828. <span class="n">preds_scores</span> <span class="o">=</span> <span class="n">preds_scores</span><span class="o">.</span><span class="n">to</span><span class="p">(</span><span class="s2">&quot;cpu&quot;</span><span class="p">)</span>
    829. <span class="n">preds_cls</span> <span class="o">=</span> <span class="n">preds_cls</span><span class="o">.</span><span class="n">to</span><span class="p">(</span><span class="s2">&quot;cpu&quot;</span><span class="p">)</span>
    830. <span class="n">targets_cls</span> <span class="o">=</span> <span class="n">targets_cls</span><span class="o">.</span><span class="n">to</span><span class="p">(</span><span class="s2">&quot;cpu&quot;</span><span class="p">)</span>
    831. <span class="k">return</span> <span class="n">preds_matched</span><span class="p">,</span> <span class="n">preds_to_ignore</span><span class="p">,</span> <span class="n">preds_scores</span><span class="p">,</span> <span class="n">preds_cls</span><span class="p">,</span> <span class="n">targets_cls</span></div>
    832. <div class="viewcode-block" id="get_top_k_idx_per_cls"><a class="viewcode-back" href="../../../../super_gradients.training.utils.html#super_gradients.training.utils.detection_utils.get_top_k_idx_per_cls">[docs]</a><span class="k">def</span> <span class="nf">get_top_k_idx_per_cls</span><span class="p">(</span><span class="n">preds_scores</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">,</span> <span class="n">preds_cls</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">,</span> <span class="n">top_k</span><span class="p">:</span> <span class="nb">int</span><span class="p">):</span>
    833. <span class="sd">&quot;&quot;&quot;Get the indexes of all the top k predictions for every class</span>
    834. <span class="sd"> :param preds_scores: The confidence scores, vector of shape (n_pred)</span>
    835. <span class="sd"> :param preds_cls: The predicted class, vector of shape (n_pred)</span>
    836. <span class="sd"> :param top_k: Number of predictions to keep per class, ordered by confidence score</span>
    837. <span class="sd"> :return top_k_idx: Indexes of the top k predictions. length &lt;= (k * n_unique_class)</span>
    838. <span class="sd"> &quot;&quot;&quot;</span>
    839. <span class="n">n_unique_cls</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">max</span><span class="p">(</span><span class="n">preds_cls</span><span class="p">)</span>
    840. <span class="n">mask</span> <span class="o">=</span> <span class="p">(</span><span class="n">preds_cls</span><span class="o">.</span><span class="n">view</span><span class="p">(</span><span class="o">-</span><span class="mi">1</span><span class="p">,</span> <span class="mi">1</span><span class="p">)</span> <span class="o">==</span> <span class="n">torch</span><span class="o">.</span><span class="n">arange</span><span class="p">(</span><span class="n">n_unique_cls</span> <span class="o">+</span> <span class="mi">1</span><span class="p">,</span> <span class="n">device</span><span class="o">=</span><span class="n">preds_scores</span><span class="o">.</span><span class="n">device</span><span class="p">)</span><span class="o">.</span><span class="n">view</span><span class="p">(</span><span class="mi">1</span><span class="p">,</span> <span class="o">-</span><span class="mi">1</span><span class="p">))</span>
    841. <span class="n">preds_scores_per_cls</span> <span class="o">=</span> <span class="n">preds_scores</span><span class="o">.</span><span class="n">view</span><span class="p">(</span><span class="o">-</span><span class="mi">1</span><span class="p">,</span> <span class="mi">1</span><span class="p">)</span> <span class="o">*</span> <span class="n">mask</span>
    842. <span class="n">sorted_scores_per_cls</span><span class="p">,</span> <span class="n">sorting_idx</span> <span class="o">=</span> <span class="n">preds_scores_per_cls</span><span class="o">.</span><span class="n">sort</span><span class="p">(</span><span class="mi">0</span><span class="p">,</span> <span class="n">descending</span><span class="o">=</span><span class="kc">True</span><span class="p">)</span>
    843. <span class="n">idx_with_satisfying_scores</span> <span class="o">=</span> <span class="n">sorted_scores_per_cls</span><span class="p">[:</span><span class="n">top_k</span><span class="p">,</span> <span class="p">:]</span><span class="o">.</span><span class="n">nonzero</span><span class="p">(</span><span class="n">as_tuple</span><span class="o">=</span><span class="kc">False</span><span class="p">)</span>
    844. <span class="n">top_k_idx</span> <span class="o">=</span> <span class="n">sorting_idx</span><span class="p">[</span><span class="n">idx_with_satisfying_scores</span><span class="o">.</span><span class="n">split</span><span class="p">(</span><span class="mi">1</span><span class="p">,</span> <span class="n">dim</span><span class="o">=</span><span class="mi">1</span><span class="p">)]</span>
    845. <span class="k">return</span> <span class="n">top_k_idx</span><span class="o">.</span><span class="n">view</span><span class="p">(</span><span class="o">-</span><span class="mi">1</span><span class="p">)</span></div>
    846. <div class="viewcode-block" id="compute_detection_metrics"><a class="viewcode-back" href="../../../../super_gradients.training.utils.html#super_gradients.training.utils.detection_utils.compute_detection_metrics">[docs]</a><span class="k">def</span> <span class="nf">compute_detection_metrics</span><span class="p">(</span>
    847. <span class="n">preds_matched</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">,</span>
    848. <span class="n">preds_to_ignore</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">,</span>
    849. <span class="n">preds_scores</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">,</span>
    850. <span class="n">preds_cls</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">,</span>
    851. <span class="n">targets_cls</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">,</span>
    852. <span class="n">device</span><span class="p">:</span> <span class="nb">str</span><span class="p">,</span>
    853. <span class="n">recall_thresholds</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">]</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span>
    854. <span class="n">score_threshold</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="nb">float</span><span class="p">]</span> <span class="o">=</span> <span class="mf">0.1</span><span class="p">,</span>
    855. <span class="p">)</span> <span class="o">-&gt;</span> <span class="n">Tuple</span><span class="p">:</span>
    856. <span class="sd">&quot;&quot;&quot;</span>
    857. <span class="sd"> Compute the list of precision, recall, MaP and f1 for every recall IoU threshold and for every class.</span>
    858. <span class="sd"> :param preds_matched: Tensor of shape (num_predictions, n_iou_thresholds)</span>
    859. <span class="sd"> True when prediction (i) is matched with a target with respect to the (j)th IoU threshold</span>
    860. <span class="sd"> :param preds_to_ignore Tensor of shape (num_predictions, n_iou_thresholds)</span>
    861. <span class="sd"> True when prediction (i) is matched with a crowd target with respect to the (j)th IoU threshold</span>
    862. <span class="sd"> :param preds_scores: Tensor of shape (num_predictions), confidence score for every prediction</span>
    863. <span class="sd"> :param preds_cls: Tensor of shape (num_predictions), predicted class for every prediction</span>
    864. <span class="sd"> :param targets_cls: Tensor of shape (num_targets), ground truth class for every target box to be detected</span>
    865. <span class="sd"> :param recall_thresholds: Recall thresholds used to compute MaP.</span>
    866. <span class="sd"> :param score_threshold: Minimum confidence score to consider a prediction for the computation of</span>
    867. <span class="sd"> precision, recall and f1 (not MaP)</span>
    868. <span class="sd"> :param device: Device</span>
    869. <span class="sd"> :return:</span>
    870. <span class="sd"> :ap, precision, recall, f1: Tensors of shape (n_class, nb_iou_thrs)</span>
    871. <span class="sd"> :unique_classes: Vector with all unique target classes</span>
    872. <span class="sd"> &quot;&quot;&quot;</span>
    873. <span class="n">preds_matched</span><span class="p">,</span> <span class="n">preds_to_ignore</span> <span class="o">=</span> <span class="n">preds_matched</span><span class="o">.</span><span class="n">to</span><span class="p">(</span><span class="n">device</span><span class="p">),</span> <span class="n">preds_to_ignore</span><span class="o">.</span><span class="n">to</span><span class="p">(</span><span class="n">device</span><span class="p">)</span>
    874. <span class="n">preds_scores</span><span class="p">,</span> <span class="n">preds_cls</span><span class="p">,</span> <span class="n">targets_cls</span> <span class="o">=</span> <span class="n">preds_scores</span><span class="o">.</span><span class="n">to</span><span class="p">(</span><span class="n">device</span><span class="p">),</span> <span class="n">preds_cls</span><span class="o">.</span><span class="n">to</span><span class="p">(</span><span class="n">device</span><span class="p">),</span> <span class="n">targets_cls</span><span class="o">.</span><span class="n">to</span><span class="p">(</span><span class="n">device</span><span class="p">)</span>
    875. <span class="n">recall_thresholds</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">linspace</span><span class="p">(</span><span class="mi">0</span><span class="p">,</span> <span class="mi">1</span><span class="p">,</span> <span class="mi">101</span><span class="p">,</span> <span class="n">device</span><span class="o">=</span><span class="n">device</span><span class="p">)</span> <span class="k">if</span> <span class="n">recall_thresholds</span> <span class="ow">is</span> <span class="kc">None</span> <span class="k">else</span> <span class="n">recall_thresholds</span><span class="o">.</span><span class="n">to</span><span class="p">(</span><span class="n">device</span><span class="p">)</span>
    876. <span class="n">unique_classes</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">unique</span><span class="p">(</span><span class="n">targets_cls</span><span class="p">)</span>
    877. <span class="n">n_class</span><span class="p">,</span> <span class="n">nb_iou_thrs</span> <span class="o">=</span> <span class="nb">len</span><span class="p">(</span><span class="n">unique_classes</span><span class="p">),</span> <span class="n">preds_matched</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="o">-</span><span class="mi">1</span><span class="p">]</span>
    878. <span class="n">ap</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">zeros</span><span class="p">((</span><span class="n">n_class</span><span class="p">,</span> <span class="n">nb_iou_thrs</span><span class="p">),</span> <span class="n">device</span><span class="o">=</span><span class="n">device</span><span class="p">)</span>
    879. <span class="n">precision</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">zeros</span><span class="p">((</span><span class="n">n_class</span><span class="p">,</span> <span class="n">nb_iou_thrs</span><span class="p">),</span> <span class="n">device</span><span class="o">=</span><span class="n">device</span><span class="p">)</span>
    880. <span class="n">recall</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">zeros</span><span class="p">((</span><span class="n">n_class</span><span class="p">,</span> <span class="n">nb_iou_thrs</span><span class="p">),</span> <span class="n">device</span><span class="o">=</span><span class="n">device</span><span class="p">)</span>
    881. <span class="k">for</span> <span class="n">cls_i</span><span class="p">,</span> <span class="bp">cls</span> <span class="ow">in</span> <span class="nb">enumerate</span><span class="p">(</span><span class="n">unique_classes</span><span class="p">):</span>
    882. <span class="n">cls_preds_idx</span><span class="p">,</span> <span class="n">cls_targets_idx</span> <span class="o">=</span> <span class="p">(</span><span class="n">preds_cls</span> <span class="o">==</span> <span class="bp">cls</span><span class="p">),</span> <span class="p">(</span><span class="n">targets_cls</span> <span class="o">==</span> <span class="bp">cls</span><span class="p">)</span>
    883. <span class="n">cls_ap</span><span class="p">,</span> <span class="n">cls_precision</span><span class="p">,</span> <span class="n">cls_recall</span> <span class="o">=</span> <span class="n">compute_detection_metrics_per_cls</span><span class="p">(</span>
    884. <span class="n">preds_matched</span><span class="o">=</span><span class="n">preds_matched</span><span class="p">[</span><span class="n">cls_preds_idx</span><span class="p">],</span>
    885. <span class="n">preds_to_ignore</span><span class="o">=</span><span class="n">preds_to_ignore</span><span class="p">[</span><span class="n">cls_preds_idx</span><span class="p">],</span>
    886. <span class="n">preds_scores</span><span class="o">=</span><span class="n">preds_scores</span><span class="p">[</span><span class="n">cls_preds_idx</span><span class="p">],</span>
    887. <span class="n">n_targets</span><span class="o">=</span><span class="n">cls_targets_idx</span><span class="o">.</span><span class="n">sum</span><span class="p">(),</span>
    888. <span class="n">recall_thresholds</span><span class="o">=</span><span class="n">recall_thresholds</span><span class="p">,</span>
    889. <span class="n">score_threshold</span><span class="o">=</span><span class="n">score_threshold</span><span class="p">,</span>
    890. <span class="n">device</span><span class="o">=</span><span class="n">device</span>
    891. <span class="p">)</span>
    892. <span class="n">ap</span><span class="p">[</span><span class="n">cls_i</span><span class="p">,</span> <span class="p">:]</span> <span class="o">=</span> <span class="n">cls_ap</span>
    893. <span class="n">precision</span><span class="p">[</span><span class="n">cls_i</span><span class="p">,</span> <span class="p">:]</span> <span class="o">=</span> <span class="n">cls_precision</span>
    894. <span class="n">recall</span><span class="p">[</span><span class="n">cls_i</span><span class="p">,</span> <span class="p">:]</span> <span class="o">=</span> <span class="n">cls_recall</span>
    895. <span class="n">f1</span> <span class="o">=</span> <span class="mi">2</span> <span class="o">*</span> <span class="n">precision</span> <span class="o">*</span> <span class="n">recall</span> <span class="o">/</span> <span class="p">(</span><span class="n">precision</span> <span class="o">+</span> <span class="n">recall</span> <span class="o">+</span> <span class="mf">1e-16</span><span class="p">)</span>
    896. <span class="k">return</span> <span class="n">ap</span><span class="p">,</span> <span class="n">precision</span><span class="p">,</span> <span class="n">recall</span><span class="p">,</span> <span class="n">f1</span><span class="p">,</span> <span class="n">unique_classes</span></div>
    897. <div class="viewcode-block" id="compute_detection_metrics_per_cls"><a class="viewcode-back" href="../../../../super_gradients.training.utils.html#super_gradients.training.utils.detection_utils.compute_detection_metrics_per_cls">[docs]</a><span class="k">def</span> <span class="nf">compute_detection_metrics_per_cls</span><span class="p">(</span>
    898. <span class="n">preds_matched</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">,</span>
    899. <span class="n">preds_to_ignore</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">,</span>
    900. <span class="n">preds_scores</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">,</span>
    901. <span class="n">n_targets</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span>
    902. <span class="n">recall_thresholds</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">,</span>
    903. <span class="n">score_threshold</span><span class="p">:</span> <span class="nb">float</span><span class="p">,</span>
    904. <span class="n">device</span><span class="p">:</span> <span class="nb">str</span><span class="p">,</span>
    905. <span class="p">):</span>
    906. <span class="sd">&quot;&quot;&quot;</span>
    907. <span class="sd"> Compute the list of precision, recall and MaP of a given class for every recall IoU threshold.</span>
    908. <span class="sd"> :param preds_matched: Tensor of shape (num_predictions, n_iou_thresholds)</span>
    909. <span class="sd"> True when prediction (i) is matched with a target</span>
    910. <span class="sd"> with respect to the(j)th IoU threshold</span>
    911. <span class="sd"> :param preds_to_ignore Tensor of shape (num_predictions, n_iou_thresholds)</span>
    912. <span class="sd"> True when prediction (i) is matched with a crowd target</span>
    913. <span class="sd"> with respect to the (j)th IoU threshold</span>
    914. <span class="sd"> :param preds_scores: Tensor of shape (num_predictions), confidence score for every prediction</span>
    915. <span class="sd"> :param n_targets: Number of target boxes of this class</span>
    916. <span class="sd"> :param recall_thresholds: Tensor of shape (max_n_rec_thresh) list of recall thresholds used to compute MaP</span>
    917. <span class="sd"> :param score_threshold: Minimum confidence score to consider a prediction for the computation of</span>
    918. <span class="sd"> precision and recall (not MaP)</span>
    919. <span class="sd"> :param device: Device</span>
    920. <span class="sd"> :return ap, precision, recall: Tensors of shape (nb_iou_thrs)</span>
    921. <span class="sd"> &quot;&quot;&quot;</span>
    922. <span class="n">nb_iou_thrs</span> <span class="o">=</span> <span class="n">preds_matched</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="o">-</span><span class="mi">1</span><span class="p">]</span>
    923. <span class="n">tps</span> <span class="o">=</span> <span class="n">preds_matched</span>
    924. <span class="n">fps</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">logical_and</span><span class="p">(</span><span class="n">torch</span><span class="o">.</span><span class="n">logical_not</span><span class="p">(</span><span class="n">preds_matched</span><span class="p">),</span> <span class="n">torch</span><span class="o">.</span><span class="n">logical_not</span><span class="p">(</span><span class="n">preds_to_ignore</span><span class="p">))</span>
    925. <span class="k">if</span> <span class="nb">len</span><span class="p">(</span><span class="n">tps</span><span class="p">)</span> <span class="o">==</span> <span class="mi">0</span><span class="p">:</span>
    926. <span class="k">return</span> <span class="mi">0</span><span class="p">,</span> <span class="mi">0</span><span class="p">,</span> <span class="n">torch</span><span class="o">.</span><span class="n">zeros</span><span class="p">(</span><span class="n">nb_iou_thrs</span><span class="p">,</span> <span class="n">device</span><span class="o">=</span><span class="n">device</span><span class="p">)</span>
    927. <span class="c1"># Sort by decreasing score</span>
    928. <span class="n">dtype</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">uint8</span> <span class="k">if</span> <span class="n">preds_scores</span><span class="o">.</span><span class="n">is_cuda</span> <span class="ow">and</span> <span class="n">preds_scores</span><span class="o">.</span><span class="n">dtype</span> <span class="ow">is</span> <span class="n">torch</span><span class="o">.</span><span class="n">bool</span> <span class="k">else</span> <span class="n">preds_scores</span><span class="o">.</span><span class="n">dtype</span>
    929. <span class="n">sort_ind</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">argsort</span><span class="p">(</span><span class="n">preds_scores</span><span class="o">.</span><span class="n">to</span><span class="p">(</span><span class="n">dtype</span><span class="p">),</span> <span class="n">descending</span><span class="o">=</span><span class="kc">True</span><span class="p">)</span>
    930. <span class="n">tps</span> <span class="o">=</span> <span class="n">tps</span><span class="p">[</span><span class="n">sort_ind</span><span class="p">,</span> <span class="p">:]</span>
    931. <span class="n">fps</span> <span class="o">=</span> <span class="n">fps</span><span class="p">[</span><span class="n">sort_ind</span><span class="p">,</span> <span class="p">:]</span>
    932. <span class="n">preds_scores</span> <span class="o">=</span> <span class="n">preds_scores</span><span class="p">[</span><span class="n">sort_ind</span><span class="p">]</span>
    933. <span class="c1"># Rolling sum over the predictions</span>
    934. <span class="n">rolling_tps</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">cumsum</span><span class="p">(</span><span class="n">tps</span><span class="p">,</span> <span class="n">axis</span><span class="o">=</span><span class="mi">0</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="n">torch</span><span class="o">.</span><span class="n">float</span><span class="p">)</span>
    935. <span class="n">rolling_fps</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">cumsum</span><span class="p">(</span><span class="n">fps</span><span class="p">,</span> <span class="n">axis</span><span class="o">=</span><span class="mi">0</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="n">torch</span><span class="o">.</span><span class="n">float</span><span class="p">)</span>
    936. <span class="n">rolling_recalls</span> <span class="o">=</span> <span class="n">rolling_tps</span> <span class="o">/</span> <span class="n">n_targets</span>
    937. <span class="n">rolling_precisions</span> <span class="o">=</span> <span class="n">rolling_tps</span> <span class="o">/</span> <span class="p">(</span><span class="n">rolling_tps</span> <span class="o">+</span> <span class="n">rolling_fps</span> <span class="o">+</span> <span class="n">torch</span><span class="o">.</span><span class="n">finfo</span><span class="p">(</span><span class="n">torch</span><span class="o">.</span><span class="n">float64</span><span class="p">)</span><span class="o">.</span><span class="n">eps</span><span class="p">)</span>
    938. <span class="c1"># Reversed cummax to only have decreasing values</span>
    939. <span class="n">rolling_precisions</span> <span class="o">=</span> <span class="n">rolling_precisions</span><span class="o">.</span><span class="n">flip</span><span class="p">(</span><span class="mi">0</span><span class="p">)</span><span class="o">.</span><span class="n">cummax</span><span class="p">(</span><span class="mi">0</span><span class="p">)</span><span class="o">.</span><span class="n">values</span><span class="o">.</span><span class="n">flip</span><span class="p">(</span><span class="mi">0</span><span class="p">)</span>
    940. <span class="c1"># ==================</span>
    941. <span class="c1"># RECALL &amp; PRECISION</span>
    942. <span class="c1"># We want the rolling precision/recall at index i so that: preds_scores[i-1] &gt;= score_threshold &gt; preds_scores[i]</span>
    943. <span class="c1"># Note: torch.searchsorted works on increasing sequence and preds_scores is decreasing, so we work with &quot;-&quot;</span>
    944. <span class="n">lowest_score_above_threshold</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">searchsorted</span><span class="p">(</span><span class="o">-</span><span class="n">preds_scores</span><span class="p">,</span> <span class="o">-</span><span class="n">score_threshold</span><span class="p">,</span> <span class="n">right</span><span class="o">=</span><span class="kc">False</span><span class="p">)</span>
    945. <span class="k">if</span> <span class="n">lowest_score_above_threshold</span> <span class="o">==</span> <span class="mi">0</span><span class="p">:</span> <span class="c1"># Here score_threshold &gt; preds_scores[0], so no pred is above the threshold</span>
    946. <span class="n">recall</span> <span class="o">=</span> <span class="mi">0</span>
    947. <span class="n">precision</span> <span class="o">=</span> <span class="mi">0</span> <span class="c1"># the precision is not really defined when no pred but we need to give it a value</span>
    948. <span class="k">else</span><span class="p">:</span>
    949. <span class="n">recall</span> <span class="o">=</span> <span class="n">rolling_recalls</span><span class="p">[</span><span class="n">lowest_score_above_threshold</span> <span class="o">-</span> <span class="mi">1</span><span class="p">]</span>
    950. <span class="n">precision</span> <span class="o">=</span> <span class="n">rolling_precisions</span><span class="p">[</span><span class="n">lowest_score_above_threshold</span> <span class="o">-</span> <span class="mi">1</span><span class="p">]</span>
    951. <span class="c1"># ==================</span>
    952. <span class="c1"># AVERAGE PRECISION</span>
    953. <span class="c1"># shape = (nb_iou_thrs, n_recall_thresholds)</span>
    954. <span class="n">recall_thresholds</span> <span class="o">=</span> <span class="n">recall_thresholds</span><span class="o">.</span><span class="n">view</span><span class="p">(</span><span class="mi">1</span><span class="p">,</span> <span class="o">-</span><span class="mi">1</span><span class="p">)</span><span class="o">.</span><span class="n">repeat</span><span class="p">(</span><span class="n">nb_iou_thrs</span><span class="p">,</span> <span class="mi">1</span><span class="p">)</span>
    955. <span class="c1"># We want the index i so that: rolling_recalls[i-1] &lt; recall_thresholds[k] &lt;= rolling_recalls[i]</span>
    956. <span class="c1"># Note: when recall_thresholds[k] &gt; max(rolling_recalls), i = len(rolling_recalls)</span>
    957. <span class="c1"># Note2: we work with transpose (.T) to apply torch.searchsorted on first dim instead of the last one</span>
    958. <span class="n">recall_threshold_idx</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">searchsorted</span><span class="p">(</span><span class="n">rolling_recalls</span><span class="o">.</span><span class="n">T</span><span class="p">,</span> <span class="n">recall_thresholds</span><span class="p">,</span> <span class="n">right</span><span class="o">=</span><span class="kc">False</span><span class="p">)</span><span class="o">.</span><span class="n">T</span>
    959. <span class="c1"># When recall_thresholds[k] &gt; max(rolling_recalls), rolling_precisions[i] is not defined, and we want precision = 0</span>
    960. <span class="n">rolling_precisions</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">cat</span><span class="p">((</span><span class="n">rolling_precisions</span><span class="p">,</span> <span class="n">torch</span><span class="o">.</span><span class="n">zeros</span><span class="p">(</span><span class="mi">1</span><span class="p">,</span> <span class="n">nb_iou_thrs</span><span class="p">,</span> <span class="n">device</span><span class="o">=</span><span class="n">device</span><span class="p">)),</span> <span class="n">dim</span><span class="o">=</span><span class="mi">0</span><span class="p">)</span>
    961. <span class="c1"># shape = (n_recall_thresholds, nb_iou_thrs)</span>
    962. <span class="n">sampled_precision_points</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">gather</span><span class="p">(</span><span class="nb">input</span><span class="o">=</span><span class="n">rolling_precisions</span><span class="p">,</span> <span class="n">index</span><span class="o">=</span><span class="n">recall_threshold_idx</span><span class="p">,</span> <span class="n">dim</span><span class="o">=</span><span class="mi">0</span><span class="p">)</span>
    963. <span class="c1"># Average over the recall_thresholds</span>
    964. <span class="n">ap</span> <span class="o">=</span> <span class="n">sampled_precision_points</span><span class="o">.</span><span class="n">mean</span><span class="p">(</span><span class="mi">0</span><span class="p">)</span>
    965. <span class="k">return</span> <span class="n">ap</span><span class="p">,</span> <span class="n">precision</span><span class="p">,</span> <span class="n">recall</span></div>
    966. </pre></div>
    967. </div>
    968. </div>
    969. <footer>
    970. <hr/>
    971. <div role="contentinfo">
    972. <p>&#169; Copyright 2021, SuperGradients team.</p>
    973. </div>
    974. Built with <a href="https://www.sphinx-doc.org/">Sphinx</a> using a
    975. <a href="https://github.com/readthedocs/sphinx_rtd_theme">theme</a>
    976. provided by <a href="https://readthedocs.org">Read the Docs</a>.
    977. </footer>
    978. </div>
    979. </div>
    980. </section>
    981. </div>
    982. <script>
    983. jQuery(function () {
    984. SphinxRtdTheme.Navigation.enable(true);
    985. });
    986. </script>
    987. </body>
    988. </html>
    Discard
    1
    2
    3
    4
    5
    6
    7
    8
    9
    10
    11
    12
    13
    14
    15
    16
    17
    18
    19
    20
    21
    22
    23
    24
    25
    26
    27
    28
    29
    30
    31
    32
    33
    34
    35
    36
    37
    38
    39
    40
    41
    42
    43
    44
    45
    46
    47
    48
    49
    50
    51
    52
    53
    54
    55
    56
    57
    58
    59
    60
    61
    62
    63
    64
    65
    66
    67
    68
    69
    70
    71
    72
    73
    74
    75
    76
    77
    78
    79
    80
    81
    82
    83
    84
    85
    86
    87
    88
    89
    90
    91
    92
    93
    94
    95
    96
    97
    98
    99
    100
    101
    102
    103
    104
    105
    106
    107
    108
    109
    110
    111
    112
    113
    114
    115
    116
    117
    118
    119
    120
    121
    122
    123
    124
    125
    126
    127
    128
    129
    130
    131
    132
    133
    134
    135
    136
    137
    138
    139
    140
    141
    142
    143
    144
    145
    146
    147
    148
    149
    150
    151
    152
    153
    154
    155
    156
    157
    158
    159
    160
    161
    162
    163
    164
    165
    166
    167
    168
    169
    170
    171
    172
    173
    174
    175
    176
    177
    178
    179
    180
    181
    182
    183
    184
    185
    186
    187
    188
    189
    190
    191
    192
    193
    194
    195
    196
    197
    198
    199
    200
    201
    202
    203
    204
    205
    206
    207
    208
    209
    210
    211
    212
    213
    214
    215
    216
    217
    218
    219
    220
    221
    222
    223
    224
    225
    226
    227
    228
    229
    230
    231
    232
    233
    234
    235
    236
    237
    238
    239
    240
    241
    242
    243
    244
    245
    246
    247
    248
    249
    250
    251
    252
    253
    254
    255
    256
    257
    258
    259
    260
    261
    262
    263
    264
    265
    266
    267
    268
    269
    270
    271
    272
    273
    1. <!DOCTYPE html>
    2. <html class="writer-html5" lang="en" >
    3. <head>
    4. <meta charset="utf-8" />
    5. <meta name="viewport" content="width=device-width, initial-scale=1.0" />
    6. <title>super_gradients.training.utils.distributed_training_utils &mdash; SuperGradients 1.0 documentation</title>
    7. <link rel="stylesheet" href="../../../../_static/pygments.css" type="text/css" />
    8. <link rel="stylesheet" href="../../../../_static/css/theme.css" type="text/css" />
    9. <link rel="stylesheet" href="../../../../_static/graphviz.css" type="text/css" />
    10. <!--[if lt IE 9]>
    11. <script src="../../../../_static/js/html5shiv.min.js"></script>
    12. <![endif]-->
    13. <script data-url_root="../../../../" id="documentation_options" src="../../../../_static/documentation_options.js"></script>
    14. <script src="../../../../_static/jquery.js"></script>
    15. <script src="../../../../_static/underscore.js"></script>
    16. <script src="../../../../_static/doctools.js"></script>
    17. <script src="../../../../_static/js/theme.js"></script>
    18. <link rel="index" title="Index" href="../../../../genindex.html" />
    19. <link rel="search" title="Search" href="../../../../search.html" />
    20. </head>
    21. <body class="wy-body-for-nav">
    22. <div class="wy-grid-for-nav">
    23. <nav data-toggle="wy-nav-shift" class="wy-nav-side">
    24. <div class="wy-side-scroll">
    25. <div class="wy-side-nav-search" >
    26. <a href="../../../../index.html" class="icon icon-home"> SuperGradients
    27. </a>
    28. <div role="search">
    29. <form id="rtd-search-form" class="wy-form" action="../../../../search.html" method="get">
    30. <input type="text" name="q" placeholder="Search docs" />
    31. <input type="hidden" name="check_keywords" value="yes" />
    32. <input type="hidden" name="area" value="default" />
    33. </form>
    34. </div>
    35. </div><div class="wy-menu wy-menu-vertical" data-spy="affix" role="navigation" aria-label="Navigation menu">
    36. <p class="caption"><span class="caption-text">Welcome To SuperGradients</span></p>
    37. <ul>
    38. <li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html">Fill our 4-question quick survey! We will raffle free SuperGradients swag between those who will participate -&gt; Fill Survey</a></li>
    39. <li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#supergradients">SuperGradients</a></li>
    40. </ul>
    41. <p class="caption"><span class="caption-text">Technical Documentation</span></p>
    42. <ul>
    43. <li class="toctree-l1"><a class="reference internal" href="../../../../super_gradients.common.html">Common package</a></li>
    44. <li class="toctree-l1"><a class="reference internal" href="../../../../super_gradients.training.html">Training package</a></li>
    45. </ul>
    46. <p class="caption"><span class="caption-text">User Guide</span></p>
    47. <ul>
    48. <li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html">What is SuperGradients?</a></li>
    49. <li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#introducing-the-supergradients-library">Introducing the SuperGradients library</a></li>
    50. <li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#installation">Installation</a></li>
    51. <li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#integrating-your-training-code-complete-walkthrough">Integrating Your Training Code - Complete Walkthrough</a></li>
    52. <li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#training-parameters">Training Parameters</a></li>
    53. <li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#logs-and-checkpoints">Logs and Checkpoints</a></li>
    54. <li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#dataset-parameters">Dataset Parameters</a></li>
    55. <li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#network-architectures">Network Architectures</a></li>
    56. <li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#pretrained-models">Pretrained Models</a></li>
    57. <li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#how-to-reproduce-our-training-recipes">How To Reproduce Our Training Recipes</a></li>
    58. <li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#professional-tools-integration">Professional Tools Integration</a></li>
    59. <li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#supergradients-faq">SuperGradients FAQ</a></li>
    60. </ul>
    61. </div>
    62. </div>
    63. </nav>
    64. <section data-toggle="wy-nav-shift" class="wy-nav-content-wrap"><nav class="wy-nav-top" aria-label="Mobile navigation menu" >
    65. <i data-toggle="wy-nav-top" class="fa fa-bars"></i>
    66. <a href="../../../../index.html">SuperGradients</a>
    67. </nav>
    68. <div class="wy-nav-content">
    69. <div class="rst-content">
    70. <div role="navigation" aria-label="Page navigation">
    71. <ul class="wy-breadcrumbs">
    72. <li><a href="../../../../index.html" class="icon icon-home"></a> &raquo;</li>
    73. <li><a href="../../../index.html">Module code</a> &raquo;</li>
    74. <li>super_gradients.training.utils.distributed_training_utils</li>
    75. <li class="wy-breadcrumbs-aside">
    76. </li>
    77. </ul>
    78. <hr/>
    79. </div>
    80. <div role="main" class="document" itemscope="itemscope" itemtype="http://schema.org/Article">
    81. <div itemprop="articleBody">
    82. <h1>Source code for super_gradients.training.utils.distributed_training_utils</h1><div class="highlight"><pre>
    83. <span></span><span class="kn">import</span> <span class="nn">torch</span>
    84. <span class="kn">from</span> <span class="nn">torch</span> <span class="kn">import</span> <span class="n">distributed</span> <span class="k">as</span> <span class="n">dist</span>
    85. <span class="kn">from</span> <span class="nn">torch.cuda.amp</span> <span class="kn">import</span> <span class="n">autocast</span>
    86. <span class="kn">import</span> <span class="nn">torch.nn</span> <span class="k">as</span> <span class="nn">nn</span>
    87. <span class="kn">import</span> <span class="nn">itertools</span>
    88. <span class="kn">from</span> <span class="nn">contextlib</span> <span class="kn">import</span> <span class="n">contextmanager</span>
    89. <div class="viewcode-block" id="distributed_all_reduce_tensor_average"><a class="viewcode-back" href="../../../../super_gradients.training.utils.html#super_gradients.training.utils.distributed_training_utils.distributed_all_reduce_tensor_average">[docs]</a><span class="k">def</span> <span class="nf">distributed_all_reduce_tensor_average</span><span class="p">(</span><span class="n">tensor</span><span class="p">,</span> <span class="n">n</span><span class="p">):</span>
    90. <span class="sd">&quot;&quot;&quot;</span>
    91. <span class="sd"> This method performs a reduce operation on multiple nodes running distributed training</span>
    92. <span class="sd"> It first sums all of the results and then divides the summation</span>
    93. <span class="sd"> :param tensor: The tensor to perform the reduce operation for</span>
    94. <span class="sd"> :param n: Number of nodes</span>
    95. <span class="sd"> :return: Averaged tensor from all of the nodes</span>
    96. <span class="sd"> &quot;&quot;&quot;</span>
    97. <span class="n">rt</span> <span class="o">=</span> <span class="n">tensor</span><span class="o">.</span><span class="n">clone</span><span class="p">()</span>
    98. <span class="n">torch</span><span class="o">.</span><span class="n">distributed</span><span class="o">.</span><span class="n">all_reduce</span><span class="p">(</span><span class="n">rt</span><span class="p">,</span> <span class="n">op</span><span class="o">=</span><span class="n">torch</span><span class="o">.</span><span class="n">distributed</span><span class="o">.</span><span class="n">ReduceOp</span><span class="o">.</span><span class="n">SUM</span><span class="p">)</span>
    99. <span class="n">rt</span> <span class="o">/=</span> <span class="n">n</span>
    100. <span class="k">return</span> <span class="n">rt</span></div>
    101. <div class="viewcode-block" id="reduce_results_tuple_for_ddp"><a class="viewcode-back" href="../../../../super_gradients.training.utils.html#super_gradients.training.utils.distributed_training_utils.reduce_results_tuple_for_ddp">[docs]</a><span class="k">def</span> <span class="nf">reduce_results_tuple_for_ddp</span><span class="p">(</span><span class="n">validation_results_tuple</span><span class="p">,</span> <span class="n">device</span><span class="p">):</span>
    102. <span class="sd">&quot;&quot;&quot;Gather all validation tuples from the various devices and average them&quot;&quot;&quot;</span>
    103. <span class="n">validation_results_list</span> <span class="o">=</span> <span class="nb">list</span><span class="p">(</span><span class="n">validation_results_tuple</span><span class="p">)</span>
    104. <span class="k">for</span> <span class="n">i</span><span class="p">,</span> <span class="n">validation_result</span> <span class="ow">in</span> <span class="nb">enumerate</span><span class="p">(</span><span class="n">validation_results_list</span><span class="p">):</span>
    105. <span class="n">validation_results_list</span><span class="p">[</span><span class="n">i</span><span class="p">]</span> <span class="o">=</span> <span class="n">distributed_all_reduce_tensor_average</span><span class="p">(</span><span class="n">torch</span><span class="o">.</span><span class="n">tensor</span><span class="p">(</span><span class="n">validation_result</span><span class="p">)</span><span class="o">.</span><span class="n">to</span><span class="p">(</span><span class="n">device</span><span class="p">),</span>
    106. <span class="n">torch</span><span class="o">.</span><span class="n">distributed</span><span class="o">.</span><span class="n">get_world_size</span><span class="p">())</span>
    107. <span class="n">validation_results_tuple</span> <span class="o">=</span> <span class="nb">tuple</span><span class="p">(</span><span class="n">validation_results_list</span><span class="p">)</span>
    108. <span class="k">return</span> <span class="n">validation_results_tuple</span></div>
    109. <div class="viewcode-block" id="MultiGPUModeAutocastWrapper"><a class="viewcode-back" href="../../../../super_gradients.training.utils.html#super_gradients.training.utils.distributed_training_utils.MultiGPUModeAutocastWrapper">[docs]</a><span class="k">class</span> <span class="nc">MultiGPUModeAutocastWrapper</span><span class="p">():</span>
    110. <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">func</span><span class="p">):</span>
    111. <span class="bp">self</span><span class="o">.</span><span class="n">func</span> <span class="o">=</span> <span class="n">func</span>
    112. <span class="k">def</span> <span class="fm">__call__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="o">*</span><span class="n">args</span><span class="p">,</span> <span class="o">**</span><span class="n">kwargs</span><span class="p">):</span>
    113. <span class="k">with</span> <span class="n">autocast</span><span class="p">():</span>
    114. <span class="n">out</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">func</span><span class="p">(</span><span class="o">*</span><span class="n">args</span><span class="p">,</span> <span class="o">**</span><span class="n">kwargs</span><span class="p">)</span>
    115. <span class="k">return</span> <span class="n">out</span></div>
    116. <div class="viewcode-block" id="scaled_all_reduce"><a class="viewcode-back" href="../../../../super_gradients.training.utils.html#super_gradients.training.utils.distributed_training_utils.scaled_all_reduce">[docs]</a><span class="k">def</span> <span class="nf">scaled_all_reduce</span><span class="p">(</span><span class="n">tensors</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">,</span> <span class="n">num_gpus</span><span class="p">:</span> <span class="nb">int</span><span class="p">):</span>
    117. <span class="sd">&quot;&quot;&quot;</span>
    118. <span class="sd"> Performs the scaled all_reduce operation on the provided tensors.</span>
    119. <span class="sd"> The input tensors are modified in-place.</span>
    120. <span class="sd"> Currently supports only the sum</span>
    121. <span class="sd"> reduction operator.</span>
    122. <span class="sd"> The reduced values are scaled by the inverse size of the</span>
    123. <span class="sd"> process group (equivalent to num_gpus).</span>
    124. <span class="sd"> &quot;&quot;&quot;</span>
    125. <span class="c1"># There is no need for reduction in the single-proc case</span>
    126. <span class="k">if</span> <span class="n">num_gpus</span> <span class="o">==</span> <span class="mi">1</span><span class="p">:</span>
    127. <span class="k">return</span> <span class="n">tensors</span>
    128. <span class="c1"># Queue the reductions</span>
    129. <span class="n">reductions</span> <span class="o">=</span> <span class="p">[]</span>
    130. <span class="k">for</span> <span class="n">tensor</span> <span class="ow">in</span> <span class="n">tensors</span><span class="p">:</span>
    131. <span class="n">reduction</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">distributed</span><span class="o">.</span><span class="n">all_reduce</span><span class="p">(</span><span class="n">tensor</span><span class="p">,</span> <span class="n">async_op</span><span class="o">=</span><span class="kc">True</span><span class="p">)</span>
    132. <span class="n">reductions</span><span class="o">.</span><span class="n">append</span><span class="p">(</span><span class="n">reduction</span><span class="p">)</span>
    133. <span class="c1"># Wait for reductions to finish</span>
    134. <span class="k">for</span> <span class="n">reduction</span> <span class="ow">in</span> <span class="n">reductions</span><span class="p">:</span>
    135. <span class="n">reduction</span><span class="o">.</span><span class="n">wait</span><span class="p">()</span>
    136. <span class="c1"># Scale the results</span>
    137. <span class="k">for</span> <span class="n">tensor</span> <span class="ow">in</span> <span class="n">tensors</span><span class="p">:</span>
    138. <span class="n">tensor</span><span class="o">.</span><span class="n">mul_</span><span class="p">(</span><span class="mf">1.0</span> <span class="o">/</span> <span class="n">num_gpus</span><span class="p">)</span>
    139. <span class="k">return</span> <span class="n">tensors</span></div>
    140. <div class="viewcode-block" id="compute_precise_bn_stats"><a class="viewcode-back" href="../../../../super_gradients.training.utils.html#super_gradients.training.utils.distributed_training_utils.compute_precise_bn_stats">[docs]</a><span class="nd">@torch</span><span class="o">.</span><span class="n">no_grad</span><span class="p">()</span>
    141. <span class="k">def</span> <span class="nf">compute_precise_bn_stats</span><span class="p">(</span><span class="n">model</span><span class="p">:</span> <span class="n">nn</span><span class="o">.</span><span class="n">Module</span><span class="p">,</span> <span class="n">loader</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">utils</span><span class="o">.</span><span class="n">data</span><span class="o">.</span><span class="n">DataLoader</span><span class="p">,</span> <span class="n">precise_bn_batch_size</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span> <span class="n">num_gpus</span><span class="p">:</span> <span class="nb">int</span><span class="p">):</span>
    142. <span class="sd">&#39;&#39;&#39;</span>
    143. <span class="sd"> :param model: The model being trained (ie: SgModel.net)</span>
    144. <span class="sd"> :param loader: Training dataloader (ie: SgModel.train_loader)</span>
    145. <span class="sd"> :param precise_bn_batch_size: The effective batch size we want to calculate the batchnorm on. For example, if we are training a model</span>
    146. <span class="sd"> on 8 gpus, with a batch of 128 on each gpu, a good rule of thumb would be to give it 8192</span>
    147. <span class="sd"> (ie: effective_batch_size * num_gpus = batch_per_gpu * num_gpus * num_gpus).</span>
    148. <span class="sd"> If precise_bn_batch_size is not provided in the training_params, the latter heuristic</span>
    149. <span class="sd"> will be taken.</span>
    150. <span class="sd"> param num_gpus: The number of gpus we are training on</span>
    151. <span class="sd"> &#39;&#39;&#39;</span>
    152. <span class="c1"># Compute the number of minibatches to use</span>
    153. <span class="n">num_iter</span> <span class="o">=</span> <span class="nb">int</span><span class="p">(</span><span class="n">precise_bn_batch_size</span> <span class="o">/</span> <span class="p">(</span><span class="n">loader</span><span class="o">.</span><span class="n">batch_size</span> <span class="o">*</span> <span class="n">num_gpus</span><span class="p">))</span> <span class="k">if</span> <span class="n">precise_bn_batch_size</span> <span class="k">else</span> <span class="n">num_gpus</span>
    154. <span class="n">num_iter</span> <span class="o">=</span> <span class="nb">min</span><span class="p">(</span><span class="n">num_iter</span><span class="p">,</span> <span class="nb">len</span><span class="p">(</span><span class="n">loader</span><span class="p">))</span>
    155. <span class="c1"># Retrieve the BN layers</span>
    156. <span class="n">bns</span> <span class="o">=</span> <span class="p">[</span><span class="n">m</span> <span class="k">for</span> <span class="n">m</span> <span class="ow">in</span> <span class="n">model</span><span class="o">.</span><span class="n">modules</span><span class="p">()</span> <span class="k">if</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">m</span><span class="p">,</span> <span class="n">torch</span><span class="o">.</span><span class="n">nn</span><span class="o">.</span><span class="n">BatchNorm2d</span><span class="p">)]</span>
    157. <span class="c1"># Initialize BN stats storage for computing mean(mean(batch)) and mean(var(batch))</span>
    158. <span class="n">running_means</span> <span class="o">=</span> <span class="p">[</span><span class="n">torch</span><span class="o">.</span><span class="n">zeros_like</span><span class="p">(</span><span class="n">bn</span><span class="o">.</span><span class="n">running_mean</span><span class="p">)</span> <span class="k">for</span> <span class="n">bn</span> <span class="ow">in</span> <span class="n">bns</span><span class="p">]</span>
    159. <span class="n">running_vars</span> <span class="o">=</span> <span class="p">[</span><span class="n">torch</span><span class="o">.</span><span class="n">zeros_like</span><span class="p">(</span><span class="n">bn</span><span class="o">.</span><span class="n">running_var</span><span class="p">)</span> <span class="k">for</span> <span class="n">bn</span> <span class="ow">in</span> <span class="n">bns</span><span class="p">]</span>
    160. <span class="c1"># Remember momentum values</span>
    161. <span class="n">momentums</span> <span class="o">=</span> <span class="p">[</span><span class="n">bn</span><span class="o">.</span><span class="n">momentum</span> <span class="k">for</span> <span class="n">bn</span> <span class="ow">in</span> <span class="n">bns</span><span class="p">]</span>
    162. <span class="c1"># Set momentum to 1.0 to compute BN stats that only reflect the current batch</span>
    163. <span class="k">for</span> <span class="n">bn</span> <span class="ow">in</span> <span class="n">bns</span><span class="p">:</span>
    164. <span class="n">bn</span><span class="o">.</span><span class="n">momentum</span> <span class="o">=</span> <span class="mf">1.0</span>
    165. <span class="c1"># Average the BN stats for each BN layer over the batches</span>
    166. <span class="k">for</span> <span class="n">inputs</span><span class="p">,</span> <span class="n">_labels</span> <span class="ow">in</span> <span class="n">itertools</span><span class="o">.</span><span class="n">islice</span><span class="p">(</span><span class="n">loader</span><span class="p">,</span> <span class="n">num_iter</span><span class="p">):</span>
    167. <span class="n">model</span><span class="p">(</span><span class="n">inputs</span><span class="o">.</span><span class="n">cuda</span><span class="p">())</span>
    168. <span class="k">for</span> <span class="n">i</span><span class="p">,</span> <span class="n">bn</span> <span class="ow">in</span> <span class="nb">enumerate</span><span class="p">(</span><span class="n">bns</span><span class="p">):</span>
    169. <span class="n">running_means</span><span class="p">[</span><span class="n">i</span><span class="p">]</span> <span class="o">+=</span> <span class="n">bn</span><span class="o">.</span><span class="n">running_mean</span> <span class="o">/</span> <span class="n">num_iter</span>
    170. <span class="n">running_vars</span><span class="p">[</span><span class="n">i</span><span class="p">]</span> <span class="o">+=</span> <span class="n">bn</span><span class="o">.</span><span class="n">running_var</span> <span class="o">/</span> <span class="n">num_iter</span>
    171. <span class="c1"># Sync BN stats across GPUs (no reduction if 1 GPU used)</span>
    172. <span class="n">running_means</span> <span class="o">=</span> <span class="n">scaled_all_reduce</span><span class="p">(</span><span class="n">running_means</span><span class="p">,</span> <span class="n">num_gpus</span><span class="o">=</span><span class="n">num_gpus</span><span class="p">)</span>
    173. <span class="n">running_vars</span> <span class="o">=</span> <span class="n">scaled_all_reduce</span><span class="p">(</span><span class="n">running_vars</span><span class="p">,</span> <span class="n">num_gpus</span><span class="o">=</span><span class="n">num_gpus</span><span class="p">)</span>
    174. <span class="c1"># Set BN stats and restore original momentum values</span>
    175. <span class="k">for</span> <span class="n">i</span><span class="p">,</span> <span class="n">bn</span> <span class="ow">in</span> <span class="nb">enumerate</span><span class="p">(</span><span class="n">bns</span><span class="p">):</span>
    176. <span class="n">bn</span><span class="o">.</span><span class="n">running_mean</span> <span class="o">=</span> <span class="n">running_means</span><span class="p">[</span><span class="n">i</span><span class="p">]</span>
    177. <span class="n">bn</span><span class="o">.</span><span class="n">running_var</span> <span class="o">=</span> <span class="n">running_vars</span><span class="p">[</span><span class="n">i</span><span class="p">]</span>
    178. <span class="n">bn</span><span class="o">.</span><span class="n">momentum</span> <span class="o">=</span> <span class="n">momentums</span><span class="p">[</span><span class="n">i</span><span class="p">]</span></div>
    179. <div class="viewcode-block" id="get_local_rank"><a class="viewcode-back" href="../../../../super_gradients.training.utils.html#super_gradients.training.utils.distributed_training_utils.get_local_rank">[docs]</a><span class="k">def</span> <span class="nf">get_local_rank</span><span class="p">():</span>
    180. <span class="sd">&quot;&quot;&quot;</span>
    181. <span class="sd"> Returns the local rank if running in DDP, and 0 otherwise</span>
    182. <span class="sd"> :return: local rank</span>
    183. <span class="sd"> &quot;&quot;&quot;</span>
    184. <span class="k">return</span> <span class="n">dist</span><span class="o">.</span><span class="n">get_rank</span><span class="p">()</span> <span class="k">if</span> <span class="n">dist</span><span class="o">.</span><span class="n">is_initialized</span><span class="p">()</span> <span class="k">else</span> <span class="mi">0</span></div>
    185. <div class="viewcode-block" id="get_world_size"><a class="viewcode-back" href="../../../../super_gradients.training.utils.html#super_gradients.training.utils.distributed_training_utils.get_world_size">[docs]</a><span class="k">def</span> <span class="nf">get_world_size</span><span class="p">()</span> <span class="o">-&gt;</span> <span class="nb">int</span><span class="p">:</span>
    186. <span class="sd">&quot;&quot;&quot;</span>
    187. <span class="sd"> Returns the world size if running in DDP, and 1 otherwise</span>
    188. <span class="sd"> :return: world size</span>
    189. <span class="sd"> &quot;&quot;&quot;</span>
    190. <span class="k">if</span> <span class="ow">not</span> <span class="n">dist</span><span class="o">.</span><span class="n">is_available</span><span class="p">():</span>
    191. <span class="k">return</span> <span class="mi">1</span>
    192. <span class="k">if</span> <span class="ow">not</span> <span class="n">dist</span><span class="o">.</span><span class="n">is_initialized</span><span class="p">():</span>
    193. <span class="k">return</span> <span class="mi">1</span>
    194. <span class="k">return</span> <span class="n">dist</span><span class="o">.</span><span class="n">get_world_size</span><span class="p">()</span></div>
    195. <div class="viewcode-block" id="wait_for_the_master"><a class="viewcode-back" href="../../../../super_gradients.training.utils.html#super_gradients.training.utils.distributed_training_utils.wait_for_the_master">[docs]</a><span class="nd">@contextmanager</span>
    196. <span class="k">def</span> <span class="nf">wait_for_the_master</span><span class="p">(</span><span class="n">local_rank</span><span class="p">:</span> <span class="nb">int</span><span class="p">):</span>
    197. <span class="sd">&quot;&quot;&quot;</span>
    198. <span class="sd"> Make all processes waiting for the master to do some task.</span>
    199. <span class="sd"> &quot;&quot;&quot;</span>
    200. <span class="k">if</span> <span class="n">local_rank</span> <span class="o">&gt;</span> <span class="mi">0</span><span class="p">:</span>
    201. <span class="n">dist</span><span class="o">.</span><span class="n">barrier</span><span class="p">()</span>
    202. <span class="k">yield</span>
    203. <span class="k">if</span> <span class="n">local_rank</span> <span class="o">==</span> <span class="mi">0</span><span class="p">:</span>
    204. <span class="k">if</span> <span class="ow">not</span> <span class="n">dist</span><span class="o">.</span><span class="n">is_available</span><span class="p">():</span>
    205. <span class="k">return</span>
    206. <span class="k">if</span> <span class="ow">not</span> <span class="n">dist</span><span class="o">.</span><span class="n">is_initialized</span><span class="p">():</span>
    207. <span class="k">return</span>
    208. <span class="k">else</span><span class="p">:</span>
    209. <span class="n">dist</span><span class="o">.</span><span class="n">barrier</span><span class="p">()</span></div>
    210. </pre></div>
    211. </div>
    212. </div>
    213. <footer>
    214. <hr/>
    215. <div role="contentinfo">
    216. <p>&#169; Copyright 2021, SuperGradients team.</p>
    217. </div>
    218. Built with <a href="https://www.sphinx-doc.org/">Sphinx</a> using a
    219. <a href="https://github.com/readthedocs/sphinx_rtd_theme">theme</a>
    220. provided by <a href="https://readthedocs.org">Read the Docs</a>.
    221. </footer>
    222. </div>
    223. </div>
    224. </section>
    225. </div>
    226. <script>
    227. jQuery(function () {
    228. SphinxRtdTheme.Navigation.enable(true);
    229. });
    230. </script>
    231. </body>
    232. </html>
    Discard
    1
    2
    3
    4
    5
    6
    7
    8
    9
    10
    11
    12
    13
    14
    15
    16
    17
    18
    19
    20
    21
    22
    23
    24
    25
    26
    27
    28
    29
    30
    31
    32
    33
    34
    35
    36
    37
    38
    39
    40
    41
    42
    43
    44
    45
    46
    47
    48
    49
    50
    51
    52
    53
    54
    55
    56
    57
    58
    59
    60
    61
    62
    63
    64
    65
    66
    67
    68
    69
    70
    71
    72
    73
    74
    75
    76
    77
    78
    79
    80
    81
    82
    83
    84
    85
    86
    87
    88
    89
    90
    91
    92
    93
    94
    95
    96
    97
    98
    99
    100
    101
    102
    103
    104
    105
    106
    107
    108
    109
    110
    111
    112
    113
    114
    115
    116
    117
    118
    119
    120
    121
    122
    123
    124
    125
    126
    127
    128
    129
    130
    131
    132
    133
    134
    135
    136
    137
    138
    139
    140
    141
    142
    143
    144
    145
    146
    147
    148
    149
    150
    151
    152
    153
    154
    155
    156
    157
    158
    159
    160
    161
    162
    163
    164
    165
    166
    167
    168
    169
    170
    171
    172
    173
    174
    175
    176
    177
    178
    179
    180
    181
    182
    183
    184
    185
    186
    187
    188
    189
    190
    191
    192
    193
    194
    195
    196
    197
    198
    199
    200
    201
    202
    203
    204
    205
    206
    207
    208
    209
    210
    211
    212
    213
    214
    215
    216
    217
    218
    219
    220
    221
    222
    223
    224
    225
    226
    227
    228
    229
    230
    231
    232
    233
    234
    235
    236
    237
    238
    239
    240
    241
    242
    243
    244
    245
    246
    247
    248
    249
    250
    251
    252
    253
    254
    255
    256
    257
    258
    259
    260
    261
    262
    263
    264
    265
    266
    267
    268
    1. <!DOCTYPE html>
    2. <html class="writer-html5" lang="en" >
    3. <head>
    4. <meta charset="utf-8" />
    5. <meta name="viewport" content="width=device-width, initial-scale=1.0" />
    6. <title>super_gradients.training.utils.early_stopping &mdash; SuperGradients 1.0 documentation</title>
    7. <link rel="stylesheet" href="../../../../_static/pygments.css" type="text/css" />
    8. <link rel="stylesheet" href="../../../../_static/css/theme.css" type="text/css" />
    9. <link rel="stylesheet" href="../../../../_static/graphviz.css" type="text/css" />
    10. <!--[if lt IE 9]>
    11. <script src="../../../../_static/js/html5shiv.min.js"></script>
    12. <![endif]-->
    13. <script data-url_root="../../../../" id="documentation_options" src="../../../../_static/documentation_options.js"></script>
    14. <script src="../../../../_static/jquery.js"></script>
    15. <script src="../../../../_static/underscore.js"></script>
    16. <script src="../../../../_static/doctools.js"></script>
    17. <script src="../../../../_static/js/theme.js"></script>
    18. <link rel="index" title="Index" href="../../../../genindex.html" />
    19. <link rel="search" title="Search" href="../../../../search.html" />
    20. </head>
    21. <body class="wy-body-for-nav">
    22. <div class="wy-grid-for-nav">
    23. <nav data-toggle="wy-nav-shift" class="wy-nav-side">
    24. <div class="wy-side-scroll">
    25. <div class="wy-side-nav-search" >
    26. <a href="../../../../index.html" class="icon icon-home"> SuperGradients
    27. </a>
    28. <div role="search">
    29. <form id="rtd-search-form" class="wy-form" action="../../../../search.html" method="get">
    30. <input type="text" name="q" placeholder="Search docs" />
    31. <input type="hidden" name="check_keywords" value="yes" />
    32. <input type="hidden" name="area" value="default" />
    33. </form>
    34. </div>
    35. </div><div class="wy-menu wy-menu-vertical" data-spy="affix" role="navigation" aria-label="Navigation menu">
    36. <p class="caption"><span class="caption-text">Welcome To SuperGradients</span></p>
    37. <ul>
    38. <li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html">Fill our 4-question quick survey! We will raffle free SuperGradients swag between those who will participate -&gt; Fill Survey</a></li>
    39. <li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#supergradients">SuperGradients</a></li>
    40. </ul>
    41. <p class="caption"><span class="caption-text">Technical Documentation</span></p>
    42. <ul>
    43. <li class="toctree-l1"><a class="reference internal" href="../../../../super_gradients.common.html">Common package</a></li>
    44. <li class="toctree-l1"><a class="reference internal" href="../../../../super_gradients.training.html">Training package</a></li>
    45. </ul>
    46. <p class="caption"><span class="caption-text">User Guide</span></p>
    47. <ul>
    48. <li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html">What is SuperGradients?</a></li>
    49. <li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#introducing-the-supergradients-library">Introducing the SuperGradients library</a></li>
    50. <li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#installation">Installation</a></li>
    51. <li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#integrating-your-training-code-complete-walkthrough">Integrating Your Training Code - Complete Walkthrough</a></li>
    52. <li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#training-parameters">Training Parameters</a></li>
    53. <li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#logs-and-checkpoints">Logs and Checkpoints</a></li>
    54. <li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#dataset-parameters">Dataset Parameters</a></li>
    55. <li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#network-architectures">Network Architectures</a></li>
    56. <li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#pretrained-models">Pretrained Models</a></li>
    57. <li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#how-to-reproduce-our-training-recipes">How To Reproduce Our Training Recipes</a></li>
    58. <li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#professional-tools-integration">Professional Tools Integration</a></li>
    59. <li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#supergradients-faq">SuperGradients FAQ</a></li>
    60. </ul>
    61. </div>
    62. </div>
    63. </nav>
    64. <section data-toggle="wy-nav-shift" class="wy-nav-content-wrap"><nav class="wy-nav-top" aria-label="Mobile navigation menu" >
    65. <i data-toggle="wy-nav-top" class="fa fa-bars"></i>
    66. <a href="../../../../index.html">SuperGradients</a>
    67. </nav>
    68. <div class="wy-nav-content">
    69. <div class="rst-content">
    70. <div role="navigation" aria-label="Page navigation">
    71. <ul class="wy-breadcrumbs">
    72. <li><a href="../../../../index.html" class="icon icon-home"></a> &raquo;</li>
    73. <li><a href="../../../index.html">Module code</a> &raquo;</li>
    74. <li>super_gradients.training.utils.early_stopping</li>
    75. <li class="wy-breadcrumbs-aside">
    76. </li>
    77. </ul>
    78. <hr/>
    79. </div>
    80. <div role="main" class="document" itemscope="itemscope" itemtype="http://schema.org/Article">
    81. <div itemprop="articleBody">
    82. <h1>Source code for super_gradients.training.utils.early_stopping</h1><div class="highlight"><pre>
    83. <span></span><span class="kn">from</span> <span class="nn">super_gradients.training.utils.callbacks</span> <span class="kn">import</span> <span class="n">PhaseCallback</span><span class="p">,</span> <span class="n">Phase</span><span class="p">,</span> <span class="n">PhaseContext</span>
    84. <span class="kn">from</span> <span class="nn">typing</span> <span class="kn">import</span> <span class="n">Optional</span>
    85. <span class="kn">import</span> <span class="nn">torch</span>
    86. <span class="kn">import</span> <span class="nn">numpy</span> <span class="k">as</span> <span class="nn">np</span>
    87. <span class="kn">from</span> <span class="nn">super_gradients.common.abstractions.abstract_logger</span> <span class="kn">import</span> <span class="n">get_logger</span>
    88. <span class="n">logger</span> <span class="o">=</span> <span class="n">get_logger</span><span class="p">(</span><span class="vm">__name__</span><span class="p">)</span>
    89. <div class="viewcode-block" id="EarlyStop"><a class="viewcode-back" href="../../../../super_gradients.training.utils.html#super_gradients.training.utils.early_stopping.EarlyStop">[docs]</a><span class="k">class</span> <span class="nc">EarlyStop</span><span class="p">(</span><span class="n">PhaseCallback</span><span class="p">):</span>
    90. <span class="sd">&quot;&quot;&quot;</span>
    91. <span class="sd"> Callback to monitor a metric and stop training when it stops improving.</span>
    92. <span class="sd"> Inspired by pytorch_lightning.callbacks.early_stopping and tf.keras.callbacks.EarlyStopping</span>
    93. <span class="sd"> &quot;&quot;&quot;</span>
    94. <span class="n">mode_dict</span> <span class="o">=</span> <span class="p">{</span><span class="s2">&quot;min&quot;</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">lt</span><span class="p">,</span> <span class="s2">&quot;max&quot;</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">gt</span><span class="p">}</span>
    95. <span class="n">supported_phases</span> <span class="o">=</span> <span class="p">(</span><span class="n">Phase</span><span class="o">.</span><span class="n">VALIDATION_EPOCH_END</span><span class="p">,</span> <span class="n">Phase</span><span class="o">.</span><span class="n">TRAIN_EPOCH_END</span><span class="p">)</span>
    96. <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span>
    97. <span class="n">phase</span><span class="p">:</span> <span class="n">Phase</span><span class="p">,</span>
    98. <span class="n">monitor</span><span class="p">:</span> <span class="nb">str</span><span class="p">,</span>
    99. <span class="n">mode</span><span class="p">:</span> <span class="nb">str</span> <span class="o">=</span> <span class="s2">&quot;min&quot;</span><span class="p">,</span>
    100. <span class="n">min_delta</span><span class="p">:</span> <span class="nb">float</span> <span class="o">=</span> <span class="mf">0.0</span><span class="p">,</span>
    101. <span class="n">patience</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="mi">3</span><span class="p">,</span>
    102. <span class="n">check_finite</span><span class="p">:</span> <span class="nb">bool</span> <span class="o">=</span> <span class="kc">True</span><span class="p">,</span>
    103. <span class="n">threshold</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="nb">float</span><span class="p">]</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span>
    104. <span class="n">verbose</span><span class="p">:</span> <span class="nb">bool</span> <span class="o">=</span> <span class="kc">False</span><span class="p">,</span>
    105. <span class="n">strict</span><span class="p">:</span> <span class="nb">bool</span> <span class="o">=</span> <span class="kc">True</span>
    106. <span class="p">):</span>
    107. <span class="sd">&quot;&quot;&quot;</span>
    108. <span class="sd"> :param phase: Callback phase event.</span>
    109. <span class="sd"> :param monitor: name of the metric to be monitored.</span>
    110. <span class="sd"> :param mode: one of &#39;min&#39;, &#39;max&#39;. In &#39;min&#39; mode, training will stop when the quantity</span>
    111. <span class="sd"> monitored has stopped decreasing and in &#39;max&#39; mode it will stop when the quantity</span>
    112. <span class="sd"> monitored has stopped increasing.</span>
    113. <span class="sd"> :param min_delta: minimum change in the monitored quantity to qualify as an improvement, i.e. an absolute</span>
    114. <span class="sd"> change of less than `min_delta`, will count as no improvement.</span>
    115. <span class="sd"> :param patience: number of checks with no improvement after which training will be stopped.</span>
    116. <span class="sd"> One check happens after every phase event.</span>
    117. <span class="sd"> :param check_finite: When set ``True``, stops training when the monitor becomes NaN or infinite.</span>
    118. <span class="sd"> :param threshold: Stop training immediately once the monitored quantity reaches this threshold. For mode &#39;min&#39;</span>
    119. <span class="sd"> stops training when below threshold, For mode &#39;max&#39; stops training when above threshold.</span>
    120. <span class="sd"> :param verbose: If `True` print logs.</span>
    121. <span class="sd"> :param strict: whether to crash the training if `monitor` is not found in the metrics.</span>
    122. <span class="sd"> &quot;&quot;&quot;</span>
    123. <span class="nb">super</span><span class="p">(</span><span class="n">EarlyStop</span><span class="p">,</span> <span class="bp">self</span><span class="p">)</span><span class="o">.</span><span class="fm">__init__</span><span class="p">(</span><span class="n">phase</span><span class="p">)</span>
    124. <span class="k">if</span> <span class="n">phase</span> <span class="ow">not</span> <span class="ow">in</span> <span class="bp">self</span><span class="o">.</span><span class="n">supported_phases</span><span class="p">:</span>
    125. <span class="k">raise</span> <span class="ne">ValueError</span><span class="p">(</span><span class="sa">f</span><span class="s2">&quot;EarlyStop doesn&#39;t support phase: </span><span class="si">{</span><span class="n">phase</span><span class="si">}</span><span class="s2">, &quot;</span>
    126. <span class="sa">f</span><span class="s2">&quot;excepted </span><span class="si">{</span><span class="s1">&#39;, &#39;</span><span class="o">.</span><span class="n">join</span><span class="p">([</span><span class="nb">str</span><span class="p">(</span><span class="n">x</span><span class="p">)</span> <span class="k">for</span> <span class="n">x</span> <span class="ow">in</span> <span class="bp">self</span><span class="o">.</span><span class="n">supported_phases</span><span class="p">])</span><span class="si">}</span><span class="s2">&quot;</span><span class="p">)</span>
    127. <span class="bp">self</span><span class="o">.</span><span class="n">phase</span> <span class="o">=</span> <span class="n">phase</span>
    128. <span class="bp">self</span><span class="o">.</span><span class="n">monitor_key</span> <span class="o">=</span> <span class="n">monitor</span>
    129. <span class="bp">self</span><span class="o">.</span><span class="n">min_delta</span> <span class="o">=</span> <span class="n">min_delta</span>
    130. <span class="bp">self</span><span class="o">.</span><span class="n">patience</span> <span class="o">=</span> <span class="n">patience</span>
    131. <span class="bp">self</span><span class="o">.</span><span class="n">mode</span> <span class="o">=</span> <span class="n">mode</span>
    132. <span class="bp">self</span><span class="o">.</span><span class="n">check_finite</span> <span class="o">=</span> <span class="n">check_finite</span>
    133. <span class="bp">self</span><span class="o">.</span><span class="n">threshold</span> <span class="o">=</span> <span class="n">threshold</span>
    134. <span class="bp">self</span><span class="o">.</span><span class="n">verbose</span> <span class="o">=</span> <span class="n">verbose</span>
    135. <span class="bp">self</span><span class="o">.</span><span class="n">strict</span> <span class="o">=</span> <span class="n">strict</span>
    136. <span class="bp">self</span><span class="o">.</span><span class="n">wait_count</span> <span class="o">=</span> <span class="mi">0</span>
    137. <span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">mode</span> <span class="ow">not</span> <span class="ow">in</span> <span class="bp">self</span><span class="o">.</span><span class="n">mode_dict</span><span class="p">:</span>
    138. <span class="k">raise</span> <span class="ne">Exception</span><span class="p">(</span><span class="sa">f</span><span class="s2">&quot;`mode` can be </span><span class="si">{</span><span class="s1">&#39;, &#39;</span><span class="o">.</span><span class="n">join</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">mode_dict</span><span class="o">.</span><span class="n">keys</span><span class="p">())</span><span class="si">}</span><span class="s2">, got </span><span class="si">{</span><span class="bp">self</span><span class="o">.</span><span class="n">mode</span><span class="si">}</span><span class="s2">&quot;</span><span class="p">)</span>
    139. <span class="bp">self</span><span class="o">.</span><span class="n">monitor_op</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">mode_dict</span><span class="p">[</span><span class="bp">self</span><span class="o">.</span><span class="n">mode</span><span class="p">]</span>
    140. <span class="bp">self</span><span class="o">.</span><span class="n">min_delta</span> <span class="o">*=</span> <span class="mi">1</span> <span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">monitor_op</span> <span class="o">==</span> <span class="n">torch</span><span class="o">.</span><span class="n">gt</span> <span class="k">else</span> <span class="o">-</span><span class="mi">1</span>
    141. <span class="n">torch_inf</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">tensor</span><span class="p">(</span><span class="n">np</span><span class="o">.</span><span class="n">Inf</span><span class="p">)</span>
    142. <span class="bp">self</span><span class="o">.</span><span class="n">best_score</span> <span class="o">=</span> <span class="n">torch_inf</span> <span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">monitor_op</span> <span class="o">==</span> <span class="n">torch</span><span class="o">.</span><span class="n">lt</span> <span class="k">else</span> <span class="o">-</span><span class="n">torch_inf</span>
    143. <span class="k">def</span> <span class="nf">_get_metric_value</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">metrics_dict</span><span class="p">):</span>
    144. <span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">monitor_key</span> <span class="ow">not</span> <span class="ow">in</span> <span class="n">metrics_dict</span><span class="o">.</span><span class="n">keys</span><span class="p">():</span>
    145. <span class="n">msg</span> <span class="o">=</span> <span class="sa">f</span><span class="s2">&quot;Can&#39;t find EarlyStop monitor </span><span class="si">{</span><span class="bp">self</span><span class="o">.</span><span class="n">monitor_key</span><span class="si">}</span><span class="s2"> in metrics_dict: </span><span class="si">{</span><span class="n">metrics_dict</span><span class="o">.</span><span class="n">keys</span><span class="p">()</span><span class="si">}</span><span class="s2">&quot;</span>
    146. <span class="n">exception_cls</span> <span class="o">=</span> <span class="ne">RuntimeError</span> <span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">strict</span> <span class="k">else</span> <span class="n">MissingMonitorKeyException</span>
    147. <span class="k">raise</span> <span class="n">exception_cls</span><span class="p">(</span><span class="n">msg</span><span class="p">)</span>
    148. <span class="k">return</span> <span class="n">metrics_dict</span><span class="p">[</span><span class="bp">self</span><span class="o">.</span><span class="n">monitor_key</span><span class="p">]</span>
    149. <span class="k">def</span> <span class="nf">_check_for_early_stop</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">current</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">):</span>
    150. <span class="n">should_stop</span> <span class="o">=</span> <span class="kc">False</span>
    151. <span class="c1"># check if current value is Nan or inf</span>
    152. <span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">check_finite</span> <span class="ow">and</span> <span class="ow">not</span> <span class="n">torch</span><span class="o">.</span><span class="n">isfinite</span><span class="p">(</span><span class="n">current</span><span class="p">):</span>
    153. <span class="n">should_stop</span> <span class="o">=</span> <span class="kc">True</span>
    154. <span class="n">reason</span> <span class="o">=</span> <span class="p">(</span>
    155. <span class="sa">f</span><span class="s2">&quot;Monitored metric </span><span class="si">{</span><span class="bp">self</span><span class="o">.</span><span class="n">monitor_key</span><span class="si">}</span><span class="s2"> = </span><span class="si">{</span><span class="n">current</span><span class="si">}</span><span class="s2"> is not finite.&quot;</span>
    156. <span class="sa">f</span><span class="s2">&quot; Previous best value was </span><span class="si">{</span><span class="bp">self</span><span class="o">.</span><span class="n">best_score</span><span class="si">:</span><span class="s2">.3f</span><span class="si">}</span><span class="s2">. Signaling Trainer to stop.&quot;</span>
    157. <span class="p">)</span>
    158. <span class="c1"># check if current value reached threshold value</span>
    159. <span class="k">elif</span> <span class="bp">self</span><span class="o">.</span><span class="n">threshold</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span> <span class="ow">and</span> <span class="bp">self</span><span class="o">.</span><span class="n">monitor_op</span><span class="p">(</span><span class="n">current</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">threshold</span><span class="p">):</span>
    160. <span class="n">should_stop</span> <span class="o">=</span> <span class="kc">True</span>
    161. <span class="n">reason</span> <span class="o">=</span> <span class="p">(</span>
    162. <span class="s2">&quot;Stopping threshold reached:&quot;</span>
    163. <span class="sa">f</span><span class="s2">&quot; </span><span class="si">{</span><span class="bp">self</span><span class="o">.</span><span class="n">monitor_key</span><span class="si">}</span><span class="s2"> = </span><span class="si">{</span><span class="n">current</span><span class="si">}</span><span class="s2"> </span><span class="si">{</span><span class="bp">self</span><span class="o">.</span><span class="n">monitor_op</span><span class="si">}</span><span class="s2"> </span><span class="si">{</span><span class="bp">self</span><span class="o">.</span><span class="n">threshold</span><span class="si">}</span><span class="s2">.&quot;</span>
    164. <span class="s2">&quot; Signaling Trainer to stop.&quot;</span>
    165. <span class="p">)</span>
    166. <span class="c1"># check if current is an improvement of monitor_key metric.</span>
    167. <span class="k">elif</span> <span class="bp">self</span><span class="o">.</span><span class="n">monitor_op</span><span class="p">(</span><span class="n">current</span> <span class="o">-</span> <span class="bp">self</span><span class="o">.</span><span class="n">min_delta</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">best_score</span><span class="o">.</span><span class="n">to</span><span class="p">(</span><span class="n">current</span><span class="o">.</span><span class="n">device</span><span class="p">)):</span>
    168. <span class="n">should_stop</span> <span class="o">=</span> <span class="kc">False</span>
    169. <span class="k">if</span> <span class="n">torch</span><span class="o">.</span><span class="n">isfinite</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">best_score</span><span class="p">):</span>
    170. <span class="n">reason</span> <span class="o">=</span> <span class="p">(</span>
    171. <span class="sa">f</span><span class="s2">&quot;Metric </span><span class="si">{</span><span class="bp">self</span><span class="o">.</span><span class="n">monitor_key</span><span class="si">}</span><span class="s2"> improved by </span><span class="si">{</span><span class="nb">abs</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">best_score</span> <span class="o">-</span> <span class="n">current</span><span class="p">)</span><span class="si">:</span><span class="s2">.3f</span><span class="si">}</span><span class="s2"> &gt;=&quot;</span>
    172. <span class="sa">f</span><span class="s2">&quot; min_delta = </span><span class="si">{</span><span class="nb">abs</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">min_delta</span><span class="p">)</span><span class="si">}</span><span class="s2">. New best score: </span><span class="si">{</span><span class="n">current</span><span class="si">:</span><span class="s2">.3f</span><span class="si">}</span><span class="s2">&quot;</span>
    173. <span class="p">)</span>
    174. <span class="k">else</span><span class="p">:</span>
    175. <span class="n">reason</span> <span class="o">=</span> <span class="sa">f</span><span class="s2">&quot;Metric </span><span class="si">{</span><span class="bp">self</span><span class="o">.</span><span class="n">monitor_key</span><span class="si">}</span><span class="s2"> improved. New best score: </span><span class="si">{</span><span class="n">current</span><span class="si">:</span><span class="s2">.3f</span><span class="si">}</span><span class="s2">&quot;</span>
    176. <span class="bp">self</span><span class="o">.</span><span class="n">best_score</span> <span class="o">=</span> <span class="n">current</span>
    177. <span class="bp">self</span><span class="o">.</span><span class="n">wait_count</span> <span class="o">=</span> <span class="mi">0</span>
    178. <span class="c1"># no improvement in monitor_key metric, check if wait_count is bigger than patience.</span>
    179. <span class="k">else</span><span class="p">:</span>
    180. <span class="bp">self</span><span class="o">.</span><span class="n">wait_count</span> <span class="o">+=</span> <span class="mi">1</span>
    181. <span class="n">reason</span> <span class="o">=</span> <span class="sa">f</span><span class="s2">&quot;Monitored metric </span><span class="si">{</span><span class="bp">self</span><span class="o">.</span><span class="n">monitor_key</span><span class="si">}</span><span class="s2"> did not improve in the last </span><span class="si">{</span><span class="bp">self</span><span class="o">.</span><span class="n">wait_count</span><span class="si">}</span><span class="s2"> records.&quot;</span>
    182. <span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">wait_count</span> <span class="o">&gt;=</span> <span class="bp">self</span><span class="o">.</span><span class="n">patience</span><span class="p">:</span>
    183. <span class="n">should_stop</span> <span class="o">=</span> <span class="kc">True</span>
    184. <span class="n">reason</span> <span class="o">+=</span> <span class="sa">f</span><span class="s2">&quot; Best score: </span><span class="si">{</span><span class="bp">self</span><span class="o">.</span><span class="n">best_score</span><span class="si">:</span><span class="s2">.3f</span><span class="si">}</span><span class="s2">. Signaling Trainer to stop.&quot;</span>
    185. <span class="k">return</span> <span class="n">reason</span><span class="p">,</span> <span class="n">should_stop</span>
    186. <span class="k">def</span> <span class="fm">__call__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">context</span><span class="p">:</span> <span class="n">PhaseContext</span><span class="p">):</span>
    187. <span class="k">try</span><span class="p">:</span>
    188. <span class="n">current</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">_get_metric_value</span><span class="p">(</span><span class="n">context</span><span class="o">.</span><span class="n">metrics_dict</span><span class="p">)</span>
    189. <span class="k">except</span> <span class="n">MissingMonitorKeyException</span> <span class="k">as</span> <span class="n">e</span><span class="p">:</span>
    190. <span class="n">logger</span><span class="o">.</span><span class="n">warning</span><span class="p">(</span><span class="n">e</span><span class="p">)</span>
    191. <span class="k">return</span>
    192. <span class="k">if</span> <span class="ow">not</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">current</span><span class="p">,</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">):</span>
    193. <span class="n">current</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">tensor</span><span class="p">(</span><span class="n">current</span><span class="p">)</span>
    194. <span class="n">reason</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">should_stop</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">_check_for_early_stop</span><span class="p">(</span><span class="n">current</span><span class="p">)</span>
    195. <span class="c1"># log reason message, and signal early stop if should_stop=True.</span>
    196. <span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">should_stop</span><span class="p">:</span>
    197. <span class="bp">self</span><span class="o">.</span><span class="n">_signal_early_stop</span><span class="p">(</span><span class="n">context</span><span class="p">,</span> <span class="n">reason</span><span class="p">)</span>
    198. <span class="k">elif</span> <span class="bp">self</span><span class="o">.</span><span class="n">verbose</span><span class="p">:</span>
    199. <span class="n">logger</span><span class="o">.</span><span class="n">info</span><span class="p">(</span><span class="n">reason</span><span class="p">)</span>
    200. <span class="k">def</span> <span class="nf">_signal_early_stop</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">context</span><span class="p">:</span> <span class="n">PhaseContext</span><span class="p">,</span> <span class="n">reason</span><span class="p">:</span> <span class="nb">str</span><span class="p">):</span>
    201. <span class="n">logger</span><span class="o">.</span><span class="n">info</span><span class="p">(</span><span class="n">reason</span><span class="p">)</span>
    202. <span class="n">context</span><span class="o">.</span><span class="n">update_context</span><span class="p">(</span><span class="n">stop_training</span><span class="o">=</span><span class="kc">True</span><span class="p">)</span></div>
    203. <div class="viewcode-block" id="MissingMonitorKeyException"><a class="viewcode-back" href="../../../../super_gradients.training.utils.html#super_gradients.training.utils.early_stopping.MissingMonitorKeyException">[docs]</a><span class="k">class</span> <span class="nc">MissingMonitorKeyException</span><span class="p">(</span><span class="ne">Exception</span><span class="p">):</span>
    204. <span class="sd">&quot;&quot;&quot;</span>
    205. <span class="sd"> Exception raised for missing monitor key in metrics_dict.</span>
    206. <span class="sd"> &quot;&quot;&quot;</span>
    207. <span class="k">pass</span></div>
    208. </pre></div>
    209. </div>
    210. </div>
    211. <footer>
    212. <hr/>
    213. <div role="contentinfo">
    214. <p>&#169; Copyright 2021, SuperGradients team.</p>
    215. </div>
    216. Built with <a href="https://www.sphinx-doc.org/">Sphinx</a> using a
    217. <a href="https://github.com/readthedocs/sphinx_rtd_theme">theme</a>
    218. provided by <a href="https://readthedocs.org">Read the Docs</a>.
    219. </footer>
    220. </div>
    221. </div>
    222. </section>
    223. </div>
    224. <script>
    225. jQuery(function () {
    226. SphinxRtdTheme.Navigation.enable(true);
    227. });
    228. </script>
    229. </body>
    230. </html>
    Discard
    1
    2
    3
    4
    5
    6
    7
    8
    9
    10
    11
    12
    13
    14
    15
    16
    17
    18
    19
    20
    21
    22
    23
    24
    25
    26
    27
    28
    29
    30
    31
    32
    33
    34
    35
    36
    37
    38
    39
    40
    41
    42
    43
    44
    45
    46
    47
    48
    49
    50
    51
    52
    53
    54
    55
    56
    57
    58
    59
    60
    61
    62
    63
    64
    65
    66
    67
    68
    69
    70
    71
    72
    73
    74
    75
    76
    77
    78
    79
    80
    81
    82
    83
    84
    85
    86
    87
    88
    89
    90
    91
    92
    93
    94
    95
    96
    97
    98
    99
    100
    101
    102
    103
    104
    105
    106
    107
    108
    109
    110
    111
    112
    113
    114
    115
    116
    117
    118
    119
    120
    121
    122
    123
    124
    125
    126
    127
    128
    129
    130
    131
    132
    133
    134
    135
    136
    137
    138
    139
    140
    141
    142
    143
    144
    145
    146
    147
    148
    149
    150
    151
    152
    153
    154
    155
    156
    157
    158
    159
    160
    161
    162
    163
    164
    165
    166
    167
    168
    169
    170
    171
    172
    173
    174
    175
    176
    177
    178
    179
    180
    181
    182
    183
    184
    185
    186
    187
    188
    189
    190
    191
    192
    193
    194
    195
    196
    197
    198
    199
    200
    201
    202
    203
    204
    205
    206
    207
    208
    209
    210
    211
    212
    213
    214
    215
    216
    217
    218
    219
    220
    221
    222
    223
    224
    225
    226
    227
    228
    229
    230
    231
    232
    233
    234
    235
    236
    237
    238
    239
    240
    241
    242
    243
    244
    245
    1. <!DOCTYPE html>
    2. <html class="writer-html5" lang="en" >
    3. <head>
    4. <meta charset="utf-8" />
    5. <meta name="viewport" content="width=device-width, initial-scale=1.0" />
    6. <title>super_gradients.training.utils.ema &mdash; SuperGradients 1.0 documentation</title>
    7. <link rel="stylesheet" href="../../../../_static/pygments.css" type="text/css" />
    8. <link rel="stylesheet" href="../../../../_static/css/theme.css" type="text/css" />
    9. <link rel="stylesheet" href="../../../../_static/graphviz.css" type="text/css" />
    10. <!--[if lt IE 9]>
    11. <script src="../../../../_static/js/html5shiv.min.js"></script>
    12. <![endif]-->
    13. <script data-url_root="../../../../" id="documentation_options" src="../../../../_static/documentation_options.js"></script>
    14. <script src="../../../../_static/jquery.js"></script>
    15. <script src="../../../../_static/underscore.js"></script>
    16. <script src="../../../../_static/doctools.js"></script>
    17. <script src="../../../../_static/js/theme.js"></script>
    18. <link rel="index" title="Index" href="../../../../genindex.html" />
    19. <link rel="search" title="Search" href="../../../../search.html" />
    20. </head>
    21. <body class="wy-body-for-nav">
    22. <div class="wy-grid-for-nav">
    23. <nav data-toggle="wy-nav-shift" class="wy-nav-side">
    24. <div class="wy-side-scroll">
    25. <div class="wy-side-nav-search" >
    26. <a href="../../../../index.html" class="icon icon-home"> SuperGradients
    27. </a>
    28. <div role="search">
    29. <form id="rtd-search-form" class="wy-form" action="../../../../search.html" method="get">
    30. <input type="text" name="q" placeholder="Search docs" />
    31. <input type="hidden" name="check_keywords" value="yes" />
    32. <input type="hidden" name="area" value="default" />
    33. </form>
    34. </div>
    35. </div><div class="wy-menu wy-menu-vertical" data-spy="affix" role="navigation" aria-label="Navigation menu">
    36. <p class="caption"><span class="caption-text">Welcome To SuperGradients</span></p>
    37. <ul>
    38. <li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html">Fill our 4-question quick survey! We will raffle free SuperGradients swag between those who will participate -&gt; Fill Survey</a></li>
    39. <li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#supergradients">SuperGradients</a></li>
    40. </ul>
    41. <p class="caption"><span class="caption-text">Technical Documentation</span></p>
    42. <ul>
    43. <li class="toctree-l1"><a class="reference internal" href="../../../../super_gradients.common.html">Common package</a></li>
    44. <li class="toctree-l1"><a class="reference internal" href="../../../../super_gradients.training.html">Training package</a></li>
    45. </ul>
    46. <p class="caption"><span class="caption-text">User Guide</span></p>
    47. <ul>
    48. <li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html">What is SuperGradients?</a></li>
    49. <li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#introducing-the-supergradients-library">Introducing the SuperGradients library</a></li>
    50. <li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#installation">Installation</a></li>
    51. <li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#integrating-your-training-code-complete-walkthrough">Integrating Your Training Code - Complete Walkthrough</a></li>
    52. <li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#training-parameters">Training Parameters</a></li>
    53. <li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#logs-and-checkpoints">Logs and Checkpoints</a></li>
    54. <li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#dataset-parameters">Dataset Parameters</a></li>
    55. <li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#network-architectures">Network Architectures</a></li>
    56. <li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#pretrained-models">Pretrained Models</a></li>
    57. <li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#how-to-reproduce-our-training-recipes">How To Reproduce Our Training Recipes</a></li>
    58. <li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#professional-tools-integration">Professional Tools Integration</a></li>
    59. <li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#supergradients-faq">SuperGradients FAQ</a></li>
    60. </ul>
    61. </div>
    62. </div>
    63. </nav>
    64. <section data-toggle="wy-nav-shift" class="wy-nav-content-wrap"><nav class="wy-nav-top" aria-label="Mobile navigation menu" >
    65. <i data-toggle="wy-nav-top" class="fa fa-bars"></i>
    66. <a href="../../../../index.html">SuperGradients</a>
    67. </nav>
    68. <div class="wy-nav-content">
    69. <div class="rst-content">
    70. <div role="navigation" aria-label="Page navigation">
    71. <ul class="wy-breadcrumbs">
    72. <li><a href="../../../../index.html" class="icon icon-home"></a> &raquo;</li>
    73. <li><a href="../../../index.html">Module code</a> &raquo;</li>
    74. <li>super_gradients.training.utils.ema</li>
    75. <li class="wy-breadcrumbs-aside">
    76. </li>
    77. </ul>
    78. <hr/>
    79. </div>
    80. <div role="main" class="document" itemscope="itemscope" itemtype="http://schema.org/Article">
    81. <div itemprop="articleBody">
    82. <h1>Source code for super_gradients.training.utils.ema</h1><div class="highlight"><pre>
    83. <span></span><span class="kn">import</span> <span class="nn">math</span>
    84. <span class="kn">import</span> <span class="nn">warnings</span>
    85. <span class="kn">from</span> <span class="nn">copy</span> <span class="kn">import</span> <span class="n">deepcopy</span>
    86. <span class="kn">from</span> <span class="nn">typing</span> <span class="kn">import</span> <span class="n">Union</span>
    87. <span class="kn">import</span> <span class="nn">torch</span>
    88. <span class="kn">from</span> <span class="nn">torch</span> <span class="kn">import</span> <span class="n">nn</span>
    89. <span class="kn">from</span> <span class="nn">super_gradients.training</span> <span class="kn">import</span> <span class="n">utils</span> <span class="k">as</span> <span class="n">core_utils</span>
    90. <span class="kn">from</span> <span class="nn">super_gradients.training.models</span> <span class="kn">import</span> <span class="n">SgModule</span>
    91. <span class="kn">from</span> <span class="nn">super_gradients.training.models.kd_modules.kd_module</span> <span class="kn">import</span> <span class="n">KDModule</span>
    92. <div class="viewcode-block" id="copy_attr"><a class="viewcode-back" href="../../../../super_gradients.training.utils.html#super_gradients.training.utils.ema.copy_attr">[docs]</a><span class="k">def</span> <span class="nf">copy_attr</span><span class="p">(</span><span class="n">a</span><span class="p">:</span> <span class="n">nn</span><span class="o">.</span><span class="n">Module</span><span class="p">,</span> <span class="n">b</span><span class="p">:</span> <span class="n">nn</span><span class="o">.</span><span class="n">Module</span><span class="p">,</span> <span class="n">include</span><span class="p">:</span> <span class="n">Union</span><span class="p">[</span><span class="nb">list</span><span class="p">,</span> <span class="nb">tuple</span><span class="p">]</span> <span class="o">=</span> <span class="p">(),</span> <span class="n">exclude</span><span class="p">:</span> <span class="n">Union</span><span class="p">[</span><span class="nb">list</span><span class="p">,</span> <span class="nb">tuple</span><span class="p">]</span> <span class="o">=</span> <span class="p">()):</span>
    93. <span class="c1"># Copy attributes from b to a, options to only include [...] and to exclude [...]</span>
    94. <span class="k">for</span> <span class="n">k</span><span class="p">,</span> <span class="n">v</span> <span class="ow">in</span> <span class="n">b</span><span class="o">.</span><span class="vm">__dict__</span><span class="o">.</span><span class="n">items</span><span class="p">():</span>
    95. <span class="k">if</span> <span class="p">(</span><span class="nb">len</span><span class="p">(</span><span class="n">include</span><span class="p">)</span> <span class="ow">and</span> <span class="n">k</span> <span class="ow">not</span> <span class="ow">in</span> <span class="n">include</span><span class="p">)</span> <span class="ow">or</span> <span class="n">k</span><span class="o">.</span><span class="n">startswith</span><span class="p">(</span><span class="s1">&#39;_&#39;</span><span class="p">)</span> <span class="ow">or</span> <span class="n">k</span> <span class="ow">in</span> <span class="n">exclude</span><span class="p">:</span>
    96. <span class="k">continue</span>
    97. <span class="k">else</span><span class="p">:</span>
    98. <span class="nb">setattr</span><span class="p">(</span><span class="n">a</span><span class="p">,</span> <span class="n">k</span><span class="p">,</span> <span class="n">v</span><span class="p">)</span></div>
    99. <div class="viewcode-block" id="ModelEMA"><a class="viewcode-back" href="../../../../super_gradients.training.utils.html#super_gradients.training.utils.ema.ModelEMA">[docs]</a><span class="k">class</span> <span class="nc">ModelEMA</span><span class="p">:</span>
    100. <span class="sd">&quot;&quot;&quot; Model Exponential Moving Average from https://github.com/rwightman/pytorch-image-models</span>
    101. <span class="sd"> Keep a moving average of everything in the model state_dict (parameters and buffers).</span>
    102. <span class="sd"> This is intended to allow functionality like</span>
    103. <span class="sd"> https://www.tensorflow.org/api_docs/python/tf/train/ExponentialMovingAverage</span>
    104. <span class="sd"> A smoothed version of the weights is necessary for some training schemes to perform well.</span>
    105. <span class="sd"> This class is sensitive where it is initialized in the sequence of model init,</span>
    106. <span class="sd"> GPU assignment and distributed training wrappers.</span>
    107. <span class="sd"> &quot;&quot;&quot;</span>
    108. <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">model</span><span class="p">,</span> <span class="n">decay</span><span class="p">:</span> <span class="nb">float</span> <span class="o">=</span> <span class="mf">0.9999</span><span class="p">,</span> <span class="n">beta</span><span class="p">:</span> <span class="nb">float</span> <span class="o">=</span> <span class="mi">15</span><span class="p">,</span> <span class="n">exp_activation</span><span class="p">:</span> <span class="nb">bool</span> <span class="o">=</span> <span class="kc">True</span><span class="p">):</span>
    109. <span class="sd">&quot;&quot;&quot;</span>
    110. <span class="sd"> Init the EMA</span>
    111. <span class="sd"> :param model: Union[SgModule, nn.Module], the training model to construct the EMA model by</span>
    112. <span class="sd"> IMPORTANT: WHEN THE APPLICATION OF EMA ONLY ON A SUBSET OF ATTRIBUTES IS DESIRED, WRAP THE NN.MODULE</span>
    113. <span class="sd"> AS SgModule AND OVERWRITE get_include_attributes() AND get_exclude_attributes() AS DESIRED.</span>
    114. <span class="sd"> :param decay: the maximum decay value. as the training process advances, the decay will climb towards this value</span>
    115. <span class="sd"> until the EMA_t+1 = EMA_t * decay + TRAINING_MODEL * (1- decay)</span>
    116. <span class="sd"> :param beta: the exponent coefficient. The higher the beta, the sooner in the training the decay will saturate to</span>
    117. <span class="sd"> its final value. beta=15 is ~40% of the training process.</span>
    118. <span class="sd"> &quot;&quot;&quot;</span>
    119. <span class="c1"># Create EMA</span>
    120. <span class="bp">self</span><span class="o">.</span><span class="n">ema</span> <span class="o">=</span> <span class="n">deepcopy</span><span class="p">(</span><span class="n">model</span><span class="p">)</span>
    121. <span class="bp">self</span><span class="o">.</span><span class="n">ema</span><span class="o">.</span><span class="n">eval</span><span class="p">()</span>
    122. <span class="k">if</span> <span class="n">exp_activation</span><span class="p">:</span>
    123. <span class="bp">self</span><span class="o">.</span><span class="n">decay_function</span> <span class="o">=</span> <span class="k">lambda</span> <span class="n">x</span><span class="p">:</span> <span class="n">decay</span> <span class="o">*</span> <span class="p">(</span><span class="mi">1</span> <span class="o">-</span> <span class="n">math</span><span class="o">.</span><span class="n">exp</span><span class="p">(</span><span class="o">-</span><span class="n">x</span> <span class="o">*</span> <span class="n">beta</span><span class="p">))</span> <span class="c1"># decay exponential ramp (to help early epochs)</span>
    124. <span class="k">else</span><span class="p">:</span>
    125. <span class="bp">self</span><span class="o">.</span><span class="n">decay_function</span> <span class="o">=</span> <span class="k">lambda</span> <span class="n">x</span><span class="p">:</span> <span class="n">decay</span> <span class="c1"># always return the same decay factor</span>
    126. <span class="sd">&quot;&quot;&quot;&quot;</span>
    127. <span class="sd"> we hold a list of model attributes (not wights and biases) which we would like to include in each</span>
    128. <span class="sd"> attribute update or exclude from each update. a SgModule declare these attribute using</span>
    129. <span class="sd"> get_include_attributes and get_exclude_attributes functions. for a nn.Module which is not a SgModule</span>
    130. <span class="sd"> all non-private (not starting with &#39;_&#39;) attributes will be updated (and only them).</span>
    131. <span class="sd"> &quot;&quot;&quot;</span>
    132. <span class="k">if</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">model</span><span class="o">.</span><span class="n">module</span><span class="p">,</span> <span class="n">SgModule</span><span class="p">):</span>
    133. <span class="bp">self</span><span class="o">.</span><span class="n">include_attributes</span> <span class="o">=</span> <span class="n">model</span><span class="o">.</span><span class="n">module</span><span class="o">.</span><span class="n">get_include_attributes</span><span class="p">()</span>
    134. <span class="bp">self</span><span class="o">.</span><span class="n">exclude_attributes</span> <span class="o">=</span> <span class="n">model</span><span class="o">.</span><span class="n">module</span><span class="o">.</span><span class="n">get_exclude_attributes</span><span class="p">()</span>
    135. <span class="k">else</span><span class="p">:</span>
    136. <span class="n">warnings</span><span class="o">.</span><span class="n">warn</span><span class="p">(</span><span class="s2">&quot;Warning: EMA should be used with SgModule instance. All attributes of the model will be &quot;</span>
    137. <span class="s2">&quot;included in EMA&quot;</span><span class="p">)</span>
    138. <span class="bp">self</span><span class="o">.</span><span class="n">include_attributes</span> <span class="o">=</span> <span class="p">[]</span>
    139. <span class="bp">self</span><span class="o">.</span><span class="n">exclude_attributes</span> <span class="o">=</span> <span class="p">[]</span>
    140. <span class="k">for</span> <span class="n">p</span> <span class="ow">in</span> <span class="bp">self</span><span class="o">.</span><span class="n">ema</span><span class="o">.</span><span class="n">module</span><span class="o">.</span><span class="n">parameters</span><span class="p">():</span>
    141. <span class="n">p</span><span class="o">.</span><span class="n">requires_grad_</span><span class="p">(</span><span class="kc">False</span><span class="p">)</span>
    142. <div class="viewcode-block" id="ModelEMA.update"><a class="viewcode-back" href="../../../../super_gradients.training.utils.html#super_gradients.training.utils.ema.ModelEMA.update">[docs]</a> <span class="k">def</span> <span class="nf">update</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">model</span><span class="p">,</span> <span class="n">training_percent</span><span class="p">:</span> <span class="nb">float</span><span class="p">):</span>
    143. <span class="sd">&quot;&quot;&quot;</span>
    144. <span class="sd"> Update the state of the EMA model.</span>
    145. <span class="sd"> :param model: current training model</span>
    146. <span class="sd"> :param training_percent: the percentage of the training process [0,1]. i.e 0.4 means 40% of the training have passed</span>
    147. <span class="sd"> &quot;&quot;&quot;</span>
    148. <span class="c1"># Update EMA parameters</span>
    149. <span class="k">with</span> <span class="n">torch</span><span class="o">.</span><span class="n">no_grad</span><span class="p">():</span>
    150. <span class="n">decay</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">decay_function</span><span class="p">(</span><span class="n">training_percent</span><span class="p">)</span>
    151. <span class="k">for</span> <span class="n">ema_v</span><span class="p">,</span> <span class="n">model_v</span> <span class="ow">in</span> <span class="nb">zip</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">ema</span><span class="o">.</span><span class="n">module</span><span class="o">.</span><span class="n">state_dict</span><span class="p">()</span><span class="o">.</span><span class="n">values</span><span class="p">(),</span> <span class="n">model</span><span class="o">.</span><span class="n">state_dict</span><span class="p">()</span><span class="o">.</span><span class="n">values</span><span class="p">()):</span>
    152. <span class="k">if</span> <span class="n">ema_v</span><span class="o">.</span><span class="n">dtype</span><span class="o">.</span><span class="n">is_floating_point</span><span class="p">:</span>
    153. <span class="n">ema_v</span><span class="o">.</span><span class="n">copy_</span><span class="p">(</span><span class="n">ema_v</span> <span class="o">*</span> <span class="n">decay</span> <span class="o">+</span> <span class="p">(</span><span class="mf">1.</span> <span class="o">-</span> <span class="n">decay</span><span class="p">)</span> <span class="o">*</span> <span class="n">model_v</span><span class="o">.</span><span class="n">detach</span><span class="p">())</span></div>
    154. <div class="viewcode-block" id="ModelEMA.update_attr"><a class="viewcode-back" href="../../../../super_gradients.training.utils.html#super_gradients.training.utils.ema.ModelEMA.update_attr">[docs]</a> <span class="k">def</span> <span class="nf">update_attr</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">model</span><span class="p">):</span>
    155. <span class="sd">&quot;&quot;&quot;</span>
    156. <span class="sd"> This function updates model attributes (not weight and biases) from original model to the ema model.</span>
    157. <span class="sd"> attributes of the original model, such as anchors and grids (of detection models), may be crucial to the</span>
    158. <span class="sd"> model operation and need to be updated.</span>
    159. <span class="sd"> If include_attributes and exclude_attributes lists were not defined, all non-private (not starting with &#39;_&#39;)</span>
    160. <span class="sd"> attributes will be updated (and only them).</span>
    161. <span class="sd"> :param model: the source model</span>
    162. <span class="sd"> &quot;&quot;&quot;</span>
    163. <span class="n">copy_attr</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">ema</span><span class="o">.</span><span class="n">module</span><span class="p">,</span> <span class="n">model</span><span class="o">.</span><span class="n">module</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">include_attributes</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">exclude_attributes</span><span class="p">)</span></div></div>
    164. <div class="viewcode-block" id="KDModelEMA"><a class="viewcode-back" href="../../../../super_gradients.training.utils.html#super_gradients.training.utils.ema.KDModelEMA">[docs]</a><span class="k">class</span> <span class="nc">KDModelEMA</span><span class="p">(</span><span class="n">ModelEMA</span><span class="p">):</span>
    165. <span class="sd">&quot;&quot;&quot; Model Exponential Moving Average from https://github.com/rwightman/pytorch-image-models</span>
    166. <span class="sd"> Keep a moving average of everything in the model state_dict (parameters and buffers).</span>
    167. <span class="sd"> This is intended to allow functionality like</span>
    168. <span class="sd"> https://www.tensorflow.org/api_docs/python/tf/train/ExponentialMovingAverage</span>
    169. <span class="sd"> A smoothed version of the weights is necessary for some training schemes to perform well.</span>
    170. <span class="sd"> This class is sensitive where it is initialized in the sequence of model init,</span>
    171. <span class="sd"> GPU assignment and distributed training wrappers.</span>
    172. <span class="sd"> &quot;&quot;&quot;</span>
    173. <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">kd_model</span><span class="p">:</span> <span class="n">KDModule</span><span class="p">,</span> <span class="n">decay</span><span class="p">:</span> <span class="nb">float</span> <span class="o">=</span> <span class="mf">0.9999</span><span class="p">,</span> <span class="n">beta</span><span class="p">:</span> <span class="nb">float</span> <span class="o">=</span> <span class="mi">15</span><span class="p">,</span> <span class="n">exp_activation</span><span class="p">:</span> <span class="nb">bool</span> <span class="o">=</span> <span class="kc">True</span><span class="p">):</span>
    174. <span class="sd">&quot;&quot;&quot;</span>
    175. <span class="sd"> Init the EMA</span>
    176. <span class="sd"> :param kd_model: KDModule, the training Knowledge distillation model to construct the EMA model by</span>
    177. <span class="sd"> IMPORTANT: WHEN THE APPLICATION OF EMA ONLY ON A SUBSET OF ATTRIBUTES IS DESIRED, WRAP THE NN.MODULE</span>
    178. <span class="sd"> AS SgModule AND OVERWRITE get_include_attributes() AND get_exclude_attributes() AS DESIRED.</span>
    179. <span class="sd"> :param decay: the maximum decay value. as the training process advances, the decay will climb towards this value</span>
    180. <span class="sd"> until the EMA_t+1 = EMA_t * decay + TRAINING_MODEL * (1- decay)</span>
    181. <span class="sd"> :param beta: the exponent coefficient. The higher the beta, the sooner in the training the decay will saturate to</span>
    182. <span class="sd"> its final value. beta=15 is ~40% of the training process.</span>
    183. <span class="sd"> &quot;&quot;&quot;</span>
    184. <span class="c1"># Only work on the student (we don&#39;t want to update and to have a duplicate of the teacher)</span>
    185. <span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="fm">__init__</span><span class="p">(</span><span class="n">model</span><span class="o">=</span><span class="n">core_utils</span><span class="o">.</span><span class="n">WrappedModel</span><span class="p">(</span><span class="n">kd_model</span><span class="o">.</span><span class="n">module</span><span class="o">.</span><span class="n">student</span><span class="p">),</span>
    186. <span class="n">decay</span><span class="o">=</span><span class="n">decay</span><span class="p">,</span>
    187. <span class="n">beta</span><span class="o">=</span><span class="n">beta</span><span class="p">,</span>
    188. <span class="n">exp_activation</span><span class="o">=</span><span class="n">exp_activation</span><span class="p">)</span>
    189. <span class="c1"># Overwrite current ema attribute with combination of the student model EMA (current self.ema)</span>
    190. <span class="c1"># with already the instantiated teacher, to have the final KD EMA</span>
    191. <span class="bp">self</span><span class="o">.</span><span class="n">ema</span> <span class="o">=</span> <span class="n">core_utils</span><span class="o">.</span><span class="n">WrappedModel</span><span class="p">(</span><span class="n">KDModule</span><span class="p">(</span><span class="n">arch_params</span><span class="o">=</span><span class="n">kd_model</span><span class="o">.</span><span class="n">module</span><span class="o">.</span><span class="n">arch_params</span><span class="p">,</span>
    192. <span class="n">student</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">ema</span><span class="o">.</span><span class="n">module</span><span class="p">,</span>
    193. <span class="n">teacher</span><span class="o">=</span><span class="n">kd_model</span><span class="o">.</span><span class="n">module</span><span class="o">.</span><span class="n">teacher</span><span class="p">,</span>
    194. <span class="n">run_teacher_on_eval</span><span class="o">=</span><span class="n">kd_model</span><span class="o">.</span><span class="n">module</span><span class="o">.</span><span class="n">run_teacher_on_eval</span><span class="p">))</span></div>
    195. </pre></div>
    196. </div>
    197. </div>
    198. <footer>
    199. <hr/>
    200. <div role="contentinfo">
    201. <p>&#169; Copyright 2021, SuperGradients team.</p>
    202. </div>
    203. Built with <a href="https://www.sphinx-doc.org/">Sphinx</a> using a
    204. <a href="https://github.com/readthedocs/sphinx_rtd_theme">theme</a>
    205. provided by <a href="https://readthedocs.org">Read the Docs</a>.
    206. </footer>
    207. </div>
    208. </div>
    209. </section>
    210. </div>
    211. <script>
    212. jQuery(function () {
    213. SphinxRtdTheme.Navigation.enable(true);
    214. });
    215. </script>
    216. </body>
    217. </html>
    Discard
    1
    2
    3
    4
    5
    6
    7
    8
    9
    10
    11
    12
    13
    14
    15
    16
    17
    18
    19
    20
    21
    22
    23
    24
    25
    26
    27
    28
    29
    30
    31
    32
    33
    34
    35
    36
    37
    38
    39
    40
    41
    42
    43
    44
    45
    46
    47
    48
    49
    50
    51
    52
    53
    54
    55
    56
    57
    58
    59
    60
    61
    62
    63
    64
    65
    66
    67
    68
    69
    70
    71
    72
    73
    74
    75
    76
    77
    78
    79
    80
    81
    82
    83
    84
    85
    86
    87
    88
    89
    90
    91
    92
    93
    94
    95
    96
    97
    98
    99
    100
    101
    102
    103
    104
    105
    106
    107
    108
    109
    110
    111
    112
    113
    114
    115
    116
    117
    118
    119
    120
    121
    122
    123
    124
    125
    126
    127
    128
    129
    130
    131
    132
    133
    134
    135
    136
    137
    138
    139
    140
    141
    142
    143
    1. <!DOCTYPE html>
    2. <html class="writer-html5" lang="en" >
    3. <head>
    4. <meta charset="utf-8" />
    5. <meta name="viewport" content="width=device-width, initial-scale=1.0" />
    6. <title>super_gradients.training.utils.export_utils &mdash; SuperGradients 1.0 documentation</title>
    7. <link rel="stylesheet" href="../../../../_static/pygments.css" type="text/css" />
    8. <link rel="stylesheet" href="../../../../_static/css/theme.css" type="text/css" />
    9. <link rel="stylesheet" href="../../../../_static/graphviz.css" type="text/css" />
    10. <!--[if lt IE 9]>
    11. <script src="../../../../_static/js/html5shiv.min.js"></script>
    12. <![endif]-->
    13. <script data-url_root="../../../../" id="documentation_options" src="../../../../_static/documentation_options.js"></script>
    14. <script src="../../../../_static/jquery.js"></script>
    15. <script src="../../../../_static/underscore.js"></script>
    16. <script src="../../../../_static/doctools.js"></script>
    17. <script src="../../../../_static/js/theme.js"></script>
    18. <link rel="index" title="Index" href="../../../../genindex.html" />
    19. <link rel="search" title="Search" href="../../../../search.html" />
    20. </head>
    21. <body class="wy-body-for-nav">
    22. <div class="wy-grid-for-nav">
    23. <nav data-toggle="wy-nav-shift" class="wy-nav-side">
    24. <div class="wy-side-scroll">
    25. <div class="wy-side-nav-search" >
    26. <a href="../../../../index.html" class="icon icon-home"> SuperGradients
    27. </a>
    28. <div role="search">
    29. <form id="rtd-search-form" class="wy-form" action="../../../../search.html" method="get">
    30. <input type="text" name="q" placeholder="Search docs" />
    31. <input type="hidden" name="check_keywords" value="yes" />
    32. <input type="hidden" name="area" value="default" />
    33. </form>
    34. </div>
    35. </div><div class="wy-menu wy-menu-vertical" data-spy="affix" role="navigation" aria-label="Navigation menu">
    36. <p class="caption"><span class="caption-text">Welcome To SuperGradients</span></p>
    37. <ul>
    38. <li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html">Fill our 4-question quick survey! We will raffle free SuperGradients swag between those who will participate -&gt; Fill Survey</a></li>
    39. <li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#supergradients">SuperGradients</a></li>
    40. </ul>
    41. <p class="caption"><span class="caption-text">Technical Documentation</span></p>
    42. <ul>
    43. <li class="toctree-l1"><a class="reference internal" href="../../../../super_gradients.common.html">Common package</a></li>
    44. <li class="toctree-l1"><a class="reference internal" href="../../../../super_gradients.training.html">Training package</a></li>
    45. </ul>
    46. <p class="caption"><span class="caption-text">User Guide</span></p>
    47. <ul>
    48. <li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html">What is SuperGradients?</a></li>
    49. <li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#introducing-the-supergradients-library">Introducing the SuperGradients library</a></li>
    50. <li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#installation">Installation</a></li>
    51. <li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#integrating-your-training-code-complete-walkthrough">Integrating Your Training Code - Complete Walkthrough</a></li>
    52. <li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#training-parameters">Training Parameters</a></li>
    53. <li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#logs-and-checkpoints">Logs and Checkpoints</a></li>
    54. <li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#dataset-parameters">Dataset Parameters</a></li>
    55. <li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#network-architectures">Network Architectures</a></li>
    56. <li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#pretrained-models">Pretrained Models</a></li>
    57. <li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#how-to-reproduce-our-training-recipes">How To Reproduce Our Training Recipes</a></li>
    58. <li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#professional-tools-integration">Professional Tools Integration</a></li>
    59. <li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#supergradients-faq">SuperGradients FAQ</a></li>
    60. </ul>
    61. </div>
    62. </div>
    63. </nav>
    64. <section data-toggle="wy-nav-shift" class="wy-nav-content-wrap"><nav class="wy-nav-top" aria-label="Mobile navigation menu" >
    65. <i data-toggle="wy-nav-top" class="fa fa-bars"></i>
    66. <a href="../../../../index.html">SuperGradients</a>
    67. </nav>
    68. <div class="wy-nav-content">
    69. <div class="rst-content">
    70. <div role="navigation" aria-label="Page navigation">
    71. <ul class="wy-breadcrumbs">
    72. <li><a href="../../../../index.html" class="icon icon-home"></a> &raquo;</li>
    73. <li><a href="../../../index.html">Module code</a> &raquo;</li>
    74. <li>super_gradients.training.utils.export_utils</li>
    75. <li class="wy-breadcrumbs-aside">
    76. </li>
    77. </ul>
    78. <hr/>
    79. </div>
    80. <div role="main" class="document" itemscope="itemscope" itemtype="http://schema.org/Article">
    81. <div itemprop="articleBody">
    82. <h1>Source code for super_gradients.training.utils.export_utils</h1><div class="highlight"><pre>
    83. <span></span><span class="kn">import</span> <span class="nn">torch</span>
    84. <span class="kn">import</span> <span class="nn">torch.nn</span> <span class="k">as</span> <span class="nn">nn</span>
    85. <div class="viewcode-block" id="fuse_conv_bn"><a class="viewcode-back" href="../../../../super_gradients.training.utils.html#super_gradients.training.utils.export_utils.fuse_conv_bn">[docs]</a><span class="k">def</span> <span class="nf">fuse_conv_bn</span><span class="p">(</span><span class="n">model</span><span class="p">:</span> <span class="n">nn</span><span class="o">.</span><span class="n">Module</span><span class="p">,</span> <span class="n">replace_bn_with_identity</span><span class="p">:</span> <span class="nb">bool</span> <span class="o">=</span> <span class="kc">False</span><span class="p">):</span>
    86. <span class="sd">&quot;&quot;&quot;</span>
    87. <span class="sd"> Fuses consecutive nn.Conv2d and nn.BatchNorm2d layers recursively inplace in all of the model</span>
    88. <span class="sd"> :param replace_bn_with_identity: if set to true, bn will be replaced with identity. otherwise, bn will be removed</span>
    89. <span class="sd"> :param model: the target model</span>
    90. <span class="sd"> :return: the number of fuses executed</span>
    91. <span class="sd"> &quot;&quot;&quot;</span>
    92. <span class="n">children</span> <span class="o">=</span> <span class="nb">list</span><span class="p">(</span><span class="n">model</span><span class="o">.</span><span class="n">named_children</span><span class="p">())</span>
    93. <span class="n">counter</span> <span class="o">=</span> <span class="mi">0</span>
    94. <span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="nb">len</span><span class="p">(</span><span class="n">children</span><span class="p">)</span> <span class="o">-</span> <span class="mi">1</span><span class="p">):</span>
    95. <span class="k">if</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">children</span><span class="p">[</span><span class="n">i</span><span class="p">][</span><span class="mi">1</span><span class="p">],</span> <span class="n">torch</span><span class="o">.</span><span class="n">nn</span><span class="o">.</span><span class="n">Conv2d</span><span class="p">)</span> <span class="ow">and</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">children</span><span class="p">[</span><span class="n">i</span> <span class="o">+</span> <span class="mi">1</span><span class="p">][</span><span class="mi">1</span><span class="p">],</span> <span class="n">torch</span><span class="o">.</span><span class="n">nn</span><span class="o">.</span><span class="n">BatchNorm2d</span><span class="p">):</span>
    96. <span class="nb">setattr</span><span class="p">(</span><span class="n">model</span><span class="p">,</span> <span class="n">children</span><span class="p">[</span><span class="n">i</span><span class="p">][</span><span class="mi">0</span><span class="p">],</span> <span class="n">torch</span><span class="o">.</span><span class="n">nn</span><span class="o">.</span><span class="n">utils</span><span class="o">.</span><span class="n">fuse_conv_bn_eval</span><span class="p">(</span><span class="n">children</span><span class="p">[</span><span class="n">i</span><span class="p">][</span><span class="mi">1</span><span class="p">],</span> <span class="n">children</span><span class="p">[</span><span class="n">i</span> <span class="o">+</span> <span class="mi">1</span><span class="p">][</span><span class="mi">1</span><span class="p">]))</span>
    97. <span class="k">if</span> <span class="n">replace_bn_with_identity</span><span class="p">:</span>
    98. <span class="nb">setattr</span><span class="p">(</span><span class="n">model</span><span class="p">,</span> <span class="n">children</span><span class="p">[</span><span class="n">i</span> <span class="o">+</span> <span class="mi">1</span><span class="p">][</span><span class="mi">0</span><span class="p">],</span> <span class="n">nn</span><span class="o">.</span><span class="n">Identity</span><span class="p">())</span>
    99. <span class="k">else</span><span class="p">:</span>
    100. <span class="nb">delattr</span><span class="p">(</span><span class="n">model</span><span class="p">,</span> <span class="n">children</span><span class="p">[</span><span class="n">i</span> <span class="o">+</span> <span class="mi">1</span><span class="p">][</span><span class="mi">0</span><span class="p">])</span>
    101. <span class="n">counter</span> <span class="o">+=</span> <span class="mi">1</span>
    102. <span class="k">for</span> <span class="n">child_name</span><span class="p">,</span> <span class="n">child</span> <span class="ow">in</span> <span class="n">children</span><span class="p">:</span>
    103. <span class="n">counter</span> <span class="o">+=</span> <span class="n">fuse_conv_bn</span><span class="p">(</span><span class="n">child</span><span class="p">,</span> <span class="n">replace_bn_with_identity</span><span class="p">)</span>
    104. <span class="k">return</span> <span class="n">counter</span></div>
    105. </pre></div>
    106. </div>
    107. </div>
    108. <footer>
    109. <hr/>
    110. <div role="contentinfo">
    111. <p>&#169; Copyright 2021, SuperGradients team.</p>
    112. </div>
    113. Built with <a href="https://www.sphinx-doc.org/">Sphinx</a> using a
    114. <a href="https://github.com/readthedocs/sphinx_rtd_theme">theme</a>
    115. provided by <a href="https://readthedocs.org">Read the Docs</a>.
    116. </footer>
    117. </div>
    118. </div>
    119. </section>
    120. </div>
    121. <script>
    122. jQuery(function () {
    123. SphinxRtdTheme.Navigation.enable(true);
    124. });
    125. </script>
    126. </body>
    127. </html>
    Discard
    1
    2
    3
    4
    5
    6
    7
    8
    9
    10
    11
    12
    13
    14
    15
    16
    17
    18
    19
    20
    21
    22
    23
    24
    25
    26
    27
    28
    29
    30
    31
    32
    33
    34
    35
    36
    37
    38
    39
    40
    41
    42
    43
    44
    45
    46
    47
    48
    49
    50
    51
    52
    53
    54
    55
    56
    57
    58
    59
    60
    61
    62
    63
    64
    65
    66
    67
    68
    69
    70
    71
    72
    73
    74
    75
    76
    77
    78
    79
    80
    81
    82
    83
    84
    85
    86
    87
    88
    89
    90
    91
    92
    93
    94
    95
    96
    97
    98
    99
    100
    101
    102
    103
    104
    105
    106
    107
    108
    109
    110
    111
    112
    113
    114
    115
    116
    117
    118
    119
    120
    121
    122
    123
    124
    125
    126
    127
    128
    129
    130
    131
    132
    133
    134
    135
    136
    137
    138
    139
    140
    141
    142
    143
    144
    145
    146
    147
    148
    149
    150
    151
    152
    153
    154
    155
    156
    157
    158
    159
    160
    161
    162
    163
    164
    165
    166
    167
    168
    169
    170
    171
    172
    173
    174
    175
    176
    177
    178
    179
    180
    181
    182
    183
    184
    185
    186
    187
    188
    189
    190
    191
    192
    193
    194
    195
    196
    197
    198
    199
    200
    201
    202
    203
    204
    205
    206
    207
    208
    209
    210
    211
    212
    213
    214
    215
    216
    217
    218
    219
    220
    221
    222
    223
    224
    225
    226
    227
    228
    229
    230
    231
    232
    233
    234
    235
    236
    237
    238
    239
    240
    241
    242
    243
    244
    245
    246
    247
    248
    249
    250
    251
    252
    253
    254
    255
    256
    257
    258
    259
    260
    261
    262
    263
    264
    265
    266
    267
    268
    269
    270
    271
    272
    273
    274
    275
    276
    277
    278
    279
    280
    281
    282
    283
    284
    285
    286
    287
    288
    289
    290
    291
    292
    293
    294
    295
    296
    297
    298
    299
    300
    301
    302
    303
    304
    305
    306
    307
    308
    309
    310
    311
    312
    313
    314
    315
    316
    317
    318
    319
    320
    321
    322
    323
    324
    325
    326
    327
    328
    329
    330
    331
    332
    333
    334
    335
    336
    337
    338
    339
    1. <!DOCTYPE html>
    2. <html class="writer-html5" lang="en" >
    3. <head>
    4. <meta charset="utf-8" />
    5. <meta name="viewport" content="width=device-width, initial-scale=1.0" />
    6. <title>super_gradients.training.utils.module_utils &mdash; SuperGradients 1.0 documentation</title>
    7. <link rel="stylesheet" href="../../../../_static/pygments.css" type="text/css" />
    8. <link rel="stylesheet" href="../../../../_static/css/theme.css" type="text/css" />
    9. <link rel="stylesheet" href="../../../../_static/graphviz.css" type="text/css" />
    10. <!--[if lt IE 9]>
    11. <script src="../../../../_static/js/html5shiv.min.js"></script>
    12. <![endif]-->
    13. <script data-url_root="../../../../" id="documentation_options" src="../../../../_static/documentation_options.js"></script>
    14. <script src="../../../../_static/jquery.js"></script>
    15. <script src="../../../../_static/underscore.js"></script>
    16. <script src="../../../../_static/doctools.js"></script>
    17. <script src="../../../../_static/js/theme.js"></script>
    18. <link rel="index" title="Index" href="../../../../genindex.html" />
    19. <link rel="search" title="Search" href="../../../../search.html" />
    20. </head>
    21. <body class="wy-body-for-nav">
    22. <div class="wy-grid-for-nav">
    23. <nav data-toggle="wy-nav-shift" class="wy-nav-side">
    24. <div class="wy-side-scroll">
    25. <div class="wy-side-nav-search" >
    26. <a href="../../../../index.html" class="icon icon-home"> SuperGradients
    27. </a>
    28. <div role="search">
    29. <form id="rtd-search-form" class="wy-form" action="../../../../search.html" method="get">
    30. <input type="text" name="q" placeholder="Search docs" />
    31. <input type="hidden" name="check_keywords" value="yes" />
    32. <input type="hidden" name="area" value="default" />
    33. </form>
    34. </div>
    35. </div><div class="wy-menu wy-menu-vertical" data-spy="affix" role="navigation" aria-label="Navigation menu">
    36. <p class="caption"><span class="caption-text">Welcome To SuperGradients</span></p>
    37. <ul>
    38. <li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html">Fill our 4-question quick survey! We will raffle free SuperGradients swag between those who will participate -&gt; Fill Survey</a></li>
    39. <li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#supergradients">SuperGradients</a></li>
    40. </ul>
    41. <p class="caption"><span class="caption-text">Technical Documentation</span></p>
    42. <ul>
    43. <li class="toctree-l1"><a class="reference internal" href="../../../../super_gradients.common.html">Common package</a></li>
    44. <li class="toctree-l1"><a class="reference internal" href="../../../../super_gradients.training.html">Training package</a></li>
    45. </ul>
    46. <p class="caption"><span class="caption-text">User Guide</span></p>
    47. <ul>
    48. <li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html">What is SuperGradients?</a></li>
    49. <li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#introducing-the-supergradients-library">Introducing the SuperGradients library</a></li>
    50. <li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#installation">Installation</a></li>
    51. <li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#integrating-your-training-code-complete-walkthrough">Integrating Your Training Code - Complete Walkthrough</a></li>
    52. <li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#training-parameters">Training Parameters</a></li>
    53. <li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#logs-and-checkpoints">Logs and Checkpoints</a></li>
    54. <li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#dataset-parameters">Dataset Parameters</a></li>
    55. <li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#network-architectures">Network Architectures</a></li>
    56. <li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#pretrained-models">Pretrained Models</a></li>
    57. <li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#how-to-reproduce-our-training-recipes">How To Reproduce Our Training Recipes</a></li>
    58. <li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#professional-tools-integration">Professional Tools Integration</a></li>
    59. <li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#supergradients-faq">SuperGradients FAQ</a></li>
    60. </ul>
    61. </div>
    62. </div>
    63. </nav>
    64. <section data-toggle="wy-nav-shift" class="wy-nav-content-wrap"><nav class="wy-nav-top" aria-label="Mobile navigation menu" >
    65. <i data-toggle="wy-nav-top" class="fa fa-bars"></i>
    66. <a href="../../../../index.html">SuperGradients</a>
    67. </nav>
    68. <div class="wy-nav-content">
    69. <div class="rst-content">
    70. <div role="navigation" aria-label="Page navigation">
    71. <ul class="wy-breadcrumbs">
    72. <li><a href="../../../../index.html" class="icon icon-home"></a> &raquo;</li>
    73. <li><a href="../../../index.html">Module code</a> &raquo;</li>
    74. <li>super_gradients.training.utils.module_utils</li>
    75. <li class="wy-breadcrumbs-aside">
    76. </li>
    77. </ul>
    78. <hr/>
    79. </div>
    80. <div role="main" class="document" itemscope="itemscope" itemtype="http://schema.org/Article">
    81. <div itemprop="articleBody">
    82. <h1>Source code for super_gradients.training.utils.module_utils</h1><div class="highlight"><pre>
    83. <span></span><span class="kn">from</span> <span class="nn">collections</span> <span class="kn">import</span> <span class="n">OrderedDict</span>
    84. <span class="kn">import</span> <span class="nn">copy</span>
    85. <span class="kn">from</span> <span class="nn">typing</span> <span class="kn">import</span> <span class="n">List</span><span class="p">,</span> <span class="n">Union</span><span class="p">,</span> <span class="n">Tuple</span>
    86. <span class="kn">import</span> <span class="nn">torch</span>
    87. <span class="kn">from</span> <span class="nn">torch</span> <span class="kn">import</span> <span class="n">nn</span>
    88. <div class="viewcode-block" id="MultiOutputModule"><a class="viewcode-back" href="../../../../super_gradients.training.utils.html#super_gradients.training.utils.module_utils.MultiOutputModule">[docs]</a><span class="k">class</span> <span class="nc">MultiOutputModule</span><span class="p">(</span><span class="n">nn</span><span class="o">.</span><span class="n">Module</span><span class="p">):</span>
    89. <span class="sd">&quot;&quot;&quot;</span>
    90. <span class="sd"> This module wraps around a container nn.Module (such as Module, Sequential and ModuleList) and allows to extract</span>
    91. <span class="sd"> multiple output from its inner modules on each forward call() (as a list of output tensors)</span>
    92. <span class="sd"> note: the default output of the wrapped module will not be added to the output list by default. To get</span>
    93. <span class="sd"> the default output in the outputs list, explicitly include its path in the @output_paths parameter</span>
    94. <span class="sd"> i.e.</span>
    95. <span class="sd"> for module:</span>
    96. <span class="sd"> Sequential(</span>
    97. <span class="sd"> (0): Sequential(</span>
    98. <span class="sd"> (0): Conv2d(3, 32, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)</span>
    99. <span class="sd"> (1): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)</span>
    100. <span class="sd"> (2): ReLU6(inplace=True)</span>
    101. <span class="sd"> ) ===================================&gt;&gt;</span>
    102. <span class="sd"> (1): InvertedResidual(</span>
    103. <span class="sd"> (conv): Sequential(</span>
    104. <span class="sd"> (0): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=32, bias=False)</span>
    105. <span class="sd"> (1): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)</span>
    106. <span class="sd"> (2): ReLU6(inplace=True) ===================================&gt;&gt;</span>
    107. <span class="sd"> (3): Conv2d(32, 16, kernel_size=(1, 1), stride=(1, 1), bias=False)</span>
    108. <span class="sd"> (4): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)</span>
    109. <span class="sd"> )</span>
    110. <span class="sd"> )</span>
    111. <span class="sd"> )</span>
    112. <span class="sd"> and paths:</span>
    113. <span class="sd"> [0, [1, &#39;conv&#39;, 2]]</span>
    114. <span class="sd"> the output are marked with arrows</span>
    115. <span class="sd"> &quot;&quot;&quot;</span>
    116. <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">module</span><span class="p">:</span> <span class="n">nn</span><span class="o">.</span><span class="n">Module</span><span class="p">,</span> <span class="n">output_paths</span><span class="p">:</span> <span class="nb">list</span><span class="p">,</span> <span class="n">prune</span><span class="p">:</span> <span class="nb">bool</span> <span class="o">=</span> <span class="kc">True</span><span class="p">):</span>
    117. <span class="sd">&quot;&quot;&quot;</span>
    118. <span class="sd"> :param module: The wrapped container module</span>
    119. <span class="sd"> :param output_paths: a list of lists or keys containing the canonical paths to the outputs</span>
    120. <span class="sd"> i.e. [3, [4, &#39;conv&#39;, 5], 7] will extract outputs of layers 3, 7 and 4-&gt;conv-&gt;5</span>
    121. <span class="sd"> &quot;&quot;&quot;</span>
    122. <span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="fm">__init__</span><span class="p">()</span>
    123. <span class="bp">self</span><span class="o">.</span><span class="n">output_paths</span> <span class="o">=</span> <span class="n">output_paths</span>
    124. <span class="bp">self</span><span class="o">.</span><span class="n">_modules</span><span class="p">[</span><span class="s1">&#39;0&#39;</span><span class="p">]</span> <span class="o">=</span> <span class="n">module</span>
    125. <span class="bp">self</span><span class="o">.</span><span class="n">_outputs_lists</span> <span class="o">=</span> <span class="p">{}</span>
    126. <span class="k">for</span> <span class="n">path</span> <span class="ow">in</span> <span class="n">output_paths</span><span class="p">:</span>
    127. <span class="n">child</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">_get_recursive</span><span class="p">(</span><span class="n">module</span><span class="p">,</span> <span class="n">path</span><span class="p">)</span>
    128. <span class="n">child</span><span class="o">.</span><span class="n">register_forward_hook</span><span class="p">(</span><span class="n">hook</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">save_output_hook</span><span class="p">)</span>
    129. <span class="c1"># PRUNE THE MODULE TO SUPPORT ALL PROVIDED OUTPUT_PATHS BUT REMOVE ALL REDUNDANT LAYERS</span>
    130. <span class="k">if</span> <span class="n">prune</span><span class="p">:</span>
    131. <span class="bp">self</span><span class="o">.</span><span class="n">_prune</span><span class="p">(</span><span class="n">module</span><span class="p">,</span> <span class="n">output_paths</span><span class="p">)</span>
    132. <div class="viewcode-block" id="MultiOutputModule.save_output_hook"><a class="viewcode-back" href="../../../../super_gradients.training.utils.html#super_gradients.training.utils.module_utils.MultiOutputModule.save_output_hook">[docs]</a> <span class="k">def</span> <span class="nf">save_output_hook</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">_</span><span class="p">,</span> <span class="nb">input</span><span class="p">,</span> <span class="n">output</span><span class="p">):</span>
    133. <span class="bp">self</span><span class="o">.</span><span class="n">_outputs_lists</span><span class="p">[</span><span class="nb">input</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span><span class="o">.</span><span class="n">device</span><span class="p">]</span><span class="o">.</span><span class="n">append</span><span class="p">(</span><span class="n">output</span><span class="p">)</span></div>
    134. <div class="viewcode-block" id="MultiOutputModule.forward"><a class="viewcode-back" href="../../../../super_gradients.training.utils.html#super_gradients.training.utils.module_utils.MultiOutputModule.forward">[docs]</a> <span class="k">def</span> <span class="nf">forward</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">x</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="nb">list</span><span class="p">:</span>
    135. <span class="bp">self</span><span class="o">.</span><span class="n">_outputs_lists</span><span class="p">[</span><span class="n">x</span><span class="o">.</span><span class="n">device</span><span class="p">]</span> <span class="o">=</span> <span class="p">[]</span>
    136. <span class="bp">self</span><span class="o">.</span><span class="n">_modules</span><span class="p">[</span><span class="s1">&#39;0&#39;</span><span class="p">](</span><span class="n">x</span><span class="p">)</span>
    137. <span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">_outputs_lists</span><span class="p">[</span><span class="n">x</span><span class="o">.</span><span class="n">device</span><span class="p">]</span></div>
    138. <span class="k">def</span> <span class="nf">_get_recursive</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">module</span><span class="p">:</span> <span class="n">nn</span><span class="o">.</span><span class="n">Module</span><span class="p">,</span> <span class="n">path</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="n">nn</span><span class="o">.</span><span class="n">Module</span><span class="p">:</span>
    139. <span class="sd">&quot;&quot;&quot;recursively look for a module using a path&quot;&quot;&quot;</span>
    140. <span class="k">if</span> <span class="ow">not</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">path</span><span class="p">,</span> <span class="nb">list</span><span class="p">):</span>
    141. <span class="k">return</span> <span class="n">module</span><span class="o">.</span><span class="n">_modules</span><span class="p">[</span><span class="nb">str</span><span class="p">(</span><span class="n">path</span><span class="p">)]</span>
    142. <span class="k">elif</span> <span class="nb">len</span><span class="p">(</span><span class="n">path</span><span class="p">)</span> <span class="o">==</span> <span class="mi">1</span><span class="p">:</span>
    143. <span class="k">return</span> <span class="n">module</span><span class="o">.</span><span class="n">_modules</span><span class="p">[</span><span class="nb">str</span><span class="p">(</span><span class="n">path</span><span class="p">[</span><span class="mi">0</span><span class="p">])]</span>
    144. <span class="k">else</span><span class="p">:</span>
    145. <span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">_get_recursive</span><span class="p">(</span><span class="n">module</span><span class="o">.</span><span class="n">_modules</span><span class="p">[</span><span class="nb">str</span><span class="p">(</span><span class="n">path</span><span class="p">[</span><span class="mi">0</span><span class="p">])],</span> <span class="n">path</span><span class="p">[</span><span class="mi">1</span><span class="p">:])</span>
    146. <span class="k">def</span> <span class="nf">_prune</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">module</span><span class="p">:</span> <span class="n">nn</span><span class="o">.</span><span class="n">Module</span><span class="p">,</span> <span class="n">output_paths</span><span class="p">:</span> <span class="nb">list</span><span class="p">):</span>
    147. <span class="sd">&quot;&quot;&quot;</span>
    148. <span class="sd"> Recursively prune the module to support all provided output_paths but remove all redundant layers</span>
    149. <span class="sd"> &quot;&quot;&quot;</span>
    150. <span class="n">last_index</span> <span class="o">=</span> <span class="o">-</span><span class="mi">1</span>
    151. <span class="n">last_key</span> <span class="o">=</span> <span class="kc">None</span>
    152. <span class="c1"># look for the last key from all paths</span>
    153. <span class="k">for</span> <span class="n">path</span> <span class="ow">in</span> <span class="n">output_paths</span><span class="p">:</span>
    154. <span class="n">key</span> <span class="o">=</span> <span class="n">path</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span> <span class="k">if</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">path</span><span class="p">,</span> <span class="nb">list</span><span class="p">)</span> <span class="k">else</span> <span class="n">path</span>
    155. <span class="n">index</span> <span class="o">=</span> <span class="nb">list</span><span class="p">(</span><span class="n">module</span><span class="o">.</span><span class="n">_modules</span><span class="p">)</span><span class="o">.</span><span class="n">index</span><span class="p">(</span><span class="nb">str</span><span class="p">(</span><span class="n">key</span><span class="p">))</span>
    156. <span class="k">if</span> <span class="n">index</span> <span class="o">&gt;</span> <span class="n">last_index</span><span class="p">:</span>
    157. <span class="n">last_index</span> <span class="o">=</span> <span class="n">index</span>
    158. <span class="n">last_key</span> <span class="o">=</span> <span class="n">key</span>
    159. <span class="n">module</span><span class="o">.</span><span class="n">_modules</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">_slice_odict</span><span class="p">(</span><span class="n">module</span><span class="o">.</span><span class="n">_modules</span><span class="p">,</span> <span class="mi">0</span><span class="p">,</span> <span class="n">last_index</span> <span class="o">+</span> <span class="mi">1</span><span class="p">)</span>
    160. <span class="n">next_level_paths</span> <span class="o">=</span> <span class="p">[]</span>
    161. <span class="k">for</span> <span class="n">path</span> <span class="ow">in</span> <span class="n">output_paths</span><span class="p">:</span>
    162. <span class="k">if</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">path</span><span class="p">,</span> <span class="nb">list</span><span class="p">)</span> <span class="ow">and</span> <span class="n">path</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span> <span class="o">==</span> <span class="n">last_key</span> <span class="ow">and</span> <span class="nb">len</span><span class="p">(</span><span class="n">path</span><span class="p">)</span> <span class="o">&gt;</span> <span class="mi">1</span><span class="p">:</span>
    163. <span class="n">next_level_paths</span><span class="o">.</span><span class="n">append</span><span class="p">(</span><span class="n">path</span><span class="p">[</span><span class="mi">1</span><span class="p">:])</span>
    164. <span class="k">if</span> <span class="nb">len</span><span class="p">(</span><span class="n">next_level_paths</span><span class="p">)</span> <span class="o">&gt;</span> <span class="mi">0</span><span class="p">:</span>
    165. <span class="bp">self</span><span class="o">.</span><span class="n">_prune</span><span class="p">(</span><span class="n">module</span><span class="o">.</span><span class="n">_modules</span><span class="p">[</span><span class="nb">str</span><span class="p">(</span><span class="n">last_key</span><span class="p">)],</span> <span class="n">next_level_paths</span><span class="p">)</span>
    166. <span class="k">def</span> <span class="nf">_slice_odict</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">odict</span><span class="p">:</span> <span class="n">OrderedDict</span><span class="p">,</span> <span class="n">start</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span> <span class="n">end</span><span class="p">:</span> <span class="nb">int</span><span class="p">):</span>
    167. <span class="sd">&quot;&quot;&quot;Slice an OrderedDict in the same logic list,tuple... are sliced&quot;&quot;&quot;</span>
    168. <span class="k">return</span> <span class="n">OrderedDict</span><span class="p">([</span>
    169. <span class="p">(</span><span class="n">k</span><span class="p">,</span> <span class="n">v</span><span class="p">)</span> <span class="k">for</span> <span class="p">(</span><span class="n">k</span><span class="p">,</span> <span class="n">v</span><span class="p">)</span> <span class="ow">in</span> <span class="n">odict</span><span class="o">.</span><span class="n">items</span><span class="p">()</span>
    170. <span class="k">if</span> <span class="n">k</span> <span class="ow">in</span> <span class="nb">list</span><span class="p">(</span><span class="n">odict</span><span class="o">.</span><span class="n">keys</span><span class="p">())[</span><span class="n">start</span><span class="p">:</span><span class="n">end</span><span class="p">]</span>
    171. <span class="p">])</span></div>
    172. <span class="k">def</span> <span class="nf">_replace_activations_recursive</span><span class="p">(</span><span class="n">module</span><span class="p">:</span> <span class="n">nn</span><span class="o">.</span><span class="n">Module</span><span class="p">,</span> <span class="n">new_activation</span><span class="p">:</span> <span class="n">nn</span><span class="o">.</span><span class="n">Module</span><span class="p">,</span> <span class="n">activations_to_replace</span><span class="p">:</span> <span class="n">List</span><span class="p">[</span><span class="nb">type</span><span class="p">]):</span>
    173. <span class="sd">&quot;&quot;&quot;</span>
    174. <span class="sd"> A helper called in replace_activations(...)</span>
    175. <span class="sd"> &quot;&quot;&quot;</span>
    176. <span class="k">for</span> <span class="n">n</span><span class="p">,</span> <span class="n">m</span> <span class="ow">in</span> <span class="n">module</span><span class="o">.</span><span class="n">named_children</span><span class="p">():</span>
    177. <span class="k">if</span> <span class="nb">type</span><span class="p">(</span><span class="n">m</span><span class="p">)</span> <span class="ow">in</span> <span class="n">activations_to_replace</span><span class="p">:</span>
    178. <span class="nb">setattr</span><span class="p">(</span><span class="n">module</span><span class="p">,</span> <span class="n">n</span><span class="p">,</span> <span class="n">copy</span><span class="o">.</span><span class="n">deepcopy</span><span class="p">(</span><span class="n">new_activation</span><span class="p">))</span>
    179. <span class="k">else</span><span class="p">:</span>
    180. <span class="n">_replace_activations_recursive</span><span class="p">(</span><span class="n">m</span><span class="p">,</span> <span class="n">new_activation</span><span class="p">,</span> <span class="n">activations_to_replace</span><span class="p">)</span>
    181. <div class="viewcode-block" id="replace_activations"><a class="viewcode-back" href="../../../../super_gradients.training.utils.html#super_gradients.training.utils.module_utils.replace_activations">[docs]</a><span class="k">def</span> <span class="nf">replace_activations</span><span class="p">(</span><span class="n">module</span><span class="p">:</span> <span class="n">nn</span><span class="o">.</span><span class="n">Module</span><span class="p">,</span> <span class="n">new_activation</span><span class="p">:</span> <span class="n">nn</span><span class="o">.</span><span class="n">Module</span><span class="p">,</span> <span class="n">activations_to_replace</span><span class="p">:</span> <span class="n">List</span><span class="p">[</span><span class="nb">type</span><span class="p">]):</span>
    182. <span class="sd">&quot;&quot;&quot;</span>
    183. <span class="sd"> Recursively go through module and replaces each activation in activations_to_replace with a copy of new_activation</span>
    184. <span class="sd"> :param module: a module that will be changed inplace</span>
    185. <span class="sd"> :param new_activation: a sample of a new activation (will be copied)</span>
    186. <span class="sd"> :param activations_to_replace: types of activations to replace, each must be a subclass of nn.Module</span>
    187. <span class="sd"> &quot;&quot;&quot;</span>
    188. <span class="c1"># check arguments once before the recursion</span>
    189. <span class="k">assert</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">new_activation</span><span class="p">,</span> <span class="n">nn</span><span class="o">.</span><span class="n">Module</span><span class="p">),</span> <span class="s1">&#39;new_activation should be nn.Module&#39;</span>
    190. <span class="k">assert</span> <span class="nb">all</span><span class="p">([</span><span class="nb">isinstance</span><span class="p">(</span><span class="n">t</span><span class="p">,</span> <span class="nb">type</span><span class="p">)</span> <span class="ow">and</span> <span class="nb">issubclass</span><span class="p">(</span><span class="n">t</span><span class="p">,</span> <span class="n">nn</span><span class="o">.</span><span class="n">Module</span><span class="p">)</span> <span class="k">for</span> <span class="n">t</span> <span class="ow">in</span> <span class="n">activations_to_replace</span><span class="p">]),</span> \
    191. <span class="s1">&#39;activations_to_replace should be types that are subclasses of nn.Module&#39;</span>
    192. <span class="c1"># do the replacement</span>
    193. <span class="n">_replace_activations_recursive</span><span class="p">(</span><span class="n">module</span><span class="p">,</span> <span class="n">new_activation</span><span class="p">,</span> <span class="n">activations_to_replace</span><span class="p">)</span></div>
    194. <div class="viewcode-block" id="fuse_repvgg_blocks_residual_branches"><a class="viewcode-back" href="../../../../super_gradients.training.utils.html#super_gradients.training.utils.module_utils.fuse_repvgg_blocks_residual_branches">[docs]</a><span class="k">def</span> <span class="nf">fuse_repvgg_blocks_residual_branches</span><span class="p">(</span><span class="n">model</span><span class="p">:</span> <span class="n">nn</span><span class="o">.</span><span class="n">Module</span><span class="p">):</span>
    195. <span class="sd">&#39;&#39;&#39;</span>
    196. <span class="sd"> Call fuse_block_residual_branches for all repvgg blocks in the model</span>
    197. <span class="sd"> :param model: torch.nn.Module with repvgg blocks. Doesn&#39;t have to be entirely consists of repvgg.</span>
    198. <span class="sd"> :type model: torch.nn.Module</span>
    199. <span class="sd"> &#39;&#39;&#39;</span>
    200. <span class="k">assert</span> <span class="ow">not</span> <span class="n">model</span><span class="o">.</span><span class="n">training</span><span class="p">,</span> <span class="s2">&quot;To fuse RepVGG block residual branches, model must be on eval mode&quot;</span>
    201. <span class="k">for</span> <span class="n">module</span> <span class="ow">in</span> <span class="n">model</span><span class="o">.</span><span class="n">modules</span><span class="p">():</span>
    202. <span class="k">if</span> <span class="nb">hasattr</span><span class="p">(</span><span class="n">module</span><span class="p">,</span> <span class="s1">&#39;fuse_block_residual_branches&#39;</span><span class="p">):</span>
    203. <span class="n">module</span><span class="o">.</span><span class="n">fuse_block_residual_branches</span><span class="p">()</span>
    204. <span class="n">model</span><span class="o">.</span><span class="n">build_residual_branches</span> <span class="o">=</span> <span class="kc">False</span></div>
    205. <div class="viewcode-block" id="ConvBNReLU"><a class="viewcode-back" href="../../../../super_gradients.training.utils.html#super_gradients.training.utils.module_utils.ConvBNReLU">[docs]</a><span class="k">class</span> <span class="nc">ConvBNReLU</span><span class="p">(</span><span class="n">nn</span><span class="o">.</span><span class="n">Module</span><span class="p">):</span>
    206. <span class="sd">&quot;&quot;&quot;</span>
    207. <span class="sd"> Class for Convolution2d-Batchnorm2d-Relu layer. Default behaviour is Conv-BN-Relu. To exclude Batchnorm module use</span>
    208. <span class="sd"> `use_normalization=False`, to exclude Relu activation use `use_activation=False`.</span>
    209. <span class="sd"> For convolution arguments documentation see `nn.Conv2d`.</span>
    210. <span class="sd"> For batchnorm arguments documentation see `nn.BatchNorm2d`.</span>
    211. <span class="sd"> For relu arguments documentation see `nn.Relu`.</span>
    212. <span class="sd"> &quot;&quot;&quot;</span>
    213. <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span>
    214. <span class="n">in_channels</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span>
    215. <span class="n">out_channels</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span>
    216. <span class="n">kernel_size</span><span class="p">:</span> <span class="n">Union</span><span class="p">[</span><span class="nb">int</span><span class="p">,</span> <span class="n">Tuple</span><span class="p">[</span><span class="nb">int</span><span class="p">,</span> <span class="nb">int</span><span class="p">]],</span>
    217. <span class="n">stride</span><span class="p">:</span> <span class="n">Union</span><span class="p">[</span><span class="nb">int</span><span class="p">,</span> <span class="n">Tuple</span><span class="p">[</span><span class="nb">int</span><span class="p">,</span> <span class="nb">int</span><span class="p">]]</span> <span class="o">=</span> <span class="mi">1</span><span class="p">,</span>
    218. <span class="n">padding</span><span class="p">:</span> <span class="n">Union</span><span class="p">[</span><span class="nb">int</span><span class="p">,</span> <span class="n">Tuple</span><span class="p">[</span><span class="nb">int</span><span class="p">,</span> <span class="nb">int</span><span class="p">]]</span> <span class="o">=</span> <span class="mi">0</span><span class="p">,</span>
    219. <span class="n">dilation</span><span class="p">:</span> <span class="n">Union</span><span class="p">[</span><span class="nb">int</span><span class="p">,</span> <span class="n">Tuple</span><span class="p">[</span><span class="nb">int</span><span class="p">,</span> <span class="nb">int</span><span class="p">]]</span> <span class="o">=</span> <span class="mi">1</span><span class="p">,</span>
    220. <span class="n">groups</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="mi">1</span><span class="p">,</span>
    221. <span class="n">bias</span><span class="p">:</span> <span class="nb">bool</span> <span class="o">=</span> <span class="kc">True</span><span class="p">,</span>
    222. <span class="n">padding_mode</span><span class="p">:</span> <span class="nb">str</span> <span class="o">=</span> <span class="s1">&#39;zeros&#39;</span><span class="p">,</span>
    223. <span class="n">use_normalization</span><span class="p">:</span> <span class="nb">bool</span> <span class="o">=</span> <span class="kc">True</span><span class="p">,</span>
    224. <span class="n">eps</span><span class="p">:</span> <span class="nb">float</span> <span class="o">=</span> <span class="mf">1e-5</span><span class="p">,</span>
    225. <span class="n">momentum</span><span class="p">:</span> <span class="nb">float</span> <span class="o">=</span> <span class="mf">0.1</span><span class="p">,</span>
    226. <span class="n">affine</span><span class="p">:</span> <span class="nb">bool</span> <span class="o">=</span> <span class="kc">True</span><span class="p">,</span>
    227. <span class="n">track_running_stats</span><span class="p">:</span> <span class="nb">bool</span> <span class="o">=</span> <span class="kc">True</span><span class="p">,</span>
    228. <span class="n">device</span><span class="o">=</span><span class="kc">None</span><span class="p">,</span>
    229. <span class="n">dtype</span><span class="o">=</span><span class="kc">None</span><span class="p">,</span>
    230. <span class="n">use_activation</span><span class="p">:</span> <span class="nb">bool</span> <span class="o">=</span> <span class="kc">True</span><span class="p">,</span>
    231. <span class="n">inplace</span><span class="p">:</span> <span class="nb">bool</span> <span class="o">=</span> <span class="kc">False</span><span class="p">):</span>
    232. <span class="nb">super</span><span class="p">(</span><span class="n">ConvBNReLU</span><span class="p">,</span> <span class="bp">self</span><span class="p">)</span><span class="o">.</span><span class="fm">__init__</span><span class="p">()</span>
    233. <span class="bp">self</span><span class="o">.</span><span class="n">seq</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">Sequential</span><span class="p">()</span>
    234. <span class="bp">self</span><span class="o">.</span><span class="n">seq</span><span class="o">.</span><span class="n">add_module</span><span class="p">(</span><span class="s2">&quot;conv&quot;</span><span class="p">,</span> <span class="n">nn</span><span class="o">.</span><span class="n">Conv2d</span><span class="p">(</span><span class="n">in_channels</span><span class="p">,</span>
    235. <span class="n">out_channels</span><span class="p">,</span>
    236. <span class="n">kernel_size</span><span class="o">=</span><span class="n">kernel_size</span><span class="p">,</span>
    237. <span class="n">stride</span><span class="o">=</span><span class="n">stride</span><span class="p">,</span>
    238. <span class="n">padding</span><span class="o">=</span><span class="n">padding</span><span class="p">,</span>
    239. <span class="n">dilation</span><span class="o">=</span><span class="n">dilation</span><span class="p">,</span>
    240. <span class="n">groups</span><span class="o">=</span><span class="n">groups</span><span class="p">,</span>
    241. <span class="n">bias</span><span class="o">=</span><span class="n">bias</span><span class="p">,</span>
    242. <span class="n">padding_mode</span><span class="o">=</span><span class="n">padding_mode</span><span class="p">))</span>
    243. <span class="k">if</span> <span class="n">use_normalization</span><span class="p">:</span>
    244. <span class="bp">self</span><span class="o">.</span><span class="n">seq</span><span class="o">.</span><span class="n">add_module</span><span class="p">(</span><span class="s2">&quot;bn&quot;</span><span class="p">,</span> <span class="n">nn</span><span class="o">.</span><span class="n">BatchNorm2d</span><span class="p">(</span><span class="n">out_channels</span><span class="p">,</span> <span class="n">eps</span><span class="o">=</span><span class="n">eps</span><span class="p">,</span> <span class="n">momentum</span><span class="o">=</span><span class="n">momentum</span><span class="p">,</span> <span class="n">affine</span><span class="o">=</span><span class="n">affine</span><span class="p">,</span>
    245. <span class="n">track_running_stats</span><span class="o">=</span><span class="n">track_running_stats</span><span class="p">,</span> <span class="n">device</span><span class="o">=</span><span class="n">device</span><span class="p">,</span>
    246. <span class="n">dtype</span><span class="o">=</span><span class="n">dtype</span><span class="p">))</span>
    247. <span class="k">if</span> <span class="n">use_activation</span><span class="p">:</span>
    248. <span class="bp">self</span><span class="o">.</span><span class="n">seq</span><span class="o">.</span><span class="n">add_module</span><span class="p">(</span><span class="s2">&quot;relu&quot;</span><span class="p">,</span> <span class="n">nn</span><span class="o">.</span><span class="n">ReLU</span><span class="p">(</span><span class="n">inplace</span><span class="o">=</span><span class="n">inplace</span><span class="p">))</span>
    249. <div class="viewcode-block" id="ConvBNReLU.forward"><a class="viewcode-back" href="../../../../super_gradients.training.utils.html#super_gradients.training.utils.module_utils.ConvBNReLU.forward">[docs]</a> <span class="k">def</span> <span class="nf">forward</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">x</span><span class="p">):</span>
    250. <span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">seq</span><span class="p">(</span><span class="n">x</span><span class="p">)</span></div></div>
    251. <div class="viewcode-block" id="NormalizationAdapter"><a class="viewcode-back" href="../../../../super_gradients.training.utils.html#super_gradients.training.utils.module_utils.NormalizationAdapter">[docs]</a><span class="k">class</span> <span class="nc">NormalizationAdapter</span><span class="p">(</span><span class="n">torch</span><span class="o">.</span><span class="n">nn</span><span class="o">.</span><span class="n">Module</span><span class="p">):</span>
    252. <span class="sd">&quot;&quot;&quot;</span>
    253. <span class="sd"> Denormalizes input by mean_original, std_original, then normalizes by mean_required, std_required.</span>
    254. <span class="sd"> Used in KD training where teacher expects data normalized by mean_required, std_required.</span>
    255. <span class="sd"> mean_original, std_original, mean_required, std_required are all list-like objects of length that&#39;s equal to the</span>
    256. <span class="sd"> number of input channels.</span>
    257. <span class="sd"> &quot;&quot;&quot;</span>
    258. <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">mean_original</span><span class="p">,</span> <span class="n">std_original</span><span class="p">,</span> <span class="n">mean_required</span><span class="p">,</span> <span class="n">std_required</span><span class="p">):</span>
    259. <span class="nb">super</span><span class="p">(</span><span class="n">NormalizationAdapter</span><span class="p">,</span> <span class="bp">self</span><span class="p">)</span><span class="o">.</span><span class="fm">__init__</span><span class="p">()</span>
    260. <span class="n">mean_original</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">tensor</span><span class="p">(</span><span class="n">mean_original</span><span class="p">)</span><span class="o">.</span><span class="n">unsqueeze</span><span class="p">(</span><span class="o">-</span><span class="mi">1</span><span class="p">)</span><span class="o">.</span><span class="n">unsqueeze</span><span class="p">(</span><span class="o">-</span><span class="mi">1</span><span class="p">)</span>
    261. <span class="n">std_original</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">tensor</span><span class="p">(</span><span class="n">std_original</span><span class="p">)</span><span class="o">.</span><span class="n">unsqueeze</span><span class="p">(</span><span class="o">-</span><span class="mi">1</span><span class="p">)</span><span class="o">.</span><span class="n">unsqueeze</span><span class="p">(</span><span class="o">-</span><span class="mi">1</span><span class="p">)</span>
    262. <span class="n">mean_required</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">tensor</span><span class="p">(</span><span class="n">mean_required</span><span class="p">)</span><span class="o">.</span><span class="n">unsqueeze</span><span class="p">(</span><span class="o">-</span><span class="mi">1</span><span class="p">)</span><span class="o">.</span><span class="n">unsqueeze</span><span class="p">(</span><span class="o">-</span><span class="mi">1</span><span class="p">)</span>
    263. <span class="n">std_required</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">tensor</span><span class="p">(</span><span class="n">std_required</span><span class="p">)</span><span class="o">.</span><span class="n">unsqueeze</span><span class="p">(</span><span class="o">-</span><span class="mi">1</span><span class="p">)</span><span class="o">.</span><span class="n">unsqueeze</span><span class="p">(</span><span class="o">-</span><span class="mi">1</span><span class="p">)</span>
    264. <span class="bp">self</span><span class="o">.</span><span class="n">additive</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">nn</span><span class="o">.</span><span class="n">Parameter</span><span class="p">((</span><span class="n">mean_original</span> <span class="o">-</span> <span class="n">mean_required</span><span class="p">)</span> <span class="o">/</span> <span class="n">std_original</span><span class="p">)</span>
    265. <span class="bp">self</span><span class="o">.</span><span class="n">multiplier</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">nn</span><span class="o">.</span><span class="n">Parameter</span><span class="p">(</span><span class="n">std_original</span> <span class="o">/</span> <span class="n">std_required</span><span class="p">)</span>
    266. <div class="viewcode-block" id="NormalizationAdapter.forward"><a class="viewcode-back" href="../../../../super_gradients.training.utils.html#super_gradients.training.utils.module_utils.NormalizationAdapter.forward">[docs]</a> <span class="k">def</span> <span class="nf">forward</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">x</span><span class="p">):</span>
    267. <span class="n">x</span> <span class="o">=</span> <span class="p">(</span><span class="n">x</span> <span class="o">+</span> <span class="bp">self</span><span class="o">.</span><span class="n">additive</span><span class="p">)</span> <span class="o">*</span> <span class="bp">self</span><span class="o">.</span><span class="n">multiplier</span>
    268. <span class="k">return</span> <span class="n">x</span></div></div>
    269. </pre></div>
    270. </div>
    271. </div>
    272. <footer>
    273. <hr/>
    274. <div role="contentinfo">
    275. <p>&#169; Copyright 2021, SuperGradients team.</p>
    276. </div>
    277. Built with <a href="https://www.sphinx-doc.org/">Sphinx</a> using a
    278. <a href="https://github.com/readthedocs/sphinx_rtd_theme">theme</a>
    279. provided by <a href="https://readthedocs.org">Read the Docs</a>.
    280. </footer>
    281. </div>
    282. </div>
    283. </section>
    284. </div>
    285. <script>
    286. jQuery(function () {
    287. SphinxRtdTheme.Navigation.enable(true);
    288. });
    289. </script>
    290. </body>
    291. </html>
    Discard
    1
    2
    3
    4
    5
    6
    7
    8
    9
    10
    11
    12
    13
    14
    15
    16
    17
    18
    19
    20
    21
    22
    23
    24
    25
    26
    27
    28
    29
    30
    31
    32
    33
    34
    35
    36
    37
    38
    39
    40
    41
    42
    43
    44
    45
    46
    47
    48
    49
    50
    51
    52
    53
    54
    55
    56
    57
    58
    59
    60
    61
    62
    63
    64
    65
    66
    67
    68
    69
    70
    71
    72
    73
    74
    75
    76
    77
    78
    79
    80
    81
    82
    83
    84
    85
    86
    87
    88
    89
    90
    91
    92
    93
    94
    95
    96
    97
    98
    99
    100
    101
    102
    103
    104
    105
    106
    107
    108
    109
    110
    111
    112
    113
    114
    115
    116
    117
    118
    119
    120
    121
    122
    123
    124
    125
    126
    127
    128
    129
    130
    131
    132
    133
    134
    135
    136
    137
    138
    139
    140
    141
    142
    143
    144
    145
    146
    147
    148
    149
    150
    151
    152
    153
    154
    155
    156
    157
    158
    159
    160
    161
    162
    163
    164
    165
    166
    167
    168
    169
    170
    171
    172
    173
    174
    175
    176
    177
    178
    179
    180
    181
    182
    183
    184
    185
    186
    187
    188
    189
    190
    191
    192
    193
    194
    195
    196
    197
    198
    199
    200
    201
    202
    203
    204
    205
    206
    207
    208
    209
    210
    211
    212
    213
    214
    215
    216
    217
    218
    219
    220
    221
    222
    223
    224
    225
    226
    227
    228
    229
    230
    1. <!DOCTYPE html>
    2. <html class="writer-html5" lang="en" >
    3. <head>
    4. <meta charset="utf-8" />
    5. <meta name="viewport" content="width=device-width, initial-scale=1.0" />
    6. <title>super_gradients.training.utils.optimizer_utils &mdash; SuperGradients 1.0 documentation</title>
    7. <link rel="stylesheet" href="../../../../_static/pygments.css" type="text/css" />
    8. <link rel="stylesheet" href="../../../../_static/css/theme.css" type="text/css" />
    9. <link rel="stylesheet" href="../../../../_static/graphviz.css" type="text/css" />
    10. <!--[if lt IE 9]>
    11. <script src="../../../../_static/js/html5shiv.min.js"></script>
    12. <![endif]-->
    13. <script data-url_root="../../../../" id="documentation_options" src="../../../../_static/documentation_options.js"></script>
    14. <script src="../../../../_static/jquery.js"></script>
    15. <script src="../../../../_static/underscore.js"></script>
    16. <script src="../../../../_static/doctools.js"></script>
    17. <script src="../../../../_static/js/theme.js"></script>
    18. <link rel="index" title="Index" href="../../../../genindex.html" />
    19. <link rel="search" title="Search" href="../../../../search.html" />
    20. </head>
    21. <body class="wy-body-for-nav">
    22. <div class="wy-grid-for-nav">
    23. <nav data-toggle="wy-nav-shift" class="wy-nav-side">
    24. <div class="wy-side-scroll">
    25. <div class="wy-side-nav-search" >
    26. <a href="../../../../index.html" class="icon icon-home"> SuperGradients
    27. </a>
    28. <div role="search">
    29. <form id="rtd-search-form" class="wy-form" action="../../../../search.html" method="get">
    30. <input type="text" name="q" placeholder="Search docs" />
    31. <input type="hidden" name="check_keywords" value="yes" />
    32. <input type="hidden" name="area" value="default" />
    33. </form>
    34. </div>
    35. </div><div class="wy-menu wy-menu-vertical" data-spy="affix" role="navigation" aria-label="Navigation menu">
    36. <p class="caption"><span class="caption-text">Welcome To SuperGradients</span></p>
    37. <ul>
    38. <li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html">Fill our 4-question quick survey! We will raffle free SuperGradients swag between those who will participate -&gt; Fill Survey</a></li>
    39. <li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#supergradients">SuperGradients</a></li>
    40. </ul>
    41. <p class="caption"><span class="caption-text">Technical Documentation</span></p>
    42. <ul>
    43. <li class="toctree-l1"><a class="reference internal" href="../../../../super_gradients.common.html">Common package</a></li>
    44. <li class="toctree-l1"><a class="reference internal" href="../../../../super_gradients.training.html">Training package</a></li>
    45. </ul>
    46. <p class="caption"><span class="caption-text">User Guide</span></p>
    47. <ul>
    48. <li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html">What is SuperGradients?</a></li>
    49. <li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#introducing-the-supergradients-library">Introducing the SuperGradients library</a></li>
    50. <li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#installation">Installation</a></li>
    51. <li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#integrating-your-training-code-complete-walkthrough">Integrating Your Training Code - Complete Walkthrough</a></li>
    52. <li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#training-parameters">Training Parameters</a></li>
    53. <li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#logs-and-checkpoints">Logs and Checkpoints</a></li>
    54. <li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#dataset-parameters">Dataset Parameters</a></li>
    55. <li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#network-architectures">Network Architectures</a></li>
    56. <li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#pretrained-models">Pretrained Models</a></li>
    57. <li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#how-to-reproduce-our-training-recipes">How To Reproduce Our Training Recipes</a></li>
    58. <li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#professional-tools-integration">Professional Tools Integration</a></li>
    59. <li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#supergradients-faq">SuperGradients FAQ</a></li>
    60. </ul>
    61. </div>
    62. </div>
    63. </nav>
    64. <section data-toggle="wy-nav-shift" class="wy-nav-content-wrap"><nav class="wy-nav-top" aria-label="Mobile navigation menu" >
    65. <i data-toggle="wy-nav-top" class="fa fa-bars"></i>
    66. <a href="../../../../index.html">SuperGradients</a>
    67. </nav>
    68. <div class="wy-nav-content">
    69. <div class="rst-content">
    70. <div role="navigation" aria-label="Page navigation">
    71. <ul class="wy-breadcrumbs">
    72. <li><a href="../../../../index.html" class="icon icon-home"></a> &raquo;</li>
    73. <li><a href="../../../index.html">Module code</a> &raquo;</li>
    74. <li>super_gradients.training.utils.optimizer_utils</li>
    75. <li class="wy-breadcrumbs-aside">
    76. </li>
    77. </ul>
    78. <hr/>
    79. </div>
    80. <div role="main" class="document" itemscope="itemscope" itemtype="http://schema.org/Article">
    81. <div itemprop="articleBody">
    82. <h1>Source code for super_gradients.training.utils.optimizer_utils</h1><div class="highlight"><pre>
    83. <span></span><span class="kn">import</span> <span class="nn">torch.optim</span> <span class="k">as</span> <span class="nn">optim</span>
    84. <span class="kn">import</span> <span class="nn">torch.nn</span> <span class="k">as</span> <span class="nn">nn</span>
    85. <span class="kn">from</span> <span class="nn">torch.nn.modules.batchnorm</span> <span class="kn">import</span> <span class="n">_BatchNorm</span>
    86. <span class="kn">from</span> <span class="nn">torch.nn.modules.conv</span> <span class="kn">import</span> <span class="n">_ConvNd</span>
    87. <span class="kn">from</span> <span class="nn">super_gradients.common.abstractions.abstract_logger</span> <span class="kn">import</span> <span class="n">get_logger</span>
    88. <span class="kn">from</span> <span class="nn">super_gradients.common.factories.optimizers_type_factory</span> <span class="kn">import</span> <span class="n">OptimizersTypeFactory</span>
    89. <span class="kn">from</span> <span class="nn">super_gradients.training.params</span> <span class="kn">import</span> <span class="n">DEFAULT_OPTIMIZER_PARAMS_SGD</span><span class="p">,</span> <span class="n">DEFAULT_OPTIMIZER_PARAMS_ADAM</span><span class="p">,</span> \
    90. <span class="n">DEFAULT_OPTIMIZER_PARAMS_RMSPROP</span><span class="p">,</span> <span class="n">DEFAULT_OPTIMIZER_PARAMS_RMSPROPTF</span>
    91. <span class="kn">from</span> <span class="nn">super_gradients.training.utils</span> <span class="kn">import</span> <span class="n">get_param</span>
    92. <span class="kn">from</span> <span class="nn">super_gradients.training.utils.optimizers.rmsprop_tf</span> <span class="kn">import</span> <span class="n">RMSpropTF</span>
    93. <span class="n">logger</span> <span class="o">=</span> <span class="n">get_logger</span><span class="p">(</span><span class="vm">__name__</span><span class="p">)</span>
    94. <span class="n">OPTIMIZERS_DEFAULT_PARAMS</span> <span class="o">=</span> <span class="p">{</span><span class="n">optim</span><span class="o">.</span><span class="n">SGD</span><span class="p">:</span> <span class="n">DEFAULT_OPTIMIZER_PARAMS_SGD</span><span class="p">,</span>
    95. <span class="n">optim</span><span class="o">.</span><span class="n">Adam</span><span class="p">:</span> <span class="n">DEFAULT_OPTIMIZER_PARAMS_ADAM</span><span class="p">,</span>
    96. <span class="n">optim</span><span class="o">.</span><span class="n">RMSprop</span><span class="p">:</span> <span class="n">DEFAULT_OPTIMIZER_PARAMS_RMSPROP</span><span class="p">,</span>
    97. <span class="n">RMSpropTF</span><span class="p">:</span> <span class="n">DEFAULT_OPTIMIZER_PARAMS_RMSPROPTF</span><span class="p">}</span>
    98. <div class="viewcode-block" id="separate_zero_wd_params_groups_for_optimizer"><a class="viewcode-back" href="../../../../super_gradients.training.utils.html#super_gradients.training.utils.optimizer_utils.separate_zero_wd_params_groups_for_optimizer">[docs]</a><span class="k">def</span> <span class="nf">separate_zero_wd_params_groups_for_optimizer</span><span class="p">(</span><span class="n">module</span><span class="p">:</span> <span class="n">nn</span><span class="o">.</span><span class="n">Module</span><span class="p">,</span> <span class="n">net_named_params</span><span class="p">,</span> <span class="n">weight_decay</span><span class="p">:</span> <span class="nb">float</span><span class="p">):</span>
    99. <span class="sd">&quot;&quot;&quot;</span>
    100. <span class="sd"> separate param groups for batchnorm and biases and others with weight decay. return list of param groups in format</span>
    101. <span class="sd"> required by torch Optimizer classes.</span>
    102. <span class="sd"> bias + BN with weight decay=0 and the rest with the given weight decay</span>
    103. <span class="sd"> :param module: train net module.</span>
    104. <span class="sd"> :param net_named_params: list of params groups, output of SgModule.initialize_param_groups</span>
    105. <span class="sd"> :param weight_decay: value to set for the non BN and bias parameters</span>
    106. <span class="sd"> &quot;&quot;&quot;</span>
    107. <span class="c1"># FIXME - replace usage of ids addresses to find batchnorm and biases params.</span>
    108. <span class="c1"># This solution iterate 2 times over module parameters, find a way to iterate only one time.</span>
    109. <span class="n">no_decay_ids</span> <span class="o">=</span> <span class="n">_get_no_decay_param_ids</span><span class="p">(</span><span class="n">module</span><span class="p">)</span>
    110. <span class="c1"># split param groups for optimizer</span>
    111. <span class="n">optimizer_param_groups</span> <span class="o">=</span> <span class="p">[]</span>
    112. <span class="k">for</span> <span class="n">param_group</span> <span class="ow">in</span> <span class="n">net_named_params</span><span class="p">:</span>
    113. <span class="n">no_decay_params</span> <span class="o">=</span> <span class="p">[]</span>
    114. <span class="n">decay_params</span> <span class="o">=</span> <span class="p">[]</span>
    115. <span class="k">for</span> <span class="n">name</span><span class="p">,</span> <span class="n">param</span> <span class="ow">in</span> <span class="n">param_group</span><span class="p">[</span><span class="s2">&quot;named_params&quot;</span><span class="p">]:</span>
    116. <span class="k">if</span> <span class="nb">id</span><span class="p">(</span><span class="n">param</span><span class="p">)</span> <span class="ow">in</span> <span class="n">no_decay_ids</span><span class="p">:</span>
    117. <span class="n">no_decay_params</span><span class="o">.</span><span class="n">append</span><span class="p">(</span><span class="n">param</span><span class="p">)</span>
    118. <span class="k">else</span><span class="p">:</span>
    119. <span class="n">decay_params</span><span class="o">.</span><span class="n">append</span><span class="p">(</span><span class="n">param</span><span class="p">)</span>
    120. <span class="c1"># append two param groups from the original param group, with and without weight decay.</span>
    121. <span class="n">extra_optim_params</span> <span class="o">=</span> <span class="p">{</span><span class="n">key</span><span class="p">:</span> <span class="n">param_group</span><span class="p">[</span><span class="n">key</span><span class="p">]</span> <span class="k">for</span> <span class="n">key</span> <span class="ow">in</span> <span class="n">param_group</span>
    122. <span class="k">if</span> <span class="n">key</span> <span class="ow">not</span> <span class="ow">in</span> <span class="p">[</span><span class="s2">&quot;named_params&quot;</span><span class="p">,</span> <span class="s2">&quot;weight_decay&quot;</span><span class="p">]}</span>
    123. <span class="n">optimizer_param_groups</span><span class="o">.</span><span class="n">append</span><span class="p">({</span><span class="s2">&quot;params&quot;</span><span class="p">:</span> <span class="n">no_decay_params</span><span class="p">,</span> <span class="s2">&quot;weight_decay&quot;</span><span class="p">:</span> <span class="mf">0.0</span><span class="p">,</span> <span class="o">**</span><span class="n">extra_optim_params</span><span class="p">})</span>
    124. <span class="n">optimizer_param_groups</span><span class="o">.</span><span class="n">append</span><span class="p">({</span><span class="s2">&quot;params&quot;</span><span class="p">:</span> <span class="n">decay_params</span><span class="p">,</span> <span class="s2">&quot;weight_decay&quot;</span><span class="p">:</span> <span class="n">weight_decay</span><span class="p">,</span> <span class="o">**</span><span class="n">extra_optim_params</span><span class="p">})</span>
    125. <span class="k">return</span> <span class="n">optimizer_param_groups</span></div>
    126. <span class="k">def</span> <span class="nf">_get_no_decay_param_ids</span><span class="p">(</span><span class="n">module</span><span class="p">:</span> <span class="n">nn</span><span class="o">.</span><span class="n">Module</span><span class="p">):</span>
    127. <span class="c1"># FIXME - replace usage of ids addresses to find batchnorm and biases params.</span>
    128. <span class="c1"># Use other common way to identify torch parameters other than id or layer names</span>
    129. <span class="sd">&quot;&quot;&quot;</span>
    130. <span class="sd"> Iterate over module.modules() and returns params id addresses of batch-norm and biases params.</span>
    131. <span class="sd"> NOTE - ALL MODULES WITH ATTRIBUTES NAMED BIAS AND ARE INSTANCE OF nn.Parameter WILL BE CONSIDERED A BIAS PARAM FOR</span>
    132. <span class="sd"> ZERO WEIGHT DECAY.</span>
    133. <span class="sd"> &quot;&quot;&quot;</span>
    134. <span class="n">batchnorm_types</span> <span class="o">=</span> <span class="p">(</span><span class="n">_BatchNorm</span><span class="p">,)</span>
    135. <span class="n">torch_weight_with_bias_types</span> <span class="o">=</span> <span class="p">(</span><span class="n">_ConvNd</span><span class="p">,</span> <span class="n">nn</span><span class="o">.</span><span class="n">Linear</span><span class="p">)</span>
    136. <span class="n">no_decay_ids</span> <span class="o">=</span> <span class="p">[]</span>
    137. <span class="k">for</span> <span class="n">name</span><span class="p">,</span> <span class="n">m</span> <span class="ow">in</span> <span class="n">module</span><span class="o">.</span><span class="n">named_modules</span><span class="p">():</span>
    138. <span class="k">if</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">m</span><span class="p">,</span> <span class="n">batchnorm_types</span><span class="p">):</span>
    139. <span class="n">no_decay_ids</span><span class="o">.</span><span class="n">append</span><span class="p">(</span><span class="nb">id</span><span class="p">(</span><span class="n">m</span><span class="o">.</span><span class="n">weight</span><span class="p">))</span>
    140. <span class="n">no_decay_ids</span><span class="o">.</span><span class="n">append</span><span class="p">(</span><span class="nb">id</span><span class="p">(</span><span class="n">m</span><span class="o">.</span><span class="n">bias</span><span class="p">))</span>
    141. <span class="k">elif</span> <span class="nb">hasattr</span><span class="p">(</span><span class="n">m</span><span class="p">,</span> <span class="s2">&quot;bias&quot;</span><span class="p">)</span> <span class="ow">and</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">m</span><span class="o">.</span><span class="n">bias</span><span class="p">,</span> <span class="n">nn</span><span class="o">.</span><span class="n">Parameter</span><span class="p">):</span>
    142. <span class="k">if</span> <span class="ow">not</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">m</span><span class="p">,</span> <span class="n">torch_weight_with_bias_types</span><span class="p">):</span>
    143. <span class="n">logger</span><span class="o">.</span><span class="n">warning</span><span class="p">(</span><span class="sa">f</span><span class="s2">&quot;Module class: </span><span class="si">{</span><span class="n">m</span><span class="o">.</span><span class="vm">__class__</span><span class="si">}</span><span class="s2">, have a `bias` parameter attribute but is not instance of&quot;</span>
    144. <span class="sa">f</span><span class="s2">&quot; torch primitive modules, this bias parameter will be part of param group with zero&quot;</span>
    145. <span class="sa">f</span><span class="s2">&quot; weight decay.&quot;</span><span class="p">)</span>
    146. <span class="n">no_decay_ids</span><span class="o">.</span><span class="n">append</span><span class="p">(</span><span class="nb">id</span><span class="p">(</span><span class="n">m</span><span class="o">.</span><span class="n">bias</span><span class="p">))</span>
    147. <span class="k">return</span> <span class="n">no_decay_ids</span>
    148. <div class="viewcode-block" id="build_optimizer"><a class="viewcode-back" href="../../../../super_gradients.training.utils.html#super_gradients.training.utils.optimizer_utils.build_optimizer">[docs]</a><span class="k">def</span> <span class="nf">build_optimizer</span><span class="p">(</span><span class="n">net</span><span class="p">,</span> <span class="n">lr</span><span class="p">,</span> <span class="n">training_params</span><span class="p">):</span>
    149. <span class="sd">&quot;&quot;&quot;</span>
    150. <span class="sd"> Wrapper function for initializing the optimizer</span>
    151. <span class="sd"> :param net: the nn_module to build the optimizer for</span>
    152. <span class="sd"> :param lr: initial learning rate</span>
    153. <span class="sd"> :param training_params: training_parameters</span>
    154. <span class="sd"> &quot;&quot;&quot;</span>
    155. <span class="k">if</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">training_params</span><span class="o">.</span><span class="n">optimizer</span><span class="p">,</span> <span class="nb">str</span><span class="p">):</span>
    156. <span class="n">optimizer_cls</span> <span class="o">=</span> <span class="n">OptimizersTypeFactory</span><span class="p">()</span><span class="o">.</span><span class="n">get</span><span class="p">(</span><span class="n">training_params</span><span class="o">.</span><span class="n">optimizer</span><span class="p">)</span>
    157. <span class="k">else</span><span class="p">:</span>
    158. <span class="n">optimizer_cls</span> <span class="o">=</span> <span class="n">training_params</span><span class="o">.</span><span class="n">optimizer</span>
    159. <span class="n">default_optimizer_params</span> <span class="o">=</span> <span class="n">OPTIMIZERS_DEFAULT_PARAMS</span><span class="p">[</span><span class="n">optimizer_cls</span><span class="p">]</span> <span class="k">if</span> <span class="n">optimizer_cls</span> <span class="ow">in</span> <span class="n">OPTIMIZERS_DEFAULT_PARAMS</span> <span class="k">else</span> <span class="p">{}</span>
    160. <span class="n">training_params</span><span class="o">.</span><span class="n">optimizer_params</span> <span class="o">=</span> <span class="n">get_param</span><span class="p">(</span><span class="n">training_params</span><span class="p">,</span> <span class="s1">&#39;optimizer_params&#39;</span><span class="p">,</span> <span class="n">default_optimizer_params</span><span class="p">)</span>
    161. <span class="c1"># OPTIMIZER PARAM GROUPS ARE SET USING DEFAULT OR MODEL SPECIFIC INIT</span>
    162. <span class="k">if</span> <span class="nb">hasattr</span><span class="p">(</span><span class="n">net</span><span class="o">.</span><span class="n">module</span><span class="p">,</span> <span class="s1">&#39;initialize_param_groups&#39;</span><span class="p">):</span>
    163. <span class="c1"># INITIALIZE_PARAM_GROUPS MUST RETURN A LIST OF DICTS WITH &#39;named_params&#39; AND OPTIMIZER&#39;s ATTRIBUTES PER GROUP</span>
    164. <span class="n">net_named_params</span> <span class="o">=</span> <span class="n">net</span><span class="o">.</span><span class="n">module</span><span class="o">.</span><span class="n">initialize_param_groups</span><span class="p">(</span><span class="n">lr</span><span class="p">,</span> <span class="n">training_params</span><span class="p">)</span>
    165. <span class="k">else</span><span class="p">:</span>
    166. <span class="n">net_named_params</span> <span class="o">=</span> <span class="p">[{</span><span class="s1">&#39;named_params&#39;</span><span class="p">:</span> <span class="n">net</span><span class="o">.</span><span class="n">named_parameters</span><span class="p">()}]</span>
    167. <span class="k">if</span> <span class="n">training_params</span><span class="o">.</span><span class="n">zero_weight_decay_on_bias_and_bn</span><span class="p">:</span>
    168. <span class="n">optimizer_training_params</span> <span class="o">=</span> <span class="n">separate_zero_wd_params_groups_for_optimizer</span><span class="p">(</span>
    169. <span class="n">net</span><span class="o">.</span><span class="n">module</span><span class="p">,</span> <span class="n">net_named_params</span><span class="p">,</span> <span class="n">training_params</span><span class="o">.</span><span class="n">optimizer_params</span><span class="p">[</span><span class="s1">&#39;weight_decay&#39;</span><span class="p">]</span>
    170. <span class="p">)</span>
    171. <span class="k">else</span><span class="p">:</span>
    172. <span class="c1"># Overwrite groups to include params instead of named params</span>
    173. <span class="k">for</span> <span class="n">ind_group</span><span class="p">,</span> <span class="n">param_group</span> <span class="ow">in</span> <span class="nb">enumerate</span><span class="p">(</span><span class="n">net_named_params</span><span class="p">):</span>
    174. <span class="n">param_group</span><span class="p">[</span><span class="s1">&#39;params&#39;</span><span class="p">]</span> <span class="o">=</span> <span class="p">[</span><span class="n">param</span><span class="p">[</span><span class="mi">1</span><span class="p">]</span> <span class="k">for</span> <span class="n">param</span> <span class="ow">in</span> <span class="nb">list</span><span class="p">(</span><span class="n">param_group</span><span class="p">[</span><span class="s1">&#39;named_params&#39;</span><span class="p">])]</span>
    175. <span class="k">del</span> <span class="n">param_group</span><span class="p">[</span><span class="s1">&#39;named_params&#39;</span><span class="p">]</span>
    176. <span class="n">net_named_params</span><span class="p">[</span><span class="n">ind_group</span><span class="p">]</span> <span class="o">=</span> <span class="n">param_group</span>
    177. <span class="n">optimizer_training_params</span> <span class="o">=</span> <span class="n">net_named_params</span>
    178. <span class="c1"># CREATE AN OPTIMIZER OBJECT AND INITIALIZE IT</span>
    179. <span class="n">optimizer</span> <span class="o">=</span> <span class="n">optimizer_cls</span><span class="p">(</span><span class="n">optimizer_training_params</span><span class="p">,</span> <span class="n">lr</span><span class="o">=</span><span class="n">lr</span><span class="p">,</span> <span class="o">**</span><span class="n">training_params</span><span class="o">.</span><span class="n">optimizer_params</span><span class="p">)</span>
    180. <span class="k">return</span> <span class="n">optimizer</span></div>
    181. </pre></div>
    182. </div>
    183. </div>
    184. <footer>
    185. <hr/>
    186. <div role="contentinfo">
    187. <p>&#169; Copyright 2021, SuperGradients team.</p>
    188. </div>
    189. Built with <a href="https://www.sphinx-doc.org/">Sphinx</a> using a
    190. <a href="https://github.com/readthedocs/sphinx_rtd_theme">theme</a>
    191. provided by <a href="https://readthedocs.org">Read the Docs</a>.
    192. </footer>
    193. </div>
    194. </div>
    195. </section>
    196. </div>
    197. <script>
    198. jQuery(function () {
    199. SphinxRtdTheme.Navigation.enable(true);
    200. });
    201. </script>
    202. </body>
    203. </html>
    Discard
    1
    2
    3
    4
    5
    6
    7
    8
    9
    10
    11
    12
    13
    14
    15
    16
    17
    18
    19
    20
    21
    22
    23
    24
    25
    26
    27
    28
    29
    30
    31
    32
    33
    34
    35
    36
    37
    38
    39
    40
    41
    42
    43
    44
    45
    46
    47
    48
    49
    50
    51
    52
    53
    54
    55
    56
    57
    58
    59
    60
    61
    62
    63
    64
    65
    66
    67
    68
    69
    70
    71
    72
    73
    74
    75
    76
    77
    78
    79
    80
    81
    82
    83
    84
    85
    86
    87
    88
    89
    90
    91
    92
    93
    94
    95
    96
    97
    98
    99
    100
    101
    102
    103
    104
    105
    106
    107
    108
    109
    110
    111
    112
    113
    114
    115
    116
    117
    118
    119
    120
    121
    122
    123
    124
    125
    126
    127
    128
    129
    130
    131
    132
    133
    134
    135
    136
    137
    138
    139
    140
    141
    142
    143
    144
    145
    146
    147
    148
    149
    150
    151
    152
    153
    154
    155
    156
    157
    158
    159
    160
    161
    162
    163
    164
    165
    166
    167
    168
    169
    170
    171
    172
    173
    174
    175
    176
    177
    178
    179
    180
    181
    182
    183
    184
    185
    186
    187
    188
    189
    190
    191
    192
    193
    194
    195
    196
    197
    198
    199
    200
    201
    202
    203
    204
    205
    206
    207
    208
    209
    210
    211
    212
    213
    214
    215
    216
    217
    218
    219
    220
    221
    222
    223
    224
    225
    226
    227
    228
    229
    230
    231
    232
    233
    234
    235
    236
    237
    238
    239
    240
    241
    242
    243
    244
    245
    246
    247
    1. <!DOCTYPE html>
    2. <html class="writer-html5" lang="en" >
    3. <head>
    4. <meta charset="utf-8" />
    5. <meta name="viewport" content="width=device-width, initial-scale=1.0" />
    6. <title>super_gradients.training.utils.optimizers.rmsprop_tf &mdash; SuperGradients 1.0 documentation</title>
    7. <link rel="stylesheet" href="../../../../../_static/pygments.css" type="text/css" />
    8. <link rel="stylesheet" href="../../../../../_static/css/theme.css" type="text/css" />
    9. <link rel="stylesheet" href="../../../../../_static/graphviz.css" type="text/css" />
    10. <!--[if lt IE 9]>
    11. <script src="../../../../../_static/js/html5shiv.min.js"></script>
    12. <![endif]-->
    13. <script data-url_root="../../../../../" id="documentation_options" src="../../../../../_static/documentation_options.js"></script>
    14. <script src="../../../../../_static/jquery.js"></script>
    15. <script src="../../../../../_static/underscore.js"></script>
    16. <script src="../../../../../_static/doctools.js"></script>
    17. <script src="../../../../../_static/js/theme.js"></script>
    18. <link rel="index" title="Index" href="../../../../../genindex.html" />
    19. <link rel="search" title="Search" href="../../../../../search.html" />
    20. </head>
    21. <body class="wy-body-for-nav">
    22. <div class="wy-grid-for-nav">
    23. <nav data-toggle="wy-nav-shift" class="wy-nav-side">
    24. <div class="wy-side-scroll">
    25. <div class="wy-side-nav-search" >
    26. <a href="../../../../../index.html" class="icon icon-home"> SuperGradients
    27. </a>
    28. <div role="search">
    29. <form id="rtd-search-form" class="wy-form" action="../../../../../search.html" method="get">
    30. <input type="text" name="q" placeholder="Search docs" />
    31. <input type="hidden" name="check_keywords" value="yes" />
    32. <input type="hidden" name="area" value="default" />
    33. </form>
    34. </div>
    35. </div><div class="wy-menu wy-menu-vertical" data-spy="affix" role="navigation" aria-label="Navigation menu">
    36. <p class="caption"><span class="caption-text">Welcome To SuperGradients</span></p>
    37. <ul>
    38. <li class="toctree-l1"><a class="reference internal" href="../../../../../welcome.html">Fill our 4-question quick survey! We will raffle free SuperGradients swag between those who will participate -&gt; Fill Survey</a></li>
    39. <li class="toctree-l1"><a class="reference internal" href="../../../../../welcome.html#supergradients">SuperGradients</a></li>
    40. </ul>
    41. <p class="caption"><span class="caption-text">Technical Documentation</span></p>
    42. <ul>
    43. <li class="toctree-l1"><a class="reference internal" href="../../../../../super_gradients.common.html">Common package</a></li>
    44. <li class="toctree-l1"><a class="reference internal" href="../../../../../super_gradients.training.html">Training package</a></li>
    45. </ul>
    46. <p class="caption"><span class="caption-text">User Guide</span></p>
    47. <ul>
    48. <li class="toctree-l1"><a class="reference internal" href="../../../../../user_guide.html">What is SuperGradients?</a></li>
    49. <li class="toctree-l1"><a class="reference internal" href="../../../../../user_guide.html#introducing-the-supergradients-library">Introducing the SuperGradients library</a></li>
    50. <li class="toctree-l1"><a class="reference internal" href="../../../../../user_guide.html#installation">Installation</a></li>
    51. <li class="toctree-l1"><a class="reference internal" href="../../../../../user_guide.html#integrating-your-training-code-complete-walkthrough">Integrating Your Training Code - Complete Walkthrough</a></li>
    52. <li class="toctree-l1"><a class="reference internal" href="../../../../../user_guide.html#training-parameters">Training Parameters</a></li>
    53. <li class="toctree-l1"><a class="reference internal" href="../../../../../user_guide.html#logs-and-checkpoints">Logs and Checkpoints</a></li>
    54. <li class="toctree-l1"><a class="reference internal" href="../../../../../user_guide.html#dataset-parameters">Dataset Parameters</a></li>
    55. <li class="toctree-l1"><a class="reference internal" href="../../../../../user_guide.html#network-architectures">Network Architectures</a></li>
    56. <li class="toctree-l1"><a class="reference internal" href="../../../../../user_guide.html#pretrained-models">Pretrained Models</a></li>
    57. <li class="toctree-l1"><a class="reference internal" href="../../../../../user_guide.html#how-to-reproduce-our-training-recipes">How To Reproduce Our Training Recipes</a></li>
    58. <li class="toctree-l1"><a class="reference internal" href="../../../../../user_guide.html#professional-tools-integration">Professional Tools Integration</a></li>
    59. <li class="toctree-l1"><a class="reference internal" href="../../../../../user_guide.html#supergradients-faq">SuperGradients FAQ</a></li>
    60. </ul>
    61. </div>
    62. </div>
    63. </nav>
    64. <section data-toggle="wy-nav-shift" class="wy-nav-content-wrap"><nav class="wy-nav-top" aria-label="Mobile navigation menu" >
    65. <i data-toggle="wy-nav-top" class="fa fa-bars"></i>
    66. <a href="../../../../../index.html">SuperGradients</a>
    67. </nav>
    68. <div class="wy-nav-content">
    69. <div class="rst-content">
    70. <div role="navigation" aria-label="Page navigation">
    71. <ul class="wy-breadcrumbs">
    72. <li><a href="../../../../../index.html" class="icon icon-home"></a> &raquo;</li>
    73. <li><a href="../../../../index.html">Module code</a> &raquo;</li>
    74. <li>super_gradients.training.utils.optimizers.rmsprop_tf</li>
    75. <li class="wy-breadcrumbs-aside">
    76. </li>
    77. </ul>
    78. <hr/>
    79. </div>
    80. <div role="main" class="document" itemscope="itemscope" itemtype="http://schema.org/Article">
    81. <div itemprop="articleBody">
    82. <h1>Source code for super_gradients.training.utils.optimizers.rmsprop_tf</h1><div class="highlight"><pre>
    83. <span></span><span class="kn">import</span> <span class="nn">torch</span>
    84. <span class="kn">from</span> <span class="nn">torch.optim</span> <span class="kn">import</span> <span class="n">Optimizer</span>
    85. <span class="sd">&quot;&quot;&quot;</span>
    86. <span class="sd">This implementation is taken from timm&#39;s github:</span>
    87. <span class="sd">https://github.com/rwightman/pytorch-image-models/blob/master/timm/optim/rmsprop_tf.py</span>
    88. <span class="sd">&quot;&quot;&quot;</span>
    89. <span class="sd">&quot;&quot;&quot; RMSProp modified to behave like Tensorflow impl</span>
    90. <span class="sd">Originally cut &amp; paste from PyTorch RMSProp</span>
    91. <span class="sd">https://github.com/pytorch/pytorch/blob/063946d2b3f3f1e953a2a3b54e0b34f1393de295/torch/optim/rmsprop.py</span>
    92. <span class="sd">Licensed under BSD-Clause 3 (ish), https://github.com/pytorch/pytorch/blob/master/LICENSE</span>
    93. <span class="sd">Modifications Copyright 2020 Ross Wightman</span>
    94. <span class="sd">&quot;&quot;&quot;</span>
    95. <div class="viewcode-block" id="RMSpropTF"><a class="viewcode-back" href="../../../../../super_gradients.training.utils.optimizers.html#super_gradients.training.utils.optimizers.rmsprop_tf.RMSpropTF">[docs]</a><span class="k">class</span> <span class="nc">RMSpropTF</span><span class="p">(</span><span class="n">Optimizer</span><span class="p">):</span>
    96. <span class="sd">&quot;&quot;&quot;Implements RMSprop algorithm (TensorFlow style epsilon)</span>
    97. <span class="sd"> NOTE: This is a direct cut-and-paste of PyTorch RMSprop with eps applied before sqrt</span>
    98. <span class="sd"> and a few other modifications to closer match Tensorflow for matching hyper-params.</span>
    99. <span class="sd"> Noteworthy changes include:</span>
    100. <span class="sd"> 1. Epsilon applied inside square-root</span>
    101. <span class="sd"> 2. square_avg initialized to ones</span>
    102. <span class="sd"> 3. LR scaling of update accumulated in momentum buffer</span>
    103. <span class="sd"> Proposed by G. Hinton in his</span>
    104. <span class="sd"> `course &lt;http://www.cs.toronto.edu/~tijmen/csc321/slides/lecture_slides_lec6.pdf&gt;`_.</span>
    105. <span class="sd"> The centered version first appears in `Generating Sequences</span>
    106. <span class="sd"> With Recurrent Neural Networks &lt;https://arxiv.org/pdf/1308.0850v5.pdf&gt;`_.&quot;&quot;&quot;</span>
    107. <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">params</span><span class="p">,</span> <span class="n">lr</span><span class="o">=</span><span class="mf">1e-2</span><span class="p">,</span> <span class="n">alpha</span><span class="o">=</span><span class="mf">0.9</span><span class="p">,</span> <span class="n">eps</span><span class="o">=</span><span class="mf">1e-10</span><span class="p">,</span> <span class="n">weight_decay</span><span class="o">=</span><span class="mi">0</span><span class="p">,</span> <span class="n">momentum</span><span class="o">=</span><span class="mf">0.</span><span class="p">,</span> <span class="n">centered</span><span class="o">=</span><span class="kc">False</span><span class="p">,</span>
    108. <span class="n">decoupled_decay</span><span class="o">=</span><span class="kc">False</span><span class="p">,</span> <span class="n">lr_in_momentum</span><span class="o">=</span><span class="kc">True</span><span class="p">):</span>
    109. <span class="sd">&quot;&quot;&quot;RMSprop optimizer that follows the tf&#39;s RMSprop characteristics</span>
    110. <span class="sd"> :param params (iterable): iterable of parameters to optimize or dicts defining parameter groups</span>
    111. <span class="sd"> :param lr (float, optional): learning rate</span>
    112. <span class="sd"> :param momentum (float, optional): momentum factor</span>
    113. <span class="sd"> :param alpha (float, optional): smoothing (decay) constant</span>
    114. <span class="sd"> :param eps (float, optional): term added to the denominator to improve numerical stability</span>
    115. <span class="sd"> :param centered (bool, optional) : if ``True``, compute the centered RMSProp, the gradient is normalized by an</span>
    116. <span class="sd"> estimation of its variance</span>
    117. <span class="sd"> :param weight_decay (float, optional): weight decay (L2 penalty)</span>
    118. <span class="sd"> :param decoupled_decay (bool, optional): decoupled weight decay as per https://arxiv.org/abs/1711.05101</span>
    119. <span class="sd"> :param lr_in_momentum (bool, optional): learning rate scaling is included in the momentum buffer update as per</span>
    120. <span class="sd"> defaults in Tensorflow</span>
    121. <span class="sd"> &quot;&quot;&quot;</span>
    122. <span class="k">if</span> <span class="ow">not</span> <span class="mf">0.0</span> <span class="o">&lt;=</span> <span class="n">lr</span><span class="p">:</span>
    123. <span class="k">raise</span> <span class="ne">ValueError</span><span class="p">(</span><span class="s2">&quot;Invalid learning rate: </span><span class="si">{}</span><span class="s2">&quot;</span><span class="o">.</span><span class="n">format</span><span class="p">(</span><span class="n">lr</span><span class="p">))</span>
    124. <span class="k">if</span> <span class="ow">not</span> <span class="mf">0.0</span> <span class="o">&lt;=</span> <span class="n">eps</span><span class="p">:</span>
    125. <span class="k">raise</span> <span class="ne">ValueError</span><span class="p">(</span><span class="s2">&quot;Invalid epsilon value: </span><span class="si">{}</span><span class="s2">&quot;</span><span class="o">.</span><span class="n">format</span><span class="p">(</span><span class="n">eps</span><span class="p">))</span>
    126. <span class="k">if</span> <span class="ow">not</span> <span class="mf">0.0</span> <span class="o">&lt;=</span> <span class="n">momentum</span><span class="p">:</span>
    127. <span class="k">raise</span> <span class="ne">ValueError</span><span class="p">(</span><span class="s2">&quot;Invalid momentum value: </span><span class="si">{}</span><span class="s2">&quot;</span><span class="o">.</span><span class="n">format</span><span class="p">(</span><span class="n">momentum</span><span class="p">))</span>
    128. <span class="k">if</span> <span class="ow">not</span> <span class="mf">0.0</span> <span class="o">&lt;=</span> <span class="n">weight_decay</span><span class="p">:</span>
    129. <span class="k">raise</span> <span class="ne">ValueError</span><span class="p">(</span><span class="s2">&quot;Invalid weight_decay value: </span><span class="si">{}</span><span class="s2">&quot;</span><span class="o">.</span><span class="n">format</span><span class="p">(</span><span class="n">weight_decay</span><span class="p">))</span>
    130. <span class="k">if</span> <span class="ow">not</span> <span class="mf">0.0</span> <span class="o">&lt;=</span> <span class="n">alpha</span><span class="p">:</span>
    131. <span class="k">raise</span> <span class="ne">ValueError</span><span class="p">(</span><span class="s2">&quot;Invalid alpha value: </span><span class="si">{}</span><span class="s2">&quot;</span><span class="o">.</span><span class="n">format</span><span class="p">(</span><span class="n">alpha</span><span class="p">))</span>
    132. <span class="n">defaults</span> <span class="o">=</span> <span class="nb">dict</span><span class="p">(</span><span class="n">lr</span><span class="o">=</span><span class="n">lr</span><span class="p">,</span> <span class="n">momentum</span><span class="o">=</span><span class="n">momentum</span><span class="p">,</span> <span class="n">alpha</span><span class="o">=</span><span class="n">alpha</span><span class="p">,</span> <span class="n">eps</span><span class="o">=</span><span class="n">eps</span><span class="p">,</span> <span class="n">centered</span><span class="o">=</span><span class="n">centered</span><span class="p">,</span> <span class="n">weight_decay</span><span class="o">=</span><span class="n">weight_decay</span><span class="p">,</span>
    133. <span class="n">decoupled_decay</span><span class="o">=</span><span class="n">decoupled_decay</span><span class="p">,</span> <span class="n">lr_in_momentum</span><span class="o">=</span><span class="n">lr_in_momentum</span><span class="p">)</span>
    134. <span class="nb">super</span><span class="p">(</span><span class="n">RMSpropTF</span><span class="p">,</span> <span class="bp">self</span><span class="p">)</span><span class="o">.</span><span class="fm">__init__</span><span class="p">(</span><span class="n">params</span><span class="p">,</span> <span class="n">defaults</span><span class="p">)</span>
    135. <span class="k">def</span> <span class="nf">__setstate__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">state</span><span class="p">):</span>
    136. <span class="nb">super</span><span class="p">(</span><span class="n">RMSpropTF</span><span class="p">,</span> <span class="bp">self</span><span class="p">)</span><span class="o">.</span><span class="n">__setstate__</span><span class="p">(</span><span class="n">state</span><span class="p">)</span>
    137. <span class="k">for</span> <span class="n">group</span> <span class="ow">in</span> <span class="bp">self</span><span class="o">.</span><span class="n">param_groups</span><span class="p">:</span>
    138. <span class="n">group</span><span class="o">.</span><span class="n">setdefault</span><span class="p">(</span><span class="s1">&#39;momentum&#39;</span><span class="p">,</span> <span class="mi">0</span><span class="p">)</span>
    139. <span class="n">group</span><span class="o">.</span><span class="n">setdefault</span><span class="p">(</span><span class="s1">&#39;centered&#39;</span><span class="p">,</span> <span class="kc">False</span><span class="p">)</span>
    140. <div class="viewcode-block" id="RMSpropTF.step"><a class="viewcode-back" href="../../../../../super_gradients.training.utils.optimizers.html#super_gradients.training.utils.optimizers.rmsprop_tf.RMSpropTF.step">[docs]</a> <span class="k">def</span> <span class="nf">step</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">closure</span><span class="o">=</span><span class="kc">None</span><span class="p">):</span> <span class="c1"># noqa: C901</span>
    141. <span class="sd">&quot;&quot;&quot;Performs a single optimization step.</span>
    142. <span class="sd"> Arguments:</span>
    143. <span class="sd"> closure (callable, optional): A closure that reevaluates the model</span>
    144. <span class="sd"> and returns the loss.</span>
    145. <span class="sd"> &quot;&quot;&quot;</span>
    146. <span class="n">loss</span> <span class="o">=</span> <span class="kc">None</span>
    147. <span class="k">if</span> <span class="n">closure</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span><span class="p">:</span>
    148. <span class="n">loss</span> <span class="o">=</span> <span class="n">closure</span><span class="p">()</span>
    149. <span class="k">for</span> <span class="n">group</span> <span class="ow">in</span> <span class="bp">self</span><span class="o">.</span><span class="n">param_groups</span><span class="p">:</span>
    150. <span class="k">for</span> <span class="n">p</span> <span class="ow">in</span> <span class="n">group</span><span class="p">[</span><span class="s1">&#39;params&#39;</span><span class="p">]:</span>
    151. <span class="k">if</span> <span class="n">p</span><span class="o">.</span><span class="n">grad</span> <span class="ow">is</span> <span class="kc">None</span><span class="p">:</span>
    152. <span class="k">continue</span>
    153. <span class="n">grad</span> <span class="o">=</span> <span class="n">p</span><span class="o">.</span><span class="n">grad</span><span class="o">.</span><span class="n">data</span>
    154. <span class="k">if</span> <span class="n">grad</span><span class="o">.</span><span class="n">is_sparse</span><span class="p">:</span>
    155. <span class="k">raise</span> <span class="ne">RuntimeError</span><span class="p">(</span><span class="s1">&#39;RMSprop does not support sparse gradients&#39;</span><span class="p">)</span>
    156. <span class="n">state</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">state</span><span class="p">[</span><span class="n">p</span><span class="p">]</span>
    157. <span class="c1"># State initialization</span>
    158. <span class="k">if</span> <span class="nb">len</span><span class="p">(</span><span class="n">state</span><span class="p">)</span> <span class="o">==</span> <span class="mi">0</span><span class="p">:</span>
    159. <span class="n">state</span><span class="p">[</span><span class="s1">&#39;step&#39;</span><span class="p">]</span> <span class="o">=</span> <span class="mi">0</span>
    160. <span class="n">state</span><span class="p">[</span><span class="s1">&#39;square_avg&#39;</span><span class="p">]</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">ones_like</span><span class="p">(</span><span class="n">p</span><span class="o">.</span><span class="n">data</span><span class="p">)</span> <span class="c1"># PyTorch inits to zero</span>
    161. <span class="k">if</span> <span class="n">group</span><span class="p">[</span><span class="s1">&#39;momentum&#39;</span><span class="p">]</span> <span class="o">&gt;</span> <span class="mi">0</span><span class="p">:</span>
    162. <span class="n">state</span><span class="p">[</span><span class="s1">&#39;momentum_buffer&#39;</span><span class="p">]</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">zeros_like</span><span class="p">(</span><span class="n">p</span><span class="o">.</span><span class="n">data</span><span class="p">)</span>
    163. <span class="k">if</span> <span class="n">group</span><span class="p">[</span><span class="s1">&#39;centered&#39;</span><span class="p">]:</span>
    164. <span class="n">state</span><span class="p">[</span><span class="s1">&#39;grad_avg&#39;</span><span class="p">]</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">zeros_like</span><span class="p">(</span><span class="n">p</span><span class="o">.</span><span class="n">data</span><span class="p">)</span>
    165. <span class="n">square_avg</span> <span class="o">=</span> <span class="n">state</span><span class="p">[</span><span class="s1">&#39;square_avg&#39;</span><span class="p">]</span>
    166. <span class="n">one_minus_alpha</span> <span class="o">=</span> <span class="mf">1.</span> <span class="o">-</span> <span class="n">group</span><span class="p">[</span><span class="s1">&#39;alpha&#39;</span><span class="p">]</span>
    167. <span class="n">state</span><span class="p">[</span><span class="s1">&#39;step&#39;</span><span class="p">]</span> <span class="o">+=</span> <span class="mi">1</span>
    168. <span class="k">if</span> <span class="n">group</span><span class="p">[</span><span class="s1">&#39;weight_decay&#39;</span><span class="p">]</span> <span class="o">!=</span> <span class="mi">0</span><span class="p">:</span>
    169. <span class="k">if</span> <span class="s1">&#39;decoupled_decay&#39;</span> <span class="ow">in</span> <span class="n">group</span> <span class="ow">and</span> <span class="n">group</span><span class="p">[</span><span class="s1">&#39;decoupled_decay&#39;</span><span class="p">]:</span>
    170. <span class="n">p</span><span class="o">.</span><span class="n">data</span><span class="o">.</span><span class="n">add_</span><span class="p">(</span><span class="o">-</span><span class="n">group</span><span class="p">[</span><span class="s1">&#39;weight_decay&#39;</span><span class="p">],</span> <span class="n">p</span><span class="o">.</span><span class="n">data</span><span class="p">)</span>
    171. <span class="k">else</span><span class="p">:</span>
    172. <span class="n">grad</span> <span class="o">=</span> <span class="n">grad</span><span class="o">.</span><span class="n">add</span><span class="p">(</span><span class="n">group</span><span class="p">[</span><span class="s1">&#39;weight_decay&#39;</span><span class="p">],</span> <span class="n">p</span><span class="o">.</span><span class="n">data</span><span class="p">)</span>
    173. <span class="c1"># Tensorflow order of ops for updating squared avg</span>
    174. <span class="n">square_avg</span><span class="o">.</span><span class="n">add_</span><span class="p">(</span><span class="n">one_minus_alpha</span><span class="p">,</span> <span class="n">grad</span><span class="o">.</span><span class="n">pow</span><span class="p">(</span><span class="mi">2</span><span class="p">)</span> <span class="o">-</span> <span class="n">square_avg</span><span class="p">)</span>
    175. <span class="c1"># square_avg.mul_(alpha).addcmul_(1 - alpha, grad, grad) # PyTorch original</span>
    176. <span class="k">if</span> <span class="n">group</span><span class="p">[</span><span class="s1">&#39;centered&#39;</span><span class="p">]:</span>
    177. <span class="n">grad_avg</span> <span class="o">=</span> <span class="n">state</span><span class="p">[</span><span class="s1">&#39;grad_avg&#39;</span><span class="p">]</span>
    178. <span class="n">grad_avg</span><span class="o">.</span><span class="n">add_</span><span class="p">(</span><span class="n">one_minus_alpha</span><span class="p">,</span> <span class="n">grad</span> <span class="o">-</span> <span class="n">grad_avg</span><span class="p">)</span>
    179. <span class="c1"># grad_avg.mul_(alpha).add_(1 - alpha, grad) # PyTorch original</span>
    180. <span class="n">avg</span> <span class="o">=</span> <span class="n">square_avg</span><span class="o">.</span><span class="n">addcmul</span><span class="p">(</span><span class="o">-</span><span class="mi">1</span><span class="p">,</span> <span class="n">grad_avg</span><span class="p">,</span> <span class="n">grad_avg</span><span class="p">)</span><span class="o">.</span><span class="n">add</span><span class="p">(</span><span class="n">group</span><span class="p">[</span><span class="s1">&#39;eps&#39;</span><span class="p">])</span><span class="o">.</span><span class="n">sqrt_</span><span class="p">()</span> <span class="c1"># eps moved in sqrt</span>
    181. <span class="k">else</span><span class="p">:</span>
    182. <span class="n">avg</span> <span class="o">=</span> <span class="n">square_avg</span><span class="o">.</span><span class="n">add</span><span class="p">(</span><span class="n">group</span><span class="p">[</span><span class="s1">&#39;eps&#39;</span><span class="p">])</span><span class="o">.</span><span class="n">sqrt_</span><span class="p">()</span> <span class="c1"># eps moved in sqrt</span>
    183. <span class="k">if</span> <span class="n">group</span><span class="p">[</span><span class="s1">&#39;momentum&#39;</span><span class="p">]</span> <span class="o">&gt;</span> <span class="mi">0</span><span class="p">:</span>
    184. <span class="n">buf</span> <span class="o">=</span> <span class="n">state</span><span class="p">[</span><span class="s1">&#39;momentum_buffer&#39;</span><span class="p">]</span>
    185. <span class="c1"># Tensorflow accumulates the LR scaling in the momentum buffer</span>
    186. <span class="k">if</span> <span class="s1">&#39;lr_in_momentum&#39;</span> <span class="ow">in</span> <span class="n">group</span> <span class="ow">and</span> <span class="n">group</span><span class="p">[</span><span class="s1">&#39;lr_in_momentum&#39;</span><span class="p">]:</span>
    187. <span class="n">buf</span><span class="o">.</span><span class="n">mul_</span><span class="p">(</span><span class="n">group</span><span class="p">[</span><span class="s1">&#39;momentum&#39;</span><span class="p">])</span><span class="o">.</span><span class="n">addcdiv_</span><span class="p">(</span><span class="n">group</span><span class="p">[</span><span class="s1">&#39;lr&#39;</span><span class="p">],</span> <span class="n">grad</span><span class="p">,</span> <span class="n">avg</span><span class="p">)</span>
    188. <span class="n">p</span><span class="o">.</span><span class="n">data</span><span class="o">.</span><span class="n">add_</span><span class="p">(</span><span class="o">-</span><span class="n">buf</span><span class="p">)</span>
    189. <span class="k">else</span><span class="p">:</span>
    190. <span class="c1"># PyTorch scales the param update by LR</span>
    191. <span class="n">buf</span><span class="o">.</span><span class="n">mul_</span><span class="p">(</span><span class="n">group</span><span class="p">[</span><span class="s1">&#39;momentum&#39;</span><span class="p">])</span><span class="o">.</span><span class="n">addcdiv_</span><span class="p">(</span><span class="n">grad</span><span class="p">,</span> <span class="n">avg</span><span class="p">)</span>
    192. <span class="n">p</span><span class="o">.</span><span class="n">data</span><span class="o">.</span><span class="n">add_</span><span class="p">(</span><span class="o">-</span><span class="n">group</span><span class="p">[</span><span class="s1">&#39;lr&#39;</span><span class="p">],</span> <span class="n">buf</span><span class="p">)</span>
    193. <span class="k">else</span><span class="p">:</span>
    194. <span class="n">p</span><span class="o">.</span><span class="n">data</span><span class="o">.</span><span class="n">addcdiv_</span><span class="p">(</span><span class="o">-</span><span class="n">group</span><span class="p">[</span><span class="s1">&#39;lr&#39;</span><span class="p">],</span> <span class="n">grad</span><span class="p">,</span> <span class="n">avg</span><span class="p">)</span>
    195. <span class="k">return</span> <span class="n">loss</span></div></div>
    196. </pre></div>
    197. </div>
    198. </div>
    199. <footer>
    200. <hr/>
    201. <div role="contentinfo">
    202. <p>&#169; Copyright 2021, SuperGradients team.</p>
    203. </div>
    204. Built with <a href="https://www.sphinx-doc.org/">Sphinx</a> using a
    205. <a href="https://github.com/readthedocs/sphinx_rtd_theme">theme</a>
    206. provided by <a href="https://readthedocs.org">Read the Docs</a>.
    207. </footer>
    208. </div>
    209. </div>
    210. </section>
    211. </div>
    212. <script>
    213. jQuery(function () {
    214. SphinxRtdTheme.Navigation.enable(true);
    215. });
    216. </script>
    217. </body>
    218. </html>
    Discard
    1
    2
    3
    4
    5
    6
    7
    8
    9
    10
    11
    12
    13
    14
    15
    16
    17
    18
    19
    20
    21
    22
    23
    24
    25
    26
    27
    28
    29
    30
    31
    32
    33
    34
    35
    36
    37
    38
    39
    40
    41
    42
    43
    44
    45
    46
    47
    48
    49
    50
    51
    52
    53
    54
    55
    56
    57
    58
    59
    60
    61
    62
    63
    64
    65
    66
    67
    68
    69
    70
    71
    72
    73
    74
    75
    76
    77
    78
    79
    80
    81
    82
    83
    84
    85
    86
    87
    88
    89
    90
    91
    92
    93
    94
    95
    96
    97
    98
    99
    100
    101
    102
    103
    104
    105
    106
    107
    108
    109
    110
    111
    112
    113
    114
    115
    116
    117
    118
    119
    120
    121
    122
    123
    124
    125
    126
    127
    128
    129
    130
    131
    132
    133
    134
    135
    136
    137
    138
    139
    140
    141
    142
    143
    1. <!DOCTYPE html>
    2. <html class="writer-html5" lang="en" >
    3. <head>
    4. <meta charset="utf-8" />
    5. <meta name="viewport" content="width=device-width, initial-scale=1.0" />
    6. <title>super_gradients.training.utils.regularization_utils &mdash; SuperGradients 1.0 documentation</title>
    7. <link rel="stylesheet" href="../../../../_static/pygments.css" type="text/css" />
    8. <link rel="stylesheet" href="../../../../_static/css/theme.css" type="text/css" />
    9. <link rel="stylesheet" href="../../../../_static/graphviz.css" type="text/css" />
    10. <!--[if lt IE 9]>
    11. <script src="../../../../_static/js/html5shiv.min.js"></script>
    12. <![endif]-->
    13. <script data-url_root="../../../../" id="documentation_options" src="../../../../_static/documentation_options.js"></script>
    14. <script src="../../../../_static/jquery.js"></script>
    15. <script src="../../../../_static/underscore.js"></script>
    16. <script src="../../../../_static/doctools.js"></script>
    17. <script src="../../../../_static/js/theme.js"></script>
    18. <link rel="index" title="Index" href="../../../../genindex.html" />
    19. <link rel="search" title="Search" href="../../../../search.html" />
    20. </head>
    21. <body class="wy-body-for-nav">
    22. <div class="wy-grid-for-nav">
    23. <nav data-toggle="wy-nav-shift" class="wy-nav-side">
    24. <div class="wy-side-scroll">
    25. <div class="wy-side-nav-search" >
    26. <a href="../../../../index.html" class="icon icon-home"> SuperGradients
    27. </a>
    28. <div role="search">
    29. <form id="rtd-search-form" class="wy-form" action="../../../../search.html" method="get">
    30. <input type="text" name="q" placeholder="Search docs" />
    31. <input type="hidden" name="check_keywords" value="yes" />
    32. <input type="hidden" name="area" value="default" />
    33. </form>
    34. </div>
    35. </div><div class="wy-menu wy-menu-vertical" data-spy="affix" role="navigation" aria-label="Navigation menu">
    36. <p class="caption"><span class="caption-text">Welcome To SuperGradients</span></p>
    37. <ul>
    38. <li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html">Fill our 4-question quick survey! We will raffle free SuperGradients swag between those who will participate -&gt; Fill Survey</a></li>
    39. <li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#supergradients">SuperGradients</a></li>
    40. </ul>
    41. <p class="caption"><span class="caption-text">Technical Documentation</span></p>
    42. <ul>
    43. <li class="toctree-l1"><a class="reference internal" href="../../../../super_gradients.common.html">Common package</a></li>
    44. <li class="toctree-l1"><a class="reference internal" href="../../../../super_gradients.training.html">Training package</a></li>
    45. </ul>
    46. <p class="caption"><span class="caption-text">User Guide</span></p>
    47. <ul>
    48. <li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html">What is SuperGradients?</a></li>
    49. <li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#introducing-the-supergradients-library">Introducing the SuperGradients library</a></li>
    50. <li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#installation">Installation</a></li>
    51. <li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#integrating-your-training-code-complete-walkthrough">Integrating Your Training Code - Complete Walkthrough</a></li>
    52. <li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#training-parameters">Training Parameters</a></li>
    53. <li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#logs-and-checkpoints">Logs and Checkpoints</a></li>
    54. <li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#dataset-parameters">Dataset Parameters</a></li>
    55. <li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#network-architectures">Network Architectures</a></li>
    56. <li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#pretrained-models">Pretrained Models</a></li>
    57. <li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#how-to-reproduce-our-training-recipes">How To Reproduce Our Training Recipes</a></li>
    58. <li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#professional-tools-integration">Professional Tools Integration</a></li>
    59. <li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#supergradients-faq">SuperGradients FAQ</a></li>
    60. </ul>
    61. </div>
    62. </div>
    63. </nav>
    64. <section data-toggle="wy-nav-shift" class="wy-nav-content-wrap"><nav class="wy-nav-top" aria-label="Mobile navigation menu" >
    65. <i data-toggle="wy-nav-top" class="fa fa-bars"></i>
    66. <a href="../../../../index.html">SuperGradients</a>
    67. </nav>
    68. <div class="wy-nav-content">
    69. <div class="rst-content">
    70. <div role="navigation" aria-label="Page navigation">
    71. <ul class="wy-breadcrumbs">
    72. <li><a href="../../../../index.html" class="icon icon-home"></a> &raquo;</li>
    73. <li><a href="../../../index.html">Module code</a> &raquo;</li>
    74. <li>super_gradients.training.utils.regularization_utils</li>
    75. <li class="wy-breadcrumbs-aside">
    76. </li>
    77. </ul>
    78. <hr/>
    79. </div>
    80. <div role="main" class="document" itemscope="itemscope" itemtype="http://schema.org/Article">
    81. <div itemprop="articleBody">
    82. <h1>Source code for super_gradients.training.utils.regularization_utils</h1><div class="highlight"><pre>
    83. <span></span><span class="kn">import</span> <span class="nn">torch</span>
    84. <span class="kn">from</span> <span class="nn">torch</span> <span class="kn">import</span> <span class="n">nn</span>
    85. <div class="viewcode-block" id="DropPath"><a class="viewcode-back" href="../../../../super_gradients.training.utils.html#super_gradients.training.utils.regularization_utils.DropPath">[docs]</a><span class="k">class</span> <span class="nc">DropPath</span><span class="p">(</span><span class="n">nn</span><span class="o">.</span><span class="n">Module</span><span class="p">):</span>
    86. <span class="sd">&quot;&quot;&quot;</span>
    87. <span class="sd"> Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).</span>
    88. <span class="sd"> Code taken from TIMM (https://github.com/rwightman/pytorch-image-models)</span>
    89. <span class="sd"> Apache License 2.0</span>
    90. <span class="sd"> &quot;&quot;&quot;</span>
    91. <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">drop_prob</span><span class="o">=</span><span class="kc">None</span><span class="p">):</span>
    92. <span class="nb">super</span><span class="p">(</span><span class="n">DropPath</span><span class="p">,</span> <span class="bp">self</span><span class="p">)</span><span class="o">.</span><span class="fm">__init__</span><span class="p">()</span>
    93. <span class="bp">self</span><span class="o">.</span><span class="n">drop_prob</span> <span class="o">=</span> <span class="n">drop_prob</span>
    94. <div class="viewcode-block" id="DropPath.forward"><a class="viewcode-back" href="../../../../super_gradients.training.utils.html#super_gradients.training.utils.regularization_utils.DropPath.forward">[docs]</a> <span class="k">def</span> <span class="nf">forward</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">x</span><span class="p">):</span>
    95. <span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">drop_prob</span> <span class="o">==</span> <span class="mf">0.</span> <span class="ow">or</span> <span class="ow">not</span> <span class="bp">self</span><span class="o">.</span><span class="n">training</span><span class="p">:</span>
    96. <span class="k">return</span> <span class="n">x</span>
    97. <span class="n">keep_prob</span> <span class="o">=</span> <span class="mi">1</span> <span class="o">-</span> <span class="bp">self</span><span class="o">.</span><span class="n">drop_prob</span>
    98. <span class="n">shape</span> <span class="o">=</span> <span class="p">(</span><span class="n">x</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">0</span><span class="p">],)</span> <span class="o">+</span> <span class="p">(</span><span class="mi">1</span><span class="p">,)</span> <span class="o">*</span> <span class="p">(</span><span class="n">x</span><span class="o">.</span><span class="n">ndim</span> <span class="o">-</span> <span class="mi">1</span><span class="p">)</span> <span class="c1"># work with diff dim tensors, not just 2D ConvNets</span>
    99. <span class="n">random_tensor</span> <span class="o">=</span> <span class="n">keep_prob</span> <span class="o">+</span> <span class="n">torch</span><span class="o">.</span><span class="n">rand</span><span class="p">(</span><span class="n">shape</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="n">x</span><span class="o">.</span><span class="n">dtype</span><span class="p">,</span> <span class="n">device</span><span class="o">=</span><span class="n">x</span><span class="o">.</span><span class="n">device</span><span class="p">)</span>
    100. <span class="n">random_tensor</span><span class="o">.</span><span class="n">floor_</span><span class="p">()</span> <span class="c1"># binarize</span>
    101. <span class="n">output</span> <span class="o">=</span> <span class="n">x</span><span class="o">.</span><span class="n">div</span><span class="p">(</span><span class="n">keep_prob</span><span class="p">)</span> <span class="o">*</span> <span class="n">random_tensor</span>
    102. <span class="k">return</span> <span class="n">output</span></div></div>
    103. </pre></div>
    104. </div>
    105. </div>
    106. <footer>
    107. <hr/>
    108. <div role="contentinfo">
    109. <p>&#169; Copyright 2021, SuperGradients team.</p>
    110. </div>
    111. Built with <a href="https://www.sphinx-doc.org/">Sphinx</a> using a
    112. <a href="https://github.com/readthedocs/sphinx_rtd_theme">theme</a>
    113. provided by <a href="https://readthedocs.org">Read the Docs</a>.
    114. </footer>
    115. </div>
    116. </div>
    117. </section>
    118. </div>
    119. <script>
    120. jQuery(function () {
    121. SphinxRtdTheme.Navigation.enable(true);
    122. });
    123. </script>
    124. </body>
    125. </html>
    Discard
    1
    2
    3
    4
    5
    6
    7
    8
    9
    10
    11
    12
    13
    14
    15
    16
    17
    18
    19
    20
    21
    22
    23
    24
    25
    26
    27
    28
    29
    30
    31
    32
    33
    34
    35
    36
    37
    38
    39
    40
    41
    42
    43
    44
    45
    46
    47
    48
    49
    50
    51
    52
    53
    54
    55
    56
    57
    58
    59
    60
    61
    62
    63
    64
    65
    66
    67
    68
    69
    70
    71
    72
    73
    74
    75
    76
    77
    78
    79
    80
    81
    82
    83
    84
    85
    86
    87
    88
    89
    90
    91
    92
    93
    94
    95
    96
    97
    98
    99
    100
    101
    102
    103
    104
    105
    106
    107
    108
    109
    110
    111
    112
    113
    114
    115
    116
    117
    118
    119
    120
    121
    122
    123
    124
    125
    126
    127
    128
    129
    130
    131
    132
    133
    134
    135
    136
    137
    138
    139
    140
    141
    142
    143
    144
    145
    146
    147
    148
    149
    150
    151
    152
    153
    154
    155
    156
    157
    158
    159
    160
    161
    162
    163
    164
    165
    166
    167
    168
    169
    170
    171
    172
    173
    174
    175
    176
    177
    178
    179
    180
    181
    182
    183
    184
    185
    186
    187
    188
    189
    190
    191
    192
    193
    194
    195
    196
    197
    198
    199
    200
    201
    202
    203
    204
    205
    206
    207
    208
    209
    210
    211
    212
    213
    214
    215
    216
    217
    218
    219
    220
    221
    222
    223
    224
    225
    226
    227
    228
    229
    230
    231
    232
    233
    234
    235
    236
    237
    238
    239
    240
    241
    242
    243
    244
    245
    246
    247
    248
    249
    250
    251
    252
    253
    254
    255
    256
    257
    258
    259
    260
    261
    262
    263
    264
    265
    266
    267
    268
    269
    270
    271
    272
    273
    274
    275
    276
    277
    278
    279
    280
    281
    282
    283
    284
    285
    286
    287
    288
    289
    290
    291
    292
    293
    294
    295
    296
    297
    298
    299
    300
    301
    302
    303
    304
    305
    306
    307
    308
    309
    310
    311
    312
    313
    314
    315
    316
    317
    318
    319
    320
    321
    1. <!DOCTYPE html>
    2. <html class="writer-html5" lang="en" >
    3. <head>
    4. <meta charset="utf-8" />
    5. <meta name="viewport" content="width=device-width, initial-scale=1.0" />
    6. <title>super_gradients.training.utils.segmentation_utils &mdash; SuperGradients 1.0 documentation</title>
    7. <link rel="stylesheet" href="../../../../_static/pygments.css" type="text/css" />
    8. <link rel="stylesheet" href="../../../../_static/css/theme.css" type="text/css" />
    9. <link rel="stylesheet" href="../../../../_static/graphviz.css" type="text/css" />
    10. <!--[if lt IE 9]>
    11. <script src="../../../../_static/js/html5shiv.min.js"></script>
    12. <![endif]-->
    13. <script data-url_root="../../../../" id="documentation_options" src="../../../../_static/documentation_options.js"></script>
    14. <script src="../../../../_static/jquery.js"></script>
    15. <script src="../../../../_static/underscore.js"></script>
    16. <script src="../../../../_static/doctools.js"></script>
    17. <script src="../../../../_static/js/theme.js"></script>
    18. <link rel="index" title="Index" href="../../../../genindex.html" />
    19. <link rel="search" title="Search" href="../../../../search.html" />
    20. </head>
    21. <body class="wy-body-for-nav">
    22. <div class="wy-grid-for-nav">
    23. <nav data-toggle="wy-nav-shift" class="wy-nav-side">
    24. <div class="wy-side-scroll">
    25. <div class="wy-side-nav-search" >
    26. <a href="../../../../index.html" class="icon icon-home"> SuperGradients
    27. </a>
    28. <div role="search">
    29. <form id="rtd-search-form" class="wy-form" action="../../../../search.html" method="get">
    30. <input type="text" name="q" placeholder="Search docs" />
    31. <input type="hidden" name="check_keywords" value="yes" />
    32. <input type="hidden" name="area" value="default" />
    33. </form>
    34. </div>
    35. </div><div class="wy-menu wy-menu-vertical" data-spy="affix" role="navigation" aria-label="Navigation menu">
    36. <p class="caption"><span class="caption-text">Welcome To SuperGradients</span></p>
    37. <ul>
    38. <li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html">Fill our 4-question quick survey! We will raffle free SuperGradients swag between those who will participate -&gt; Fill Survey</a></li>
    39. <li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#supergradients">SuperGradients</a></li>
    40. </ul>
    41. <p class="caption"><span class="caption-text">Technical Documentation</span></p>
    42. <ul>
    43. <li class="toctree-l1"><a class="reference internal" href="../../../../super_gradients.common.html">Common package</a></li>
    44. <li class="toctree-l1"><a class="reference internal" href="../../../../super_gradients.training.html">Training package</a></li>
    45. </ul>
    46. <p class="caption"><span class="caption-text">User Guide</span></p>
    47. <ul>
    48. <li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html">What is SuperGradients?</a></li>
    49. <li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#introducing-the-supergradients-library">Introducing the SuperGradients library</a></li>
    50. <li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#installation">Installation</a></li>
    51. <li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#integrating-your-training-code-complete-walkthrough">Integrating Your Training Code - Complete Walkthrough</a></li>
    52. <li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#training-parameters">Training Parameters</a></li>
    53. <li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#logs-and-checkpoints">Logs and Checkpoints</a></li>
    54. <li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#dataset-parameters">Dataset Parameters</a></li>
    55. <li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#network-architectures">Network Architectures</a></li>
    56. <li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#pretrained-models">Pretrained Models</a></li>
    57. <li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#how-to-reproduce-our-training-recipes">How To Reproduce Our Training Recipes</a></li>
    58. <li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#professional-tools-integration">Professional Tools Integration</a></li>
    59. <li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#supergradients-faq">SuperGradients FAQ</a></li>
    60. </ul>
    61. </div>
    62. </div>
    63. </nav>
    64. <section data-toggle="wy-nav-shift" class="wy-nav-content-wrap"><nav class="wy-nav-top" aria-label="Mobile navigation menu" >
    65. <i data-toggle="wy-nav-top" class="fa fa-bars"></i>
    66. <a href="../../../../index.html">SuperGradients</a>
    67. </nav>
    68. <div class="wy-nav-content">
    69. <div class="rst-content">
    70. <div role="navigation" aria-label="Page navigation">
    71. <ul class="wy-breadcrumbs">
    72. <li><a href="../../../../index.html" class="icon icon-home"></a> &raquo;</li>
    73. <li><a href="../../../index.html">Module code</a> &raquo;</li>
    74. <li>super_gradients.training.utils.segmentation_utils</li>
    75. <li class="wy-breadcrumbs-aside">
    76. </li>
    77. </ul>
    78. <hr/>
    79. </div>
    80. <div role="main" class="document" itemscope="itemscope" itemtype="http://schema.org/Article">
    81. <div itemprop="articleBody">
    82. <h1>Source code for super_gradients.training.utils.segmentation_utils</h1><div class="highlight"><pre>
    83. <span></span><span class="kn">import</span> <span class="nn">os</span>
    84. <span class="kn">import</span> <span class="nn">cv2</span>
    85. <span class="kn">import</span> <span class="nn">numpy</span> <span class="k">as</span> <span class="nn">np</span>
    86. <span class="kn">from</span> <span class="nn">typing</span> <span class="kn">import</span> <span class="n">Union</span><span class="p">,</span> <span class="n">Callable</span>
    87. <span class="kn">import</span> <span class="nn">torch</span>
    88. <span class="kn">import</span> <span class="nn">torch.nn.functional</span> <span class="k">as</span> <span class="nn">F</span>
    89. <span class="kn">from</span> <span class="nn">torchvision.utils</span> <span class="kn">import</span> <span class="n">draw_segmentation_masks</span>
    90. <span class="c1"># FIXME: REFACTOR AUGMENTATIONS, CONSIDER USING A MORE EFFICIENT LIBRARIES SUCH AS, IMGAUG, DALI ETC.</span>
    91. <span class="kn">from</span> <span class="nn">super_gradients.training</span> <span class="kn">import</span> <span class="n">utils</span> <span class="k">as</span> <span class="n">core_utils</span>
    92. <div class="viewcode-block" id="coco_sub_classes_inclusion_tuples_list"><a class="viewcode-back" href="../../../../super_gradients.training.utils.html#super_gradients.training.utils.segmentation_utils.coco_sub_classes_inclusion_tuples_list">[docs]</a><span class="k">def</span> <span class="nf">coco_sub_classes_inclusion_tuples_list</span><span class="p">():</span>
    93. <span class="k">return</span> <span class="p">[(</span><span class="mi">0</span><span class="p">,</span> <span class="s1">&#39;background&#39;</span><span class="p">),</span> <span class="p">(</span><span class="mi">5</span><span class="p">,</span> <span class="s1">&#39;airplane&#39;</span><span class="p">),</span> <span class="p">(</span><span class="mi">2</span><span class="p">,</span> <span class="s1">&#39;bicycle&#39;</span><span class="p">),</span> <span class="p">(</span><span class="mi">16</span><span class="p">,</span> <span class="s1">&#39;bird&#39;</span><span class="p">),</span>
    94. <span class="p">(</span><span class="mi">9</span><span class="p">,</span> <span class="s1">&#39;boat&#39;</span><span class="p">),</span>
    95. <span class="p">(</span><span class="mi">44</span><span class="p">,</span> <span class="s1">&#39;bottle&#39;</span><span class="p">),</span> <span class="p">(</span><span class="mi">6</span><span class="p">,</span> <span class="s1">&#39;bus&#39;</span><span class="p">),</span> <span class="p">(</span><span class="mi">3</span><span class="p">,</span> <span class="s1">&#39;car&#39;</span><span class="p">),</span> <span class="p">(</span><span class="mi">17</span><span class="p">,</span> <span class="s1">&#39;cat&#39;</span><span class="p">),</span> <span class="p">(</span><span class="mi">62</span><span class="p">,</span> <span class="s1">&#39;chair&#39;</span><span class="p">),</span>
    96. <span class="p">(</span><span class="mi">21</span><span class="p">,</span> <span class="s1">&#39;cow&#39;</span><span class="p">),</span>
    97. <span class="p">(</span><span class="mi">67</span><span class="p">,</span> <span class="s1">&#39;dining table&#39;</span><span class="p">),</span> <span class="p">(</span><span class="mi">18</span><span class="p">,</span> <span class="s1">&#39;dog&#39;</span><span class="p">),</span> <span class="p">(</span><span class="mi">19</span><span class="p">,</span> <span class="s1">&#39;horse&#39;</span><span class="p">),</span> <span class="p">(</span><span class="mi">4</span><span class="p">,</span> <span class="s1">&#39;motorcycle&#39;</span><span class="p">),</span>
    98. <span class="p">(</span><span class="mi">1</span><span class="p">,</span> <span class="s1">&#39;person&#39;</span><span class="p">),</span>
    99. <span class="p">(</span><span class="mi">64</span><span class="p">,</span> <span class="s1">&#39;potted plant&#39;</span><span class="p">),</span> <span class="p">(</span><span class="mi">20</span><span class="p">,</span> <span class="s1">&#39;sheep&#39;</span><span class="p">),</span> <span class="p">(</span><span class="mi">63</span><span class="p">,</span> <span class="s1">&#39;couch&#39;</span><span class="p">),</span> <span class="p">(</span><span class="mi">7</span><span class="p">,</span> <span class="s1">&#39;train&#39;</span><span class="p">),</span>
    100. <span class="p">(</span><span class="mi">72</span><span class="p">,</span> <span class="s1">&#39;tv&#39;</span><span class="p">)]</span></div>
    101. <div class="viewcode-block" id="to_one_hot"><a class="viewcode-back" href="../../../../super_gradients.training.utils.html#super_gradients.training.utils.segmentation_utils.to_one_hot">[docs]</a><span class="k">def</span> <span class="nf">to_one_hot</span><span class="p">(</span><span class="n">target</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">,</span> <span class="n">num_classes</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span> <span class="n">ignore_index</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="kc">None</span><span class="p">):</span>
    102. <span class="sd">&quot;&quot;&quot;</span>
    103. <span class="sd"> Target label to one_hot tensor. labels and ignore_index must be consecutive numbers.</span>
    104. <span class="sd"> :param target: Class labels long tensor, with shape [N, H, W]</span>
    105. <span class="sd"> :param num_classes: num of classes in datasets excluding ignore label, this is the output channels of the one hot</span>
    106. <span class="sd"> result.</span>
    107. <span class="sd"> :return: one hot tensor with shape [N, num_classes, H, W]</span>
    108. <span class="sd"> &quot;&quot;&quot;</span>
    109. <span class="n">num_classes</span> <span class="o">=</span> <span class="n">num_classes</span> <span class="k">if</span> <span class="n">ignore_index</span> <span class="ow">is</span> <span class="kc">None</span> <span class="k">else</span> <span class="n">num_classes</span> <span class="o">+</span> <span class="mi">1</span>
    110. <span class="n">one_hot</span> <span class="o">=</span> <span class="n">F</span><span class="o">.</span><span class="n">one_hot</span><span class="p">(</span><span class="n">target</span><span class="p">,</span> <span class="n">num_classes</span><span class="p">)</span><span class="o">.</span><span class="n">permute</span><span class="p">((</span><span class="mi">0</span><span class="p">,</span> <span class="mi">3</span><span class="p">,</span> <span class="mi">1</span><span class="p">,</span> <span class="mi">2</span><span class="p">))</span>
    111. <span class="k">if</span> <span class="n">ignore_index</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span><span class="p">:</span>
    112. <span class="c1"># remove ignore_index channel</span>
    113. <span class="n">one_hot</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">cat</span><span class="p">([</span><span class="n">one_hot</span><span class="p">[:,</span> <span class="p">:</span><span class="n">ignore_index</span><span class="p">],</span> <span class="n">one_hot</span><span class="p">[:,</span> <span class="n">ignore_index</span> <span class="o">+</span> <span class="mi">1</span><span class="p">:]],</span> <span class="n">dim</span><span class="o">=</span><span class="mi">1</span><span class="p">)</span>
    114. <span class="k">return</span> <span class="n">one_hot</span></div>
    115. <div class="viewcode-block" id="reverse_imagenet_preprocessing"><a class="viewcode-back" href="../../../../super_gradients.training.utils.html#super_gradients.training.utils.segmentation_utils.reverse_imagenet_preprocessing">[docs]</a><span class="k">def</span> <span class="nf">reverse_imagenet_preprocessing</span><span class="p">(</span><span class="n">im_tensor</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="n">np</span><span class="o">.</span><span class="n">ndarray</span><span class="p">:</span>
    116. <span class="sd">&quot;&quot;&quot;</span>
    117. <span class="sd"> :param im_tensor: images in a batch after preprocessing for inference, RGB, (B, C, H, W)</span>
    118. <span class="sd"> :return: images in a batch in cv2 format, BGR, (B, H, W, C)</span>
    119. <span class="sd"> &quot;&quot;&quot;</span>
    120. <span class="n">im_np</span> <span class="o">=</span> <span class="n">im_tensor</span><span class="o">.</span><span class="n">cpu</span><span class="p">()</span><span class="o">.</span><span class="n">numpy</span><span class="p">()</span>
    121. <span class="n">im_np</span> <span class="o">=</span> <span class="n">im_np</span><span class="p">[:,</span> <span class="p">::</span><span class="o">-</span><span class="mi">1</span><span class="p">,</span> <span class="p">:,</span> <span class="p">:]</span><span class="o">.</span><span class="n">transpose</span><span class="p">(</span><span class="mi">0</span><span class="p">,</span> <span class="mi">2</span><span class="p">,</span> <span class="mi">3</span><span class="p">,</span> <span class="mi">1</span><span class="p">)</span>
    122. <span class="n">im_np</span> <span class="o">*=</span> <span class="n">np</span><span class="o">.</span><span class="n">array</span><span class="p">([[[</span><span class="mf">.229</span><span class="p">,</span> <span class="mf">.224</span><span class="p">,</span> <span class="mf">.225</span><span class="p">][::</span><span class="o">-</span><span class="mi">1</span><span class="p">]]])</span>
    123. <span class="n">im_np</span> <span class="o">+=</span> <span class="n">np</span><span class="o">.</span><span class="n">array</span><span class="p">([[[</span><span class="mf">.485</span><span class="p">,</span> <span class="mf">.456</span><span class="p">,</span> <span class="mf">.406</span><span class="p">][::</span><span class="o">-</span><span class="mi">1</span><span class="p">]]])</span>
    124. <span class="n">im_np</span> <span class="o">*=</span> <span class="mf">255.</span>
    125. <span class="k">return</span> <span class="n">np</span><span class="o">.</span><span class="n">ascontiguousarray</span><span class="p">(</span><span class="n">im_np</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="n">np</span><span class="o">.</span><span class="n">uint8</span><span class="p">)</span></div>
    126. <div class="viewcode-block" id="BinarySegmentationVisualization"><a class="viewcode-back" href="../../../../super_gradients.training.utils.html#super_gradients.training.utils.segmentation_utils.BinarySegmentationVisualization">[docs]</a><span class="k">class</span> <span class="nc">BinarySegmentationVisualization</span><span class="p">:</span>
    127. <span class="nd">@staticmethod</span>
    128. <span class="k">def</span> <span class="nf">_visualize_image</span><span class="p">(</span><span class="n">image_np</span><span class="p">:</span> <span class="n">np</span><span class="o">.</span><span class="n">ndarray</span><span class="p">,</span> <span class="n">pred_mask</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">,</span> <span class="n">target_mask</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">,</span>
    129. <span class="n">image_scale</span><span class="p">:</span> <span class="nb">float</span><span class="p">,</span> <span class="n">checkpoint_dir</span><span class="p">:</span> <span class="nb">str</span><span class="p">,</span> <span class="n">image_name</span><span class="p">:</span> <span class="nb">str</span><span class="p">):</span>
    130. <span class="n">pred_mask</span> <span class="o">=</span> <span class="n">pred_mask</span><span class="o">.</span><span class="n">copy</span><span class="p">()</span>
    131. <span class="n">image_np</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">from_numpy</span><span class="p">(</span><span class="n">np</span><span class="o">.</span><span class="n">moveaxis</span><span class="p">(</span><span class="n">image_np</span><span class="p">,</span> <span class="o">-</span><span class="mi">1</span><span class="p">,</span> <span class="mi">0</span><span class="p">)</span><span class="o">.</span><span class="n">astype</span><span class="p">(</span><span class="n">np</span><span class="o">.</span><span class="n">uint8</span><span class="p">))</span>
    132. <span class="n">pred_mask</span> <span class="o">=</span> <span class="n">pred_mask</span><span class="p">[</span><span class="n">np</span><span class="o">.</span><span class="n">newaxis</span><span class="p">,</span> <span class="p">:,</span> <span class="p">:]</span> <span class="o">&gt;</span> <span class="mf">0.5</span>
    133. <span class="n">target_mask</span> <span class="o">=</span> <span class="n">target_mask</span><span class="p">[</span><span class="n">np</span><span class="o">.</span><span class="n">newaxis</span><span class="p">,</span> <span class="p">:,</span> <span class="p">:]</span><span class="o">.</span><span class="n">astype</span><span class="p">(</span><span class="nb">bool</span><span class="p">)</span>
    134. <span class="n">tp_mask</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">logical_and</span><span class="p">(</span><span class="n">pred_mask</span><span class="p">,</span> <span class="n">target_mask</span><span class="p">)</span>
    135. <span class="n">fp_mask</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">logical_and</span><span class="p">(</span><span class="n">pred_mask</span><span class="p">,</span> <span class="n">np</span><span class="o">.</span><span class="n">logical_not</span><span class="p">(</span><span class="n">target_mask</span><span class="p">))</span>
    136. <span class="n">fn_mask</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">logical_and</span><span class="p">(</span><span class="n">np</span><span class="o">.</span><span class="n">logical_not</span><span class="p">(</span><span class="n">pred_mask</span><span class="p">),</span> <span class="n">target_mask</span><span class="p">)</span>
    137. <span class="n">overlay</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">from_numpy</span><span class="p">(</span><span class="n">np</span><span class="o">.</span><span class="n">concatenate</span><span class="p">([</span><span class="n">tp_mask</span><span class="p">,</span> <span class="n">fp_mask</span><span class="p">,</span> <span class="n">fn_mask</span><span class="p">]))</span>
    138. <span class="c1"># SWITCH BETWEEN BLUE AND RED IF WE SAVE THE IMAGE ON THE DISC AS OTHERWISE WE CHANGE CHANNEL ORDERING</span>
    139. <span class="n">colors</span> <span class="o">=</span> <span class="p">[</span><span class="s1">&#39;green&#39;</span><span class="p">,</span> <span class="s1">&#39;red&#39;</span><span class="p">,</span> <span class="s1">&#39;blue&#39;</span><span class="p">]</span>
    140. <span class="n">res_image</span> <span class="o">=</span> <span class="n">draw_segmentation_masks</span><span class="p">(</span><span class="n">image_np</span><span class="p">,</span> <span class="n">overlay</span><span class="p">,</span> <span class="n">colors</span><span class="o">=</span><span class="n">colors</span><span class="p">)</span><span class="o">.</span><span class="n">detach</span><span class="p">()</span><span class="o">.</span><span class="n">numpy</span><span class="p">()</span>
    141. <span class="n">res_image</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">concatenate</span><span class="p">([</span><span class="n">res_image</span><span class="p">[</span><span class="n">ch</span><span class="p">,</span> <span class="p">:,</span> <span class="p">:,</span> <span class="n">np</span><span class="o">.</span><span class="n">newaxis</span><span class="p">]</span> <span class="k">for</span> <span class="n">ch</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="mi">3</span><span class="p">)],</span> <span class="mi">2</span><span class="p">)</span>
    142. <span class="n">res_image</span> <span class="o">=</span> <span class="n">cv2</span><span class="o">.</span><span class="n">resize</span><span class="p">(</span><span class="n">res_image</span><span class="o">.</span><span class="n">astype</span><span class="p">(</span><span class="n">np</span><span class="o">.</span><span class="n">uint8</span><span class="p">),</span> <span class="p">(</span><span class="mi">0</span><span class="p">,</span> <span class="mi">0</span><span class="p">),</span> <span class="n">fx</span><span class="o">=</span><span class="n">image_scale</span><span class="p">,</span> <span class="n">fy</span><span class="o">=</span><span class="n">image_scale</span><span class="p">,</span>
    143. <span class="n">interpolation</span><span class="o">=</span><span class="n">cv2</span><span class="o">.</span><span class="n">INTER_NEAREST</span><span class="p">)</span>
    144. <span class="k">if</span> <span class="n">checkpoint_dir</span> <span class="ow">is</span> <span class="kc">None</span><span class="p">:</span>
    145. <span class="k">return</span> <span class="n">res_image</span>
    146. <span class="k">else</span><span class="p">:</span>
    147. <span class="n">cv2</span><span class="o">.</span><span class="n">imwrite</span><span class="p">(</span><span class="n">os</span><span class="o">.</span><span class="n">path</span><span class="o">.</span><span class="n">join</span><span class="p">(</span><span class="n">checkpoint_dir</span><span class="p">,</span> <span class="nb">str</span><span class="p">(</span><span class="n">image_name</span><span class="p">)</span> <span class="o">+</span> <span class="s1">&#39;.jpg&#39;</span><span class="p">),</span> <span class="n">res_image</span><span class="p">)</span>
    148. <div class="viewcode-block" id="BinarySegmentationVisualization.visualize_batch"><a class="viewcode-back" href="../../../../super_gradients.training.utils.html#super_gradients.training.utils.segmentation_utils.BinarySegmentationVisualization.visualize_batch">[docs]</a> <span class="nd">@staticmethod</span>
    149. <span class="k">def</span> <span class="nf">visualize_batch</span><span class="p">(</span><span class="n">image_tensor</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">,</span> <span class="n">pred_mask</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">,</span> <span class="n">target_mask</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">,</span>
    150. <span class="n">batch_name</span><span class="p">:</span> <span class="n">Union</span><span class="p">[</span><span class="nb">int</span><span class="p">,</span> <span class="nb">str</span><span class="p">],</span> <span class="n">checkpoint_dir</span><span class="p">:</span> <span class="nb">str</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span>
    151. <span class="n">undo_preprocessing_func</span><span class="p">:</span> <span class="n">Callable</span><span class="p">[[</span><span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">],</span> <span class="n">np</span><span class="o">.</span><span class="n">ndarray</span><span class="p">]</span> <span class="o">=</span> <span class="n">reverse_imagenet_preprocessing</span><span class="p">,</span>
    152. <span class="n">image_scale</span><span class="p">:</span> <span class="nb">float</span> <span class="o">=</span> <span class="mf">1.</span><span class="p">):</span>
    153. <span class="sd">&quot;&quot;&quot;</span>
    154. <span class="sd"> A helper function to visualize detections predicted by a network:</span>
    155. <span class="sd"> saves images into a given path with a name that is {batch_name}_{imade_idx_in_the_batch}.jpg, one batch per call.</span>
    156. <span class="sd"> Colors are generated on the fly: uniformly sampled from color wheel to support all given classes.</span>
    157. <span class="sd"> :param image_tensor: rgb images, (B, H, W, 3)</span>
    158. <span class="sd"> :param pred_boxes: boxes after NMS for each image in a batch, each (Num_boxes, 6),</span>
    159. <span class="sd"> values on dim 1 are: x1, y1, x2, y2, confidence, class</span>
    160. <span class="sd"> :param target_boxes: (Num_targets, 6), values on dim 1 are: image id in a batch, class, x y w h</span>
    161. <span class="sd"> (coordinates scaled to [0, 1])</span>
    162. <span class="sd"> :param batch_name: id of the current batch to use for image naming</span>
    163. <span class="sd"> :param checkpoint_dir: a path where images with boxes will be saved. if None, the result images will</span>
    164. <span class="sd"> be returns as a list of numpy image arrays</span>
    165. <span class="sd"> :param undo_preprocessing_func: a function to convert preprocessed images tensor into a batch of cv2-like images</span>
    166. <span class="sd"> :param image_scale: scale factor for output image</span>
    167. <span class="sd"> &quot;&quot;&quot;</span>
    168. <span class="n">image_np</span> <span class="o">=</span> <span class="n">undo_preprocessing_func</span><span class="p">(</span><span class="n">image_tensor</span><span class="o">.</span><span class="n">detach</span><span class="p">())</span>
    169. <span class="n">pred_mask</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">sigmoid</span><span class="p">(</span><span class="n">pred_mask</span><span class="p">[:,</span> <span class="mi">0</span><span class="p">,</span> <span class="p">:,</span> <span class="p">:])</span> <span class="c1"># comment out</span>
    170. <span class="n">out_images</span> <span class="o">=</span> <span class="p">[]</span>
    171. <span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="n">image_np</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">0</span><span class="p">]):</span>
    172. <span class="n">preds</span> <span class="o">=</span> <span class="n">pred_mask</span><span class="p">[</span><span class="n">i</span><span class="p">]</span><span class="o">.</span><span class="n">detach</span><span class="p">()</span><span class="o">.</span><span class="n">cpu</span><span class="p">()</span><span class="o">.</span><span class="n">numpy</span><span class="p">()</span>
    173. <span class="n">targets</span> <span class="o">=</span> <span class="n">target_mask</span><span class="p">[</span><span class="n">i</span><span class="p">]</span><span class="o">.</span><span class="n">detach</span><span class="p">()</span><span class="o">.</span><span class="n">cpu</span><span class="p">()</span><span class="o">.</span><span class="n">numpy</span><span class="p">()</span>
    174. <span class="n">image_name</span> <span class="o">=</span> <span class="s1">&#39;_&#39;</span><span class="o">.</span><span class="n">join</span><span class="p">([</span><span class="nb">str</span><span class="p">(</span><span class="n">batch_name</span><span class="p">),</span> <span class="nb">str</span><span class="p">(</span><span class="n">i</span><span class="p">)])</span>
    175. <span class="n">res_image</span> <span class="o">=</span> <span class="n">BinarySegmentationVisualization</span><span class="o">.</span><span class="n">_visualize_image</span><span class="p">(</span><span class="n">image_np</span><span class="p">[</span><span class="n">i</span><span class="p">],</span> <span class="n">preds</span><span class="p">,</span> <span class="n">targets</span><span class="p">,</span> <span class="n">image_scale</span><span class="p">,</span>
    176. <span class="n">checkpoint_dir</span><span class="p">,</span> <span class="n">image_name</span><span class="p">)</span>
    177. <span class="k">if</span> <span class="n">res_image</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span><span class="p">:</span>
    178. <span class="n">out_images</span><span class="o">.</span><span class="n">append</span><span class="p">(</span><span class="n">res_image</span><span class="p">)</span>
    179. <span class="k">return</span> <span class="n">out_images</span></div></div>
    180. <div class="viewcode-block" id="visualize_batches"><a class="viewcode-back" href="../../../../super_gradients.training.utils.html#super_gradients.training.utils.segmentation_utils.visualize_batches">[docs]</a><span class="k">def</span> <span class="nf">visualize_batches</span><span class="p">(</span><span class="n">dataloader</span><span class="p">,</span> <span class="n">module</span><span class="p">,</span> <span class="n">visualization_path</span><span class="p">,</span> <span class="n">num_batches</span><span class="o">=</span><span class="mi">1</span><span class="p">,</span> <span class="n">undo_preprocessing_func</span><span class="o">=</span><span class="kc">None</span><span class="p">):</span>
    181. <span class="n">os</span><span class="o">.</span><span class="n">makedirs</span><span class="p">(</span><span class="n">visualization_path</span><span class="p">,</span> <span class="n">exist_ok</span><span class="o">=</span><span class="kc">True</span><span class="p">)</span>
    182. <span class="k">for</span> <span class="n">batch_i</span><span class="p">,</span> <span class="p">(</span><span class="n">imgs</span><span class="p">,</span> <span class="n">targets</span><span class="p">)</span> <span class="ow">in</span> <span class="nb">enumerate</span><span class="p">(</span><span class="n">dataloader</span><span class="p">):</span>
    183. <span class="k">if</span> <span class="n">batch_i</span> <span class="o">==</span> <span class="n">num_batches</span><span class="p">:</span>
    184. <span class="k">return</span>
    185. <span class="n">imgs</span> <span class="o">=</span> <span class="n">core_utils</span><span class="o">.</span><span class="n">tensor_container_to_device</span><span class="p">(</span><span class="n">imgs</span><span class="p">,</span> <span class="n">torch</span><span class="o">.</span><span class="n">device</span><span class="p">(</span><span class="s1">&#39;cuda:0&#39;</span><span class="p">))</span>
    186. <span class="n">targets</span> <span class="o">=</span> <span class="n">core_utils</span><span class="o">.</span><span class="n">tensor_container_to_device</span><span class="p">(</span><span class="n">targets</span><span class="p">,</span> <span class="n">torch</span><span class="o">.</span><span class="n">device</span><span class="p">(</span><span class="s1">&#39;cuda:0&#39;</span><span class="p">))</span>
    187. <span class="n">pred_mask</span> <span class="o">=</span> <span class="n">module</span><span class="p">(</span><span class="n">imgs</span><span class="p">)</span>
    188. <span class="c1"># Visualize the batch</span>
    189. <span class="k">if</span> <span class="n">undo_preprocessing_func</span><span class="p">:</span>
    190. <span class="n">BinarySegmentationVisualization</span><span class="o">.</span><span class="n">visualize_batch</span><span class="p">(</span><span class="n">imgs</span><span class="p">,</span> <span class="n">pred_mask</span><span class="p">,</span> <span class="n">targets</span><span class="p">,</span> <span class="n">batch_i</span><span class="p">,</span> <span class="n">visualization_path</span><span class="p">,</span>
    191. <span class="n">undo_preprocessing_func</span><span class="o">=</span><span class="n">undo_preprocessing_func</span><span class="p">)</span>
    192. <span class="k">else</span><span class="p">:</span>
    193. <span class="n">BinarySegmentationVisualization</span><span class="o">.</span><span class="n">visualize_batch</span><span class="p">(</span><span class="n">imgs</span><span class="p">,</span> <span class="n">pred_mask</span><span class="p">,</span> <span class="n">targets</span><span class="p">,</span> <span class="n">batch_i</span><span class="p">,</span> <span class="n">visualization_path</span><span class="p">)</span></div>
    194. <div class="viewcode-block" id="one_hot_to_binary_edge"><a class="viewcode-back" href="../../../../super_gradients.training.utils.html#super_gradients.training.utils.segmentation_utils.one_hot_to_binary_edge">[docs]</a><span class="k">def</span> <span class="nf">one_hot_to_binary_edge</span><span class="p">(</span><span class="n">x</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">,</span>
    195. <span class="n">kernel_size</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span>
    196. <span class="n">flatten_channels</span><span class="p">:</span> <span class="nb">bool</span> <span class="o">=</span> <span class="kc">True</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">:</span>
    197. <span class="sd">&quot;&quot;&quot;</span>
    198. <span class="sd"> Utils function to create edge feature maps.</span>
    199. <span class="sd"> :param x: input tensor, must be one_hot tensor with shape [B, C, H, W]</span>
    200. <span class="sd"> :param kernel_size: kernel size of dilation erosion convolutions. The result edge widths depends on this argument as</span>
    201. <span class="sd"> follows: `edge_width = kernel - 1`</span>
    202. <span class="sd"> :param flatten_channels: Whether to apply logical_or across channels dimension, if at least one pixel class is</span>
    203. <span class="sd"> considered as edge pixel flatten value is 1. If set as `False` the output tensor shape is [B, C, H, W], else</span>
    204. <span class="sd"> [B, 1, H, W]. Default is `True`.</span>
    205. <span class="sd"> :return: one_hot edge torch.Tensor.</span>
    206. <span class="sd"> &quot;&quot;&quot;</span>
    207. <span class="k">if</span> <span class="n">kernel_size</span> <span class="o">&lt;</span> <span class="mi">0</span> <span class="ow">or</span> <span class="n">kernel_size</span> <span class="o">%</span> <span class="mi">2</span> <span class="o">==</span> <span class="mi">0</span><span class="p">:</span>
    208. <span class="k">raise</span> <span class="ne">ValueError</span><span class="p">(</span><span class="sa">f</span><span class="s2">&quot;kernel size must be an odd positive values, such as [1, 3, 5, ..], found: </span><span class="si">{</span><span class="n">kernel_size</span><span class="si">}</span><span class="s2">&quot;</span><span class="p">)</span>
    209. <span class="n">_kernel</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">ones</span><span class="p">(</span><span class="n">x</span><span class="o">.</span><span class="n">size</span><span class="p">(</span><span class="mi">1</span><span class="p">),</span> <span class="mi">1</span><span class="p">,</span> <span class="n">kernel_size</span><span class="p">,</span> <span class="n">kernel_size</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="n">torch</span><span class="o">.</span><span class="n">float32</span><span class="p">,</span> <span class="n">device</span><span class="o">=</span><span class="n">x</span><span class="o">.</span><span class="n">device</span><span class="p">)</span>
    210. <span class="n">padding</span> <span class="o">=</span> <span class="p">(</span><span class="n">kernel_size</span> <span class="o">-</span> <span class="mi">1</span><span class="p">)</span> <span class="o">//</span> <span class="mi">2</span>
    211. <span class="c1"># Use replicate padding to prevent class shifting and edge formation at the image boundaries.</span>
    212. <span class="n">padded_x</span> <span class="o">=</span> <span class="n">F</span><span class="o">.</span><span class="n">pad</span><span class="p">(</span><span class="n">x</span><span class="o">.</span><span class="n">float</span><span class="p">(),</span> <span class="n">mode</span><span class="o">=</span><span class="s2">&quot;replicate&quot;</span><span class="p">,</span> <span class="n">pad</span><span class="o">=</span><span class="p">[</span><span class="n">padding</span><span class="p">]</span> <span class="o">*</span> <span class="mi">4</span><span class="p">)</span>
    213. <span class="c1"># The binary edges feature map is created by subtracting dilated features from erosed features.</span>
    214. <span class="c1"># First the positive one value masks are expanded (dilation) by applying a sliding window filter of one values.</span>
    215. <span class="c1"># The resulted output is then clamped to binary format to [0, 1], this way the one-hot boundaries are expanded by</span>
    216. <span class="c1"># (kernel_size - 1) / 2.</span>
    217. <span class="n">dilation</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">clamp</span><span class="p">(</span>
    218. <span class="n">F</span><span class="o">.</span><span class="n">conv2d</span><span class="p">(</span><span class="n">padded_x</span><span class="p">,</span> <span class="n">_kernel</span><span class="p">,</span> <span class="n">groups</span><span class="o">=</span><span class="n">x</span><span class="o">.</span><span class="n">size</span><span class="p">(</span><span class="mi">1</span><span class="p">)),</span>
    219. <span class="mi">0</span><span class="p">,</span> <span class="mi">1</span>
    220. <span class="p">)</span>
    221. <span class="c1"># Similar to dilation, erosion (can be seen as inverse of dilation) is applied to contract the one-hot features by</span>
    222. <span class="c1"># applying a dilation operation on the inverse of the one-hot features.</span>
    223. <span class="n">erosion</span> <span class="o">=</span> <span class="mi">1</span> <span class="o">-</span> <span class="n">torch</span><span class="o">.</span><span class="n">clamp</span><span class="p">(</span>
    224. <span class="n">F</span><span class="o">.</span><span class="n">conv2d</span><span class="p">(</span><span class="mi">1</span> <span class="o">-</span> <span class="n">padded_x</span><span class="p">,</span> <span class="n">_kernel</span><span class="p">,</span> <span class="n">groups</span><span class="o">=</span><span class="n">x</span><span class="o">.</span><span class="n">size</span><span class="p">(</span><span class="mi">1</span><span class="p">)),</span>
    225. <span class="mi">0</span><span class="p">,</span> <span class="mi">1</span>
    226. <span class="p">)</span>
    227. <span class="c1"># Finally the edge features are the result of subtracting dilation by erosion.</span>
    228. <span class="c1"># i.e for a simple 1D one-hot input: [0, 0, 0, 1, 1, 1, 0, 0, 0], using sliding kernel with size 3: [1, 1, 1]</span>
    229. <span class="c1"># Dilated features: [0, 0, 1, 1, 1, 1, 1, 0, 0]</span>
    230. <span class="c1"># Erosed inverse features: [0, 0, 0, 0, 1, 0, 0, 0, 0]</span>
    231. <span class="c1"># Edge features: dilation - erosion: [0, 0, 1, 1, 0, 1, 1, 0, 0]</span>
    232. <span class="n">edge</span> <span class="o">=</span> <span class="n">dilation</span> <span class="o">-</span> <span class="n">erosion</span>
    233. <span class="k">if</span> <span class="n">flatten_channels</span><span class="p">:</span>
    234. <span class="c1"># use max operator across channels. Equivalent to logical or for input with binary values [0, 1].</span>
    235. <span class="n">edge</span> <span class="o">=</span> <span class="n">edge</span><span class="o">.</span><span class="n">max</span><span class="p">(</span><span class="n">dim</span><span class="o">=</span><span class="mi">1</span><span class="p">,</span> <span class="n">keepdim</span><span class="o">=</span><span class="kc">True</span><span class="p">)[</span><span class="mi">0</span><span class="p">]</span>
    236. <span class="k">return</span> <span class="n">edge</span></div>
    237. <div class="viewcode-block" id="target_to_binary_edge"><a class="viewcode-back" href="../../../../super_gradients.training.utils.html#super_gradients.training.utils.segmentation_utils.target_to_binary_edge">[docs]</a><span class="k">def</span> <span class="nf">target_to_binary_edge</span><span class="p">(</span><span class="n">target</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">,</span>
    238. <span class="n">num_classes</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span>
    239. <span class="n">kernel_size</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span>
    240. <span class="n">ignore_index</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span>
    241. <span class="n">flatten_channels</span><span class="p">:</span> <span class="nb">bool</span> <span class="o">=</span> <span class="kc">True</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">:</span>
    242. <span class="sd">&quot;&quot;&quot;</span>
    243. <span class="sd"> Utils function to create edge feature maps from target.</span>
    244. <span class="sd"> :param target: Class labels long tensor, with shape [N, H, W]</span>
    245. <span class="sd"> :param num_classes: num of classes in datasets excluding ignore label, this is the output channels of the one hot</span>
    246. <span class="sd"> result.</span>
    247. <span class="sd"> :param kernel_size: kernel size of dilation erosion convolutions. The result edge widths depends on this argument as</span>
    248. <span class="sd"> follows: `edge_width = kernel - 1`</span>
    249. <span class="sd"> :param flatten_channels: Whether to apply logical or across channels dimension, if at least one pixel class is</span>
    250. <span class="sd"> considered as edge pixel flatten value is 1. If set as `False` the output tensor shape is [B, C, H, W], else</span>
    251. <span class="sd"> [B, 1, H, W]. Default is `True`.</span>
    252. <span class="sd"> :return: one_hot edge torch.Tensor.</span>
    253. <span class="sd"> &quot;&quot;&quot;</span>
    254. <span class="n">one_hot</span> <span class="o">=</span> <span class="n">to_one_hot</span><span class="p">(</span><span class="n">target</span><span class="p">,</span> <span class="n">num_classes</span><span class="o">=</span><span class="n">num_classes</span><span class="p">,</span> <span class="n">ignore_index</span><span class="o">=</span><span class="n">ignore_index</span><span class="p">)</span>
    255. <span class="k">return</span> <span class="n">one_hot_to_binary_edge</span><span class="p">(</span><span class="n">one_hot</span><span class="p">,</span> <span class="n">kernel_size</span><span class="o">=</span><span class="n">kernel_size</span><span class="p">,</span> <span class="n">flatten_channels</span><span class="o">=</span><span class="n">flatten_channels</span><span class="p">)</span></div>
    256. </pre></div>
    257. </div>
    258. </div>
    259. <footer>
    260. <hr/>
    261. <div role="contentinfo">
    262. <p>&#169; Copyright 2021, SuperGradients team.</p>
    263. </div>
    264. Built with <a href="https://www.sphinx-doc.org/">Sphinx</a> using a
    265. <a href="https://github.com/readthedocs/sphinx_rtd_theme">theme</a>
    266. provided by <a href="https://readthedocs.org">Read the Docs</a>.
    267. </footer>
    268. </div>
    269. </div>
    270. </section>
    271. </div>
    272. <script>
    273. jQuery(function () {
    274. SphinxRtdTheme.Navigation.enable(true);
    275. });
    276. </script>
    277. </body>
    278. </html>
    Discard
    1
    2
    3
    4
    5
    6
    7
    8
    9
    10
    11
    12
    13
    14
    15
    16
    17
    18
    19
    20
    21
    22
    23
    24
    25
    26
    27
    28
    29
    30
    31
    32
    33
    34
    35
    36
    37
    38
    39
    40
    41
    42
    43
    44
    45
    46
    47
    48
    49
    50
    51
    52
    53
    54
    55
    56
    57
    58
    59
    60
    61
    62
    63
    64
    65
    66
    67
    68
    69
    70
    71
    72
    73
    74
    75
    76
    77
    78
    79
    80
    81
    82
    83
    84
    85
    86
    87
    88
    89
    90
    91
    92
    93
    94
    95
    96
    97
    98
    99
    100
    101
    102
    103
    104
    105
    106
    107
    108
    109
    110
    111
    112
    113
    114
    115
    116
    117
    118
    119
    120
    121
    122
    123
    124
    125
    126
    127
    128
    129
    130
    131
    132
    133
    134
    135
    136
    137
    138
    139
    140
    141
    142
    143
    144
    145
    146
    147
    148
    149
    150
    151
    152
    153
    154
    155
    156
    157
    158
    159
    160
    161
    162
    163
    164
    165
    166
    167
    168
    169
    170
    171
    172
    173
    174
    175
    176
    177
    178
    179
    180
    181
    182
    183
    184
    185
    186
    187
    188
    189
    190
    191
    192
    193
    194
    195
    196
    197
    198
    199
    200
    201
    202
    203
    204
    205
    206
    207
    208
    209
    210
    211
    212
    213
    214
    215
    216
    217
    218
    219
    220
    221
    222
    223
    224
    225
    226
    227
    228
    229
    230
    231
    232
    233
    234
    235
    236
    237
    238
    239
    240
    241
    242
    243
    244
    245
    246
    247
    248
    249
    250
    251
    252
    253
    254
    255
    256
    257
    258
    259
    260
    261
    262
    263
    264
    265
    266
    267
    268
    269
    270
    271
    272
    273
    274
    275
    276
    277
    278
    279
    280
    281
    282
    283
    284
    285
    286
    287
    288
    289
    290
    291
    292
    293
    294
    295
    296
    297
    298
    299
    300
    301
    302
    303
    304
    305
    306
    307
    308
    309
    310
    311
    312
    313
    314
    315
    316
    317
    318
    319
    320
    321
    322
    323
    324
    325
    326
    327
    328
    329
    330
    331
    332
    333
    334
    335
    336
    337
    338
    339
    340
    341
    342
    343
    344
    345
    346
    347
    348
    349
    350
    351
    352
    353
    354
    355
    356
    357
    358
    359
    360
    361
    362
    363
    364
    365
    366
    367
    368
    369
    370
    371
    372
    373
    374
    375
    376
    377
    378
    379
    380
    381
    382
    383
    384
    385
    386
    387
    388
    389
    390
    391
    392
    393
    394
    395
    396
    397
    398
    399
    400
    401
    402
    403
    404
    405
    406
    407
    408
    409
    410
    411
    412
    413
    414
    415
    416
    417
    418
    419
    420
    421
    422
    423
    424
    425
    426
    427
    428
    429
    430
    431
    432
    433
    434
    435
    436
    437
    438
    439
    440
    441
    442
    443
    444
    445
    446
    447
    448
    449
    1. <!DOCTYPE html>
    2. <html class="writer-html5" lang="en" >
    3. <head>
    4. <meta charset="utf-8" />
    5. <meta name="viewport" content="width=device-width, initial-scale=1.0" />
    6. <title>super_gradients.training.utils.sg_model_utils &mdash; SuperGradients 1.0 documentation</title>
    7. <link rel="stylesheet" href="../../../../_static/pygments.css" type="text/css" />
    8. <link rel="stylesheet" href="../../../../_static/css/theme.css" type="text/css" />
    9. <link rel="stylesheet" href="../../../../_static/graphviz.css" type="text/css" />
    10. <!--[if lt IE 9]>
    11. <script src="../../../../_static/js/html5shiv.min.js"></script>
    12. <![endif]-->
    13. <script data-url_root="../../../../" id="documentation_options" src="../../../../_static/documentation_options.js"></script>
    14. <script src="../../../../_static/jquery.js"></script>
    15. <script src="../../../../_static/underscore.js"></script>
    16. <script src="../../../../_static/doctools.js"></script>
    17. <script src="../../../../_static/js/theme.js"></script>
    18. <link rel="index" title="Index" href="../../../../genindex.html" />
    19. <link rel="search" title="Search" href="../../../../search.html" />
    20. </head>
    21. <body class="wy-body-for-nav">
    22. <div class="wy-grid-for-nav">
    23. <nav data-toggle="wy-nav-shift" class="wy-nav-side">
    24. <div class="wy-side-scroll">
    25. <div class="wy-side-nav-search" >
    26. <a href="../../../../index.html" class="icon icon-home"> SuperGradients
    27. </a>
    28. <div role="search">
    29. <form id="rtd-search-form" class="wy-form" action="../../../../search.html" method="get">
    30. <input type="text" name="q" placeholder="Search docs" />
    31. <input type="hidden" name="check_keywords" value="yes" />
    32. <input type="hidden" name="area" value="default" />
    33. </form>
    34. </div>
    35. </div><div class="wy-menu wy-menu-vertical" data-spy="affix" role="navigation" aria-label="Navigation menu">
    36. <p class="caption"><span class="caption-text">Welcome To SuperGradients</span></p>
    37. <ul>
    38. <li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html">Fill our 4-question quick survey! We will raffle free SuperGradients swag between those who will participate -&gt; Fill Survey</a></li>
    39. <li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#supergradients">SuperGradients</a></li>
    40. </ul>
    41. <p class="caption"><span class="caption-text">Technical Documentation</span></p>
    42. <ul>
    43. <li class="toctree-l1"><a class="reference internal" href="../../../../super_gradients.common.html">Common package</a></li>
    44. <li class="toctree-l1"><a class="reference internal" href="../../../../super_gradients.training.html">Training package</a></li>
    45. </ul>
    46. <p class="caption"><span class="caption-text">User Guide</span></p>
    47. <ul>
    48. <li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html">What is SuperGradients?</a></li>
    49. <li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#introducing-the-supergradients-library">Introducing the SuperGradients library</a></li>
    50. <li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#installation">Installation</a></li>
    51. <li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#integrating-your-training-code-complete-walkthrough">Integrating Your Training Code - Complete Walkthrough</a></li>
    52. <li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#training-parameters">Training Parameters</a></li>
    53. <li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#logs-and-checkpoints">Logs and Checkpoints</a></li>
    54. <li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#dataset-parameters">Dataset Parameters</a></li>
    55. <li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#network-architectures">Network Architectures</a></li>
    56. <li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#pretrained-models">Pretrained Models</a></li>
    57. <li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#how-to-reproduce-our-training-recipes">How To Reproduce Our Training Recipes</a></li>
    58. <li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#professional-tools-integration">Professional Tools Integration</a></li>
    59. <li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#supergradients-faq">SuperGradients FAQ</a></li>
    60. </ul>
    61. </div>
    62. </div>
    63. </nav>
    64. <section data-toggle="wy-nav-shift" class="wy-nav-content-wrap"><nav class="wy-nav-top" aria-label="Mobile navigation menu" >
    65. <i data-toggle="wy-nav-top" class="fa fa-bars"></i>
    66. <a href="../../../../index.html">SuperGradients</a>
    67. </nav>
    68. <div class="wy-nav-content">
    69. <div class="rst-content">
    70. <div role="navigation" aria-label="Page navigation">
    71. <ul class="wy-breadcrumbs">
    72. <li><a href="../../../../index.html" class="icon icon-home"></a> &raquo;</li>
    73. <li><a href="../../../index.html">Module code</a> &raquo;</li>
    74. <li>super_gradients.training.utils.sg_model_utils</li>
    75. <li class="wy-breadcrumbs-aside">
    76. </li>
    77. </ul>
    78. <hr/>
    79. </div>
    80. <div role="main" class="document" itemscope="itemscope" itemtype="http://schema.org/Article">
    81. <div itemprop="articleBody">
    82. <h1>Source code for super_gradients.training.utils.sg_model_utils</h1><div class="highlight"><pre>
    83. <span></span><span class="kn">import</span> <span class="nn">os</span>
    84. <span class="kn">import</span> <span class="nn">sys</span>
    85. <span class="kn">import</span> <span class="nn">socket</span>
    86. <span class="kn">import</span> <span class="nn">time</span>
    87. <span class="kn">from</span> <span class="nn">dataclasses</span> <span class="kn">import</span> <span class="n">dataclass</span>
    88. <span class="kn">from</span> <span class="nn">multiprocessing</span> <span class="kn">import</span> <span class="n">Process</span>
    89. <span class="kn">from</span> <span class="nn">pathlib</span> <span class="kn">import</span> <span class="n">Path</span>
    90. <span class="kn">from</span> <span class="nn">typing</span> <span class="kn">import</span> <span class="n">Tuple</span><span class="p">,</span> <span class="n">Union</span><span class="p">,</span> <span class="n">Dict</span>
    91. <span class="kn">import</span> <span class="nn">random</span>
    92. <span class="kn">from</span> <span class="nn">treelib</span> <span class="kn">import</span> <span class="n">Tree</span>
    93. <span class="kn">from</span> <span class="nn">termcolor</span> <span class="kn">import</span> <span class="n">colored</span>
    94. <span class="kn">import</span> <span class="nn">torch</span>
    95. <span class="kn">from</span> <span class="nn">torch.utils.tensorboard</span> <span class="kn">import</span> <span class="n">SummaryWriter</span>
    96. <span class="kn">from</span> <span class="nn">super_gradients.training.exceptions.dataset_exceptions</span> <span class="kn">import</span> <span class="n">UnsupportedBatchItemsFormat</span>
    97. <span class="c1"># TODO: These utils should move to sg_model package as internal (private) helper functions</span>
    98. <span class="n">IS_BETTER_COLOR</span> <span class="o">=</span> <span class="p">{</span><span class="kc">True</span><span class="p">:</span> <span class="s2">&quot;green&quot;</span><span class="p">,</span> <span class="kc">False</span><span class="p">:</span> <span class="s2">&quot;red&quot;</span><span class="p">}</span>
    99. <span class="n">IS_GREATER_SYMBOLS</span> <span class="o">=</span> <span class="p">{</span><span class="kc">True</span><span class="p">:</span> <span class="s2">&quot;↗&quot;</span><span class="p">,</span> <span class="kc">False</span><span class="p">:</span> <span class="s2">&quot;↘&quot;</span><span class="p">}</span>
    100. <div class="viewcode-block" id="MonitoredValue"><a class="viewcode-back" href="../../../../super_gradients.training.utils.html#super_gradients.training.utils.sg_model_utils.MonitoredValue">[docs]</a><span class="nd">@dataclass</span>
    101. <span class="k">class</span> <span class="nc">MonitoredValue</span><span class="p">:</span>
    102. <span class="sd">&quot;&quot;&quot;Store a value and some indicators relative to its past iterations.</span>
    103. <span class="sd"> The value can be a metric/loss, and the iteration can be epochs/batch.</span>
    104. <span class="sd"> &quot;&quot;&quot;</span>
    105. <span class="n">name</span><span class="p">:</span> <span class="nb">str</span>
    106. <span class="n">greater_is_better</span><span class="p">:</span> <span class="nb">bool</span>
    107. <span class="n">current</span><span class="p">:</span> <span class="nb">float</span> <span class="o">=</span> <span class="kc">None</span>
    108. <span class="n">previous</span><span class="p">:</span> <span class="nb">float</span> <span class="o">=</span> <span class="kc">None</span>
    109. <span class="n">best</span><span class="p">:</span> <span class="nb">float</span> <span class="o">=</span> <span class="kc">None</span>
    110. <span class="n">change_from_previous</span><span class="p">:</span> <span class="nb">float</span> <span class="o">=</span> <span class="kc">None</span>
    111. <span class="n">change_from_best</span><span class="p">:</span> <span class="nb">float</span> <span class="o">=</span> <span class="kc">None</span>
    112. <span class="nd">@property</span>
    113. <span class="k">def</span> <span class="nf">is_better_than_previous</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span>
    114. <span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">greater_is_better</span> <span class="ow">is</span> <span class="kc">None</span> <span class="ow">or</span> <span class="bp">self</span><span class="o">.</span><span class="n">change_from_best</span> <span class="ow">is</span> <span class="kc">None</span><span class="p">:</span>
    115. <span class="k">return</span> <span class="kc">None</span>
    116. <span class="k">elif</span> <span class="bp">self</span><span class="o">.</span><span class="n">greater_is_better</span><span class="p">:</span>
    117. <span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">change_from_previous</span> <span class="o">&gt;=</span> <span class="mi">0</span>
    118. <span class="k">else</span><span class="p">:</span>
    119. <span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">change_from_previous</span> <span class="o">&lt;</span> <span class="mi">0</span>
    120. <span class="nd">@property</span>
    121. <span class="k">def</span> <span class="nf">is_best_value</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span>
    122. <span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">greater_is_better</span> <span class="ow">is</span> <span class="kc">None</span> <span class="ow">or</span> <span class="bp">self</span><span class="o">.</span><span class="n">change_from_best</span> <span class="ow">is</span> <span class="kc">None</span><span class="p">:</span>
    123. <span class="k">return</span> <span class="kc">None</span>
    124. <span class="k">elif</span> <span class="bp">self</span><span class="o">.</span><span class="n">greater_is_better</span><span class="p">:</span>
    125. <span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">change_from_best</span> <span class="o">&gt;=</span> <span class="mi">0</span>
    126. <span class="k">else</span><span class="p">:</span>
    127. <span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">change_from_best</span> <span class="o">&lt;</span> <span class="mi">0</span></div>
    128. <div class="viewcode-block" id="update_monitored_value"><a class="viewcode-back" href="../../../../super_gradients.training.utils.html#super_gradients.training.utils.sg_model_utils.update_monitored_value">[docs]</a><span class="k">def</span> <span class="nf">update_monitored_value</span><span class="p">(</span><span class="n">previous_monitored_value</span><span class="p">:</span> <span class="n">MonitoredValue</span><span class="p">,</span> <span class="n">new_value</span><span class="p">:</span> <span class="nb">float</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="n">MonitoredValue</span><span class="p">:</span>
    129. <span class="sd">&quot;&quot;&quot;Update the given ValueToMonitor object (could be a loss or a metric) with the new value</span>
    130. <span class="sd"> :param previous_monitored_value: The stats about the value that is monitored throughout epochs.</span>
    131. <span class="sd"> :param new_value: The value of the current epoch that will be used to update previous_monitored_value</span>
    132. <span class="sd"> :return:</span>
    133. <span class="sd"> &quot;&quot;&quot;</span>
    134. <span class="n">previous_value</span><span class="p">,</span> <span class="n">previous_best_value</span> <span class="o">=</span> <span class="n">previous_monitored_value</span><span class="o">.</span><span class="n">current</span><span class="p">,</span> <span class="n">previous_monitored_value</span><span class="o">.</span><span class="n">best</span>
    135. <span class="n">name</span><span class="p">,</span> <span class="n">greater_is_better</span> <span class="o">=</span> <span class="n">previous_monitored_value</span><span class="o">.</span><span class="n">name</span><span class="p">,</span> <span class="n">previous_monitored_value</span><span class="o">.</span><span class="n">greater_is_better</span>
    136. <span class="k">if</span> <span class="n">previous_best_value</span> <span class="ow">is</span> <span class="kc">None</span><span class="p">:</span>
    137. <span class="n">previous_best_value</span> <span class="o">=</span> <span class="n">previous_value</span>
    138. <span class="k">elif</span> <span class="n">greater_is_better</span><span class="p">:</span>
    139. <span class="n">previous_best_value</span> <span class="o">=</span> <span class="nb">max</span><span class="p">(</span><span class="n">previous_value</span><span class="p">,</span> <span class="n">previous_best_value</span><span class="p">)</span>
    140. <span class="k">else</span><span class="p">:</span>
    141. <span class="n">previous_best_value</span> <span class="o">=</span> <span class="nb">min</span><span class="p">(</span><span class="n">previous_value</span><span class="p">,</span> <span class="n">previous_best_value</span><span class="p">)</span>
    142. <span class="k">if</span> <span class="n">previous_value</span> <span class="ow">is</span> <span class="kc">None</span><span class="p">:</span>
    143. <span class="n">change_from_previous</span> <span class="o">=</span> <span class="kc">None</span>
    144. <span class="n">change_from_best</span> <span class="o">=</span> <span class="kc">None</span>
    145. <span class="k">else</span><span class="p">:</span>
    146. <span class="n">change_from_previous</span> <span class="o">=</span> <span class="n">new_value</span> <span class="o">-</span> <span class="n">previous_value</span>
    147. <span class="n">change_from_best</span> <span class="o">=</span> <span class="n">new_value</span> <span class="o">-</span> <span class="n">previous_best_value</span>
    148. <span class="k">return</span> <span class="n">MonitoredValue</span><span class="p">(</span><span class="n">name</span><span class="o">=</span><span class="n">name</span><span class="p">,</span> <span class="n">current</span><span class="o">=</span><span class="n">new_value</span><span class="p">,</span> <span class="n">previous</span><span class="o">=</span><span class="n">previous_value</span><span class="p">,</span> <span class="n">best</span><span class="o">=</span><span class="n">previous_best_value</span><span class="p">,</span>
    149. <span class="n">change_from_previous</span><span class="o">=</span><span class="n">change_from_previous</span><span class="p">,</span> <span class="n">change_from_best</span><span class="o">=</span><span class="n">change_from_best</span><span class="p">,</span>
    150. <span class="n">greater_is_better</span><span class="o">=</span><span class="n">greater_is_better</span><span class="p">)</span></div>
    151. <div class="viewcode-block" id="update_monitored_values_dict"><a class="viewcode-back" href="../../../../super_gradients.training.utils.html#super_gradients.training.utils.sg_model_utils.update_monitored_values_dict">[docs]</a><span class="k">def</span> <span class="nf">update_monitored_values_dict</span><span class="p">(</span><span class="n">monitored_values_dict</span><span class="p">:</span> <span class="n">Dict</span><span class="p">[</span><span class="nb">str</span><span class="p">,</span> <span class="n">MonitoredValue</span><span class="p">],</span>
    152. <span class="n">new_values_dict</span><span class="p">:</span> <span class="n">Dict</span><span class="p">[</span><span class="nb">str</span><span class="p">,</span> <span class="nb">float</span><span class="p">])</span> <span class="o">-&gt;</span> <span class="n">Dict</span><span class="p">[</span><span class="nb">str</span><span class="p">,</span> <span class="n">MonitoredValue</span><span class="p">]:</span>
    153. <span class="sd">&quot;&quot;&quot;Update the given ValueToMonitor object (could be a loss or a metric) with the new value</span>
    154. <span class="sd"> :param monitored_values_dict: Dict mapping value names to their stats throughout epochs.</span>
    155. <span class="sd"> :param new_values_dict: Dict mapping value names to their new (i.e. current epoch) value.</span>
    156. <span class="sd"> :return: Updated monitored_values_dict</span>
    157. <span class="sd"> &quot;&quot;&quot;</span>
    158. <span class="k">for</span> <span class="n">monitored_value_name</span> <span class="ow">in</span> <span class="n">monitored_values_dict</span><span class="o">.</span><span class="n">keys</span><span class="p">():</span>
    159. <span class="n">monitored_values_dict</span><span class="p">[</span><span class="n">monitored_value_name</span><span class="p">]</span> <span class="o">=</span> <span class="n">update_monitored_value</span><span class="p">(</span>
    160. <span class="n">new_value</span><span class="o">=</span><span class="n">new_values_dict</span><span class="p">[</span><span class="n">monitored_value_name</span><span class="p">],</span>
    161. <span class="n">previous_monitored_value</span><span class="o">=</span><span class="n">monitored_values_dict</span><span class="p">[</span><span class="n">monitored_value_name</span><span class="p">],</span>
    162. <span class="p">)</span>
    163. <span class="k">return</span> <span class="n">monitored_values_dict</span></div>
    164. <div class="viewcode-block" id="display_epoch_summary"><a class="viewcode-back" href="../../../../super_gradients.training.utils.html#super_gradients.training.utils.sg_model_utils.display_epoch_summary">[docs]</a><span class="k">def</span> <span class="nf">display_epoch_summary</span><span class="p">(</span><span class="n">epoch</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span> <span class="n">n_digits</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span>
    165. <span class="n">train_monitored_values</span><span class="p">:</span> <span class="n">Dict</span><span class="p">[</span><span class="nb">str</span><span class="p">,</span> <span class="n">MonitoredValue</span><span class="p">],</span>
    166. <span class="n">valid_monitored_values</span><span class="p">:</span> <span class="n">Dict</span><span class="p">[</span><span class="nb">str</span><span class="p">,</span> <span class="n">MonitoredValue</span><span class="p">])</span> <span class="o">-&gt;</span> <span class="kc">None</span><span class="p">:</span>
    167. <span class="sd">&quot;&quot;&quot;Display a summary of loss/metric of interest, for a given epoch.</span>
    168. <span class="sd"> :param epoch: the number of epoch.</span>
    169. <span class="sd"> :param n_digits: number of digits to display on screen for float values</span>
    170. <span class="sd"> :param train_monitored_values: mapping of loss/metric with their stats that will be displayed</span>
    171. <span class="sd"> :param valid_monitored_values: mapping of loss/metric with their stats that will be displayed</span>
    172. <span class="sd"> :return:</span>
    173. <span class="sd"> &quot;&quot;&quot;</span>
    174. <span class="k">def</span> <span class="nf">_format_to_str</span><span class="p">(</span><span class="n">val</span><span class="p">:</span> <span class="nb">float</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="nb">str</span><span class="p">:</span>
    175. <span class="k">return</span> <span class="nb">str</span><span class="p">(</span><span class="nb">round</span><span class="p">(</span><span class="n">val</span><span class="p">,</span> <span class="n">n_digits</span><span class="p">))</span>
    176. <span class="k">def</span> <span class="nf">_generate_tree</span><span class="p">(</span><span class="n">value_name</span><span class="p">:</span> <span class="nb">str</span><span class="p">,</span> <span class="n">monitored_value</span><span class="p">:</span> <span class="n">MonitoredValue</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="n">Tree</span><span class="p">:</span>
    177. <span class="sd">&quot;&quot;&quot;Generate a tree that represents the stats of a given loss/metric.&quot;&quot;&quot;</span>
    178. <span class="n">current</span> <span class="o">=</span> <span class="n">_format_to_str</span><span class="p">(</span><span class="n">monitored_value</span><span class="o">.</span><span class="n">current</span><span class="p">)</span>
    179. <span class="n">root_id</span> <span class="o">=</span> <span class="nb">str</span><span class="p">(</span><span class="nb">hash</span><span class="p">(</span><span class="sa">f</span><span class="s2">&quot;</span><span class="si">{</span><span class="n">value_name</span><span class="si">}</span><span class="s2"> = </span><span class="si">{</span><span class="n">current</span><span class="si">}</span><span class="s2">&quot;</span><span class="p">))</span> <span class="o">+</span> <span class="nb">str</span><span class="p">(</span><span class="n">random</span><span class="o">.</span><span class="n">random</span><span class="p">())</span>
    180. <span class="n">tree</span> <span class="o">=</span> <span class="n">Tree</span><span class="p">()</span>
    181. <span class="n">tree</span><span class="o">.</span><span class="n">create_node</span><span class="p">(</span><span class="n">tag</span><span class="o">=</span><span class="sa">f</span><span class="s2">&quot;</span><span class="si">{</span><span class="n">value_name</span><span class="o">.</span><span class="n">capitalize</span><span class="p">()</span><span class="si">}</span><span class="s2"> = </span><span class="si">{</span><span class="n">current</span><span class="si">}</span><span class="s2">&quot;</span><span class="p">,</span> <span class="n">identifier</span><span class="o">=</span><span class="n">root_id</span><span class="p">)</span>
    182. <span class="k">if</span> <span class="n">monitored_value</span><span class="o">.</span><span class="n">previous</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span><span class="p">:</span>
    183. <span class="n">previous</span> <span class="o">=</span> <span class="n">_format_to_str</span><span class="p">(</span><span class="n">monitored_value</span><span class="o">.</span><span class="n">previous</span><span class="p">)</span>
    184. <span class="n">best</span> <span class="o">=</span> <span class="n">_format_to_str</span><span class="p">(</span><span class="n">monitored_value</span><span class="o">.</span><span class="n">best</span><span class="p">)</span>
    185. <span class="n">change_from_previous</span> <span class="o">=</span> <span class="n">_format_to_str</span><span class="p">(</span><span class="n">monitored_value</span><span class="o">.</span><span class="n">change_from_previous</span><span class="p">)</span>
    186. <span class="n">change_from_best</span> <span class="o">=</span> <span class="n">_format_to_str</span><span class="p">(</span><span class="n">monitored_value</span><span class="o">.</span><span class="n">change_from_best</span><span class="p">)</span>
    187. <span class="n">diff_with_prev_colored</span> <span class="o">=</span> <span class="n">colored</span><span class="p">(</span>
    188. <span class="n">text</span><span class="o">=</span><span class="sa">f</span><span class="s2">&quot;</span><span class="si">{</span><span class="n">IS_GREATER_SYMBOLS</span><span class="p">[</span><span class="n">monitored_value</span><span class="o">.</span><span class="n">change_from_previous</span> <span class="o">&gt;</span> <span class="mi">0</span><span class="p">]</span><span class="si">}</span><span class="s2"> </span><span class="si">{</span><span class="n">change_from_previous</span><span class="si">}</span><span class="s2">&quot;</span><span class="p">,</span>
    189. <span class="n">color</span><span class="o">=</span><span class="n">IS_BETTER_COLOR</span><span class="p">[</span><span class="n">monitored_value</span><span class="o">.</span><span class="n">is_better_than_previous</span><span class="p">]</span>
    190. <span class="p">)</span>
    191. <span class="n">diff_with_best_colored</span> <span class="o">=</span> <span class="n">colored</span><span class="p">(</span>
    192. <span class="n">text</span><span class="o">=</span><span class="sa">f</span><span class="s2">&quot;</span><span class="si">{</span><span class="n">IS_GREATER_SYMBOLS</span><span class="p">[</span><span class="n">monitored_value</span><span class="o">.</span><span class="n">change_from_best</span> <span class="o">&gt;</span> <span class="mi">0</span><span class="p">]</span><span class="si">}</span><span class="s2"> </span><span class="si">{</span><span class="n">change_from_best</span><span class="si">}</span><span class="s2">&quot;</span><span class="p">,</span>
    193. <span class="n">color</span><span class="o">=</span><span class="n">IS_BETTER_COLOR</span><span class="p">[</span><span class="n">monitored_value</span><span class="o">.</span><span class="n">is_best_value</span><span class="p">]</span>
    194. <span class="p">)</span>
    195. <span class="n">tree</span><span class="o">.</span><span class="n">create_node</span><span class="p">(</span>
    196. <span class="n">tag</span><span class="o">=</span><span class="sa">f</span><span class="s2">&quot;Epoch N-1 = </span><span class="si">{</span><span class="n">previous</span><span class="si">:</span><span class="s2">6</span><span class="si">}</span><span class="s2"> (</span><span class="si">{</span><span class="n">diff_with_prev_colored</span><span class="si">:</span><span class="s2">8</span><span class="si">}</span><span class="s2">)&quot;</span><span class="p">,</span>
    197. <span class="n">identifier</span><span class="o">=</span><span class="sa">f</span><span class="s2">&quot;0_previous_</span><span class="si">{</span><span class="n">root_id</span><span class="si">}</span><span class="s2">&quot;</span><span class="p">,</span>
    198. <span class="n">parent</span><span class="o">=</span><span class="n">root_id</span>
    199. <span class="p">)</span>
    200. <span class="n">tree</span><span class="o">.</span><span class="n">create_node</span><span class="p">(</span>
    201. <span class="n">tag</span><span class="o">=</span><span class="sa">f</span><span class="s2">&quot;Best until now = </span><span class="si">{</span><span class="n">best</span><span class="si">:</span><span class="s2">6</span><span class="si">}</span><span class="s2"> (</span><span class="si">{</span><span class="n">diff_with_best_colored</span><span class="si">:</span><span class="s2">8</span><span class="si">}</span><span class="s2">)&quot;</span><span class="p">,</span>
    202. <span class="n">identifier</span><span class="o">=</span><span class="sa">f</span><span class="s2">&quot;1_best_</span><span class="si">{</span><span class="n">root_id</span><span class="si">}</span><span class="s2">&quot;</span><span class="p">,</span>
    203. <span class="n">parent</span><span class="o">=</span><span class="n">root_id</span>
    204. <span class="p">)</span>
    205. <span class="k">return</span> <span class="n">tree</span>
    206. <span class="n">train_tree</span> <span class="o">=</span> <span class="n">Tree</span><span class="p">()</span>
    207. <span class="n">train_tree</span><span class="o">.</span><span class="n">create_node</span><span class="p">(</span><span class="s2">&quot;Training&quot;</span><span class="p">,</span> <span class="s2">&quot;Training&quot;</span><span class="p">)</span>
    208. <span class="k">for</span> <span class="n">name</span><span class="p">,</span> <span class="n">value</span> <span class="ow">in</span> <span class="n">train_monitored_values</span><span class="o">.</span><span class="n">items</span><span class="p">():</span>
    209. <span class="n">train_tree</span><span class="o">.</span><span class="n">paste</span><span class="p">(</span><span class="s1">&#39;Training&#39;</span><span class="p">,</span> <span class="n">new_tree</span><span class="o">=</span><span class="n">_generate_tree</span><span class="p">(</span><span class="n">name</span><span class="p">,</span> <span class="n">monitored_value</span><span class="o">=</span><span class="n">value</span><span class="p">))</span>
    210. <span class="n">valid_tree</span> <span class="o">=</span> <span class="n">Tree</span><span class="p">()</span>
    211. <span class="n">valid_tree</span><span class="o">.</span><span class="n">create_node</span><span class="p">(</span><span class="s2">&quot;Validation&quot;</span><span class="p">,</span> <span class="s2">&quot;Validation&quot;</span><span class="p">)</span>
    212. <span class="k">for</span> <span class="n">name</span><span class="p">,</span> <span class="n">value</span> <span class="ow">in</span> <span class="n">valid_monitored_values</span><span class="o">.</span><span class="n">items</span><span class="p">():</span>
    213. <span class="n">valid_tree</span><span class="o">.</span><span class="n">paste</span><span class="p">(</span><span class="s1">&#39;Validation&#39;</span><span class="p">,</span> <span class="n">new_tree</span><span class="o">=</span><span class="n">_generate_tree</span><span class="p">(</span><span class="n">name</span><span class="p">,</span> <span class="n">monitored_value</span><span class="o">=</span><span class="n">value</span><span class="p">))</span>
    214. <span class="n">summary_tree</span> <span class="o">=</span> <span class="n">Tree</span><span class="p">()</span>
    215. <span class="n">summary_tree</span><span class="o">.</span><span class="n">create_node</span><span class="p">(</span><span class="sa">f</span><span class="s2">&quot;SUMMARY OF EPOCH </span><span class="si">{</span><span class="n">epoch</span><span class="si">}</span><span class="s2">&quot;</span><span class="p">,</span> <span class="s2">&quot;Summary&quot;</span><span class="p">)</span>
    216. <span class="n">summary_tree</span><span class="o">.</span><span class="n">paste</span><span class="p">(</span><span class="s2">&quot;Summary&quot;</span><span class="p">,</span> <span class="n">train_tree</span><span class="p">)</span>
    217. <span class="n">summary_tree</span><span class="o">.</span><span class="n">paste</span><span class="p">(</span><span class="s2">&quot;Summary&quot;</span><span class="p">,</span> <span class="n">valid_tree</span><span class="p">)</span>
    218. <span class="n">summary_tree</span><span class="o">.</span><span class="n">show</span><span class="p">()</span></div>
    219. <div class="viewcode-block" id="try_port"><a class="viewcode-back" href="../../../../super_gradients.training.utils.html#super_gradients.training.utils.sg_model_utils.try_port">[docs]</a><span class="k">def</span> <span class="nf">try_port</span><span class="p">(</span><span class="n">port</span><span class="p">):</span>
    220. <span class="sd">&quot;&quot;&quot;</span>
    221. <span class="sd"> try_port - Helper method for tensorboard port binding</span>
    222. <span class="sd"> :param port:</span>
    223. <span class="sd"> :return:</span>
    224. <span class="sd"> &quot;&quot;&quot;</span>
    225. <span class="n">sock</span> <span class="o">=</span> <span class="n">socket</span><span class="o">.</span><span class="n">socket</span><span class="p">(</span><span class="n">socket</span><span class="o">.</span><span class="n">AF_INET</span><span class="p">,</span> <span class="n">socket</span><span class="o">.</span><span class="n">SOCK_STREAM</span><span class="p">)</span>
    226. <span class="n">is_port_available</span> <span class="o">=</span> <span class="kc">False</span>
    227. <span class="k">try</span><span class="p">:</span>
    228. <span class="n">sock</span><span class="o">.</span><span class="n">bind</span><span class="p">((</span><span class="s2">&quot;localhost&quot;</span><span class="p">,</span> <span class="n">port</span><span class="p">))</span>
    229. <span class="n">is_port_available</span> <span class="o">=</span> <span class="kc">True</span>
    230. <span class="k">except</span> <span class="ne">Exception</span> <span class="k">as</span> <span class="n">ex</span><span class="p">:</span>
    231. <span class="nb">print</span><span class="p">(</span><span class="s1">&#39;Port &#39;</span> <span class="o">+</span> <span class="nb">str</span><span class="p">(</span><span class="n">port</span><span class="p">)</span> <span class="o">+</span> <span class="s1">&#39; is in use&#39;</span> <span class="o">+</span> <span class="nb">str</span><span class="p">(</span><span class="n">ex</span><span class="p">))</span>
    232. <span class="n">sock</span><span class="o">.</span><span class="n">close</span><span class="p">()</span>
    233. <span class="k">return</span> <span class="n">is_port_available</span></div>
    234. <div class="viewcode-block" id="launch_tensorboard_process"><a class="viewcode-back" href="../../../../super_gradients.training.utils.html#super_gradients.training.utils.sg_model_utils.launch_tensorboard_process">[docs]</a><span class="k">def</span> <span class="nf">launch_tensorboard_process</span><span class="p">(</span><span class="n">checkpoints_dir_path</span><span class="p">:</span> <span class="nb">str</span><span class="p">,</span> <span class="n">sleep_postpone</span><span class="p">:</span> <span class="nb">bool</span> <span class="o">=</span> <span class="kc">True</span><span class="p">,</span> <span class="n">port</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="kc">None</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="n">Tuple</span><span class="p">[</span><span class="n">Process</span><span class="p">,</span> <span class="nb">int</span><span class="p">]:</span>
    235. <span class="sd">&quot;&quot;&quot;</span>
    236. <span class="sd"> launch_tensorboard_process - Default behavior is to scan all free ports from 6006-6016 and try using them</span>
    237. <span class="sd"> unless port is defined by the user</span>
    238. <span class="sd"> :param checkpoints_dir_path:</span>
    239. <span class="sd"> :param sleep_postpone:</span>
    240. <span class="sd"> :param port:</span>
    241. <span class="sd"> :return: tuple of tb process, port</span>
    242. <span class="sd"> &quot;&quot;&quot;</span>
    243. <span class="n">logdir_path</span> <span class="o">=</span> <span class="nb">str</span><span class="p">(</span><span class="n">Path</span><span class="p">(</span><span class="n">checkpoints_dir_path</span><span class="p">)</span><span class="o">.</span><span class="n">parent</span><span class="o">.</span><span class="n">absolute</span><span class="p">())</span>
    244. <span class="n">tb_cmd</span> <span class="o">=</span> <span class="s1">&#39;tensorboard --logdir=&#39;</span> <span class="o">+</span> <span class="n">logdir_path</span> <span class="o">+</span> <span class="s1">&#39; --bind_all&#39;</span>
    245. <span class="k">if</span> <span class="n">port</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span><span class="p">:</span>
    246. <span class="n">tb_ports</span> <span class="o">=</span> <span class="p">[</span><span class="n">port</span><span class="p">]</span>
    247. <span class="k">else</span><span class="p">:</span>
    248. <span class="n">tb_ports</span> <span class="o">=</span> <span class="nb">range</span><span class="p">(</span><span class="mi">6006</span><span class="p">,</span> <span class="mi">6016</span><span class="p">)</span>
    249. <span class="k">for</span> <span class="n">tb_port</span> <span class="ow">in</span> <span class="n">tb_ports</span><span class="p">:</span>
    250. <span class="k">if</span> <span class="ow">not</span> <span class="n">try_port</span><span class="p">(</span><span class="n">tb_port</span><span class="p">):</span>
    251. <span class="k">continue</span>
    252. <span class="k">else</span><span class="p">:</span>
    253. <span class="nb">print</span><span class="p">(</span><span class="s1">&#39;Starting Tensor-Board process on port: &#39;</span> <span class="o">+</span> <span class="nb">str</span><span class="p">(</span><span class="n">tb_port</span><span class="p">))</span>
    254. <span class="n">tensor_board_process</span> <span class="o">=</span> <span class="n">Process</span><span class="p">(</span><span class="n">target</span><span class="o">=</span><span class="n">os</span><span class="o">.</span><span class="n">system</span><span class="p">,</span> <span class="n">args</span><span class="o">=</span><span class="p">([</span><span class="n">tb_cmd</span> <span class="o">+</span> <span class="s1">&#39; --port=&#39;</span> <span class="o">+</span> <span class="nb">str</span><span class="p">(</span><span class="n">tb_port</span><span class="p">)]))</span>
    255. <span class="n">tensor_board_process</span><span class="o">.</span><span class="n">daemon</span> <span class="o">=</span> <span class="kc">True</span>
    256. <span class="n">tensor_board_process</span><span class="o">.</span><span class="n">start</span><span class="p">()</span>
    257. <span class="c1"># LET THE TENSORBOARD PROCESS START</span>
    258. <span class="k">if</span> <span class="n">sleep_postpone</span><span class="p">:</span>
    259. <span class="n">time</span><span class="o">.</span><span class="n">sleep</span><span class="p">(</span><span class="mi">3</span><span class="p">)</span>
    260. <span class="k">return</span> <span class="n">tensor_board_process</span><span class="p">,</span> <span class="n">tb_port</span>
    261. <span class="c1"># RETURNING IRRELEVANT VALUES</span>
    262. <span class="nb">print</span><span class="p">(</span><span class="s1">&#39;Failed to initialize Tensor-Board process on port: &#39;</span> <span class="o">+</span> <span class="s1">&#39;, &#39;</span><span class="o">.</span><span class="n">join</span><span class="p">(</span><span class="nb">map</span><span class="p">(</span><span class="nb">str</span><span class="p">,</span> <span class="n">tb_ports</span><span class="p">)))</span>
    263. <span class="k">return</span> <span class="kc">None</span><span class="p">,</span> <span class="o">-</span><span class="mi">1</span></div>
    264. <div class="viewcode-block" id="init_summary_writer"><a class="viewcode-back" href="../../../../super_gradients.training.utils.html#super_gradients.training.utils.sg_model_utils.init_summary_writer">[docs]</a><span class="k">def</span> <span class="nf">init_summary_writer</span><span class="p">(</span><span class="n">tb_dir</span><span class="p">,</span> <span class="n">checkpoint_loaded</span><span class="p">,</span> <span class="n">user_prompt</span><span class="o">=</span><span class="kc">False</span><span class="p">):</span>
    265. <span class="sd">&quot;&quot;&quot;Remove previous tensorboard files from directory and launch a tensor board process&quot;&quot;&quot;</span>
    266. <span class="c1"># If the training is from scratch, Walk through destination folder and delete existing tensorboard logs</span>
    267. <span class="n">user</span> <span class="o">=</span> <span class="s1">&#39;&#39;</span>
    268. <span class="k">if</span> <span class="ow">not</span> <span class="n">checkpoint_loaded</span><span class="p">:</span>
    269. <span class="k">for</span> <span class="n">filename</span> <span class="ow">in</span> <span class="n">os</span><span class="o">.</span><span class="n">listdir</span><span class="p">(</span><span class="n">tb_dir</span><span class="p">):</span>
    270. <span class="k">if</span> <span class="s1">&#39;events&#39;</span> <span class="ow">in</span> <span class="n">filename</span><span class="p">:</span>
    271. <span class="k">if</span> <span class="ow">not</span> <span class="n">user_prompt</span><span class="p">:</span>
    272. <span class="nb">print</span><span class="p">(</span><span class="s1">&#39;&quot;</span><span class="si">{}</span><span class="s1">&quot; will not be deleted&#39;</span><span class="o">.</span><span class="n">format</span><span class="p">(</span><span class="n">filename</span><span class="p">))</span>
    273. <span class="k">continue</span>
    274. <span class="k">while</span> <span class="kc">True</span><span class="p">:</span>
    275. <span class="c1"># Verify with user before deleting old tensorboard files</span>
    276. <span class="n">user</span> <span class="o">=</span> <span class="nb">input</span><span class="p">(</span><span class="s1">&#39;</span><span class="se">\n</span><span class="s1">OLDER TENSORBOARD FILES EXISTS IN EXPERIMENT FOLDER:</span><span class="se">\n</span><span class="s1">&quot;</span><span class="si">{}</span><span class="s1">&quot;</span><span class="se">\n</span><span class="s1">&#39;</span>
    277. <span class="s1">&#39;DO YOU WANT TO DELETE THEM? [y/n]&#39;</span>
    278. <span class="o">.</span><span class="n">format</span><span class="p">(</span><span class="n">filename</span><span class="p">))</span> <span class="k">if</span> <span class="p">(</span><span class="n">user</span> <span class="o">!=</span> <span class="s1">&#39;n&#39;</span> <span class="ow">or</span> <span class="n">user</span> <span class="o">!=</span> <span class="s1">&#39;y&#39;</span><span class="p">)</span> <span class="k">else</span> <span class="n">user</span>
    279. <span class="k">if</span> <span class="n">user</span> <span class="o">==</span> <span class="s1">&#39;y&#39;</span><span class="p">:</span>
    280. <span class="n">os</span><span class="o">.</span><span class="n">remove</span><span class="p">(</span><span class="s1">&#39;</span><span class="si">{}</span><span class="s1">/</span><span class="si">{}</span><span class="s1">&#39;</span><span class="o">.</span><span class="n">format</span><span class="p">(</span><span class="n">tb_dir</span><span class="p">,</span> <span class="n">filename</span><span class="p">))</span>
    281. <span class="nb">print</span><span class="p">(</span><span class="s1">&#39;DELETED: </span><span class="si">{}</span><span class="s1">!&#39;</span><span class="o">.</span><span class="n">format</span><span class="p">(</span><span class="n">filename</span><span class="p">))</span>
    282. <span class="k">break</span>
    283. <span class="k">elif</span> <span class="n">user</span> <span class="o">==</span> <span class="s1">&#39;n&#39;</span><span class="p">:</span>
    284. <span class="nb">print</span><span class="p">(</span><span class="s1">&#39;&quot;</span><span class="si">{}</span><span class="s1">&quot; will not be deleted&#39;</span><span class="o">.</span><span class="n">format</span><span class="p">(</span><span class="n">filename</span><span class="p">))</span>
    285. <span class="k">break</span>
    286. <span class="nb">print</span><span class="p">(</span><span class="s1">&#39;Unknown answer...&#39;</span><span class="p">)</span>
    287. <span class="c1"># Launch a tensorboard process</span>
    288. <span class="k">return</span> <span class="n">SummaryWriter</span><span class="p">(</span><span class="n">tb_dir</span><span class="p">)</span></div>
    289. <div class="viewcode-block" id="add_log_to_file"><a class="viewcode-back" href="../../../../super_gradients.training.utils.html#super_gradients.training.utils.sg_model_utils.add_log_to_file">[docs]</a><span class="k">def</span> <span class="nf">add_log_to_file</span><span class="p">(</span><span class="n">filename</span><span class="p">,</span> <span class="n">results_titles_list</span><span class="p">,</span> <span class="n">results_values_list</span><span class="p">,</span> <span class="n">epoch</span><span class="p">,</span> <span class="n">max_epochs</span><span class="p">):</span>
    290. <span class="sd">&quot;&quot;&quot;Add a message to the log file&quot;&quot;&quot;</span>
    291. <span class="c1"># -Note: opening and closing the file every time is in-efficient. It is done for experimental purposes</span>
    292. <span class="k">with</span> <span class="nb">open</span><span class="p">(</span><span class="n">filename</span><span class="p">,</span> <span class="s1">&#39;a&#39;</span><span class="p">)</span> <span class="k">as</span> <span class="n">f</span><span class="p">:</span>
    293. <span class="n">f</span><span class="o">.</span><span class="n">write</span><span class="p">(</span><span class="s1">&#39;</span><span class="se">\n</span><span class="s1">Epoch (</span><span class="si">%d</span><span class="s1">/</span><span class="si">%d</span><span class="s1">) - &#39;</span> <span class="o">%</span> <span class="p">(</span><span class="n">epoch</span><span class="p">,</span> <span class="n">max_epochs</span><span class="p">))</span>
    294. <span class="k">for</span> <span class="n">result_title</span><span class="p">,</span> <span class="n">result_value</span> <span class="ow">in</span> <span class="nb">zip</span><span class="p">(</span><span class="n">results_titles_list</span><span class="p">,</span> <span class="n">results_values_list</span><span class="p">):</span>
    295. <span class="k">if</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">result_value</span><span class="p">,</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">):</span>
    296. <span class="n">result_value</span> <span class="o">=</span> <span class="n">result_value</span><span class="o">.</span><span class="n">item</span><span class="p">()</span>
    297. <span class="n">f</span><span class="o">.</span><span class="n">write</span><span class="p">(</span><span class="n">result_title</span> <span class="o">+</span> <span class="s1">&#39;: &#39;</span> <span class="o">+</span> <span class="nb">str</span><span class="p">(</span><span class="n">result_value</span><span class="p">)</span> <span class="o">+</span> <span class="s1">&#39;</span><span class="se">\t</span><span class="s1">&#39;</span><span class="p">)</span></div>
    298. <div class="viewcode-block" id="write_training_results"><a class="viewcode-back" href="../../../../super_gradients.training.utils.html#super_gradients.training.utils.sg_model_utils.write_training_results">[docs]</a><span class="k">def</span> <span class="nf">write_training_results</span><span class="p">(</span><span class="n">writer</span><span class="p">,</span> <span class="n">results_titles_list</span><span class="p">,</span> <span class="n">results_values_list</span><span class="p">,</span> <span class="n">epoch</span><span class="p">):</span>
    299. <span class="sd">&quot;&quot;&quot;Stores the training and validation loss and accuracy for current epoch in a tensorboard file&quot;&quot;&quot;</span>
    300. <span class="k">for</span> <span class="n">res_key</span><span class="p">,</span> <span class="n">res_val</span> <span class="ow">in</span> <span class="nb">zip</span><span class="p">(</span><span class="n">results_titles_list</span><span class="p">,</span> <span class="n">results_values_list</span><span class="p">):</span>
    301. <span class="c1"># USE ONLY LOWER-CASE LETTERS AND REPLACE SPACES WITH &#39;_&#39; TO AVOID MANY TITLES FOR THE SAME KEY</span>
    302. <span class="n">corrected_res_key</span> <span class="o">=</span> <span class="n">res_key</span><span class="o">.</span><span class="n">lower</span><span class="p">()</span><span class="o">.</span><span class="n">replace</span><span class="p">(</span><span class="s1">&#39; &#39;</span><span class="p">,</span> <span class="s1">&#39;_&#39;</span><span class="p">)</span>
    303. <span class="n">writer</span><span class="o">.</span><span class="n">add_scalar</span><span class="p">(</span><span class="n">corrected_res_key</span><span class="p">,</span> <span class="n">res_val</span><span class="p">,</span> <span class="n">epoch</span><span class="p">)</span>
    304. <span class="n">writer</span><span class="o">.</span><span class="n">flush</span><span class="p">()</span></div>
    305. <div class="viewcode-block" id="write_hpms"><a class="viewcode-back" href="../../../../super_gradients.training.utils.html#super_gradients.training.utils.sg_model_utils.write_hpms">[docs]</a><span class="k">def</span> <span class="nf">write_hpms</span><span class="p">(</span><span class="n">writer</span><span class="p">,</span> <span class="n">hpmstructs</span><span class="o">=</span><span class="p">[],</span> <span class="n">special_conf</span><span class="o">=</span><span class="p">{}):</span>
    306. <span class="sd">&quot;&quot;&quot;Stores the training and dataset hyper params in the tensorboard file&quot;&quot;&quot;</span>
    307. <span class="n">hpm_string</span> <span class="o">=</span> <span class="s2">&quot;&quot;</span>
    308. <span class="k">for</span> <span class="n">hpm</span> <span class="ow">in</span> <span class="n">hpmstructs</span><span class="p">:</span>
    309. <span class="k">for</span> <span class="n">key</span><span class="p">,</span> <span class="n">val</span> <span class="ow">in</span> <span class="n">hpm</span><span class="o">.</span><span class="vm">__dict__</span><span class="o">.</span><span class="n">items</span><span class="p">():</span>
    310. <span class="n">hpm_string</span> <span class="o">+=</span> <span class="s1">&#39;</span><span class="si">{}</span><span class="s1">: </span><span class="si">{}</span><span class="s1"> </span><span class="se">\n</span><span class="s1"> &#39;</span><span class="o">.</span><span class="n">format</span><span class="p">(</span><span class="n">key</span><span class="p">,</span> <span class="n">val</span><span class="p">)</span>
    311. <span class="k">for</span> <span class="n">key</span><span class="p">,</span> <span class="n">val</span> <span class="ow">in</span> <span class="n">special_conf</span><span class="o">.</span><span class="n">items</span><span class="p">():</span>
    312. <span class="n">hpm_string</span> <span class="o">+=</span> <span class="s1">&#39;</span><span class="si">{}</span><span class="s1">: </span><span class="si">{}</span><span class="s1"> </span><span class="se">\n</span><span class="s1"> &#39;</span><span class="o">.</span><span class="n">format</span><span class="p">(</span><span class="n">key</span><span class="p">,</span> <span class="n">val</span><span class="p">)</span>
    313. <span class="n">writer</span><span class="o">.</span><span class="n">add_text</span><span class="p">(</span><span class="s2">&quot;Hyper_parameters&quot;</span><span class="p">,</span> <span class="n">hpm_string</span><span class="p">)</span>
    314. <span class="n">writer</span><span class="o">.</span><span class="n">flush</span><span class="p">()</span></div>
    315. <span class="c1"># TODO: This should probably move into datasets/datasets_utils.py?</span>
    316. <div class="viewcode-block" id="unpack_batch_items"><a class="viewcode-back" href="../../../../super_gradients.training.utils.html#super_gradients.training.utils.sg_model_utils.unpack_batch_items">[docs]</a><span class="k">def</span> <span class="nf">unpack_batch_items</span><span class="p">(</span><span class="n">batch_items</span><span class="p">:</span> <span class="n">Union</span><span class="p">[</span><span class="nb">tuple</span><span class="p">,</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">]):</span>
    317. <span class="sd">&quot;&quot;&quot;</span>
    318. <span class="sd"> Adds support for unpacking batch items in train/validation loop.</span>
    319. <span class="sd"> @param batch_items: (Union[tuple, torch.Tensor]) returned by the data loader, which is expected to be in one of</span>
    320. <span class="sd"> the following formats:</span>
    321. <span class="sd"> 1. torch.Tensor or tuple, s.t inputs = batch_items[0], targets = batch_items[1] and len(batch_items) = 2</span>
    322. <span class="sd"> 2. tuple: (inputs, targets, additional_batch_items)</span>
    323. <span class="sd"> where inputs are fed to the network, targets are their corresponding labels and additional_batch_items is a</span>
    324. <span class="sd"> dictionary (format {additional_batch_item_i_name: additional_batch_item_i ...}) which can be accessed through</span>
    325. <span class="sd"> the phase context under the attribute additional_batch_item_i_name, using a phase callback.</span>
    326. <span class="sd"> @return: inputs, target, additional_batch_items</span>
    327. <span class="sd"> &quot;&quot;&quot;</span>
    328. <span class="n">additional_batch_items</span> <span class="o">=</span> <span class="p">{}</span>
    329. <span class="k">if</span> <span class="nb">len</span><span class="p">(</span><span class="n">batch_items</span><span class="p">)</span> <span class="o">==</span> <span class="mi">2</span><span class="p">:</span>
    330. <span class="n">inputs</span><span class="p">,</span> <span class="n">target</span> <span class="o">=</span> <span class="n">batch_items</span>
    331. <span class="k">elif</span> <span class="nb">len</span><span class="p">(</span><span class="n">batch_items</span><span class="p">)</span> <span class="o">==</span> <span class="mi">3</span><span class="p">:</span>
    332. <span class="n">inputs</span><span class="p">,</span> <span class="n">target</span><span class="p">,</span> <span class="n">additional_batch_items</span> <span class="o">=</span> <span class="n">batch_items</span>
    333. <span class="k">else</span><span class="p">:</span>
    334. <span class="k">raise</span> <span class="n">UnsupportedBatchItemsFormat</span><span class="p">()</span>
    335. <span class="k">return</span> <span class="n">inputs</span><span class="p">,</span> <span class="n">target</span><span class="p">,</span> <span class="n">additional_batch_items</span></div>
    336. <div class="viewcode-block" id="log_uncaught_exceptions"><a class="viewcode-back" href="../../../../super_gradients.training.utils.html#super_gradients.training.utils.sg_model_utils.log_uncaught_exceptions">[docs]</a><span class="k">def</span> <span class="nf">log_uncaught_exceptions</span><span class="p">(</span><span class="n">logger</span><span class="p">):</span>
    337. <span class="sd">&quot;&quot;&quot;</span>
    338. <span class="sd"> Makes logger log uncaught exceptions</span>
    339. <span class="sd"> @param logger: logging.Logger</span>
    340. <span class="sd"> @return: None</span>
    341. <span class="sd"> &quot;&quot;&quot;</span>
    342. <span class="k">def</span> <span class="nf">handle_exception</span><span class="p">(</span><span class="n">exc_type</span><span class="p">,</span> <span class="n">exc_value</span><span class="p">,</span> <span class="n">exc_traceback</span><span class="p">):</span>
    343. <span class="k">if</span> <span class="nb">issubclass</span><span class="p">(</span><span class="n">exc_type</span><span class="p">,</span> <span class="ne">KeyboardInterrupt</span><span class="p">):</span>
    344. <span class="n">sys</span><span class="o">.</span><span class="n">__excepthook__</span><span class="p">(</span><span class="n">exc_type</span><span class="p">,</span> <span class="n">exc_value</span><span class="p">,</span> <span class="n">exc_traceback</span><span class="p">)</span>
    345. <span class="k">return</span>
    346. <span class="n">logger</span><span class="o">.</span><span class="n">error</span><span class="p">(</span><span class="s2">&quot;Uncaught exception&quot;</span><span class="p">,</span> <span class="n">exc_info</span><span class="o">=</span><span class="p">(</span><span class="n">exc_type</span><span class="p">,</span> <span class="n">exc_value</span><span class="p">,</span> <span class="n">exc_traceback</span><span class="p">))</span>
    347. <span class="n">sys</span><span class="o">.</span><span class="n">excepthook</span> <span class="o">=</span> <span class="n">handle_exception</span></div>
    348. </pre></div>
    349. </div>
    350. </div>
    351. <footer>
    352. <hr/>
    353. <div role="contentinfo">
    354. <p>&#169; Copyright 2021, SuperGradients team.</p>
    355. </div>
    356. Built with <a href="https://www.sphinx-doc.org/">Sphinx</a> using a
    357. <a href="https://github.com/readthedocs/sphinx_rtd_theme">theme</a>
    358. provided by <a href="https://readthedocs.org">Read the Docs</a>.
    359. </footer>
    360. </div>
    361. </div>
    362. </section>
    363. </div>
    364. <script>
    365. jQuery(function () {
    366. SphinxRtdTheme.Navigation.enable(true);
    367. });
    368. </script>
    369. </body>
    370. </html>
    Discard
    1
    2
    3
    4
    5
    6
    7
    8
    9
    10
    11
    12
    13
    14
    15
    16
    17
    18
    19
    20
    21
    22
    23
    24
    25
    26
    27
    28
    29
    30
    31
    32
    33
    34
    35
    36
    37
    38
    39
    40
    41
    42
    43
    44
    45
    46
    47
    48
    49
    50
    51
    52
    53
    54
    55
    56
    57
    58
    59
    60
    61
    62
    63
    64
    65
    66
    67
    68
    69
    70
    71
    72
    73
    74
    75
    76
    77
    78
    79
    80
    81
    82
    83
    84
    85
    86
    87
    88
    89
    90
    91
    92
    93
    94
    95
    96
    97
    98
    99
    100
    101
    102
    103
    104
    105
    106
    107
    108
    109
    110
    111
    112
    113
    114
    115
    116
    117
    118
    119
    120
    121
    122
    123
    124
    125
    126
    127
    128
    129
    130
    131
    132
    133
    134
    135
    136
    137
    138
    139
    140
    141
    142
    143
    144
    145
    146
    147
    148
    149
    150
    151
    152
    153
    154
    155
    156
    157
    158
    159
    160
    161
    162
    163
    164
    165
    166
    167
    168
    169
    170
    171
    172
    173
    174
    175
    176
    177
    178
    179
    180
    181
    182
    183
    184
    185
    186
    187
    188
    189
    190
    191
    192
    193
    194
    195
    196
    197
    198
    199
    200
    201
    202
    203
    204
    205
    206
    207
    208
    209
    210
    211
    212
    213
    214
    215
    216
    217
    218
    219
    220
    221
    222
    223
    224
    225
    226
    227
    228
    229
    230
    231
    232
    233
    234
    235
    236
    237
    238
    239
    240
    241
    242
    243
    244
    245
    246
    247
    248
    249
    250
    251
    252
    253
    254
    255
    256
    257
    258
    259
    260
    261
    262
    263
    264
    265
    1. <!DOCTYPE html>
    2. <html class="writer-html5" lang="en" >
    3. <head>
    4. <meta charset="utf-8" />
    5. <meta name="viewport" content="width=device-width, initial-scale=1.0" />
    6. <title>super_gradients.training.utils.ssd_utils &mdash; SuperGradients 1.0 documentation</title>
    7. <link rel="stylesheet" href="../../../../_static/pygments.css" type="text/css" />
    8. <link rel="stylesheet" href="../../../../_static/css/theme.css" type="text/css" />
    9. <link rel="stylesheet" href="../../../../_static/graphviz.css" type="text/css" />
    10. <!--[if lt IE 9]>
    11. <script src="../../../../_static/js/html5shiv.min.js"></script>
    12. <![endif]-->
    13. <script data-url_root="../../../../" id="documentation_options" src="../../../../_static/documentation_options.js"></script>
    14. <script src="../../../../_static/jquery.js"></script>
    15. <script src="../../../../_static/underscore.js"></script>
    16. <script src="../../../../_static/doctools.js"></script>
    17. <script src="../../../../_static/js/theme.js"></script>
    18. <link rel="index" title="Index" href="../../../../genindex.html" />
    19. <link rel="search" title="Search" href="../../../../search.html" />
    20. </head>
    21. <body class="wy-body-for-nav">
    22. <div class="wy-grid-for-nav">
    23. <nav data-toggle="wy-nav-shift" class="wy-nav-side">
    24. <div class="wy-side-scroll">
    25. <div class="wy-side-nav-search" >
    26. <a href="../../../../index.html" class="icon icon-home"> SuperGradients
    27. </a>
    28. <div role="search">
    29. <form id="rtd-search-form" class="wy-form" action="../../../../search.html" method="get">
    30. <input type="text" name="q" placeholder="Search docs" />
    31. <input type="hidden" name="check_keywords" value="yes" />
    32. <input type="hidden" name="area" value="default" />
    33. </form>
    34. </div>
    35. </div><div class="wy-menu wy-menu-vertical" data-spy="affix" role="navigation" aria-label="Navigation menu">
    36. <p class="caption"><span class="caption-text">Welcome To SuperGradients</span></p>
    37. <ul>
    38. <li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html">Fill our 4-question quick survey! We will raffle free SuperGradients swag between those who will participate -&gt; Fill Survey</a></li>
    39. <li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#supergradients">SuperGradients</a></li>
    40. </ul>
    41. <p class="caption"><span class="caption-text">Technical Documentation</span></p>
    42. <ul>
    43. <li class="toctree-l1"><a class="reference internal" href="../../../../super_gradients.common.html">Common package</a></li>
    44. <li class="toctree-l1"><a class="reference internal" href="../../../../super_gradients.training.html">Training package</a></li>
    45. </ul>
    46. <p class="caption"><span class="caption-text">User Guide</span></p>
    47. <ul>
    48. <li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html">What is SuperGradients?</a></li>
    49. <li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#introducing-the-supergradients-library">Introducing the SuperGradients library</a></li>
    50. <li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#installation">Installation</a></li>
    51. <li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#integrating-your-training-code-complete-walkthrough">Integrating Your Training Code - Complete Walkthrough</a></li>
    52. <li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#training-parameters">Training Parameters</a></li>
    53. <li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#logs-and-checkpoints">Logs and Checkpoints</a></li>
    54. <li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#dataset-parameters">Dataset Parameters</a></li>
    55. <li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#network-architectures">Network Architectures</a></li>
    56. <li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#pretrained-models">Pretrained Models</a></li>
    57. <li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#how-to-reproduce-our-training-recipes">How To Reproduce Our Training Recipes</a></li>
    58. <li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#professional-tools-integration">Professional Tools Integration</a></li>
    59. <li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#supergradients-faq">SuperGradients FAQ</a></li>
    60. </ul>
    61. </div>
    62. </div>
    63. </nav>
    64. <section data-toggle="wy-nav-shift" class="wy-nav-content-wrap"><nav class="wy-nav-top" aria-label="Mobile navigation menu" >
    65. <i data-toggle="wy-nav-top" class="fa fa-bars"></i>
    66. <a href="../../../../index.html">SuperGradients</a>
    67. </nav>
    68. <div class="wy-nav-content">
    69. <div class="rst-content">
    70. <div role="navigation" aria-label="Page navigation">
    71. <ul class="wy-breadcrumbs">
    72. <li><a href="../../../../index.html" class="icon icon-home"></a> &raquo;</li>
    73. <li><a href="../../../index.html">Module code</a> &raquo;</li>
    74. <li>super_gradients.training.utils.ssd_utils</li>
    75. <li class="wy-breadcrumbs-aside">
    76. </li>
    77. </ul>
    78. <hr/>
    79. </div>
    80. <div role="main" class="document" itemscope="itemscope" itemtype="http://schema.org/Article">
    81. <div itemprop="articleBody">
    82. <h1>Source code for super_gradients.training.utils.ssd_utils</h1><div class="highlight"><pre>
    83. <span></span><span class="kn">import</span> <span class="nn">itertools</span>
    84. <span class="kn">from</span> <span class="nn">math</span> <span class="kn">import</span> <span class="n">sqrt</span>
    85. <span class="kn">from</span> <span class="nn">typing</span> <span class="kn">import</span> <span class="n">List</span>
    86. <span class="kn">import</span> <span class="nn">numpy</span> <span class="k">as</span> <span class="nn">np</span>
    87. <span class="kn">import</span> <span class="nn">torch</span>
    88. <span class="kn">from</span> <span class="nn">super_gradients.training.utils.detection_utils</span> <span class="kn">import</span> <span class="n">non_max_suppression</span><span class="p">,</span> <span class="n">NMS_Type</span><span class="p">,</span> \
    89. <span class="n">matrix_non_max_suppression</span><span class="p">,</span> <span class="n">DetectionPostPredictionCallback</span>
    90. <div class="viewcode-block" id="DefaultBoxes"><a class="viewcode-back" href="../../../../super_gradients.training.utils.html#super_gradients.training.utils.ssd_utils.DefaultBoxes">[docs]</a><span class="k">class</span> <span class="nc">DefaultBoxes</span><span class="p">(</span><span class="nb">object</span><span class="p">):</span>
    91. <span class="sd">&quot;&quot;&quot;</span>
    92. <span class="sd"> Default Boxes, (aka: anchor boxes or priors boxes) used by SSD model</span>
    93. <span class="sd"> &quot;&quot;&quot;</span>
    94. <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">fig_size</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span> <span class="n">feat_size</span><span class="p">:</span> <span class="n">List</span><span class="p">[</span><span class="nb">int</span><span class="p">],</span> <span class="n">scales</span><span class="p">:</span> <span class="n">List</span><span class="p">[</span><span class="nb">int</span><span class="p">],</span> <span class="n">aspect_ratios</span><span class="p">:</span> <span class="n">List</span><span class="p">[</span><span class="n">List</span><span class="p">[</span><span class="nb">int</span><span class="p">]],</span>
    95. <span class="n">scale_xy</span><span class="o">=</span><span class="mf">0.1</span><span class="p">,</span> <span class="n">scale_wh</span><span class="o">=</span><span class="mf">0.2</span><span class="p">):</span>
    96. <span class="sd">&quot;&quot;&quot;</span>
    97. <span class="sd"> For each feature map i (each predicting level, grids) the anchors (a.k.a. default boxes) will be:</span>
    98. <span class="sd"> [</span>
    99. <span class="sd"> [s, s], [sqrt(s * s_next), sqrt(s * s_next)],</span>
    100. <span class="sd"> [s * sqrt(alpha1), s / sqrt(alpha1)], [s / sqrt(alpha1), s * sqrt(alpha1)],</span>
    101. <span class="sd"> ...</span>
    102. <span class="sd"> [s * sqrt(alphaN), s / sqrt(alphaN)], [s / sqrt(alphaN), s * sqrt(alphaN)]</span>
    103. <span class="sd"> ] / fig_size</span>
    104. <span class="sd"> where:</span>
    105. <span class="sd"> * s = scale[i] - this level&#39;s scale</span>
    106. <span class="sd"> * s_next = scale[i + 1] - next level&#39;s scale</span>
    107. <span class="sd"> * alpha1, ... alphaN - this level&#39;s alphas, e.g. [2, 3]</span>
    108. <span class="sd"> * fig_size - input image resolution</span>
    109. <span class="sd"> Because of division by image resolution, the anchors will be in image coordinates normalized to [0, 1]</span>
    110. <span class="sd"> :param fig_size: input image resolution</span>
    111. <span class="sd"> :param feat_size: resolution of all feature maps with predictions (grids)</span>
    112. <span class="sd"> :param scales: anchor sizes in pixels for each feature level;</span>
    113. <span class="sd"> one value per level will be used to generate anchors based on the formula above</span>
    114. <span class="sd"> :param aspect_ratios: lists of alpha values for each feature map</span>
    115. <span class="sd"> :param scale_xy: predicted boxes will be with a factor scale_xy</span>
    116. <span class="sd"> so will be multiplied by scale_xy during post-prediction processing;</span>
    117. <span class="sd"> e.g. scale 0.1 means that prediction will be 10 times bigger</span>
    118. <span class="sd"> (improves predictions quality)</span>
    119. <span class="sd"> :param scale_wh: same logic as in scale_xy, but for width and height.</span>
    120. <span class="sd"> &quot;&quot;&quot;</span>
    121. <span class="bp">self</span><span class="o">.</span><span class="n">feat_size</span> <span class="o">=</span> <span class="n">feat_size</span>
    122. <span class="bp">self</span><span class="o">.</span><span class="n">fig_size</span> <span class="o">=</span> <span class="n">fig_size</span>
    123. <span class="bp">self</span><span class="o">.</span><span class="n">scale_xy_</span> <span class="o">=</span> <span class="n">scale_xy</span>
    124. <span class="bp">self</span><span class="o">.</span><span class="n">scale_wh_</span> <span class="o">=</span> <span class="n">scale_wh</span>
    125. <span class="c1"># According to https://github.com/weiliu89/caffe</span>
    126. <span class="c1"># Calculation method slightly different from paper</span>
    127. <span class="bp">self</span><span class="o">.</span><span class="n">scales</span> <span class="o">=</span> <span class="n">scales</span>
    128. <span class="bp">self</span><span class="o">.</span><span class="n">aspect_ratios</span> <span class="o">=</span> <span class="n">aspect_ratios</span>
    129. <span class="bp">self</span><span class="o">.</span><span class="n">default_boxes</span> <span class="o">=</span> <span class="p">[]</span>
    130. <span class="bp">self</span><span class="o">.</span><span class="n">num_anchors</span> <span class="o">=</span> <span class="p">[]</span>
    131. <span class="c1"># size of feature and number of feature</span>
    132. <span class="k">for</span> <span class="n">idx</span><span class="p">,</span> <span class="n">sfeat</span> <span class="ow">in</span> <span class="nb">enumerate</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">feat_size</span><span class="p">):</span>
    133. <span class="n">sk1</span> <span class="o">=</span> <span class="n">scales</span><span class="p">[</span><span class="n">idx</span><span class="p">]</span>
    134. <span class="n">sk2</span> <span class="o">=</span> <span class="n">scales</span><span class="p">[</span><span class="n">idx</span> <span class="o">+</span> <span class="mi">1</span><span class="p">]</span>
    135. <span class="n">sk3</span> <span class="o">=</span> <span class="n">sqrt</span><span class="p">(</span><span class="n">sk1</span> <span class="o">*</span> <span class="n">sk2</span><span class="p">)</span>
    136. <span class="n">all_sizes</span> <span class="o">=</span> <span class="p">[(</span><span class="n">sk1</span><span class="p">,</span> <span class="n">sk1</span><span class="p">),</span> <span class="p">(</span><span class="n">sk3</span><span class="p">,</span> <span class="n">sk3</span><span class="p">)]</span>
    137. <span class="k">for</span> <span class="n">alpha</span> <span class="ow">in</span> <span class="n">aspect_ratios</span><span class="p">[</span><span class="n">idx</span><span class="p">]:</span>
    138. <span class="n">w</span><span class="p">,</span> <span class="n">h</span> <span class="o">=</span> <span class="n">sk1</span> <span class="o">*</span> <span class="n">sqrt</span><span class="p">(</span><span class="n">alpha</span><span class="p">),</span> <span class="n">sk1</span> <span class="o">/</span> <span class="n">sqrt</span><span class="p">(</span><span class="n">alpha</span><span class="p">)</span>
    139. <span class="n">all_sizes</span><span class="o">.</span><span class="n">append</span><span class="p">((</span><span class="n">w</span><span class="p">,</span> <span class="n">h</span><span class="p">))</span>
    140. <span class="n">all_sizes</span><span class="o">.</span><span class="n">append</span><span class="p">((</span><span class="n">h</span><span class="p">,</span> <span class="n">w</span><span class="p">))</span>
    141. <span class="n">all_sizes</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">array</span><span class="p">(</span><span class="n">all_sizes</span><span class="p">)</span> <span class="o">/</span> <span class="n">fig_size</span>
    142. <span class="bp">self</span><span class="o">.</span><span class="n">num_anchors</span><span class="o">.</span><span class="n">append</span><span class="p">(</span><span class="nb">len</span><span class="p">(</span><span class="n">all_sizes</span><span class="p">))</span>
    143. <span class="k">for</span> <span class="n">w</span><span class="p">,</span> <span class="n">h</span> <span class="ow">in</span> <span class="n">all_sizes</span><span class="p">:</span>
    144. <span class="k">for</span> <span class="n">i</span><span class="p">,</span> <span class="n">j</span> <span class="ow">in</span> <span class="n">itertools</span><span class="o">.</span><span class="n">product</span><span class="p">(</span><span class="nb">range</span><span class="p">(</span><span class="n">sfeat</span><span class="p">),</span> <span class="n">repeat</span><span class="o">=</span><span class="mi">2</span><span class="p">):</span>
    145. <span class="n">cx</span><span class="p">,</span> <span class="n">cy</span> <span class="o">=</span> <span class="p">(</span><span class="n">j</span> <span class="o">+</span> <span class="mf">0.5</span><span class="p">)</span> <span class="o">/</span> <span class="n">sfeat</span><span class="p">,</span> <span class="p">(</span><span class="n">i</span> <span class="o">+</span> <span class="mf">0.5</span><span class="p">)</span> <span class="o">/</span> <span class="n">sfeat</span>
    146. <span class="bp">self</span><span class="o">.</span><span class="n">default_boxes</span><span class="o">.</span><span class="n">append</span><span class="p">((</span><span class="n">cx</span><span class="p">,</span> <span class="n">cy</span><span class="p">,</span> <span class="n">w</span><span class="p">,</span> <span class="n">h</span><span class="p">))</span>
    147. <span class="bp">self</span><span class="o">.</span><span class="n">dboxes</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">tensor</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">default_boxes</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="n">torch</span><span class="o">.</span><span class="n">float</span><span class="p">)</span>
    148. <span class="bp">self</span><span class="o">.</span><span class="n">dboxes</span><span class="o">.</span><span class="n">clamp_</span><span class="p">(</span><span class="nb">min</span><span class="o">=</span><span class="mi">0</span><span class="p">,</span> <span class="nb">max</span><span class="o">=</span><span class="mi">1</span><span class="p">)</span>
    149. <span class="c1"># For IoU calculation</span>
    150. <span class="bp">self</span><span class="o">.</span><span class="n">dboxes_xyxy</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">dboxes</span><span class="o">.</span><span class="n">clone</span><span class="p">()</span>
    151. <span class="bp">self</span><span class="o">.</span><span class="n">dboxes_xyxy</span><span class="p">[:,</span> <span class="mi">0</span><span class="p">]</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">dboxes</span><span class="p">[:,</span> <span class="mi">0</span><span class="p">]</span> <span class="o">-</span> <span class="mf">0.5</span> <span class="o">*</span> <span class="bp">self</span><span class="o">.</span><span class="n">dboxes</span><span class="p">[:,</span> <span class="mi">2</span><span class="p">]</span>
    152. <span class="bp">self</span><span class="o">.</span><span class="n">dboxes_xyxy</span><span class="p">[:,</span> <span class="mi">1</span><span class="p">]</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">dboxes</span><span class="p">[:,</span> <span class="mi">1</span><span class="p">]</span> <span class="o">-</span> <span class="mf">0.5</span> <span class="o">*</span> <span class="bp">self</span><span class="o">.</span><span class="n">dboxes</span><span class="p">[:,</span> <span class="mi">3</span><span class="p">]</span>
    153. <span class="bp">self</span><span class="o">.</span><span class="n">dboxes_xyxy</span><span class="p">[:,</span> <span class="mi">2</span><span class="p">]</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">dboxes</span><span class="p">[:,</span> <span class="mi">0</span><span class="p">]</span> <span class="o">+</span> <span class="mf">0.5</span> <span class="o">*</span> <span class="bp">self</span><span class="o">.</span><span class="n">dboxes</span><span class="p">[:,</span> <span class="mi">2</span><span class="p">]</span>
    154. <span class="bp">self</span><span class="o">.</span><span class="n">dboxes_xyxy</span><span class="p">[:,</span> <span class="mi">3</span><span class="p">]</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">dboxes</span><span class="p">[:,</span> <span class="mi">1</span><span class="p">]</span> <span class="o">+</span> <span class="mf">0.5</span> <span class="o">*</span> <span class="bp">self</span><span class="o">.</span><span class="n">dboxes</span><span class="p">[:,</span> <span class="mi">3</span><span class="p">]</span>
    155. <span class="nd">@property</span>
    156. <span class="k">def</span> <span class="nf">scale_xy</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span>
    157. <span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">scale_xy_</span>
    158. <span class="nd">@property</span>
    159. <span class="k">def</span> <span class="nf">scale_wh</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span>
    160. <span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">scale_wh_</span>
    161. <span class="k">def</span> <span class="fm">__call__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">order</span><span class="o">=</span><span class="s2">&quot;xyxy&quot;</span><span class="p">):</span>
    162. <span class="k">if</span> <span class="n">order</span> <span class="o">==</span> <span class="s2">&quot;xyxy&quot;</span><span class="p">:</span>
    163. <span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">dboxes_xyxy</span>
    164. <span class="k">if</span> <span class="n">order</span> <span class="o">==</span> <span class="s2">&quot;xywh&quot;</span><span class="p">:</span>
    165. <span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">dboxes</span></div>
    166. <div class="viewcode-block" id="SSDPostPredictCallback"><a class="viewcode-back" href="../../../../super_gradients.training.utils.html#super_gradients.training.utils.ssd_utils.SSDPostPredictCallback">[docs]</a><span class="k">class</span> <span class="nc">SSDPostPredictCallback</span><span class="p">(</span><span class="n">DetectionPostPredictionCallback</span><span class="p">):</span>
    167. <span class="sd">&quot;&quot;&quot;</span>
    168. <span class="sd"> post prediction callback module to convert and filter predictions coming from the SSD net to a format</span>
    169. <span class="sd"> used by all other detection models</span>
    170. <span class="sd"> &quot;&quot;&quot;</span>
    171. <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">conf</span><span class="p">:</span> <span class="nb">float</span> <span class="o">=</span> <span class="mf">0.001</span><span class="p">,</span> <span class="n">iou</span><span class="p">:</span> <span class="nb">float</span> <span class="o">=</span> <span class="mf">0.6</span><span class="p">,</span> <span class="n">classes</span><span class="p">:</span> <span class="nb">list</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span>
    172. <span class="n">max_predictions</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="mi">300</span><span class="p">,</span>
    173. <span class="n">nms_type</span><span class="p">:</span> <span class="n">NMS_Type</span> <span class="o">=</span> <span class="n">NMS_Type</span><span class="o">.</span><span class="n">ITERATIVE</span><span class="p">,</span>
    174. <span class="n">multi_label_per_box</span><span class="o">=</span><span class="kc">True</span><span class="p">):</span>
    175. <span class="sd">&quot;&quot;&quot;</span>
    176. <span class="sd"> Predictions of SSD contain unnormalized probabilities for a background class,</span>
    177. <span class="sd"> together with confidences for all the dataset classes. Background will be utilized and discarded,</span>
    178. <span class="sd"> so this callback will return 0-based classes without background</span>
    179. <span class="sd"> :param conf: confidence threshold</span>
    180. <span class="sd"> :param iou: IoU threshold</span>
    181. <span class="sd"> :param classes: (optional list) filter by class</span>
    182. <span class="sd"> :param nms_type: the type of nms to use (iterative or matrix)</span>
    183. <span class="sd"> :param multi_label_per_box: whether to use re-use each box with all possible labels</span>
    184. <span class="sd"> (instead of the maximum confidence all confidences above threshold</span>
    185. <span class="sd"> will be sent to NMS)</span>
    186. <span class="sd"> &quot;&quot;&quot;</span>
    187. <span class="nb">super</span><span class="p">(</span><span class="n">SSDPostPredictCallback</span><span class="p">,</span> <span class="bp">self</span><span class="p">)</span><span class="o">.</span><span class="fm">__init__</span><span class="p">()</span>
    188. <span class="bp">self</span><span class="o">.</span><span class="n">conf</span> <span class="o">=</span> <span class="n">conf</span>
    189. <span class="bp">self</span><span class="o">.</span><span class="n">iou</span> <span class="o">=</span> <span class="n">iou</span>
    190. <span class="bp">self</span><span class="o">.</span><span class="n">nms_type</span> <span class="o">=</span> <span class="n">nms_type</span>
    191. <span class="bp">self</span><span class="o">.</span><span class="n">classes</span> <span class="o">=</span> <span class="n">classes</span>
    192. <span class="bp">self</span><span class="o">.</span><span class="n">max_predictions</span> <span class="o">=</span> <span class="n">max_predictions</span>
    193. <span class="bp">self</span><span class="o">.</span><span class="n">multi_label_per_box</span> <span class="o">=</span> <span class="n">multi_label_per_box</span>
    194. <div class="viewcode-block" id="SSDPostPredictCallback.forward"><a class="viewcode-back" href="../../../../super_gradients.training.utils.html#super_gradients.training.utils.ssd_utils.SSDPostPredictCallback.forward">[docs]</a> <span class="k">def</span> <span class="nf">forward</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">predictions</span><span class="p">,</span> <span class="n">device</span><span class="o">=</span><span class="kc">None</span><span class="p">):</span>
    195. <span class="n">nms_input</span> <span class="o">=</span> <span class="n">predictions</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span>
    196. <span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">nms_type</span> <span class="o">==</span> <span class="n">NMS_Type</span><span class="o">.</span><span class="n">ITERATIVE</span><span class="p">:</span>
    197. <span class="n">nms_res</span> <span class="o">=</span> <span class="n">non_max_suppression</span><span class="p">(</span><span class="n">nms_input</span><span class="p">,</span> <span class="n">conf_thres</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">conf</span><span class="p">,</span> <span class="n">iou_thres</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">iou</span><span class="p">,</span>
    198. <span class="n">multi_label_per_box</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">multi_label_per_box</span><span class="p">,</span> <span class="n">with_confidence</span><span class="o">=</span><span class="kc">True</span><span class="p">)</span>
    199. <span class="k">else</span><span class="p">:</span>
    200. <span class="n">nms_res</span> <span class="o">=</span> <span class="n">matrix_non_max_suppression</span><span class="p">(</span><span class="n">nms_input</span><span class="p">,</span> <span class="n">conf_thres</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">conf</span><span class="p">,</span>
    201. <span class="n">max_num_of_detections</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">max_predictions</span><span class="p">)</span>
    202. <span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">_filter_max_predictions</span><span class="p">(</span><span class="n">nms_res</span><span class="p">)</span></div>
    203. <span class="k">def</span> <span class="nf">_filter_max_predictions</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">res</span><span class="p">:</span> <span class="n">List</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="n">List</span><span class="p">:</span>
    204. <span class="n">res</span><span class="p">[:]</span> <span class="o">=</span> <span class="p">[</span><span class="n">im</span><span class="p">[:</span><span class="bp">self</span><span class="o">.</span><span class="n">max_predictions</span><span class="p">]</span> <span class="k">if</span> <span class="p">(</span><span class="n">im</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span> <span class="ow">and</span> <span class="n">im</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span> <span class="o">&gt;</span> <span class="bp">self</span><span class="o">.</span><span class="n">max_predictions</span><span class="p">)</span> <span class="k">else</span> <span class="n">im</span> <span class="k">for</span> <span class="n">im</span> <span class="ow">in</span> <span class="n">res</span><span class="p">]</span>
    205. <span class="k">return</span> <span class="n">res</span></div>
    206. </pre></div>
    207. </div>
    208. </div>
    209. <footer>
    210. <hr/>
    211. <div role="contentinfo">
    212. <p>&#169; Copyright 2021, SuperGradients team.</p>
    213. </div>
    214. Built with <a href="https://www.sphinx-doc.org/">Sphinx</a> using a
    215. <a href="https://github.com/readthedocs/sphinx_rtd_theme">theme</a>
    216. provided by <a href="https://readthedocs.org">Read the Docs</a>.
    217. </footer>
    218. </div>
    219. </div>
    220. </section>
    221. </div>
    222. <script>
    223. jQuery(function () {
    224. SphinxRtdTheme.Navigation.enable(true);
    225. });
    226. </script>
    227. </body>
    228. </html>
    Discard
    Some lines were truncated since they exceed the maximum allowed length of 500, please use a local Git client to see the full diff.
    @@ -3,10 +3,11 @@
     <head>
     <head>
       <meta charset="utf-8" />
       <meta charset="utf-8" />
       <meta name="viewport" content="width=device-width, initial-scale=1.0" />
       <meta name="viewport" content="width=device-width, initial-scale=1.0" />
    -  <title>super_gradients.training.utils.utils &mdash; SuperGradients 1.0 documentation</title>
    +  <title>super_gradients.training.utils.utils &mdash; SuperGradients 3.0.3 documentation</title>
           <link rel="stylesheet" href="../../../../_static/pygments.css" type="text/css" />
           <link rel="stylesheet" href="../../../../_static/pygments.css" type="text/css" />
           <link rel="stylesheet" href="../../../../_static/css/theme.css" type="text/css" />
           <link rel="stylesheet" href="../../../../_static/css/theme.css" type="text/css" />
           <link rel="stylesheet" href="../../../../_static/graphviz.css" type="text/css" />
           <link rel="stylesheet" href="../../../../_static/graphviz.css" type="text/css" />
    +      <link rel="stylesheet" href="../../../../_static/custom.css" type="text/css" />
       <!--[if lt IE 9]>
       <!--[if lt IE 9]>
         <script src="../../../../_static/js/html5shiv.min.js"></script>
         <script src="../../../../_static/js/html5shiv.min.js"></script>
       <![endif]-->
       <![endif]-->
    @@ -14,7 +15,9 @@
             <script data-url_root="../../../../" id="documentation_options" src="../../../../_static/documentation_options.js"></script>
             <script data-url_root="../../../../" id="documentation_options" src="../../../../_static/documentation_options.js"></script>
             <script src="../../../../_static/jquery.js"></script>
             <script src="../../../../_static/jquery.js"></script>
             <script src="../../../../_static/underscore.js"></script>
             <script src="../../../../_static/underscore.js"></script>
    +        <script src="../../../../_static/_sphinx_javascript_frameworks_compat.js"></script>
             <script src="../../../../_static/doctools.js"></script>
             <script src="../../../../_static/doctools.js"></script>
    +        <script src="../../../../_static/sphinx_highlight.js"></script>
         <script src="../../../../_static/js/theme.js"></script>
         <script src="../../../../_static/js/theme.js"></script>
         <link rel="index" title="Index" href="../../../../genindex.html" />
         <link rel="index" title="Index" href="../../../../genindex.html" />
         <link rel="search" title="Search" href="../../../../search.html" /> 
         <link rel="search" title="Search" href="../../../../search.html" /> 
    @@ -35,30 +38,29 @@
       </form>
       </form>
     </div>
     </div>
             </div><div class="wy-menu wy-menu-vertical" data-spy="affix" role="navigation" aria-label="Navigation menu">
             </div><div class="wy-menu wy-menu-vertical" data-spy="affix" role="navigation" aria-label="Navigation menu">
    -              <p class="caption"><span class="caption-text">Welcome To SuperGradients</span></p>
    +              <p class="caption" role="heading"><span class="caption-text">Welcome To SuperGradients</span></p>
     <ul>
     <ul>
    -<li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html">Fill our 4-question quick survey! We will raffle free SuperGradients swag between those who will participate -&gt; Fill Survey</a></li>
    -<li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#supergradients">SuperGradients</a></li>
    +<li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html">Version 3 is out! Notebooks have been updated!</a></li>
    +<li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#build-with-supergradients">Build with SuperGradients</a></li>
    +<li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#quick-installation">Quick Installation</a></li>
    +<li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#what-s-new">What’s New</a></li>
    +<li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#coming-soon">Coming soon</a></li>
    +<li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#table-of-content">Table of Content</a></li>
    +<li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#getting-started">Getting Started</a></li>
    +<li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#advanced-features">Advanced Features</a></li>
    +<li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#installation-methods">Installation Methods</a></li>
    +<li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#implemented-model-architectures">Implemented Model Architectures</a></li>
    +<li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#documentation">Documentation</a></li>
    +<li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#contributing">Contributing</a></li>
    +<li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#citation">Citation</a></li>
    +<li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#community">Community</a></li>
    +<li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#license">License</a></li>
    +<li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#deci-platform">Deci Platform</a></li>
     </ul>
     </ul>
    -<p class="caption"><span class="caption-text">Technical Documentation</span></p>
    +<p class="caption" role="heading"><span class="caption-text">Technical Documentation</span></p>
     <ul>
     <ul>
     <li class="toctree-l1"><a class="reference internal" href="../../../../super_gradients.common.html">Common package</a></li>
     <li class="toctree-l1"><a class="reference internal" href="../../../../super_gradients.common.html">Common package</a></li>
     <li class="toctree-l1"><a class="reference internal" href="../../../../super_gradients.training.html">Training package</a></li>
     <li class="toctree-l1"><a class="reference internal" href="../../../../super_gradients.training.html">Training package</a></li>
    -</ul>
    -<p class="caption"><span class="caption-text">User Guide</span></p>
    -<ul>
    -<li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html">What is SuperGradients?</a></li>
    -<li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#introducing-the-supergradients-library">Introducing the SuperGradients library</a></li>
    -<li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#installation">Installation</a></li>
    -<li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#integrating-your-training-code-complete-walkthrough">Integrating Your Training Code - Complete Walkthrough</a></li>
    -<li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#training-parameters">Training Parameters</a></li>
    -<li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#logs-and-checkpoints">Logs and Checkpoints</a></li>
    -<li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#dataset-parameters">Dataset Parameters</a></li>
    -<li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#network-architectures">Network Architectures</a></li>
    -<li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#pretrained-models">Pretrained Models</a></li>
    -<li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#how-to-reproduce-our-training-recipes">How To Reproduce Our Training Recipes</a></li>
    -<li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#professional-tools-integration">Professional Tools Integration</a></li>
    -<li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#supergradients-faq">SuperGradients FAQ</a></li>
     </ul>
     </ul>
     
     
             </div>
             </div>
    @@ -90,7 +92,7 @@
     <span class="kn">import</span> <span class="nn">time</span>
     <span class="kn">import</span> <span class="nn">time</span>
     <span class="kn">from</span> <span class="nn">functools</span> <span class="kn">import</span> <span class="n">lru_cache</span>
     <span class="kn">from</span> <span class="nn">functools</span> <span class="kn">import</span> <span class="n">lru_cache</span>
     <span class="kn">from</span> <span class="nn">pathlib</span> <span class="kn">import</span> <span class="n">Path</span>
     <span class="kn">from</span> <span class="nn">pathlib</span> <span class="kn">import</span> <span class="n">Path</span>
    -<span class="kn">from</span> <span class="nn">typing</span> <span class="kn">import</span> <span class="n">Mapping</span><span class="p">,</span> <span class="n">Optional</span><span class="p">,</span> <span class="n">Tuple</span><span class="p">,</span> <span class="n">Union</span><span class="p">,</span> <span class="n">List</span>
    +<span class="kn">from</span> <span class="nn">typing</span> <span class="kn">import</span> <span class="n">Mapping</span><span class="p">,</span> <span class="n">Optional</span><span class="p">,</span> <span class="n">Tuple</span><span class="p">,</span> <span class="n">Union</span><span class="p">,</span> <span class="n">List</span><span class="p">,</span> <span class="n">Dict</span>
     <span class="kn">from</span> <span class="nn">zipfile</span> <span class="kn">import</span> <span class="n">ZipFile</span>
     <span class="kn">from</span> <span class="nn">zipfile</span> <span class="kn">import</span> <span class="n">ZipFile</span>
     <span class="kn">import</span> <span class="nn">os</span>
     <span class="kn">import</span> <span class="nn">os</span>
     <span class="kn">from</span> <span class="nn">jsonschema</span> <span class="kn">import</span> <span class="n">validate</span>
     <span class="kn">from</span> <span class="nn">jsonschema</span> <span class="kn">import</span> <span class="n">validate</span>
    @@ -112,28 +114,40 @@
     <span class="n">logger</span> <span class="o">=</span> <span class="n">get_logger</span><span class="p">(</span><span class="vm">__name__</span><span class="p">)</span>
     <span class="n">logger</span> <span class="o">=</span> <span class="n">get_logger</span><span class="p">(</span><span class="vm">__name__</span><span class="p">)</span>
     
     
     
     
    -<div class="viewcode-block" id="convert_to_tensor"><a class="viewcode-back" href="../../../../super_gradients.training.utils.html#super_gradients.training.utils.convert_to_tensor">[docs]</a><span class="k">def</span> <span class="nf">convert_to_tensor</span><span class="p">(</span><span class="n">array</span><span class="p">):</span>
    +<span class="k">def</span> <span class="nf">empty_list</span><span class="p">():</span>
    +    <span class="sd">&quot;&quot;&quot;Instantiate an empty list. This is a workaround to generate a list with a function call in hydra, instead of the &quot;[]&quot;.&quot;&quot;&quot;</span>
    +    <span class="k">return</span> <span class="nb">list</span><span class="p">()</span>
    +
    +
    +<div class="viewcode-block" id="convert_to_tensor"><a class="viewcode-back" href="../../../../super_gradients.training.html#super_gradients.training.utils.convert_to_tensor">[docs]</a><span class="k">def</span> <span class="nf">convert_to_tensor</span><span class="p">(</span><span class="n">array</span><span class="p">):</span>
         <span class="sd">&quot;&quot;&quot;Converts numpy arrays and lists to Torch tensors before calculation losses</span>
         <span class="sd">&quot;&quot;&quot;Converts numpy arrays and lists to Torch tensors before calculation losses</span>
     <span class="sd">    :param array: torch.tensor / Numpy array / List</span>
     <span class="sd">    :param array: torch.tensor / Numpy array / List</span>
     <span class="sd">    &quot;&quot;&quot;</span>
     <span class="sd">    &quot;&quot;&quot;</span>
         <span class="k">return</span> <span class="n">torch</span><span class="o">.</span><span class="n">FloatTensor</span><span class="p">(</span><span class="n">array</span><span class="p">)</span> <span class="k">if</span> <span class="nb">type</span><span class="p">(</span><span class="n">array</span><span class="p">)</span> <span class="o">!=</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span> <span class="k">else</span> <span class="n">array</span></div>
         <span class="k">return</span> <span class="n">torch</span><span class="o">.</span><span class="n">FloatTensor</span><span class="p">(</span><span class="n">array</span><span class="p">)</span> <span class="k">if</span> <span class="nb">type</span><span class="p">(</span><span class="n">array</span><span class="p">)</span> <span class="o">!=</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span> <span class="k">else</span> <span class="n">array</span></div>
     
     
     
     
    -<div class="viewcode-block" id="HpmStruct"><a class="viewcode-back" href="../../../../super_gradients.training.utils.html#super_gradients.training.utils.HpmStruct">[docs]</a><span class="k">class</span> <span class="nc">HpmStruct</span><span class="p">:</span>
    +<div class="viewcode-block" id="HpmStruct"><a class="viewcode-back" href="../../../../super_gradients.training.html#super_gradients.training.utils.HpmStruct">[docs]</a><span class="k">class</span> <span class="nc">HpmStruct</span><span class="p">:</span>
         <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="o">**</span><span class="n">entries</span><span class="p">):</span>
         <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="o">**</span><span class="n">entries</span><span class="p">):</span>
             <span class="bp">self</span><span class="o">.</span><span class="vm">__dict__</span><span class="o">.</span><span class="n">update</span><span class="p">(</span><span class="n">entries</span><span class="p">)</span>
             <span class="bp">self</span><span class="o">.</span><span class="vm">__dict__</span><span class="o">.</span><span class="n">update</span><span class="p">(</span><span class="n">entries</span><span class="p">)</span>
             <span class="bp">self</span><span class="o">.</span><span class="n">schema</span> <span class="o">=</span> <span class="kc">None</span>
             <span class="bp">self</span><span class="o">.</span><span class="n">schema</span> <span class="o">=</span> <span class="kc">None</span>
     
     
    -<div class="viewcode-block" id="HpmStruct.set_schema"><a class="viewcode-back" href="../../../../super_gradients.training.utils.html#super_gradients.training.utils.HpmStruct.set_schema">[docs]</a>    <span class="k">def</span> <span class="nf">set_schema</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">schema</span><span class="p">:</span> <span class="nb">dict</span><span class="p">):</span>
    +<div class="viewcode-block" id="HpmStruct.set_schema"><a class="viewcode-back" href="../../../../super_gradients.training.html#super_gradients.training.utils.HpmStruct.set_schema">[docs]</a>    <span class="k">def</span> <span class="nf">set_schema</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">schema</span><span class="p">:</span> <span class="nb">dict</span><span class="p">):</span>
             <span class="bp">self</span><span class="o">.</span><span class="n">schema</span> <span class="o">=</span> <span class="n">schema</span></div>
             <span class="bp">self</span><span class="o">.</span><span class="n">schema</span> <span class="o">=</span> <span class="n">schema</span></div>
     
     
    -<div class="viewcode-block" id="HpmStruct.override"><a class="viewcode-back" href="../../../../super_gradients.training.utils.html#super_gradients.training.utils.HpmStruct.override">[docs]</a>    <span class="k">def</span> <span class="nf">override</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="o">**</span><span class="n">entries</span><span class="p">):</span>
    +<div class="viewcode-block" id="HpmStruct.override"><a class="viewcode-back" href="../../../../super_gradients.training.html#super_gradients.training.utils.HpmStruct.override">[docs]</a>    <span class="k">def</span> <span class="nf">override</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="o">**</span><span class="n">entries</span><span class="p">):</span>
             <span class="n">recursive_override</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="vm">__dict__</span><span class="p">,</span> <span class="n">entries</span><span class="p">)</span></div>
             <span class="n">recursive_override</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="vm">__dict__</span><span class="p">,</span> <span class="n">entries</span><span class="p">)</span></div>
     
     
    -<div class="viewcode-block" id="HpmStruct.to_dict"><a class="viewcode-back" href="../../../../super_gradients.training.utils.html#super_gradients.training.utils.HpmStruct.to_dict">[docs]</a>    <span class="k">def</span> <span class="nf">to_dict</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span>
    -        <span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="vm">__dict__</span></div>
    +<div class="viewcode-block" id="HpmStruct.to_dict"><a class="viewcode-back" href="../../../../super_gradients.training.html#super_gradients.training.utils.HpmStruct.to_dict">[docs]</a>    <span class="k">def</span> <span class="nf">to_dict</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">include_schema</span><span class="o">=</span><span class="kc">True</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="nb">dict</span><span
    +        <span class="sd">&quot;&quot;&quot;Convert this HpmStruct instance into a dict.</span>
    +<span class="sd">        :param include_schema: If True, also return the field &quot;schema&quot;</span>
    +<span class="sd">        :return: Dict representation of this HpmStruct instance.</span>
    +<span class="sd">        &quot;&quot;&quot;</span>
    +        <span class="n">out_dict</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="vm">__dict__</span><span class="o">.</span><span class="n">copy</span><span class="p">()</span>
    +        <span class="k">if</span> <span class="ow">not</span> <span class="n">include_schema</span><span class="p">:</span>
    +            <span class="n">out_dict</span><span class="o">.</span><span class="n">pop</span><span class="p">(</span><span class="s2">&quot;schema&quot;</span><span class="p">)</span>
    +        <span class="k">return</span> <span class="n">out_dict</span></div>
     
     
    -<div class="viewcode-block" id="HpmStruct.validate"><a class="viewcode-back" href="../../../../super_gradients.training.utils.html#super_gradients.training.utils.HpmStruct.validate">[docs]</a>    <span class="k">def</span> <span class="nf">validate</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span>
    +<div class="viewcode-block" id="HpmStruct.validate"><a class="viewcode-back" href="../../../../super_gradients.training.html#super_gradients.training.utils.HpmStruct.validate">[docs]</a>    <span class="k">def</span> <span class="nf">validate</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span>
             <span class="sd">&quot;&quot;&quot;</span>
             <span class="sd">&quot;&quot;&quot;</span>
     <span class="sd">        Validate the current dict values according to the provided schema</span>
     <span class="sd">        Validate the current dict values according to the provided schema</span>
     <span class="sd">        :raises</span>
     <span class="sd">        :raises</span>
    @@ -147,16 +161,16 @@
                 <span class="n">validate</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="vm">__dict__</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">schema</span><span class="p">)</span></div></div>
                 <span class="n">validate</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="vm">__dict__</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">schema</span><span class="p">)</span></div></div>
     
     
     
     
    -<div class="viewcode-block" id="WrappedModel"><a class="viewcode-back" href="../../../../super_gradients.training.utils.html#super_gradients.training.utils.WrappedModel">[docs]</a><span class="k">class</span> <span class="nc">WrappedModel</span><span class="p">(</span><span class="n">nn</span><span class="o">.</span><span class="n">Module</span><span class="p">):</span>
    +<div class="viewcode-block" id="WrappedModel"><a class="viewcode-back" href="../../../../super_gradients.training.html#super_gradients.training.utils.WrappedModel">[docs]</a><span class="k">class</span> <span class="nc">WrappedModel</span><span class="p">(</span><span class="n">nn</span><span class="o">.</span><span class="n">Module</span><span class="p">):</span>
         <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">module</span><span class="p">):</span>
         <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">module</span><span class="p">):</span>
             <span class="nb">super</span><span class="p">(</span><span class="n">WrappedModel</span><span class="p">,</span> <span class="bp">self</span><span class="p">)</span><span class="o">.</span><span class="fm">__init__</span><span class="p">()</span>
             <span class="nb">super</span><span class="p">(</span><span class="n">WrappedModel</span><span class="p">,</span> <span class="bp">self</span><span class="p">)</span><span class="o">.</span><span class="fm">__init__</span><span class="p">()</span>
             <span class="bp">self</span><span class="o">.</span><span class="n">module</span> <span class="o">=</span> <span class="n">module</span>  <span class="c1"># that I actually define.</span>
             <span class="bp">self</span><span class="o">.</span><span class="n">module</span> <span class="o">=</span> <span class="n">module</span>  <span class="c1"># that I actually define.</span>
     
     
    -<div class="viewcode-block" id="WrappedModel.forward"><a class="viewcode-back" href="../../../../super_gradients.training.utils.html#super_gradients.training.utils.WrappedModel.forward">[docs]</a>    <span class="k">def</span> <span class="nf">forward</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">x</span><span class="p">):</span>
    +<div class="viewcode-block" id="WrappedModel.forward"><a class="viewcode-back" href="../../../../super_gradients.training.html#super_gradients.training.utils.WrappedModel.forward">[docs]</a>    <span class="k">def</span> <span class="nf">forward</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">x</span><span class="p">):</span>
             <span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">module</span><span class="p">(</span><span class="n">x</span><span class="p">)</span></div></div>
             <span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">module</span><span class="p">(</span><span class="n">x</span><span class="p">)</span></div></div>
     
     
     
     
    -<div class="viewcode-block" id="Timer"><a class="viewcode-back" href="../../../../super_gradients.training.utils.html#super_gradients.training.utils.Timer">[docs]</a><span class="k">class</span> <span class="nc">Timer</span><span class="p">:</span>
    +<div class="viewcode-block" id="Timer"><a class="viewcode-back" href="../../../../super_gradients.training.html#super_gradients.training.utils.Timer">[docs]</a><span class="k">class</span> <span class="nc">Timer</span><span class="p">:</span>
         <span class="sd">&quot;&quot;&quot;A class to measure time handling both GPU &amp; CPU processes</span>
         <span class="sd">&quot;&quot;&quot;A class to measure time handling both GPU &amp; CPU processes</span>
     <span class="sd">    Returns time in milliseconds&quot;&quot;&quot;</span>
     <span class="sd">    Returns time in milliseconds&quot;&quot;&quot;</span>
     
     
    @@ -174,13 +188,13 @@
             <span class="k">else</span><span class="p">:</span>
             <span class="k">else</span><span class="p">:</span>
                 <span class="bp">self</span><span class="o">.</span><span class="n">starter</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">ender</span> <span class="o">=</span> <span class="mi">0</span><span class="p">,</span> <span class="mi">0</span>
                 <span class="bp">self</span><span class="o">.</span><span class="n">starter</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">ender</span> <span class="o">=</span> <span class="mi">0</span><span class="p">,</span> <span class="mi">0</span>
     
     
    -<div class="viewcode-block" id="Timer.start"><a class="viewcode-back" href="../../../../super_gradients.training.utils.html#super_gradients.training.utils.Timer.start">[docs]</a>    <span class="k">def</span> <span class="nf">start</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span>
    +<div class="viewcode-block" id="Timer.start"><a class="viewcode-back" href="../../../../super_gradients.training.html#super_gradients.training.utils.Timer.start">[docs]</a>    <span class="k">def</span> <span class="nf">start</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span>
             <span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">on_gpu</span><span class="p">:</span>
             <span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">on_gpu</span><span class="p">:</span>
                 <span class="bp">self</span><span class="o">.</span><span class="n">starter</span><span class="o">.</span><span class="n">record</span><span class="p">()</span>
                 <span class="bp">self</span><span class="o">.</span><span class="n">starter</span><span class="o">.</span><span class="n">record</span><span class="p">()</span>
             <span class="k">else</span><span class="p">:</span>
             <span class="k">else</span><span class="p">:</span>
                 <span class="bp">self</span><span class="o">.</span><span class="n">starter</span> <span class="o">=</span> <span class="n">time</span><span class="o">.</span><span class="n">time</span><span class="p">()</span></div>
                 <span class="bp">self</span><span class="o">.</span><span class="n">starter</span> <span class="o">=</span> <span class="n">time</span><span class="o">.</span><span class="n">time</span><span class="p">()</span></div>
     
     
    -<div class="viewcode-block" id="Timer.stop"><a class="viewcode-back" href="../../../../super_gradients.training.utils.html#super_gradients.training.utils.Timer.stop">[docs]</a>    <span class="k">def</span> <span class="nf">stop</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span>
    +<div class="viewcode-block" id="Timer.stop"><a class="viewcode-back" href="../../../../super_gradients.training.html#super_gradients.training.utils.Timer.stop">[docs]</a>    <span class="k">def</span> <span class="nf">stop</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span>
             <span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">on_gpu</span><span class="p">:</span>
             <span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">on_gpu</span><span class="p">:</span>
                 <span class="bp">self</span><span class="o">.</span><span class="n">ender</span><span class="o">.</span><span class="n">record</span><span class="p">()</span>
                 <span class="bp">self</span><span class="o">.</span><span class="n">ender</span><span class="o">.</span><span class="n">record</span><span class="p">()</span>
                 <span class="n">torch</span><span class="o">.</span><span class="n">cuda</span><span class="o">.</span><span class="n">synchronize</span><span class="p">()</span>
                 <span class="n">torch</span><span class="o">.</span><span class="n">cuda</span><span class="o">.</span><span class="n">synchronize</span><span class="p">()</span>
    @@ -193,7 +207,7 @@
             <span class="k">return</span> <span class="n">timer</span></div></div>
             <span class="k">return</span> <span class="n">timer</span></div></div>
     
     
     
     
    -<div class="viewcode-block" id="AverageMeter"><a class="viewcode-back" href="../../../../super_gradients.training.utils.html#super_gradients.training.utils.AverageMeter">[docs]</a><span class="k">class</span> <span class="nc">AverageMeter</span><span class="p">:</span>
    +<span class="k">class</span> <span class="nc">AverageMeter</span><span class="p">:</span>
         <span class="sd">&quot;&quot;&quot;A class to calculate the average of a metric, for each batch</span>
         <span class="sd">&quot;&quot;&quot;A class to calculate the average of a metric, for each batch</span>
     <span class="sd">    during training/testing&quot;&quot;&quot;</span>
     <span class="sd">    during training/testing&quot;&quot;&quot;</span>
     
     
    @@ -201,7 +215,7 @@
             <span class="bp">self</span><span class="o">.</span><span class="n">_sum</span> <span class="o">=</span> <span class="kc">None</span>
             <span class="bp">self</span><span class="o">.</span><span class="n">_sum</span> <span class="o">=</span> <span class="kc">None</span>
             <span class="bp">self</span><span class="o">.</span><span class="n">_count</span> <span class="o">=</span> <span class="mi">0</span>
             <span class="bp">self</span><span class="o">.</span><span class="n">_count</span> <span class="o">=</span> <span class="mi">0</span>
     
     
    -<div class="viewcode-block" id="AverageMeter.update"><a class="viewcode-back" href="../../../../super_gradients.training.utils.html#super_gradients.training.utils.AverageMeter.update">[docs]</a>    <span class="k">def</span> <span class="nf">update</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">value</span><span class="p">:</span> <span class="n">Union</span><span class="p">[</span><span class="nb">float</span><span class="p">,</span> <span cl
    +    <span class="k">def</span> <span class="nf">update</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">value</span><span class="p">:</span> <span class="n">Union</span><span class="p">[</span><span class="nb">float</span><span class="p">,</span> <span class="nb">tuple</span><span class="p">,</span> <span class="nb">list</span><span class="p">,</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p
     
     
             <span class="k">if</span> <span class="ow">not</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">value</span><span class="p">,</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">):</span>
             <span class="k">if</span> <span class="ow">not</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">value</span><span class="p">,</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">):</span>
                 <span class="n">value</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">tensor</span><span class="p">(</span><span class="n">value</span><span class="p">)</span>
                 <span class="n">value</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">tensor</span><span class="p">(</span><span class="n">value</span><span class="p">)</span>
    @@ -211,20 +225,20 @@
             <span class="k">else</span><span class="p">:</span>
             <span class="k">else</span><span class="p">:</span>
                 <span class="bp">self</span><span class="o">.</span><span class="n">_sum</span> <span class="o">+=</span> <span class="n">value</span> <span class="o">*</span> <span class="n">batch_size</span>
                 <span class="bp">self</span><span class="o">.</span><span class="n">_sum</span> <span class="o">+=</span> <span class="n">value</span> <span class="o">*</span> <span class="n">batch_size</span>
     
     
    -        <span class="bp">self</span><span class="o">.</span><span class="n">_count</span> <span class="o">+=</span> <span class="n">batch_size</span></div>
    +        <span class="bp">self</span><span class="o">.</span><span class="n">_count</span> <span class="o">+=</span> <span class="n">batch_size</span>
     
     
         <span class="nd">@property</span>
         <span class="nd">@property</span>
         <span class="k">def</span> <span class="nf">average</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span>
         <span class="k">def</span> <span class="nf">average</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span>
             <span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">_sum</span> <span class="ow">is</span> <span class="kc">None</span><span class="p">:</span>
             <span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">_sum</span> <span class="ow">is</span> <span class="kc">None</span><span class="p">:</span>
                 <span class="k">return</span> <span class="mi">0</span>
                 <span class="k">return</span> <span class="mi">0</span>
             <span class="k">return</span> <span class="p">((</span><span class="bp">self</span><span class="o">.</span><span class="n">_sum</span> <span class="o">/</span> <span class="bp">self</span><span class="o">.</span><span class="n">_count</span><span class="p">)</span><span class="o">.</span><span class="fm">__float__</span><span class="p">())</span> <span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">_sum</span><span class="o">.</span><span class="
             <span class="k">return</span> <span class="p">((</span><span class="bp">self</span><span class="o">.</span><span class="n">_sum</span> <span class="o">/</span> <span class="bp">self</span><span class="o">.</span><span class="n">_count</span><span class="p">)</span><span class="o">.</span><span class="fm">__float__</span><span class="p">())</span> <span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">_sum</span><span class="o">.</span><span class="
    -            <span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">_sum</span> <span class="o">/</span> <span class="bp">self</span><span class="o">.</span><span class="n">_count</span><span class="p">)</span><span class="o">.</span><span class="n">cpu</span><span class="p">()</span><span class="o">.</span><span class="n">numpy</span><span class="p">())</span></div>
    +            <span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">_sum</span> <span class="o">/</span> <span class="bp">self</span><span class="o">.</span><span class="n">_count</span><span class="p">)</span><span class="o">.</span><span class="n">cpu</span><span class="p">()</span><span class="o">.</span><span class="n">numpy</span><span class="p">())</span>
     
     
             <span class="c1"># return (self._sum / self._count).__float__() if self._sum.dim() &lt; 1 or len(self._sum) == 1 \</span>
             <span class="c1"># return (self._sum / self._count).__float__() if self._sum.dim() &lt; 1 or len(self._sum) == 1 \</span>
             <span class="c1">#     else tuple((self._sum / self._count).cpu().numpy())</span>
             <span class="c1">#     else tuple((self._sum / self._count).cpu().numpy())</span>
     
     
     
     
    -<div class="viewcode-block" id="tensor_container_to_device"><a class="viewcode-back" href="../../../../super_gradients.training.utils.html#super_gradients.training.utils.tensor_container_to_device">[docs]</a><span class="k">def</span> <span class="nf">tensor_container_to_device</span><span class="p">(</span><span class="n">obj</span><span class="p">:</span> <span class="n">Union</span><span class="p">[</span><span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span 
    +<div class="viewcode-block" id="tensor_container_to_device"><a class="viewcode-back" href="../../../../super_gradients.training.html#super_gradients.training.utils.tensor_container_to_device">[docs]</a><span class="k">def</span> <span class="nf">tensor_container_to_device</span><span class="p">(</span><span class="n">obj</span><span class="p">:</span> <span class="n">Union</span><span class="p">[</span><span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class=
         <span class="sd">&quot;&quot;&quot;</span>
         <span class="sd">&quot;&quot;&quot;</span>
     <span class="sd">    recursively send compounded objects to device (sending all tensors to device and maintaining structure)</span>
     <span class="sd">    recursively send compounded objects to device (sending all tensors to device and maintaining structure)</span>
     <span class="sd">        :param obj           the object to send to device (list / tuple / tensor / dict)</span>
     <span class="sd">        :param obj           the object to send to device (list / tuple / tensor / dict)</span>
    @@ -245,7 +259,7 @@
             <span class="k">return</span> <span class="n">obj</span></div>
             <span class="k">return</span> <span class="n">obj</span></div>
     
     
     
     
    -<div class="viewcode-block" id="get_param"><a class="viewcode-back" href="../../../../super_gradients.training.utils.html#super_gradients.training.utils.get_param">[docs]</a><span class="k">def</span> <span class="nf">get_param</span><span class="p">(</span><span class="n">params</span><span class="p">,</span> <span class="n">name</span><span class="p">,</span> <span class="n">default_val</span><span class="o">=</span><span class="kc">None</span><span class="p">):</span>
    +<div class="viewcode-block" id="get_param"><a class="viewcode-back" href="../../../../super_gradients.training.html#super_gradients.training.utils.get_param">[docs]</a><span class="k">def</span> <span class="nf">get_param</span><span class="p">(</span><span class="n">params</span><span class="p">,</span> <span class="n">name</span><span class="p">,</span> <span class="n">default_val</span><span class="o">=</span><span class="kc">None</span><span class="p">):</span>
         <span class="sd">&quot;&quot;&quot;</span>
         <span class="sd">&quot;&quot;&quot;</span>
     <span class="sd">    Retrieves a param from a parameter object/dict. If the parameter does not exist, will return default_val.</span>
     <span class="sd">    Retrieves a param from a parameter object/dict. If the parameter does not exist, will return default_val.</span>
     <span class="sd">    In case the default_val is of type dictionary, and a value is found in the params - the function</span>
     <span class="sd">    In case the default_val is of type dictionary, and a value is found in the params - the function</span>
    @@ -279,23 +293,23 @@
             <span class="k">return</span> <span class="n">default_val</span></div>
             <span class="k">return</span> <span class="n">default_val</span></div>
     
     
     
     
    -<div class="viewcode-block" id="static_vars"><a class="viewcode-back" href="../../../../super_gradients.training.utils.html#super_gradients.training.utils.static_vars">[docs]</a><span class="k">def</span> <span class="nf">static_vars</span><span class="p">(</span><span class="o">**</span><span class="n">kwargs</span><span class="p">):</span>
    +<span class="k">def</span> <span class="nf">static_vars</span><span class="p">(</span><span class="o">**</span><span class="n">kwargs</span><span class="p">):</span>
         <span class="k">def</span> <span class="nf">decorate</span><span class="p">(</span><span class="n">func</span><span class="p">):</span>
         <span class="k">def</span> <span class="nf">decorate</span><span class="p">(</span><span class="n">func</span><span class="p">):</span>
             <span class="k">for</span> <span class="n">k</span> <span class="ow">in</span> <span class="n">kwargs</span><span class="p">:</span>
             <span class="k">for</span> <span class="n">k</span> <span class="ow">in</span> <span class="n">kwargs</span><span class="p">:</span>
                 <span class="nb">setattr</span><span class="p">(</span><span class="n">func</span><span class="p">,</span> <span class="n">k</span><span class="p">,</span> <span class="n">kwargs</span><span class="p">[</span><span class="n">k</span><span class="p">])</span>
                 <span class="nb">setattr</span><span class="p">(</span><span class="n">func</span><span class="p">,</span> <span class="n">k</span><span class="p">,</span> <span class="n">kwargs</span><span class="p">[</span><span class="n">k</span><span class="p">])</span>
             <span class="k">return</span> <span class="n">func</span>
             <span class="k">return</span> <span class="n">func</span>
     
     
    -    <span class="k">return</span> <span class="n">decorate</span></div>
    +    <span class="k">return</span> <span class="n">decorate</span>
     
     
     
     
    -<div class="viewcode-block" id="print_once"><a class="viewcode-back" href="../../../../super_gradients.training.utils.html#super_gradients.training.utils.print_once">[docs]</a><span class="nd">@static_vars</span><span class="p">(</span><span class="n">printed</span><span class="o">=</span><span class="nb">set</span><span class="p">())</span>
    +<span class="nd">@static_vars</span><span class="p">(</span><span class="n">printed</span><span class="o">=</span><span class="nb">set</span><span class="p">())</span>
     <span class="k">def</span> <span class="nf">print_once</span><span class="p">(</span><span class="n">s</span><span class="p">:</span> <span class="nb">str</span><span class="p">):</span>
     <span class="k">def</span> <span class="nf">print_once</span><span class="p">(</span><span class="n">s</span><span class="p">:</span> <span class="nb">str</span><span class="p">):</span>
         <span class="k">if</span> <span class="n">s</span> <span class="ow">not</span> <span class="ow">in</span> <span class="n">print_once</span><span class="o">.</span><span class="n">printed</span><span class="p">:</span>
         <span class="k">if</span> <span class="n">s</span> <span class="ow">not</span> <span class="ow">in</span> <span class="n">print_once</span><span class="o">.</span><span class="n">printed</span><span class="p">:</span>
             <span class="n">print_once</span><span class="o">.</span><span class="n">printed</span><span class="o">.</span><span class="n">add</span><span class="p">(</span><span class="n">s</span><span class="p">)</span>
             <span class="n">print_once</span><span class="o">.</span><span class="n">printed</span><span class="o">.</span><span class="n">add</span><span class="p">(</span><span class="n">s</span><span class="p">)</span>
    -        <span class="nb">print</span><span class="p">(</span><span class="n">s</span><span class="p">)</span></div>
    +        <span class="nb">print</span><span class="p">(</span><span class="n">s</span><span class="p">)</span>
     
     
     
     
    -<div class="viewcode-block" id="move_state_dict_to_device"><a class="viewcode-back" href="../../../../super_gradients.training.utils.html#super_gradients.training.utils.move_state_dict_to_device">[docs]</a><span class="k">def</span> <span class="nf">move_state_dict_to_device</span><span class="p">(</span><span class="n">model_sd</span><span class="p">,</span> <span class="n">device</span><span class="p">):</span>
    +<span class="k">def</span> <span class="nf">move_state_dict_to_device</span><span class="p">(</span><span class="n">model_sd</span><span class="p">,</span> <span class="n">device</span><span class="p">):</span>
         <span class="sd">&quot;&quot;&quot;</span>
         <span class="sd">&quot;&quot;&quot;</span>
     <span class="sd">    Moving model state dict tensors to target device (cuda or cpu)</span>
     <span class="sd">    Moving model state dict tensors to target device (cuda or cpu)</span>
     <span class="sd">    :param model_sd: model state dict</span>
     <span class="sd">    :param model_sd: model state dict</span>
    @@ -303,10 +317,10 @@
     <span class="sd">    &quot;&quot;&quot;</span>
     <span class="sd">    &quot;&quot;&quot;</span>
         <span class="k">for</span> <span class="n">k</span><span class="p">,</span> <span class="n">v</span> <span class="ow">in</span> <span class="n">model_sd</span><span class="o">.</span><span class="n">items</span><span class="p">():</span>
         <span class="k">for</span> <span class="n">k</span><span class="p">,</span> <span class="n">v</span> <span class="ow">in</span> <span class="n">model_sd</span><span class="o">.</span><span class="n">items</span><span class="p">():</span>
             <span class="n">model_sd</span><span class="p">[</span><span class="n">k</span><span class="p">]</span> <span class="o">=</span> <span class="n">v</span><span class="o">.</span><span class="n">to</span><span class="p">(</span><span class="n">device</span><span class="p">)</span>
             <span class="n">model_sd</span><span class="p">[</span><span class="n">k</span><span class="p">]</span> <span class="o">=</span> <span class="n">v</span><span class="o">.</span><span class="n">to</span><span class="p">(</span><span class="n">device</span><span class="p">)</span>
    -    <span class="k">return</span> <span class="n">model_sd</span></div>
    +    <span class="k">return</span> <span class="n">model_sd</span>
     
     
     
     
    -<div class="viewcode-block" id="random_seed"><a class="viewcode-back" href="../../../../super_gradients.training.utils.html#super_gradients.training.utils.random_seed">[docs]</a><span class="k">def</span> <span class="nf">random_seed</span><span class="p">(</span><span class="n">is_ddp</span><span class="p">,</span> <span class="n">device</span><span class="p">,</span> <span class="n">seed</span><span class="p">):</span>
    +<div class="viewcode-block" id="random_seed"><a class="viewcode-back" href="../../../../super_gradients.training.html#super_gradients.training.utils.random_seed">[docs]</a><span class="k">def</span> <span class="nf">random_seed</span><span class="p">(</span><span class="n">is_ddp</span><span class="p">,</span> <span class="n">device</span><span class="p">,</span> <span class="n">seed</span><span class="p">):</span>
         <span class="sd">&quot;&quot;&quot;</span>
         <span class="sd">&quot;&quot;&quot;</span>
     <span class="sd">    Sets random seed of numpy, torch and random.</span>
     <span class="sd">    Sets random seed of numpy, torch and random.</span>
     
     
    @@ -321,7 +335,7 @@
         <span class="n">random</span><span class="o">.</span><span class="n">seed</span><span class="p">(</span><span class="n">seed</span> <span class="o">+</span> <span class="n">rank</span><span class="p">)</span></div>
         <span class="n">random</span><span class="o">.</span><span class="n">seed</span><span class="p">(</span><span class="n">seed</span> <span class="o">+</span> <span class="n">rank</span><span class="p">)</span></div>
     
     
     
     
    -<div class="viewcode-block" id="load_func"><a class="viewcode-back" href="../../../../super_gradients.training.utils.html#super_gradients.training.utils.load_func">[docs]</a><span class="k">def</span> <span class="nf">load_func</span><span class="p">(</span><span class="n">dotpath</span><span class="p">:</span> <span class="nb">str</span><span class="p">):</span>
    +<span class="k">def</span> <span class="nf">load_func</span><span class="p">(</span><span class="n">dotpath</span><span class="p">:</span> <span class="nb">str</span><span class="p">):</span>
         <span class="sd">&quot;&quot;&quot;</span>
         <span class="sd">&quot;&quot;&quot;</span>
     <span class="sd">    load function in module.  function is right-most segment.</span>
     <span class="sd">    load function in module.  function is right-most segment.</span>
     
     
    @@ -332,10 +346,10 @@
     <span class="sd">    &quot;&quot;&quot;</span>
     <span class="sd">    &quot;&quot;&quot;</span>
         <span class="n">module_</span><span class="p">,</span> <span class="n">func</span> <span class="o">=</span> <span class="n">dotpath</span><span class="o">.</span><span class="n">rsplit</span><span class="p">(</span><span class="s2">&quot;.&quot;</span><span class="p">,</span> <span class="n">maxsplit</span><span class="o">=</span><span class="mi">1</span><span class="p">)</span>
         <span class="n">module_</span><span class="p">,</span> <span class="n">func</span> <span class="o">=</span> <span class="n">dotpath</span><span class="o">.</span><span class="n">rsplit</span><span class="p">(</span><span class="s2">&quot;.&quot;</span><span class="p">,</span> <span class="n">maxsplit</span><span class="o">=</span><span class="mi">1</span><span class="p">)</span>
         <span class="n">m</span> <span class="o">=</span> <span class="n">import_module</span><span class="p">(</span><span class="n">module_</span><span class="p">)</span>
         <span class="n">m</span> <span class="o">=</span> <span class="n">import_module</span><span class="p">(</span><span class="n">module_</span><span class="p">)</span>
    -    <span class="k">return</span> <span class="nb">getattr</span><span class="p">(</span><span class="n">m</span><span class="p">,</span> <span class="n">func</span><span class="p">)</span></div>
    +    <span class="k">return</span> <span class="nb">getattr</span><span class="p">(</span><span class="n">m</span><span class="p">,</span> <span class="n">func</span><span class="p">)</span>
     
     
     
     
    -<div class="viewcode-block" id="get_filename_suffix_by_framework"><a class="viewcode-back" href="../../../../super_gradients.training.utils.html#super_gradients.training.utils.get_filename_suffix_by_framework">[docs]</a><span class="k">def</span> <span class="nf">get_filename_suffix_by_framework</span><span class="p">(</span><span class="n">framework</span><span class="p">:</span> <span class="nb">str</span><span class="p">):</span>
    +<span class="k">def</span> <span class="nf">get_filename_suffix_by_framework</span><span class="p">(</span><span class="n">framework</span><span class="p">:</span> <span class="nb">str</span><span class="p">):</span>
         <span class="sd">&quot;&quot;&quot;</span>
         <span class="sd">&quot;&quot;&quot;</span>
     <span class="sd">    Return the file extension of framework.</span>
     <span class="sd">    Return the file extension of framework.</span>
     
     
    @@ -359,10 +373,10 @@
         <span class="k">if</span> <span class="n">framework</span><span class="o">.</span><span class="n">upper</span><span class="p">()</span> <span class="ow">not</span> <span class="ow">in</span> <span class="n">frameworks_dict</span><span class="o">.</span><span class="n">keys</span><span class="p">():</span>
         <span class="k">if</span> <span class="n">framework</span><span class="o">.</span><span class="n">upper</span><span class="p">()</span> <span class="ow">not</span> <span class="ow">in</span> <span class="n">frameworks_dict</span><span class="o">.</span><span class="n">keys</span><span class="p">():</span>
             <span class="k">raise</span> <span class="ne">ValueError</span><span class="p">(</span><span class="sa">f</span><span class="s1">&#39;Unsupported framework: </span><span class="si">{</span><span class="n">framework</span><span class="si">}</span><span class="s1">&#39;</span><span class="p">)</span>
             <span class="k">raise</span> <span class="ne">ValueError</span><span class="p">(</span><span class="sa">f</span><span class="s1">&#39;Unsupported framework: </span><span class="si">{</span><span class="n">framework</span><span class="si">}</span><span class="s1">&#39;</span><span class="p">)</span>
     
     
    -    <span class="k">return</span> <span class="n">frameworks_dict</span><span class="p">[</span><span class="n">framework</span><span class="o">.</span><span class="n">upper</span><span class="p">()]</span></div>
    +    <span class="k">return</span> <span class="n">frameworks_dict</span><span class="p">[</span><span class="n">framework</span><span class="o">.</span><span class="n">upper</span><span class="p">()]</span>
     
     
     
     
    -<div class="viewcode-block" id="check_models_have_same_weights"><a class="viewcode-back" href="../../../../super_gradients.training.utils.html#super_gradients.training.utils.check_models_have_same_weights">[docs]</a><span class="k">def</span> <span class="nf">check_models_have_same_weights</span><span class="p">(</span><span class="n">model_1</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">nn</span><span class="o">.</span><span class="n">Module
    +<span class="k">def</span> <span class="nf">check_models_have_same_weights</span><span class="p">(</span><span class="n">model_1</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">nn</span><span class="o">.</span><span class="n">Module</span><span class="p">,</span> <span class="n">model_2</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">nn</span><span class="o">.</span><span class="n">Module</span
         <span class="sd">&quot;&quot;&quot;</span>
         <span class="sd">&quot;&quot;&quot;</span>
     <span class="sd">    Checks whether two networks have the same weights</span>
     <span class="sd">    Checks whether two networks have the same weights</span>
     
     
    @@ -382,10 +396,10 @@
         <span class="k">if</span> <span class="n">models_differ</span> <span class="o">==</span> <span class="mi">0</span><span class="p">:</span>
         <span class="k">if</span> <span class="n">models_differ</span> <span class="o">==</span> <span class="mi">0</span><span class="p">:</span>
             <span class="k">return</span> <span class="kc">True</span>
             <span class="k">return</span> <span class="kc">True</span>
         <span class="k">else</span><span class="p">:</span>
         <span class="k">else</span><span class="p">:</span>
    -        <span class="k">return</span> <span class="kc">False</span></div>
    +        <span class="k">return</span> <span class="kc">False</span>
     
     
     
     
    -<div class="viewcode-block" id="recursive_override"><a class="viewcode-back" href="../../../../super_gradients.training.utils.html#super_gradients.training.utils.recursive_override">[docs]</a><span class="k">def</span> <span class="nf">recursive_override</span><span class="p">(</span><span class="n">base</span><span class="p">:</span> <span class="nb">dict</span><span class="p">,</span> <span class="n">extension</span><span class="p">:</span> <span class="nb">dict</span><span class="p">):</span
    +<span class="k">def</span> <span class="nf">recursive_override</span><span class="p">(</span><span class="n">base</span><span class="p">:</span> <span class="nb">dict</span><span class="p">,</span> <span class="n">extension</span><span class="p">:</span> <span class="nb">dict</span><span class="p">):</span>
         <span class="k">for</span> <span class="n">k</span><span class="p">,</span> <span class="n">v</span> <span class="ow">in</span> <span class="n">extension</span><span class="o">.</span><span class="n">items</span><span class="p">():</span>
         <span class="k">for</span> <span class="n">k</span><span class="p">,</span> <span class="n">v</span> <span class="ow">in</span> <span class="n">extension</span><span class="o">.</span><span class="n">items</span><span class="p">():</span>
             <span class="k">if</span> <span class="n">k</span> <span class="ow">in</span> <span class="n">base</span><span class="p">:</span>
             <span class="k">if</span> <span class="n">k</span> <span class="ow">in</span> <span class="n">base</span><span class="p">:</span>
                 <span class="k">if</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">v</span><span class="p">,</span> <span class="n">Mapping</span><span class="p">):</span>
                 <span class="k">if</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">v</span><span class="p">,</span> <span class="n">Mapping</span><span class="p">):</span>
    @@ -393,10 +407,10 @@
                 <span class="k">else</span><span class="p">:</span>
                 <span class="k">else</span><span class="p">:</span>
                     <span class="n">base</span><span class="p">[</span><span class="n">k</span><span class="p">]</span> <span class="o">=</span> <span class="n">extension</span><span class="p">[</span><span class="n">k</span><span class="p">]</span>
                     <span class="n">base</span><span class="p">[</span><span class="n">k</span><span class="p">]</span> <span class="o">=</span> <span class="n">extension</span><span class="p">[</span><span class="n">k</span><span class="p">]</span>
             <span class="k">else</span><span class="p">:</span>
             <span class="k">else</span><span class="p">:</span>
    -            <span class="n">base</span><span class="p">[</span><span class="n">k</span><span class="p">]</span> <span class="o">=</span> <span class="n">extension</span><span class="p">[</span><span class="n">k</span><span class="p">]</span></div>
    +            <span class="n">base</span><span class="p">[</span><span class="n">k</span><span class="p">]</span> <span class="o">=</span> <span class="n">extension</span><span class="p">[</span><span class="n">k</span><span class="p">]</span>
     
     
     
     
    -<div class="viewcode-block" id="download_and_unzip_from_url"><a class="viewcode-back" href="../../../../super_gradients.training.utils.html#super_gradients.training.utils.download_and_unzip_from_url">[docs]</a><span class="k">def</span> <span class="nf">download_and_unzip_from_url</span><span class="p">(</span><span class="n">url</span><span class="p">,</span> <span class="nb">dir</span><span class="o">=</span><span class="s1">&#39;.&#39;</span><span class="p">,</span> <span class="n">unzip</sp
    +<span class="k">def</span> <span class="nf">download_and_unzip_from_url</span><span class="p">(</span><span class="n">url</span><span class="p">,</span> <span class="nb">dir</span><span class="o">=</span><span class="s1">&#39;.&#39;</span><span class="p">,</span> <span class="n">unzip</span><span class="o">=</span><span class="kc">True</span><span class="p">,</span> <span class="n">delete</span><span class="o">=</span><span class="kc">True</span><span class="p">):</span>
         <span class="sd">&quot;&quot;&quot;</span>
         <span class="sd">&quot;&quot;&quot;</span>
     <span class="sd">    Downloads a zip file from url to dir, and unzips it.</span>
     <span class="sd">    Downloads a zip file from url to dir, and unzips it.</span>
     
     
    @@ -431,10 +445,10 @@
         <span class="nb">dir</span> <span class="o">=</span> <span class="n">Path</span><span class="p">(</span><span class="nb">dir</span><span class="p">)</span>
         <span class="nb">dir</span> <span class="o">=</span> <span class="n">Path</span><span class="p">(</span><span class="nb">dir</span><span class="p">)</span>
         <span class="nb">dir</span><span class="o">.</span><span class="n">mkdir</span><span class="p">(</span><span class="n">parents</span><span class="o">=</span><span class="kc">True</span><span class="p">,</span> <span class="n">exist_ok</span><span class="o">=</span><span class="kc">True</span><span class="p">)</span>  <span class="c1"># make directory</span>
         <span class="nb">dir</span><span class="o">.</span><span class="n">mkdir</span><span class="p">(</span><span class="n">parents</span><span class="o">=</span><span class="kc">True</span><span class="p">,</span> <span class="n">exist_ok</span><span class="o">=</span><span class="kc">True</span><span class="p">)</span>  <span class="c1"># make directory</span>
         <span class="k">for</span> <span class="n">u</span> <span class="ow">in</span> <span class="p">[</span><span class="n">url</span><span class="p">]</span> <span class="k">if</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">url</span><span class="p">,</span> <span class="p">(</span><span class="nb">str</span><span class="p">,</span> <span class="n">Path</span><span class="p">))</span> <span class="k">else</span> <span class="n">url</span><span class="p">:</span
         <span class="k">for</span> <span class="n">u</span> <span class="ow">in</span> <span class="p">[</span><span class="n">url</span><span class="p">]</span> <span class="k">if</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">url</span><span class="p">,</span> <span class="p">(</span><span class="nb">str</span><span class="p">,</span> <span class="n">Path</span><span class="p">))</span> <span class="k">else</span> <span class="n">url</span><span class="p">:</span
    -        <span class="n">download_one</span><span class="p">(</span><span class="n">u</span><span class="p">,</span> <span class="nb">dir</span><span class="p">)</span></div>
    +        <span class="n">download_one</span><span class="p">(</span><span class="n">u</span><span class="p">,</span> <span class="nb">dir</span><span class="p">)</span>
     
     
     
     
    -<div class="viewcode-block" id="download_and_untar_from_url"><a class="viewcode-back" href="../../../../super_gradients.training.utils.html#super_gradients.training.utils.download_and_untar_from_url">[docs]</a><span class="k">def</span> <span class="nf">download_and_untar_from_url</span><span class="p">(</span><span class="n">urls</span><span class="p">:</span> <span class="n">List</span><span class="p">[</span><span class="nb">str</span><span class="p">],</span> <span class="nb">dir</span><spa
    +<span class="k">def</span> <span class="nf">download_and_untar_from_url</span><span class="p">(</span><span class="n">urls</span><span class="p">:</span> <span class="n">List</span><span class="p">[</span><span class="nb">str</span><span class="p">],</span> <span class="nb">dir</span><span class="p">:</span> <span class="n">Union</span><span class="p">[</span><span class="nb">str</span><span class="p">,</span> <span class="n">Path</span><span class="p">]</span> <span class="o">=</span> <span cl
         <span class="sd">&quot;&quot;&quot;</span>
         <span class="sd">&quot;&quot;&quot;</span>
     <span class="sd">    Download a file from url and untar.</span>
     <span class="sd">    Download a file from url and untar.</span>
     
     
    @@ -460,10 +474,10 @@
             <span class="n">logger</span><span class="o">.</span><span class="n">info</span><span class="p">(</span><span class="sa">f</span><span class="s1">&#39;Extracting to </span><span class="si">{</span><span class="nb">dir</span><span class="si">}</span><span class="s1">...&#39;</span><span class="p">)</span>
             <span class="n">logger</span><span class="o">.</span><span class="n">info</span><span class="p">(</span><span class="sa">f</span><span class="s1">&#39;Extracting to </span><span class="si">{</span><span class="nb">dir</span><span class="si">}</span><span class="s1">...&#39;</span><span class="p">)</span>
             <span class="k">with</span> <span class="n">tarfile</span><span class="o">.</span><span class="n">open</span><span class="p">(</span><span class="n">filepath</span><span class="p">,</span> <span class="n">mode</span><span class="o">=</span><span class="n">modes</span><span class="p">[</span><span class="n">filepath</span><span class="o">.</span><span class="n">suffix</span><span class="p">])</span> <span class="k">as</span> <span class="n">f</span><span class="p">:</span>
             <span class="k">with</span> <span class="n">tarfile</span><span class="o">.</span><span class="n">open</span><span class="p">(</span><span class="n">filepath</span><span class="p">,</span> <span class="n">mode</span><span class="o">=</span><span class="n">modes</span><span class="p">[</span><span class="n">filepath</span><span class="o">.</span><span class="n">suffix</span><span class="p">])</span> <span class="k">as</span> <span class="n">f</span><span class="p">:</span>
                 <span class="n">f</span><span class="o">.</span><span class="n">extractall</span><span class="p">(</span><span class="nb">dir</span><span class="p">)</span>
                 <span class="n">f</span><span class="o">.</span><span class="n">extractall</span><span class="p">(</span><span class="nb">dir</span><span class="p">)</span>
    -        <span class="n">filepath</span><span class="o">.</span><span class="n">unlink</span><span class="p">()</span></div>
    +        <span class="n">filepath</span><span class="o">.</span><span class="n">unlink</span><span class="p">()</span>
     
     
     
     
    -<div class="viewcode-block" id="make_divisible"><a class="viewcode-back" href="../../../../super_gradients.training.utils.html#super_gradients.training.utils.make_divisible">[docs]</a><span class="k">def</span> <span class="nf">make_divisible</span><span class="p">(</span><span class="n">x</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span> <span class="n">divisor</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span> <span class="n">ce
    +<span class="k">def</span> <span class="nf">make_divisible</span><span class="p">(</span><span class="n">x</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span> <span class="n">divisor</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span> <span class="n">ceil</span><span class="p">:</span> <span class="nb">bool</span> <span class="o">=</span> <span class="kc">True</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="n
         <span class="sd">&quot;&quot;&quot;</span>
         <span class="sd">&quot;&quot;&quot;</span>
     <span class="sd">    Returns x evenly divisible by divisor.</span>
     <span class="sd">    Returns x evenly divisible by divisor.</span>
     <span class="sd">    If ceil=True it will return the closest larger number to the original x, and ceil=False the closest smaller number.</span>
     <span class="sd">    If ceil=True it will return the closest larger number to the original x, and ceil=False the closest smaller number.</span>
    @@ -471,10 +485,10 @@
         <span class="k">if</span> <span class="n">ceil</span><span class="p">:</span>
         <span class="k">if</span> <span class="n">ceil</span><span class="p">:</span>
             <span class="k">return</span> <span class="n">math</span><span class="o">.</span><span class="n">ceil</span><span class="p">(</span><span class="n">x</span> <span class="o">/</span> <span class="n">divisor</span><span class="p">)</span> <span class="o">*</span> <span class="n">divisor</span>
             <span class="k">return</span> <span class="n">math</span><span class="o">.</span><span class="n">ceil</span><span class="p">(</span><span class="n">x</span> <span class="o">/</span> <span class="n">divisor</span><span class="p">)</span> <span class="o">*</span> <span class="n">divisor</span>
         <span class="k">else</span><span class="p">:</span>
         <span class="k">else</span><span class="p">:</span>
    -        <span class="k">return</span> <span class="n">math</span><span class="o">.</span><span class="n">floor</span><span class="p">(</span><span class="n">x</span> <span class="o">/</span> <span class="n">divisor</span><span class="p">)</span> <span class="o">*</span> <span class="n">divisor</span></div>
    +        <span class="k">return</span> <span class="n">math</span><span class="o">.</span><span class="n">floor</span><span class="p">(</span><span class="n">x</span> <span class="o">/</span> <span class="n">divisor</span><span class="p">)</span> <span class="o">*</span> <span class="n">divisor</span>
     
     
     
     
    -<div class="viewcode-block" id="check_img_size_divisibility"><a class="viewcode-back" href="../../../../super_gradients.training.utils.html#super_gradients.training.utils.check_img_size_divisibility">[docs]</a><span class="k">def</span> <span class="nf">check_img_size_divisibility</span><span class="p">(</span><span class="n">img_size</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span> <span class="n">stride</span><span class="p">:</span> <span class="nb">int</spa
    +<span class="k">def</span> <span class="nf">check_img_size_divisibility</span><span class="p">(</span><span class="n">img_size</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span> <span class="n">stride</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="mi">32</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="n">Tuple</span><span class="p">[</span><span class="nb">bool</span><span class="p">,</spa
         <span class="sd">&quot;&quot;&quot;</span>
         <span class="sd">&quot;&quot;&quot;</span>
     <span class="sd">    :param img_size: Int, the size of the image (H or W).</span>
     <span class="sd">    :param img_size: Int, the size of the image (H or W).</span>
     <span class="sd">    :param stride: Int, the number to check if img_size is divisible by.</span>
     <span class="sd">    :param stride: Int, the number to check if img_size is divisible by.</span>
    @@ -486,22 +500,22 @@
         <span class="k">if</span> <span class="n">new_size</span> <span class="o">!=</span> <span class="n">img_size</span><span class="p">:</span>
         <span class="k">if</span> <span class="n">new_size</span> <span class="o">!=</span> <span class="n">img_size</span><span class="p">:</span>
             <span class="k">return</span> <span class="kc">False</span><span class="p">,</span> <span class="p">(</span><span class="n">new_size</span><span class="p">,</span> <span class="n">make_divisible</span><span class="p">(</span><span class="n">img_size</span><span class="p">,</span> <span class="nb">int</span><span class="p">(</span><span class="n">stride</span><span class="p">),</span> <span class="n">ceil</span><span class="o">=</span><span class="kc">False</span><span class="p">))</span
             <span class="k">return</span> <span class="kc">False</span><span class="p">,</span> <span class="p">(</span><span class="n">new_size</span><span class="p">,</span> <span class="n">make_divisible</span><span class="p">(</span><span class="n">img_size</span><span class="p">,</span> <span class="nb">int</span><span class="p">(</span><span class="n">stride</span><span class="p">),</span> <span class="n">ceil</span><span class="o">=</span><span class="kc">False</span><span class="p">))</span
         <span class="k">else</span><span class="p">:</span>
         <span class="k">else</span><span class="p">:</span>
    -        <span class="k">return</span> <span class="kc">True</span><span class="p">,</span> <span class="kc">None</span></div>
    +        <span class="k">return</span> <span class="kc">True</span><span class="p">,</span> <span class="kc">None</span>
     
     
     
     
    -<div class="viewcode-block" id="get_orientation_key"><a class="viewcode-back" href="../../../../super_gradients.training.utils.html#super_gradients.training.utils.get_orientation_key">[docs]</a><span class="nd">@lru_cache</span><span class="p">(</span><span class="kc">None</span><span class="p">)</span>
    +<span class="nd">@lru_cache</span><span class="p">(</span><span class="kc">None</span><span class="p">)</span>
     <span class="k">def</span> <span class="nf">get_orientation_key</span><span class="p">()</span> <span class="o">-&gt;</span> <span class="nb">int</span><span class="p">:</span>
     <span class="k">def</span> <span class="nf">get_orientation_key</span><span class="p">()</span> <span class="o">-&gt;</span> <span class="nb">int</span><span class="p">:</span>
         <span class="sd">&quot;&quot;&quot;Get the orientation key according to PIL, which is useful to get the image size for instance</span>
         <span class="sd">&quot;&quot;&quot;Get the orientation key according to PIL, which is useful to get the image size for instance</span>
     <span class="sd">    :return: Orientation key according to PIL&quot;&quot;&quot;</span>
     <span class="sd">    :return: Orientation key according to PIL&quot;&quot;&quot;</span>
         <span class="k">for</span> <span class="n">key</span><span class="p">,</span> <span class="n">value</span> <span class="ow">in</span> <span class="n">ExifTags</span><span class="o">.</span><span class="n">TAGS</span><span class="o">.</span><span class="n">items</span><span class="p">():</span>
         <span class="k">for</span> <span class="n">key</span><span class="p">,</span> <span class="n">value</span> <span class="ow">in</span> <span class="n">ExifTags</span><span class="o">.</span><span class="n">TAGS</span><span class="o">.</span><span class="n">items</span><span class="p">():</span>
             <span class="k">if</span> <span class="n">value</span> <span class="o">==</span> <span class="s1">&#39;Orientation&#39;</span><span class="p">:</span>
             <span class="k">if</span> <span class="n">value</span> <span class="o">==</span> <span class="s1">&#39;Orientation&#39;</span><span class="p">:</span>
    -            <span class="k">return</span> <span class="n">key</span></div>
    +            <span class="k">return</span> <span class="n">key</span>
     
     
     
     
    -<div class="viewcode-block" id="exif_size"><a class="viewcode-back" href="../../../../super_gradients.training.utils.html#super_gradients.training.utils.exif_size">[docs]</a><span class="k">def</span> <span class="nf">exif_size</span><span class="p">(</span><span class="n">image</span><span class="p">:</span> <span class="n">Image</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="n">Tuple</span><span class="p">[</span><span class="nb">int</span><span class="p">,</span> <s
    +<span class="k">def</span> <span class="nf">exif_size</span><span class="p">(</span><span class="n">image</span><span class="p">:</span> <span class="n">Image</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="n">Tuple</span><span class="p">[</span><span class="nb">int</span><span class="p">,</span> <span class="nb">int</span><span class="p">]:</span>
         <span class="sd">&quot;&quot;&quot;Get the size of image.</span>
         <span class="sd">&quot;&quot;&quot;Get the size of image.</span>
     <span class="sd">    :param image:   The image to get size from</span>
     <span class="sd">    :param image:   The image to get size from</span>
    -<span class="sd">    :return:        (width, height)</span>
    +<span class="sd">    :return:        (height, width)</span>
     <span class="sd">    &quot;&quot;&quot;</span>
     <span class="sd">    &quot;&quot;&quot;</span>
     
     
         <span class="n">orientation_key</span> <span class="o">=</span> <span class="n">get_orientation_key</span><span class="p">()</span>
         <span class="n">orientation_key</span> <span class="o">=</span> <span class="n">get_orientation_key</span><span class="p">()</span>
    @@ -519,14 +533,27 @@
                     <span class="n">image_size</span> <span class="o">=</span> <span class="p">(</span><span class="n">image_size</span><span class="p">[</span><span class="mi">1</span><span class="p">],</span> <span class="n">image_size</span><span class="p">[</span><span class="mi">0</span><span class="p">])</span>
                     <span class="n">image_size</span> <span class="o">=</span> <span class="p">(</span><span class="n">image_size</span><span class="p">[</span><span class="mi">1</span><span class="p">],</span> <span class="n">image_size</span><span class="p">[</span><span class="mi">0</span><span class="p">])</span>
         <span class="k">except</span> <span class="ne">Exception</span> <span class="k">as</span> <span class="n">ex</span><span class="p">:</span>
         <span class="k">except</span> <span class="ne">Exception</span> <span class="k">as</span> <span class="n">ex</span><span class="p">:</span>
             <span class="nb">print</span><span class="p">(</span><span class="s1">&#39;Caught Exception trying to rotate: &#39;</span> <span class="o">+</span> <span class="nb">str</span><span class="p">(</span><span class="n">image</span><span class="p">)</span> <span class="o">+</span> <span class="nb">str</span><span class="p">(</span><span class="n">ex</span><span class="p">))</span>
             <span class="nb">print</span><span class="p">(</span><span class="s1">&#39;Caught Exception trying to rotate: &#39;</span> <span class="o">+</span> <span class="nb">str</span><span class="p">(</span><span class="n">image</span><span class="p">)</span> <span class="o">+</span> <span class="nb">str</span><span class="p">(</span><span class="n">ex</span><span class="p">))</span>
    -    <span class="n">height</span><span class="p">,</span> <span class="n">width</span> <span class="o">=</span> <span class="n">image_size</span>
    -    <span class="k">return</span> <span class="n">width</span><span class="p">,</span> <span class="n">height</span></div>
    +    <span class="n">width</span><span class="p">,</span> <span class="n">height</span> <span class="o">=</span> <span class="n">image_size</span>
    +    <span class="k">return</span> <span class="n">height</span><span class="p">,</span> <span class="n">width</span>
     
     
     
     
    -<div class="viewcode-block" id="get_image_size_from_path"><a class="viewcode-back" href="../../../../super_gradients.training.utils.html#super_gradients.training.utils.get_image_size_from_path">[docs]</a><span class="k">def</span> <span class="nf">get_image_size_from_path</span><span class="p">(</span><span class="n">img_path</span><span class="p">:</span> <span class="nb">str</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="n">Tuple</span><span class="p">[</span><span c
    +<span class="k">def</span> <span class="nf">get_image_size_from_path</span><span class="p">(</span><span class="n">img_path</span><span class="p">:</span> <span class="nb">str</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="n">Tuple</span><span class="p">[</span><span class="nb">int</span><span class="p">,</span> <span class="nb">int</span><span class="p">]:</span>
         <span class="sd">&quot;&quot;&quot;Get the image size of an image at a specific path&quot;&quot;&quot;</span>
         <span class="sd">&quot;&quot;&quot;Get the image size of an image at a specific path&quot;&quot;&quot;</span>
         <span class="k">with</span> <span class="nb">open</span><span class="p">(</span><span class="n">img_path</span><span class="p">,</span> <span class="s1">&#39;rb&#39;</span><span class="p">)</span> <span class="k">as</span> <span class="n">f</span><span class="p">:</span>
         <span class="k">with</span> <span class="nb">open</span><span class="p">(</span><span class="n">img_path</span><span class="p">,</span> <span class="s1">&#39;rb&#39;</span><span class="p">)</span> <span class="k">as</span> <span class="n">f</span><span class="p">:</span>
    -        <span class="k">return</span> <span class="n">exif_size</span><span class="p">(</span><span class="n">Image</span><span class="o">.</span><span class="n">open</span><span class="p">(</span><span class="n">f</span><span class="p">))</span></div>
    +        <span class="k">return</span> <span class="n">exif_size</span><span class="p">(</span><span class="n">Image</span><span class="o">.</span><span class="n">open</span><span class="p">(</span><span class="n">f</span><span class="p">))</span>
    +
    +
    +<span class="k">def</span> <span class="nf">override_default_params_without_nones</span><span class="p">(</span><span class="n">params</span><span class="p">:</span> <span class="n">Dict</span><span class="p">,</span> <span class="n">default_params</span><span class="p">:</span> <span class="n">Dict</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="n">Dict</span><span class="p">:</span>
    +    <span class="sd">&quot;&quot;&quot;</span>
    +<span class="sd">    Helper method for overriding default dictionary&#39;s entries excluding entries with None values.</span>
    +<span class="sd">    :param params: dict, output dictionary which will take the defaults.</span>
    +<span class="sd">    :param default_params: dict, dictionary for the defaults.</span>
    +<span class="sd">    :return: dict, params after manipulation,</span>
    +<span class="sd">    &quot;&quot;&quot;</span>
    +    <span class="k">for</span> <span class="n">key</span><span class="p">,</span> <span class="n">val</span> <span class="ow">in</span> <span class="n">default_params</span><span class="o">.</span><span class="n">items</span><span class="p">():</span>
    +        <span class="k">if</span> <span class="n">key</span> <span class="ow">not</span> <span class="ow">in</span> <span class="n">params</span><span class="o">.</span><span class="n">keys</span><span class="p">()</span> <span class="ow">or</span> <span class="n">params</span><span class="p">[</span><span class="n">key</span><span class="p">]</span> <span class="ow">is</span> <span class="kc">None</span><span class="p">:</span>
    +            <span class="n">params</span><span class="p">[</span><span class="n">key</span><span class="p">]</span> <span class="o">=</span> <span class="n">val</span>
    +    <span class="k">return</span> <span class="n">params</span>
     </pre></div>
     </pre></div>
     
     
                </div>
                </div>
    @@ -556,4 +583,4 @@
       </script> 
       </script> 
     
     
     </body>
     </body>
    -</html>
    +</html>
    Discard
    1
    2
    3
    4
    5
    6
    7
    8
    9
    10
    11
    12
    13
    14
    15
    16
    17
    18
    19
    20
    21
    22
    23
    24
    25
    26
    27
    28
    29
    30
    31
    32
    33
    34
    35
    36
    37
    38
    39
    40
    41
    42
    43
    44
    45
    46
    47
    48
    49
    50
    51
    52
    53
    54
    55
    56
    57
    58
    59
    60
    61
    62
    63
    64
    65
    66
    67
    68
    69
    70
    71
    72
    73
    74
    75
    76
    77
    78
    79
    80
    81
    82
    83
    84
    85
    86
    87
    88
    89
    90
    91
    92
    93
    94
    95
    96
    97
    98
    99
    100
    101
    102
    103
    104
    105
    106
    107
    108
    109
    110
    111
    112
    113
    114
    115
    116
    117
    118
    119
    120
    121
    122
    123
    124
    125
    126
    127
    128
    129
    130
    1. <!DOCTYPE html>
    2. <html class="writer-html5" lang="en" >
    3. <head>
    4. <meta charset="utf-8" />
    5. <meta name="viewport" content="width=device-width, initial-scale=1.0" />
    6. <title>super_gradients.training.utils.version_utils &mdash; SuperGradients 3.0.3 documentation</title>
    7. <link rel="stylesheet" href="../../../../_static/pygments.css" type="text/css" />
    8. <link rel="stylesheet" href="../../../../_static/css/theme.css" type="text/css" />
    9. <link rel="stylesheet" href="../../../../_static/graphviz.css" type="text/css" />
    10. <link rel="stylesheet" href="../../../../_static/custom.css" type="text/css" />
    11. <!--[if lt IE 9]>
    12. <script src="../../../../_static/js/html5shiv.min.js"></script>
    13. <![endif]-->
    14. <script data-url_root="../../../../" id="documentation_options" src="../../../../_static/documentation_options.js"></script>
    15. <script src="../../../../_static/jquery.js"></script>
    16. <script src="../../../../_static/underscore.js"></script>
    17. <script src="../../../../_static/_sphinx_javascript_frameworks_compat.js"></script>
    18. <script src="../../../../_static/doctools.js"></script>
    19. <script src="../../../../_static/sphinx_highlight.js"></script>
    20. <script src="../../../../_static/js/theme.js"></script>
    21. <link rel="index" title="Index" href="../../../../genindex.html" />
    22. <link rel="search" title="Search" href="../../../../search.html" />
    23. </head>
    24. <body class="wy-body-for-nav">
    25. <div class="wy-grid-for-nav">
    26. <nav data-toggle="wy-nav-shift" class="wy-nav-side">
    27. <div class="wy-side-scroll">
    28. <div class="wy-side-nav-search" >
    29. <a href="../../../../index.html" class="icon icon-home"> SuperGradients
    30. </a>
    31. <div role="search">
    32. <form id="rtd-search-form" class="wy-form" action="../../../../search.html" method="get">
    33. <input type="text" name="q" placeholder="Search docs" />
    34. <input type="hidden" name="check_keywords" value="yes" />
    35. <input type="hidden" name="area" value="default" />
    36. </form>
    37. </div>
    38. </div><div class="wy-menu wy-menu-vertical" data-spy="affix" role="navigation" aria-label="Navigation menu">
    39. <p class="caption" role="heading"><span class="caption-text">Welcome To SuperGradients</span></p>
    40. <ul>
    41. <li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html">Version 3 is out! Notebooks have been updated!</a></li>
    42. <li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#build-with-supergradients">Build with SuperGradients</a></li>
    43. <li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#quick-installation">Quick Installation</a></li>
    44. <li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#what-s-new">What’s New</a></li>
    45. <li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#coming-soon">Coming soon</a></li>
    46. <li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#table-of-content">Table of Content</a></li>
    47. <li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#getting-started">Getting Started</a></li>
    48. <li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#advanced-features">Advanced Features</a></li>
    49. <li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#installation-methods">Installation Methods</a></li>
    50. <li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#implemented-model-architectures">Implemented Model Architectures</a></li>
    51. <li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#documentation">Documentation</a></li>
    52. <li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#contributing">Contributing</a></li>
    53. <li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#citation">Citation</a></li>
    54. <li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#community">Community</a></li>
    55. <li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#license">License</a></li>
    56. <li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#deci-platform">Deci Platform</a></li>
    57. </ul>
    58. <p class="caption" role="heading"><span class="caption-text">Technical Documentation</span></p>
    59. <ul>
    60. <li class="toctree-l1"><a class="reference internal" href="../../../../super_gradients.common.html">Common package</a></li>
    61. <li class="toctree-l1"><a class="reference internal" href="../../../../super_gradients.training.html">Training package</a></li>
    62. </ul>
    63. </div>
    64. </div>
    65. </nav>
    66. <section data-toggle="wy-nav-shift" class="wy-nav-content-wrap"><nav class="wy-nav-top" aria-label="Mobile navigation menu" >
    67. <i data-toggle="wy-nav-top" class="fa fa-bars"></i>
    68. <a href="../../../../index.html">SuperGradients</a>
    69. </nav>
    70. <div class="wy-nav-content">
    71. <div class="rst-content">
    72. <div role="navigation" aria-label="Page navigation">
    73. <ul class="wy-breadcrumbs">
    74. <li><a href="../../../../index.html" class="icon icon-home"></a> &raquo;</li>
    75. <li><a href="../../../index.html">Module code</a> &raquo;</li>
    76. <li>super_gradients.training.utils.version_utils</li>
    77. <li class="wy-breadcrumbs-aside">
    78. </li>
    79. </ul>
    80. <hr/>
    81. </div>
    82. <div role="main" class="document" itemscope="itemscope" itemtype="http://schema.org/Article">
    83. <div itemprop="articleBody">
    84. <h1>Source code for super_gradients.training.utils.version_utils</h1><div class="highlight"><pre>
    85. <span></span><span class="kn">import</span> <span class="nn">torch</span>
    86. <span class="n">_TORCH_VERSION_MAJOR_MINOR_</span> <span class="o">=</span> <span class="nb">tuple</span><span class="p">(</span><span class="nb">map</span><span class="p">(</span><span class="nb">int</span><span class="p">,</span> <span class="n">torch</span><span class="o">.</span><span class="n">version</span><span class="o">.</span><span class="n">__version__</span><span class="o">.</span><span class="n">split</span><span class="p">(</span><span class="s2">&quot;.&quot;</span><span class="p">)[:</span><span class="mi">2</span><span class="p">]))</span>
    87. <span class="n">__all__</span> <span class="o">=</span> <span class="p">[</span><span class="s2">&quot;torch_version_is_greater_or_equal&quot;</span><span class="p">]</span>
    88. <div class="viewcode-block" id="torch_version_is_greater_or_equal"><a class="viewcode-back" href="../../../../super_gradients.training.html#super_gradients.training.utils.torch_version_is_greater_or_equal">[docs]</a><span class="k">def</span> <span class="nf">torch_version_is_greater_or_equal</span><span class="p">(</span><span class="n">major</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span> <span class="n">minor</span><span class="p">:</span> <span class="nb">int</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="nb">bool</span><span class="p">:</span>
    89. <span class="n">version</span> <span class="o">=</span> <span class="p">(</span><span class="n">major</span><span class="p">,</span> <span class="n">minor</span><span class="p">)</span>
    90. <span class="k">return</span> <span class="n">_TORCH_VERSION_MAJOR_MINOR_</span> <span class="o">&gt;=</span> <span class="n">version</span></div>
    91. </pre></div>
    92. </div>
    93. </div>
    94. <footer>
    95. <hr/>
    96. <div role="contentinfo">
    97. <p>&#169; Copyright 2021, SuperGradients team.</p>
    98. </div>
    99. Built with <a href="https://www.sphinx-doc.org/">Sphinx</a> using a
    100. <a href="https://github.com/readthedocs/sphinx_rtd_theme">theme</a>
    101. provided by <a href="https://readthedocs.org">Read the Docs</a>.
    102. </footer>
    103. </div>
    104. </div>
    105. </section>
    106. </div>
    107. <script>
    108. jQuery(function () {
    109. SphinxRtdTheme.Navigation.enable(true);
    110. });
    111. </script>
    112. </body>
    113. </html>
    Discard
    1
    2
    3
    4
    5
    6
    7
    8
    9
    10
    11
    12
    13
    14
    15
    16
    17
    18
    19
    20
    21
    22
    23
    24
    25
    26
    27
    28
    29
    30
    31
    32
    33
    34
    35
    36
    37
    38
    39
    40
    41
    42
    43
    44
    45
    46
    47
    48
    49
    50
    51
    52
    53
    54
    55
    56
    57
    58
    59
    60
    61
    62
    63
    64
    65
    66
    67
    68
    69
    70
    71
    72
    73
    74
    75
    76
    77
    78
    79
    80
    81
    82
    83
    84
    85
    86
    87
    88
    89
    90
    91
    92
    93
    94
    95
    96
    97
    98
    99
    100
    101
    102
    103
    104
    105
    106
    107
    108
    109
    110
    111
    112
    113
    114
    115
    116
    117
    118
    119
    120
    121
    122
    123
    124
    125
    126
    127
    128
    129
    130
    131
    132
    133
    134
    135
    136
    137
    138
    139
    140
    141
    142
    143
    144
    145
    146
    147
    148
    149
    150
    151
    152
    153
    154
    155
    156
    157
    158
    159
    160
    161
    162
    163
    164
    165
    166
    167
    168
    169
    170
    171
    172
    173
    174
    175
    176
    177
    178
    179
    180
    181
    182
    183
    184
    185
    186
    187
    188
    189
    190
    191
    192
    193
    194
    195
    196
    197
    198
    199
    200
    201
    202
    203
    204
    205
    206
    207
    208
    209
    210
    211
    212
    213
    214
    215
    216
    217
    218
    219
    220
    221
    222
    223
    224
    225
    226
    227
    228
    229
    230
    231
    232
    233
    234
    235
    236
    237
    238
    239
    240
    241
    242
    243
    244
    245
    246
    247
    248
    249
    250
    251
    252
    253
    1. <!DOCTYPE html>
    2. <html class="writer-html5" lang="en" >
    3. <head>
    4. <meta charset="utf-8" />
    5. <meta name="viewport" content="width=device-width, initial-scale=1.0" />
    6. <title>super_gradients.training.utils.weight_averaging_utils &mdash; SuperGradients 1.0 documentation</title>
    7. <link rel="stylesheet" href="../../../../_static/pygments.css" type="text/css" />
    8. <link rel="stylesheet" href="../../../../_static/css/theme.css" type="text/css" />
    9. <link rel="stylesheet" href="../../../../_static/graphviz.css" type="text/css" />
    10. <!--[if lt IE 9]>
    11. <script src="../../../../_static/js/html5shiv.min.js"></script>
    12. <![endif]-->
    13. <script data-url_root="../../../../" id="documentation_options" src="../../../../_static/documentation_options.js"></script>
    14. <script src="../../../../_static/jquery.js"></script>
    15. <script src="../../../../_static/underscore.js"></script>
    16. <script src="../../../../_static/doctools.js"></script>
    17. <script src="../../../../_static/js/theme.js"></script>
    18. <link rel="index" title="Index" href="../../../../genindex.html" />
    19. <link rel="search" title="Search" href="../../../../search.html" />
    20. </head>
    21. <body class="wy-body-for-nav">
    22. <div class="wy-grid-for-nav">
    23. <nav data-toggle="wy-nav-shift" class="wy-nav-side">
    24. <div class="wy-side-scroll">
    25. <div class="wy-side-nav-search" >
    26. <a href="../../../../index.html" class="icon icon-home"> SuperGradients
    27. </a>
    28. <div role="search">
    29. <form id="rtd-search-form" class="wy-form" action="../../../../search.html" method="get">
    30. <input type="text" name="q" placeholder="Search docs" />
    31. <input type="hidden" name="check_keywords" value="yes" />
    32. <input type="hidden" name="area" value="default" />
    33. </form>
    34. </div>
    35. </div><div class="wy-menu wy-menu-vertical" data-spy="affix" role="navigation" aria-label="Navigation menu">
    36. <p class="caption"><span class="caption-text">Welcome To SuperGradients</span></p>
    37. <ul>
    38. <li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html">Fill our 4-question quick survey! We will raffle free SuperGradients swag between those who will participate -&gt; Fill Survey</a></li>
    39. <li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#supergradients">SuperGradients</a></li>
    40. </ul>
    41. <p class="caption"><span class="caption-text">Technical Documentation</span></p>
    42. <ul>
    43. <li class="toctree-l1"><a class="reference internal" href="../../../../super_gradients.common.html">Common package</a></li>
    44. <li class="toctree-l1"><a class="reference internal" href="../../../../super_gradients.training.html">Training package</a></li>
    45. </ul>
    46. <p class="caption"><span class="caption-text">User Guide</span></p>
    47. <ul>
    48. <li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html">What is SuperGradients?</a></li>
    49. <li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#introducing-the-supergradients-library">Introducing the SuperGradients library</a></li>
    50. <li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#installation">Installation</a></li>
    51. <li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#integrating-your-training-code-complete-walkthrough">Integrating Your Training Code - Complete Walkthrough</a></li>
    52. <li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#training-parameters">Training Parameters</a></li>
    53. <li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#logs-and-checkpoints">Logs and Checkpoints</a></li>
    54. <li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#dataset-parameters">Dataset Parameters</a></li>
    55. <li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#network-architectures">Network Architectures</a></li>
    56. <li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#pretrained-models">Pretrained Models</a></li>
    57. <li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#how-to-reproduce-our-training-recipes">How To Reproduce Our Training Recipes</a></li>
    58. <li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#professional-tools-integration">Professional Tools Integration</a></li>
    59. <li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#supergradients-faq">SuperGradients FAQ</a></li>
    60. </ul>
    61. </div>
    62. </div>
    63. </nav>
    64. <section data-toggle="wy-nav-shift" class="wy-nav-content-wrap"><nav class="wy-nav-top" aria-label="Mobile navigation menu" >
    65. <i data-toggle="wy-nav-top" class="fa fa-bars"></i>
    66. <a href="../../../../index.html">SuperGradients</a>
    67. </nav>
    68. <div class="wy-nav-content">
    69. <div class="rst-content">
    70. <div role="navigation" aria-label="Page navigation">
    71. <ul class="wy-breadcrumbs">
    72. <li><a href="../../../../index.html" class="icon icon-home"></a> &raquo;</li>
    73. <li><a href="../../../index.html">Module code</a> &raquo;</li>
    74. <li>super_gradients.training.utils.weight_averaging_utils</li>
    75. <li class="wy-breadcrumbs-aside">
    76. </li>
    77. </ul>
    78. <hr/>
    79. </div>
    80. <div role="main" class="document" itemscope="itemscope" itemtype="http://schema.org/Article">
    81. <div itemprop="articleBody">
    82. <h1>Source code for super_gradients.training.utils.weight_averaging_utils</h1><div class="highlight"><pre>
    83. <span></span><span class="kn">import</span> <span class="nn">os</span>
    84. <span class="kn">import</span> <span class="nn">torch</span>
    85. <span class="kn">import</span> <span class="nn">numpy</span> <span class="k">as</span> <span class="nn">np</span>
    86. <span class="kn">import</span> <span class="nn">pkg_resources</span>
    87. <span class="kn">from</span> <span class="nn">super_gradients.training</span> <span class="kn">import</span> <span class="n">utils</span> <span class="k">as</span> <span class="n">core_utils</span>
    88. <span class="kn">from</span> <span class="nn">super_gradients.training.utils.utils</span> <span class="kn">import</span> <span class="n">move_state_dict_to_device</span>
    89. <div class="viewcode-block" id="ModelWeightAveraging"><a class="viewcode-back" href="../../../../super_gradients.training.utils.html#super_gradients.training.utils.weight_averaging_utils.ModelWeightAveraging">[docs]</a><span class="k">class</span> <span class="nc">ModelWeightAveraging</span><span class="p">:</span>
    90. <span class="sd">&quot;&quot;&quot;</span>
    91. <span class="sd"> Utils class for managing the averaging of the best several snapshots into a single model.</span>
    92. <span class="sd"> A snapshot dictionary file and the average model will be saved / updated at every epoch and evaluated only when</span>
    93. <span class="sd"> training is completed. The snapshot file will only be deleted upon completing the training.</span>
    94. <span class="sd"> The snapshot dict will be managed on cpu.</span>
    95. <span class="sd"> &quot;&quot;&quot;</span>
    96. <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">ckpt_dir</span><span class="p">,</span>
    97. <span class="n">greater_is_better</span><span class="p">,</span>
    98. <span class="n">source_ckpt_folder_name</span><span class="o">=</span><span class="kc">None</span><span class="p">,</span> <span class="n">metric_to_watch</span><span class="o">=</span><span class="s1">&#39;acc&#39;</span><span class="p">,</span>
    99. <span class="n">metric_idx</span><span class="o">=</span><span class="mi">1</span><span class="p">,</span> <span class="n">load_checkpoint</span><span class="o">=</span><span class="kc">False</span><span class="p">,</span>
    100. <span class="n">number_of_models_to_average</span><span class="o">=</span><span class="mi">10</span><span class="p">,</span>
    101. <span class="n">model_checkpoints_location</span><span class="o">=</span><span class="s1">&#39;local&#39;</span>
    102. <span class="p">):</span>
    103. <span class="sd">&quot;&quot;&quot;</span>
    104. <span class="sd"> Init the ModelWeightAveraging</span>
    105. <span class="sd"> :param checkpoint_dir: the directory where the checkpoints are saved</span>
    106. <span class="sd"> :param metric_to_watch: monitoring loss or acc, will be identical to that which determines best_model</span>
    107. <span class="sd"> :param metric_idx:</span>
    108. <span class="sd"> :param load_checkpoint: whether to load pre-existing snapshot dict.</span>
    109. <span class="sd"> :param number_of_models_to_average: number of models to average</span>
    110. <span class="sd"> &quot;&quot;&quot;</span>
    111. <span class="k">if</span> <span class="n">source_ckpt_folder_name</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span><span class="p">:</span>
    112. <span class="n">source_ckpt_file</span> <span class="o">=</span> <span class="n">os</span><span class="o">.</span><span class="n">path</span><span class="o">.</span><span class="n">join</span><span class="p">(</span><span class="n">source_ckpt_folder_name</span><span class="p">,</span> <span class="s1">&#39;averaging_snapshots.pkl&#39;</span><span class="p">)</span>
    113. <span class="n">source_ckpt_file</span> <span class="o">=</span> <span class="n">pkg_resources</span><span class="o">.</span><span class="n">resource_filename</span><span class="p">(</span><span class="s1">&#39;checkpoints&#39;</span><span class="p">,</span> <span class="n">source_ckpt_file</span><span class="p">)</span>
    114. <span class="bp">self</span><span class="o">.</span><span class="n">averaging_snapshots_file</span> <span class="o">=</span> <span class="n">os</span><span class="o">.</span><span class="n">path</span><span class="o">.</span><span class="n">join</span><span class="p">(</span><span class="n">ckpt_dir</span><span class="p">,</span> <span class="s1">&#39;averaging_snapshots.pkl&#39;</span><span class="p">)</span>
    115. <span class="bp">self</span><span class="o">.</span><span class="n">number_of_models_to_average</span> <span class="o">=</span> <span class="n">number_of_models_to_average</span>
    116. <span class="bp">self</span><span class="o">.</span><span class="n">metric_to_watch</span> <span class="o">=</span> <span class="n">metric_to_watch</span>
    117. <span class="bp">self</span><span class="o">.</span><span class="n">metric_idx</span> <span class="o">=</span> <span class="n">metric_idx</span>
    118. <span class="bp">self</span><span class="o">.</span><span class="n">greater_is_better</span> <span class="o">=</span> <span class="n">greater_is_better</span>
    119. <span class="c1"># if continuing training, copy previous snapshot dict if exist</span>
    120. <span class="k">if</span> <span class="n">load_checkpoint</span> <span class="ow">and</span> <span class="n">source_ckpt_folder_name</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span> <span class="ow">and</span> <span class="n">os</span><span class="o">.</span><span class="n">path</span><span class="o">.</span><span class="n">isfile</span><span class="p">(</span><span class="n">source_ckpt_file</span><span class="p">):</span>
    121. <span class="n">averaging_snapshots_dict</span> <span class="o">=</span> <span class="n">core_utils</span><span class="o">.</span><span class="n">load_checkpoint</span><span class="p">(</span><span class="n">ckpt_destination_dir</span><span class="o">=</span><span class="n">ckpt_dir</span><span class="p">,</span>
    122. <span class="n">source_ckpt_folder_name</span><span class="o">=</span><span class="n">source_ckpt_folder_name</span><span class="p">,</span>
    123. <span class="n">ckpt_filename</span><span class="o">=</span><span class="s2">&quot;averaging_snapshots.pkl&quot;</span><span class="p">,</span>
    124. <span class="n">load_weights_only</span><span class="o">=</span><span class="kc">False</span><span class="p">,</span>
    125. <span class="n">model_checkpoints_location</span><span class="o">=</span><span class="n">model_checkpoints_location</span><span class="p">,</span>
    126. <span class="n">overwrite_local_ckpt</span><span class="o">=</span><span class="kc">True</span><span class="p">)</span>
    127. <span class="k">else</span><span class="p">:</span>
    128. <span class="n">averaging_snapshots_dict</span> <span class="o">=</span> <span class="p">{</span><span class="s1">&#39;snapshot&#39;</span> <span class="o">+</span> <span class="nb">str</span><span class="p">(</span><span class="n">i</span><span class="p">):</span> <span class="kc">None</span> <span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">number_of_models_to_average</span><span class="p">)}</span>
    129. <span class="c1"># if metric to watch is acc, hold a zero array, if loss hold inf array</span>
    130. <span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">greater_is_better</span><span class="p">:</span>
    131. <span class="n">averaging_snapshots_dict</span><span class="p">[</span><span class="s1">&#39;snapshots_metric&#39;</span><span class="p">]</span> <span class="o">=</span> <span class="o">-</span><span class="mi">1</span> <span class="o">*</span> <span class="n">np</span><span class="o">.</span><span class="n">inf</span> <span class="o">*</span> <span class="n">np</span><span class="o">.</span><span class="n">ones</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">number_of_models_to_average</span><span class="p">)</span>
    132. <span class="k">else</span><span class="p">:</span>
    133. <span class="n">averaging_snapshots_dict</span><span class="p">[</span><span class="s1">&#39;snapshots_metric&#39;</span><span class="p">]</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">inf</span> <span class="o">*</span> <span class="n">np</span><span class="o">.</span><span class="n">ones</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">number_of_models_to_average</span><span class="p">)</span>
    134. <span class="n">torch</span><span class="o">.</span><span class="n">save</span><span class="p">(</span><span class="n">averaging_snapshots_dict</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">averaging_snapshots_file</span><span class="p">)</span>
    135. <div class="viewcode-block" id="ModelWeightAveraging.update_snapshots_dict"><a class="viewcode-back" href="../../../../super_gradients.training.utils.html#super_gradients.training.utils.weight_averaging_utils.ModelWeightAveraging.update_snapshots_dict">[docs]</a> <span class="k">def</span> <span class="nf">update_snapshots_dict</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">model</span><span class="p">,</span> <span class="n">validation_results_tuple</span><span class="p">):</span>
    136. <span class="sd">&quot;&quot;&quot;</span>
    137. <span class="sd"> Update the snapshot dict and returns the updated average model for saving</span>
    138. <span class="sd"> :param model: the latest model</span>
    139. <span class="sd"> :param validation_results_tuple: performance of the latest model</span>
    140. <span class="sd"> &quot;&quot;&quot;</span>
    141. <span class="n">averaging_snapshots_dict</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">_get_averaging_snapshots_dict</span><span class="p">()</span>
    142. <span class="c1"># IF CURRENT MODEL IS BETTER, TAKING HIS PLACE IN ACC LIST AND OVERWRITE THE NEW AVERAGE</span>
    143. <span class="n">require_update</span><span class="p">,</span> <span class="n">update_ind</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">_is_better</span><span class="p">(</span><span class="n">averaging_snapshots_dict</span><span class="p">,</span> <span class="n">validation_results_tuple</span><span class="p">)</span>
    144. <span class="k">if</span> <span class="n">require_update</span><span class="p">:</span>
    145. <span class="c1"># moving state dict to cpu</span>
    146. <span class="n">new_sd</span> <span class="o">=</span> <span class="n">model</span><span class="o">.</span><span class="n">state_dict</span><span class="p">()</span>
    147. <span class="n">new_sd</span> <span class="o">=</span> <span class="n">move_state_dict_to_device</span><span class="p">(</span><span class="n">new_sd</span><span class="p">,</span> <span class="s1">&#39;cpu&#39;</span><span class="p">)</span>
    148. <span class="n">averaging_snapshots_dict</span><span class="p">[</span><span class="s1">&#39;snapshot&#39;</span> <span class="o">+</span> <span class="nb">str</span><span class="p">(</span><span class="n">update_ind</span><span class="p">)]</span> <span class="o">=</span> <span class="n">new_sd</span>
    149. <span class="n">averaging_snapshots_dict</span><span class="p">[</span><span class="s1">&#39;snapshots_metric&#39;</span><span class="p">][</span><span class="n">update_ind</span><span class="p">]</span> <span class="o">=</span> <span class="n">validation_results_tuple</span><span class="p">[</span><span class="bp">self</span><span class="o">.</span><span class="n">metric_idx</span><span class="p">]</span>
    150. <span class="k">return</span> <span class="n">averaging_snapshots_dict</span></div>
    151. <div class="viewcode-block" id="ModelWeightAveraging.get_average_model"><a class="viewcode-back" href="../../../../super_gradients.training.utils.html#super_gradients.training.utils.weight_averaging_utils.ModelWeightAveraging.get_average_model">[docs]</a> <span class="k">def</span> <span class="nf">get_average_model</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">model</span><span class="p">,</span> <span class="n">validation_results_tuple</span><span class="o">=</span><span class="kc">None</span><span class="p">):</span>
    152. <span class="sd">&quot;&quot;&quot;</span>
    153. <span class="sd"> Returns the averaged model</span>
    154. <span class="sd"> :param model: will be used to determine arch</span>
    155. <span class="sd"> :param validation_results_tuple: if provided, will update the average model before returning</span>
    156. <span class="sd"> :param target_device: if provided, return sd on target device</span>
    157. <span class="sd"> &quot;&quot;&quot;</span>
    158. <span class="c1"># If validation tuple is provided, update the average model</span>
    159. <span class="k">if</span> <span class="n">validation_results_tuple</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span><span class="p">:</span>
    160. <span class="n">averaging_snapshots_dict</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">update_snapshots_dict</span><span class="p">(</span><span class="n">model</span><span class="p">,</span> <span class="n">validation_results_tuple</span><span class="p">)</span>
    161. <span class="k">else</span><span class="p">:</span>
    162. <span class="n">averaging_snapshots_dict</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">_get_averaging_snapshots_dict</span><span class="p">()</span>
    163. <span class="n">torch</span><span class="o">.</span><span class="n">save</span><span class="p">(</span><span class="n">averaging_snapshots_dict</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">averaging_snapshots_file</span><span class="p">)</span>
    164. <span class="n">average_model_sd</span> <span class="o">=</span> <span class="n">averaging_snapshots_dict</span><span class="p">[</span><span class="s1">&#39;snapshot0&#39;</span><span class="p">]</span>
    165. <span class="k">for</span> <span class="n">n_model</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="mi">1</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">number_of_models_to_average</span><span class="p">):</span>
    166. <span class="k">if</span> <span class="n">averaging_snapshots_dict</span><span class="p">[</span><span class="s1">&#39;snapshot&#39;</span> <span class="o">+</span> <span class="nb">str</span><span class="p">(</span><span class="n">n_model</span><span class="p">)]</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span><span class="p">:</span>
    167. <span class="n">net_sd</span> <span class="o">=</span> <span class="n">averaging_snapshots_dict</span><span class="p">[</span><span class="s1">&#39;snapshot&#39;</span> <span class="o">+</span> <span class="nb">str</span><span class="p">(</span><span class="n">n_model</span><span class="p">)]</span>
    168. <span class="c1"># USING MOVING AVERAGE</span>
    169. <span class="k">for</span> <span class="n">key</span> <span class="ow">in</span> <span class="n">average_model_sd</span><span class="p">:</span>
    170. <span class="n">average_model_sd</span><span class="p">[</span><span class="n">key</span><span class="p">]</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">true_divide</span><span class="p">(</span>
    171. <span class="n">average_model_sd</span><span class="p">[</span><span class="n">key</span><span class="p">]</span> <span class="o">*</span> <span class="n">n_model</span> <span class="o">+</span> <span class="n">net_sd</span><span class="p">[</span><span class="n">key</span><span class="p">],</span>
    172. <span class="p">(</span><span class="n">n_model</span> <span class="o">+</span> <span class="mi">1</span><span class="p">))</span>
    173. <span class="k">return</span> <span class="n">average_model_sd</span></div>
    174. <div class="viewcode-block" id="ModelWeightAveraging.cleanup"><a class="viewcode-back" href="../../../../super_gradients.training.utils.html#super_gradients.training.utils.weight_averaging_utils.ModelWeightAveraging.cleanup">[docs]</a> <span class="k">def</span> <span class="nf">cleanup</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span>
    175. <span class="sd">&quot;&quot;&quot;</span>
    176. <span class="sd"> Delete snapshot file when reaching the last epoch</span>
    177. <span class="sd"> &quot;&quot;&quot;</span>
    178. <span class="n">os</span><span class="o">.</span><span class="n">remove</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">averaging_snapshots_file</span><span class="p">)</span></div>
    179. <span class="k">def</span> <span class="nf">_is_better</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">averaging_snapshots_dict</span><span class="p">,</span> <span class="n">validation_results_tuple</span><span class="p">):</span>
    180. <span class="sd">&quot;&quot;&quot;</span>
    181. <span class="sd"> Determines if the new model is better according to the specified metrics</span>
    182. <span class="sd"> :param averaging_snapshots_dict: snapshot dict</span>
    183. <span class="sd"> :param validation_results_tuple: latest model performance</span>
    184. <span class="sd"> &quot;&quot;&quot;</span>
    185. <span class="n">snapshot_metric_array</span> <span class="o">=</span> <span class="n">averaging_snapshots_dict</span><span class="p">[</span><span class="s1">&#39;snapshots_metric&#39;</span><span class="p">]</span>
    186. <span class="n">val</span> <span class="o">=</span> <span class="n">validation_results_tuple</span><span class="p">[</span><span class="bp">self</span><span class="o">.</span><span class="n">metric_idx</span><span class="p">]</span>
    187. <span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">greater_is_better</span><span class="p">:</span>
    188. <span class="n">update_ind</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">argmin</span><span class="p">(</span><span class="n">snapshot_metric_array</span><span class="p">)</span>
    189. <span class="k">else</span><span class="p">:</span>
    190. <span class="n">update_ind</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">argmax</span><span class="p">(</span><span class="n">snapshot_metric_array</span><span class="p">)</span>
    191. <span class="k">if</span> <span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">greater_is_better</span> <span class="ow">and</span> <span class="n">val</span> <span class="o">&gt;</span> <span class="n">snapshot_metric_array</span><span class="p">[</span><span class="n">update_ind</span><span class="p">])</span> <span class="ow">or</span> <span class="p">(</span>
    192. <span class="ow">not</span> <span class="bp">self</span><span class="o">.</span><span class="n">greater_is_better</span> <span class="ow">and</span> <span class="n">val</span> <span class="o">&lt;</span> <span class="n">snapshot_metric_array</span><span class="p">[</span><span class="n">update_ind</span><span class="p">]):</span>
    193. <span class="k">return</span> <span class="kc">True</span><span class="p">,</span> <span class="n">update_ind</span>
    194. <span class="k">return</span> <span class="kc">False</span><span class="p">,</span> <span class="kc">None</span>
    195. <span class="k">def</span> <span class="nf">_get_averaging_snapshots_dict</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span>
    196. <span class="k">return</span> <span class="n">torch</span><span class="o">.</span><span class="n">load</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">averaging_snapshots_file</span><span class="p">)</span></div>
    197. </pre></div>
    198. </div>
    199. </div>
    200. <footer>
    201. <hr/>
    202. <div role="contentinfo">
    203. <p>&#169; Copyright 2021, SuperGradients team.</p>
    204. </div>
    205. Built with <a href="https://www.sphinx-doc.org/">Sphinx</a> using a
    206. <a href="https://github.com/readthedocs/sphinx_rtd_theme">theme</a>
    207. provided by <a href="https://readthedocs.org">Read the Docs</a>.
    208. </footer>
    209. </div>
    210. </div>
    211. </section>
    212. </div>
    213. <script>
    214. jQuery(function () {
    215. SphinxRtdTheme.Navigation.enable(true);
    216. });
    217. </script>
    218. </body>
    219. </html>
    Discard
    1
    2
    3
    4
    5
    6
    7
    8
    9
    10
    11
    12
    13
    14
    15
    16
    17
    18
    19
    20
    21
    22
    23
    1. super\_gradients.common.abstractions
    2. ====================================
    3. .. automodule:: super_gradients.common.abstractions
    Discard
    1
    2
    3
    4
    5
    6
    7
    8
    9
    10
    11
    12
    13
    14
    15
    16
    17
    18
    19
    20
    21
    22
    23
    1. super\_gradients.common.auto\_logging
    2. =====================================
    3. .. automodule:: super_gradients.common.auto_logging
    Discard
    1
    2
    3
    4
    5
    6
    7
    8
    9
    10
    11
    12
    13
    14
    15
    16
    17
    18
    19
    20
    21
    22
    23
    1. super\_gradients.common.aws\_connection
    2. =======================================
    3. .. automodule:: super_gradients.common.aws_connection
    Discard
    1
    2
    3
    4
    5
    6
    7
    8
    9
    10
    11
    12
    13
    14
    15
    16
    17
    18
    19
    20
    21
    22
    23
    1. super\_gradients.common.data\_connection
    2. ========================================
    3. .. automodule:: super_gradients.common.data_connection
    Discard
    1
    2
    3
    4
    5
    6
    7
    8
    9
    10
    11
    12
    13
    14
    15
    16
    17
    18
    19
    20
    21
    22
    23
    1. super\_gradients.common.data\_interface
    2. =======================================
    3. .. automodule:: super_gradients.common.data_interface
    Discard
    1
    2
    3
    4
    5
    6
    7
    8
    9
    10
    11
    12
    13
    14
    15
    16
    17
    18
    19
    20
    21
    22
    23
    1. super\_gradients.common.data\_types
    2. ===================================
    3. .. automodule:: super_gradients.common.data_types
    Discard
    1
    2
    3
    4
    5
    6
    7
    8
    9
    10
    11
    12
    13
    14
    15
    16
    17
    18
    19
    20
    21
    22
    23
    1. super\_gradients.common.decorators
    2. ==================================
    3. .. automodule:: super_gradients.common.decorators
    Discard
    1
    2
    3
    4
    5
    6
    7
    8
    9
    10
    11
    12
    13
    14
    15
    16
    17
    18
    19
    20
    21
    22
    23
    1. super\_gradients.common.environment
    2. ===================================
    3. .. automodule:: super_gradients.common.environment
    Discard
    @@ -20,10 +20,8 @@ Welcome to SuperGradients's documentation!
        super_gradients.training
        super_gradients.training
     
     
     .. toctree::
     .. toctree::
    -   :maxdepth: 4
    -   :caption: User Guide
    -
    -   user_guide
    +.. :maxdepth: 4
    +.. :caption: User Guide
     
     
     Indices and tables
     Indices and tables
     ==================
     ==================
    Discard
    1
    2
    3
    4
    5
    6
    7
    1. super_gradients
    2. ===============
    3. .. toctree::
    4. :maxdepth: 8
    5. super_gradients
    Discard
    1
    2
    3
    4
    5
    6
    7
    8
    9
    10
    11
    12
    13
    14
    15
    16
    17
    18
    19
    20
    21
    1. super\_gradients.common.abstractions package
    2. ============================================
    3. Submodules
    4. ----------
    5. super\_gradients.common.abstractions.abstract\_logger module
    6. ------------------------------------------------------------
    7. .. automodule:: super_gradients.common.abstractions.abstract_logger
    8. :members:
    9. :undoc-members:
    10. :show-inheritance:
    11. Module contents
    12. ---------------
    13. .. automodule:: super_gradients.common.abstractions
    14. :members:
    15. :undoc-members:
    16. :show-inheritance:
    Discard
    1
    2
    3
    4
    5
    6
    7
    8
    9
    10
    11
    12
    13
    14
    15
    16
    17
    18
    19
    20
    21
    1. super\_gradients.common.auto\_logging package
    2. =============================================
    3. Submodules
    4. ----------
    5. super\_gradients.common.auto\_logging.auto\_logger module
    6. ---------------------------------------------------------
    7. .. automodule:: super_gradients.common.auto_logging.auto_logger
    8. :members:
    9. :undoc-members:
    10. :show-inheritance:
    11. Module contents
    12. ---------------
    13. .. automodule:: super_gradients.common.auto_logging
    14. :members:
    15. :undoc-members:
    16. :show-inheritance:
    Discard

    Some files were not shown because too many files changed in this diff