#875 Feature/sg 761 yolo nas

Merged
Ghost merged 1 commit into Deci-AI:master from deci-ai:feature/SG-761-yolo-nas
60 changed files with 2549 additions and 132 deletions
  1. .circleci/config.yml (+144, -29)
  2. LICENSE.YOLONAS.md (+16, -0)
  3. README.md (+12, -4)
  4. YOLONAS.md (+97, -0)
  5. documentation/source/YoloNASQuickstart.md (+70, -0)
  6. documentation/source/images/messi_penalty_pred_higher_conf_right_class2.gif (BIN)
  7. documentation/source/images/soccer.png (BIN)
  8. documentation/source/images/yolo_nas_frontier.png (BIN)
  9. documentation/source/images/yolo_nas_predict_demo.png (BIN)
  10. documentation/source/images/yolo_nas_qs_predict.png (BIN)
  11. documentation/source/images/yolo_nas_rf100.png (BIN)
  12. documentation/source/qat_ptq_yolo_nas.md (+378, -0)
  13. scripts/Dockerfile.branch (+15, -0)
  14. scripts/Dockerfile.branch.code (+10, -0)
  15. src/super_gradients/common/object_names.py (+5, -2)
  16. src/super_gradients/examples/predict/detection_predict.py (+2, -2)
  17. src/super_gradients/examples/predict/detection_predict_image_folder.py (+2, -2)
  18. src/super_gradients/examples/predict/detection_predict_streaming.py (+2, -2)
  19. src/super_gradients/examples/predict/detection_predict_video.py (+2, -2)
  20. src/super_gradients/module_interfaces/__init__.py (+2, -2)
  21. src/super_gradients/module_interfaces/module_interfaces.py (+28, -0)
  22. src/super_gradients/modules/__init__.py (+27, -0)
  23. src/super_gradients/modules/base_modules.py (+27, -0)
  24. src/super_gradients/modules/detection_modules.py (+19, -26)
  25. src/super_gradients/modules/head_replacement_utils.py (+47, -0)
  26. src/super_gradients/modules/pose_estimation_modules.py (+1, -1)
  27. src/super_gradients/recipes/arch_params/yolo_nas_l_arch_params.yaml (+112, -0)
  28. src/super_gradients/recipes/arch_params/yolo_nas_m_arch_params.yaml (+112, -0)
  29. src/super_gradients/recipes/arch_params/yolo_nas_s_arch_params.yaml (+112, -0)
  30. src/super_gradients/recipes/coco2017_yolo_nas_s.yaml (+43, -0)
  31. src/super_gradients/recipes/dataset_params/coco_detection_yolo_nas_dataset_params.yaml (+6, -5)
  32. src/super_gradients/recipes/dataset_params/roboflow_detection_dataset_params.yaml (+25, -8)
  33. src/super_gradients/recipes/roboflow_yolo_nas_m.yaml (+92, -0)
  34. src/super_gradients/recipes/roboflow_yolo_nas_s.yaml (+92, -0)
  35. src/super_gradients/recipes/roboflow_yolo_nas_s_qat.yaml (+18, -0)
  36. src/super_gradients/recipes/training_hyperparams/coco2017_yolo_nas_train_params.yaml (+56, -0)
  37. src/super_gradients/training/dataloaders/__init__.py (+4, -4)
  38. src/super_gradients/training/dataloaders/dataloaders.py (+6, -6)
  39. src/super_gradients/training/datasets/detection_datasets/roboflow/metadata.py (+2, -2)
  40. src/super_gradients/training/models/__init__.py (+27, -9)
  41. src/super_gradients/training/models/detection_models/csp_darknet53.py (+19, -7)
  42. src/super_gradients/training/models/detection_models/customizable_detector.py (+4, -0)
  43. src/super_gradients/training/models/detection_models/pp_yolo_e/__init__.py (+2, -2)
  44. src/super_gradients/training/models/detection_models/pp_yolo_e/pan.py (+4, -4)
  45. src/super_gradients/training/models/detection_models/pp_yolo_e/pp_yolo_e.py (+3, -2)
  46. src/super_gradients/training/models/detection_models/pp_yolo_e/pp_yolo_head.py (+2, -1)
  47. src/super_gradients/training/models/detection_models/yolo_nas/__init__.py (+26, -0)
  48. src/super_gradients/training/models/detection_models/yolo_nas/dfl_heads.py (+270, -0)
  49. src/super_gradients/training/models/detection_models/yolo_nas/panneck.py (+64, -0)
  50. src/super_gradients/training/models/detection_models/yolo_nas/yolo_nas_variants.py (+90, -0)
  51. src/super_gradients/training/models/detection_models/yolo_nas/yolo_stages.py (+332, -0)
  52. src/super_gradients/training/pipelines/pipelines.py (+7, -2)
  53. src/super_gradients/training/pretrained_models.py (+4, -0)
  54. src/super_gradients/training/processing/processing.py (+6, -6)
  55. src/super_gradients/training/utils/checkpoint_utils.py (+10, -0)
  56. tests/deci_core_integration_test_suite_runner.py (+2, -1)
  57. tests/deci_core_unit_test_suite_runner.py (+2, -0)
  58. tests/integration_tests/__init__.py (+2, -1)
  59. tests/integration_tests/yolo_nas_integration_test.py (+48, -0)
  60. tests/unit_tests/replace_head_test.py (+39, -0)
@@ -1,6 +1,12 @@
 version: 2.1
 
 parameters:
+  ad_hoc_container_build:
+    type: boolean
+    default: false
+  ad_hoc_container_build_code_only:
+    type: boolean
+    default: false
   remote_docker_version:
     type: string
     description: remote docker version
@@ -12,7 +18,7 @@ parameters:
   orb_version:
     type: string
     description: Deci ai ORB version https://circleci.com/developer/orbs/orb/deci-ai/circleci-common-orb
-    default: "10.5.0"
+    default: "10.5.1"
 #    default: "dev:alpha"
 
 orbs:
@@ -42,6 +48,47 @@ release_candidate_tag_filter: &release_candidate_tag_filter
       only: /^\d+\.\d+\.\d+rc\d+/
 
 commands:
+  build_and_publish_command:
+    parameters:
+      repo_name:
+        type: string
+      docker_context:
+        type: string
+      image_tag:
+        type: string
+      additional_tags:
+        type: string
+      build_args:
+        type: string
+        default: ""
+      dockerfile:
+        type: string
+        default: "Dockerfile"
+    steps:
+      - checkout
+      - attach_workspace:
+          at: ~/
+      - run:
+          name: Put config dir in repo context
+          command: |
+            if [ -d ~/.config ]; then
+              echo "found a .config directory, copying to repo dir"
+              cp -r  ~/.config ~/project/<< parameters.docker_context >>
+            fi
+      - deci-common/ecr_login_dev
+      - deci-common/container_image_build:
+          context: << parameters.docker_context >>
+          working_directory: "."
+          repository_name: << parameters.repo_name >>
+          image_tag: << parameters.image_tag >>
+          dockerfile: << parameters.dockerfile >>
+          build_args: << parameters.build_args >>
+      #          build_args: "PYTHON_VERSION=3.8 SG_VERSION=3.0.7"
+      - deci-common/push_docker_image_aws_dev:
+          repository_name: << parameters.repo_name >>
+          image_tag: << parameters.image_tag >>
+          additional_tags: << parameters.additional_tags >>
+
   get_beta_and_rc_tags:
     description: "getting beta and rc tag (if exist) according to ouir convention"
     steps:
@@ -607,8 +654,60 @@ jobs:
           command: "rm -r << parameters.sg_new_env_name >>"
           when: on_fail
 
+  docker-build-and-publish-branch:
+    docker:
+      - image: cimg/base:stable-20.04
+    parameters:
+      repo_name:
+        type: string
+        default: "deci/super-gradients"
+      docker_context:
+        type: string
+        default: "."
+      additional_tags:
+        type: string
+        default: ''
+    steps:
+      - setup_remote_docker:
+          version: 20.10.7
+          docker_layer_caching: true
+      - deci-common/container_image_lint_tag:
+          image_tag: "${CIRCLE_BRANCH}"
+      - run:
+          command: |
+            ADDITIONAL_TAGS="<< parameters.additional_tags >>"
+            echo "export ADDITIONAL_TAGS=${ADDITIONAL_TAGS}" >> $BASH_ENV
+      - run:
+          command: |
+            source $BASH_ENV
+            echo "$CONTAINER_LINT_TAG"
+            echo "$ADDITIONAL_TAGS"
+      - when:
+          condition: << pipeline.parameters.ad_hoc_container_build_code_only >>
+          steps:
+            - build_and_publish_command:
+                repo_name: << parameters.repo_name >>
+                docker_context: << parameters.docker_context >>
+                image_tag: $CONTAINER_LINT_TAG
+                additional_tags: $ADDITIONAL_TAGS
+                dockerfile: 'scripts/Dockerfile.branch.code'
+                build_args: "BASE_TAG=$CONTAINER_LINT_TAG"
+      - unless:
+          condition: << pipeline.parameters.ad_hoc_container_build_code_only >>
+          steps:
+            - build_and_publish_command:
+                repo_name: << parameters.repo_name >>
+                docker_context: << parameters.docker_context >>
+                image_tag: $CONTAINER_LINT_TAG
+                additional_tags: $ADDITIONAL_TAGS
+                dockerfile: 'scripts/Dockerfile.branch'
+
 workflows:
   release:
+    unless:
+      or:
+        - << pipeline.parameters.ad_hoc_container_build >>
+        - << pipeline.parameters.ad_hoc_container_build_code_only >>
     jobs:
       - deci-common/persist_version_info:
           version_override: $CIRCLE_TAG
@@ -670,6 +769,10 @@ workflows:
 
 
   build_and_deploy:
+    unless:
+      or:
+        - << pipeline.parameters.ad_hoc_container_build >>
+        - << pipeline.parameters.ad_hoc_container_build_code_only >>
     jobs:
       - deci-common/persist_version_info:
           use_rc: true
@@ -695,42 +798,46 @@ workflows:
           <<: *release_candidate_filter
 
   SG_docker:
-     jobs:
-       - change_rc_to_b: # works on release candidate creation
-           <<: *release_candidate_tag_filter
-       - build_and_publish_sg_container:  # works on release candidate creation
-           requires:
-             - "change_rc_to_b"
-           <<: *release_candidate_tag_filter
-       - testing_supergradients_docker_image:  # works on release candidate creation
+    unless:
+      or:
+        - << pipeline.parameters.ad_hoc_container_build >>
+        - << pipeline.parameters.ad_hoc_container_build_code_only >>
+    jobs:
+      - change_rc_to_b: # works on release candidate creation
+          <<: *release_candidate_tag_filter
+      - build_and_publish_sg_container:  # works on release candidate creation
+          requires:
+            - "change_rc_to_b"
+          <<: *release_candidate_tag_filter
+      - testing_supergradients_docker_image:  # works on release candidate creation
           image_repo: '307629990626.dkr.ecr.us-east-1.amazonaws.com/deci/super-gradients'
           requires:
             - "build_and_publish_sg_container"
             - "change_rc_to_b"
           <<: *release_candidate_tag_filter
-       - add_rc_tag_to_beta: # works on release candidate creation for ECR Repo
+      - add_rc_tag_to_beta: # works on release candidate creation for ECR Repo
           requires:
             - "testing_supergradients_docker_image"
             - "change_rc_to_b"
           <<: *release_candidate_tag_filter
-       - find_rc_tag_per_sha: # works on release
-           <<: *release_tag_filter
-       - add_release_tag_to_rc: # works on release
-            requires:
-              - "find_rc_tag_per_sha"
-            <<: *release_tag_filter
-       - slack/on-hold:
-           context: slack
-           channel: "sg-integration-tests"
-           requires:
-             - "add_release_tag_to_rc"
-           <<: *release_tag_filter
-       - hold-sg-public-release:  # works on release
-           type: approval
-           requires:
-             - "slack/on-hold"
-           <<: *release_tag_filter
-       - docker/publish:  # works on release
+      - find_rc_tag_per_sha: # works on release
+          <<: *release_tag_filter
+      - add_release_tag_to_rc: # works on release
+          requires:
+            - "find_rc_tag_per_sha"
+          <<: *release_tag_filter
+      - slack/on-hold:
+          context: slack
+          channel: "sg-integration-tests"
+          requires:
+           - "add_release_tag_to_rc"
+          <<: *release_tag_filter
+      - hold-sg-public-release:  # works on release
+          type: approval
+          requires:
+           - "slack/on-hold"
+          <<: *release_tag_filter
+      - docker/publish:  # works on release
           executor:
               image: cimg/base
               tag: stable-20.04
@@ -748,7 +855,7 @@ workflows:
           requires:
             - "hold-sg-public-release"
           <<: *release_tag_filter
-       - docker/publish: # works on release
+      - docker/publish: # works on release
           executor:
               image: cimg/base
               tag: stable-20.04
@@ -765,3 +872,11 @@ workflows:
           requires:
             - "hold-sg-public-release"
           <<: *release_tag_filter
+  build-and-push-container-flow:
+    when: << pipeline.parameters.ad_hoc_container_build >>
+    jobs:
+      - docker-build-and-publish-branch
+  build-and-push-container-code-only-flow:
+    when: << pipeline.parameters.ad_hoc_container_build_code_only >>
+    jobs:
+      - docker-build-and-publish-branch
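
For reference, these two flows only run when a pipeline is triggered with one of the new boolean parameters set to true. With the CircleCI v2 API, such an ad hoc trigger could look roughly like this (token and branch values are illustrative):

curl -X POST https://circleci.com/api/v2/project/gh/Deci-AI/super-gradients/pipeline \
  -H "Circle-Token: $CIRCLE_TOKEN" \
  -H "Content-Type: application/json" \
  -d '{"branch": "feature/SG-761-yolo-nas", "parameters": {"ad_hoc_container_build": true}}'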

YOLO-NAS License

These model weights or any components comprising the model and the associated documentation (the "Software") are licensed to you by Deci.AI, Inc. ("Deci") under the following terms: © 2023 – Deci.AI, Inc.

Subject to your full compliance with all of the terms herein, Deci hereby grants you a non-exclusive, revocable, non-sublicensable, non-transferable worldwide and limited right and license to use the Software. If you are using the Deci platform for model optimization, your use of the Software is subject to the Terms of Use available here (the "Terms of Use").

You shall not, without Deci's prior written consent: (i) resell, lease, sublicense or distribute the Software to any person; (ii) use the Software to provide third parties with managed services or provide remote access to the Software to any person or compete with Deci in any way; (iii) represent that you possess any proprietary interest in the Software; (iv) directly or indirectly, take any action to contest Deci's intellectual property rights or infringe them in any way; (v) reverse-engineer, decompile, disassemble, alter, enhance, improve, add to, delete from, or otherwise modify, or derive (or attempt to derive) the technology or source code underlying any part of the Software; (vi) use the Software (or any part thereof) in any illegal, indecent, misleading, harmful, abusive, harassing and/or disparaging manner or for any such purposes. Except as provided under the terms of any separate agreement between you and Deci, including the Terms of Use to the extent applicable, you may not use the Software for any commercial use, including in connection with any models used in a production environment.

DECI PROVIDES THE SOFTWARE "AS IS" WITHOUT WARRANTY OF ANY KIND, EITHER EXPRESSED OR IMPLIED, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE OR NON-INFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS OF THE SOFTWARE BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.

@@ -44,12 +44,20 @@ ________________________________________________________________________________
 
 
 ### Ready to deploy pre-trained SOTA models
+
+The YOLO-NAS architecture is out! YOLO-NAS delivers state-of-the-art performance with an unparalleled accuracy-speed tradeoff, outperforming other models such as YOLOv5, YOLOv6, YOLOv7 and YOLOv8.
+Check it out here: [YOLO-NAS](YOLONAS.md).
+
+<div align="center">
+<img src="./documentation/source/images/yolo_nas_frontier.png" width="800px">
+</div>
+
 ```python
 # Load model with pretrained weights
 from super_gradients.training import models
 from super_gradients.common.object_names import Models
 
-model = models.get(Models.YOLOX_S, pretrained_weights="coco")
+model = models.get(Models.YOLO_NAS_M, pretrained_weights="coco")
 ```
 #### All Computer Vision Models - Pretrained Checkpoints can be found in the [Model Zoo](http://bit.ly/41dkt89)
 
@@ -89,17 +97,17 @@ All SuperGradients models’ are production ready in the sense that they are com
 from super_gradients.training import models
 from super_gradients.common.object_names import Models
 
-model = models.get(Models.YOLOX_S, pretrained_weights="coco")
+model = models.get(Models.YOLO_NAS_M, pretrained_weights="coco")
 
 # Prepare model for conversion
-# Input size is in format of [Batch x Channels x Width x Height] where 640 is the standart COCO dataset dimensions
+# Input size is in format of [Batch x Channels x Width x Height] where 640 is the standard COCO dataset dimensions
 model.eval()
 model.prep_model_for_conversion(input_size=[1, 3, 640, 640])
     
 # Create dummy_input
 
 # Convert model to onnx
-torch.onnx.export(model, dummy_input,  "yolox_s.onnx")
+torch.onnx.export(model, dummy_input,  "yolo_nas_m.onnx")
 ```
 More information on how to take your model to production can be found in [Getting Started](#getting-started) notebooks
 

YOLO-NAS

A Next-Generation Object Detection Foundational Model Generated by Deci's Neural Architecture Search Technology

Deci is thrilled to announce the release of a new object detection model, YOLO-NAS - a game-changer in the world of object detection, providing superior real-time object detection capabilities and production-ready performance. Deci's mission is to provide AI teams with tools to remove development barriers and attain efficient inference performance more quickly.

YOLO-NAS

The new YOLO-NAS delivers state-of-the-art (SOTA) performance with an unparalleled accuracy-speed tradeoff, outperforming other models such as YOLOv5, YOLOv6, YOLOv7 and YOLOv8.

Deci's proprietary Neural Architecture Search technology, AutoNAC™, generated the YOLO-NAS model. The AutoNAC™ engine lets you input any task, data characteristics (access to data is not required), inference environment and performance targets, and then guides you to find the optimal architecture that delivers the best balance between accuracy and inference speed for your specific application. In addition to being data and hardware aware, the AutoNAC engine considers other components in the inference stack, including compilers and quantization.

In terms of pure numbers, YOLO-NAS is ~0.5 mAP point more accurate and 10-20% faster than equivalent variants of YOLOv8 and YOLOv7.

Model              mAP     Latency (ms)
YOLO-NAS S         47.5    3.21
YOLO-NAS M         51.55   5.85
YOLO-NAS L         52.22   7.87
YOLO-NAS S INT-8   47.03   2.36
YOLO-NAS M INT-8   51.0    3.78
YOLO-NAS L INT-8   52.1    4.78

mAP numbers in the table are reported for the COCO 2017 val dataset; latency is benchmarked for 640x640 images on an NVIDIA T4 GPU.

YOLO-NAS's architecture employs quantization-aware blocks and selective quantization for optimized performance. When converted to its INT8 quantized version, YOLO-NAS experiences a smaller precision drop (0.51, 0.65, and 0.45 points of mAP for the S, M, and L variants) than other models, which lose 1-2 mAP points during quantization. These techniques culminate in an innovative architecture with superior object detection capabilities and top-notch performance.

Quickstart

import super_gradients

yolo_nas = super_gradients.training.models.get("yolo_nas_l", pretrained_weights="coco").cuda()
yolo_nas.predict("https://deci-pretrained-models.s3.amazonaws.com/sample_images/beatles-abbeyroad.jpg").show()

YOLO-NAS Predict Demo

Recipes

We provide fine-tuning recipes for the Roboflow-100 datasets.

Great fine-tuning potential

We demonstrate great performance of YOLO-NAS on downstream tasks. When fine-tuned on Roboflow-100, our YOLO-NAS model achieves higher mAP than its nearest competitors:

YOLO-NAS-RF-100

Additional resources

Documentation: YOLO-NAS Quickstart
Documentation: YOLO-NAS Quantization-Aware training and post-training Quantization
Inference Notebook
Fine-Tuning Notebook

LICENSE

The YOLO-NAS model is available under an open-source license with pre-trained weights available for non-commercial use on SuperGradients, Deci's PyTorch-based, open-source, computer vision training library. With SuperGradients, users can train models from scratch or fine-tune existing ones, leveraging advanced built-in training techniques like Distributed Data Parallel, Exponential Moving Average, Automatic mixed precision, and Quantization Aware Training.

License file is available here: YOLO-NAS WEIGHTS LICENSE


YOLO-NAS Quickstart

Deci leveraged its proprietary Neural Architecture Search engine (AutoNAC) to generate YOLO-NAS - a new object detection architecture that delivers the world’s best accuracy-latency performance.

The YOLO-NAS model incorporates quantization-aware RepVGG blocks to ensure compatibility with post-training quantization, making it very flexible and usable for different hardware configurations.

In this tutorial, we will go over the basic functionality of the YOLO-NAS model.

Instantiate a YOLO-NAS Model

from super_gradients.training import models
from super_gradients.common.object_names import Models

net = models.get(Models.YOLO_NAS_S, pretrained_weights="coco")

Predict

prediction = net.predict("https://www.aljazeera.com/wp-content/uploads/2022/12/2022-12-03T205130Z_851430040_UP1EIC31LXSAZ_RTRMADP_3_SOCCER-WORLDCUP-ARG-AUS-REPORT.jpg?w=770&resize=770%2C436&quality=80")
prediction.show()

Export to ONNX

models.convert_to_onnx(model=net, input_shape=(3,640,640), out_path="yolo_nas_s.onnx")
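
If you want to sanity-check the exported graph, a quick load-and-run with onnxruntime (assuming the package is installed; this snippet is illustrative and not part of SuperGradients) could look like this:

import numpy as np
import onnxruntime as ort

# Load the exported model on CPU and run a random 640x640 image through it
session = ort.InferenceSession("yolo_nas_s.onnx", providers=["CPUExecutionProvider"])
input_name = session.get_inputs()[0].name
dummy = np.random.rand(1, 3, 640, 640).astype(np.float32)
print([output.shape for output in session.run(None, {input_name: dummy})])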

Train on RF100

Follow the setup instructions for RF100:

- Follow the official instructions to download Roboflow100: https://github.com/roboflow/roboflow-100-benchmark?ref=roboflow-blog
  Note: to use this dataset, you must download the "coco" format, NOT the yolov5 one.

- Your dataset should look like this:

    rf100
    ├── 4-fold-defect
    │      ├─ train
    │      │    ├─ 000000000001.jpg
    │      │    ├─ ...
    │      │    └─ _annotations.coco.json
    │      ├─ valid
    │      │    └─ ...
    │      └─ test
    │           └─ ...
    ├── abdomen-mri
    │      └─ ...
    └── ...

- Install the COCO API: https://github.com/pdollar/coco/tree/master/PythonAPI

We will use the roboflow_yolo_nas_s configuration to train the small variant of YOLO-NAS, YOLO-NAS-S.

To launch training on one of the RF100 datasets, we pass it through the dataset_name argument:

python -m super_gradients.train_from_recipe --config-name=roboflow_yolo_nas_s  dataset_name=<DATASET_NAME> dataset_params.data_dir=<PATH_TO_RF100_ROOT> ckpt_root_dir=<YOUR_CHECKPOINTS_ROOT_DIRECTORY>

Replace <DATASET_NAME> with any of the RF100 datasets that you wish to train on.


PTQ and QAT with YOLO-NAS

In this tutorial, we will guide you step by step through preparing YOLO-NAS for production. We will leverage the YOLO-NAS architecture, which includes quantization-friendly blocks, and train a YOLO-NAS model on Roboflow's Soccer Player Detection Dataset in a way that maximizes throughput without compromising the model's accuracy.

The steps will be:

  1. Training from scratch on one of the downstream datasets - this will play the role of the user's dataset (i.e., the one on which the model needs to be trained for the user's task)
  2. Performing post-training quantization and quantization-aware training

Pre-requisites:

Now, let's get to it.

Step 0: Installations and Dataset Setup

Follow the setup instructions for RF100:

- Follow the official instructions to download Roboflow100: https://github.com/roboflow/roboflow-100-benchmark?ref=roboflow-blog
  Note: to use this dataset, you must download the "coco" format, NOT the yolov5 one.

- Your dataset should look like this:

    rf100
    ├── 4-fold-defect
    │      ├─ train
    │      │    ├─ 000000000001.jpg
    │      │    ├─ ...
    │      │    └─ _annotations.coco.json
    │      ├─ valid
    │      │    └─ ...
    │      └─ test
    │           └─ ...
    ├── abdomen-mri
    │      └─ ...
    └── ...

- Install the COCO API: https://github.com/pdollar/coco/tree/master/PythonAPI

Install the latest version of SG:

pip install super-gradients

Install torch and pytorch-quantization (later versions should be compatible as well; in general, follow the torch installation instructions at https://pytorch.org/get-started/locally/):

pip install torch==1.11.0+cu113 torchvision==0.12.0+cu113 torchaudio==0.11.0 --extra-index-url https://download.pytorch.org/whl/cu113 &> /dev/null
pip install pytorch-quantization==2.1.2 --extra-index-url https://pypi.ngc.nvidia.com &> /dev/null

Launch Training (non-QA)

Although this might come as a surprise - quantization-aware training, despite its name, needs to be performed on a trained checkpoint rather than from scratch. So in practice, we first need to fully train our model on our dataset; then, after performing calibration, we fine-tune the model once again, which will be our final step. As we discuss in our Training with configuration files, we clone the SG repo, then use the repo's configuration files in our training examples. We will use the src/super_gradients/recipes/roboflow_yolo_nas_s.yaml configuration to train the small variant of YOLO-NAS, YOLO-NAS-S.

So we navigate to our train_from_recipe script:

cd <YOUR-LOCAL-PATH>/super_gradients/src/super_gradients/examples/train_from_recipe_example

Then to avoid collisions between our cloned and installed SG:

export PYTHONPATH=$PYTHONPATH:<YOUR-LOCAL-PATH>/super_gradients/

To launch training on one of the RF100 datasets, we pass it through the dataset_name argument:

python -m train_from_recipe --config-name=roboflow_yolo_nas_s  dataset_name=soccer-players-5fuqs dataset_params.data_dir=<PATH_TO_RF100_ROOT> ckpt_root_dir=<YOUR_CHECKPOINTS_ROOT_DIRECTORY> experiment_name=yolo_nas_s_soccer_players

...

Train epoch 99: 100%|██████████| 32/32 [00:23<00:00,  1.35it/s, PPYoloELoss/loss=0.853, PPYoloELoss/loss_cls=0.417, PPYoloELoss/loss_dfl=0.56, PPYoloELoss/loss_iou=0.0621, gpu_mem=11.7]
Validation epoch 99: 100%|██████████| 3/3 [00:00<00:00,  5.49it/s]
===========================================================
SUMMARY OF EPOCH 99
├── Training
│   ├── Ppyoloeloss/loss = 0.8527
│   │   ├── Best until now = 0.8515 (↗ 0.0012)
│   │   └── Epoch N-1      = 0.8515 (↗ 0.0012)
│   ├── Ppyoloeloss/loss_cls = 0.4174
│   │   ├── Best until now = 0.4178 (↘ -0.0004)
│   │   └── Epoch N-1      = 0.4178 (↘ -0.0004)
│   ├── Ppyoloeloss/loss_dfl = 0.5602
│   │   ├── Best until now = 0.5573 (↗ 0.0029)
│   │   └── Epoch N-1      = 0.5573 (↗ 0.0029)
│   └── Ppyoloeloss/loss_iou = 0.0621
│       ├── Best until now = 0.062  (↗ 0.0)
│       └── Epoch N-1      = 0.062  (↗ 0.0)
└── Validation
    ├── F1@0.50 = 0.779
    │   ├── Best until now = 0.8185 (↘ -0.0395)
    │   └── Epoch N-1      = 0.796  (↘ -0.017)
    ├── Map@0.50 = 0.9601
    │   ├── Best until now = 0.967  (↘ -0.0069)
    │   └── Epoch N-1      = 0.957  (↗ 0.0031)
    ├── Ppyoloeloss/loss = 1.4472
    │   ├── Best until now = 1.3971 (↗ 0.0501)
    │   └── Epoch N-1      = 1.4421 (↗ 0.0051)
    ├── Ppyoloeloss/loss_cls = 0.5981
    │   ├── Best until now = 0.527  (↗ 0.0711)
    │   └── Epoch N-1      = 0.5986 (↘ -0.0005)
    ├── Ppyoloeloss/loss_dfl = 0.8216
    │   ├── Best until now = 0.7849 (↗ 0.0367)
    │   └── Epoch N-1      = 0.8202 (↗ 0.0014)
    ├── Ppyoloeloss/loss_iou = 0.1753
    │   ├── Best until now = 0.1684 (↗ 0.007)
    │   └── Epoch N-1      = 0.1734 (↗ 0.002)
    ├── Precision@0.50 = 0.6758
    │   ├── Best until now = 0.7254 (↘ -0.0495)
    │   └── Epoch N-1      = 0.6931 (↘ -0.0172)
    └── Recall@0.50 = 0.9567
        ├── Best until now = 0.9872 (↘ -0.0304)
        └── Epoch N-1      = 0.9567 (= 0.0)

===========================================================
[2023-03-30 14:09:47] INFO - sg_trainer.py - RUNNING ADDITIONAL TEST ON THE AVERAGED MODEL...
Validation epoch 100: 100%|██████████| 3/3 [00:00<00:00,  5.45it/s]
===========================================================
SUMMARY OF EPOCH 100
├── Training
│   ├── Ppyoloeloss/loss = 0.8527
│   │   ├── Best until now = 0.8515 (↗ 0.0012)
│   │   └── Epoch N-1      = 0.8515 (↗ 0.0012)
│   ├── Ppyoloeloss/loss_cls = 0.4174
│   │   ├── Best until now = 0.4178 (↘ -0.0004)
│   │   └── Epoch N-1      = 0.4178 (↘ -0.0004)
│   ├── Ppyoloeloss/loss_dfl = 0.5602
│   │   ├── Best until now = 0.5573 (↗ 0.0029)
│   │   └── Epoch N-1      = 0.5573 (↗ 0.0029)
│   └── Ppyoloeloss/loss_iou = 0.0621
│       ├── Best until now = 0.062  (↗ 0.0)
│       └── Epoch N-1      = 0.062  (↗ 0.0)
└── Validation
    ├── F1@0.50 = 0.7824
    │   ├── Best until now = 0.8185 (↘ -0.0361)
    │   └── Epoch N-1      = 0.779  (↗ 0.0034)
    ├── Map@0.50 = 0.9635
    │   ├── Best until now = 0.967  (↘ -0.0036)
    │   └── Epoch N-1      = 0.9601 (↗ 0.0033)
    ├── Ppyoloeloss/loss = 1.432
    │   ├── Best until now = 1.3971 (↗ 0.0349)
    │   └── Epoch N-1      = 1.4472 (↘ -0.0152)
    ├── Ppyoloeloss/loss_cls = 0.588
    │   ├── Best until now = 0.527  (↗ 0.061)
    │   └── Epoch N-1      = 0.5981 (↘ -0.0101)
    ├── Ppyoloeloss/loss_dfl = 0.8191
    │   ├── Best until now = 0.7849 (↗ 0.0343)
    │   └── Epoch N-1      = 0.8216 (↘ -0.0025)
    ├── Ppyoloeloss/loss_iou = 0.1738
    │   ├── Best until now = 0.1684 (↗ 0.0054)
    │   └── Epoch N-1      = 0.1753 (↘ -0.0015)
    ├── Precision@0.50 = 0.6769
    │   ├── Best until now = 0.7254 (↘ -0.0485)
    │   └── Epoch N-1      = 0.6758 (↗ 0.0011)
    └── Recall@0.50 = 0.9567
        ├── Best until now = 0.9872 (↘ -0.0304)
        └── Epoch N-1      = 0.9567 (= 0.0)

And so our best checkpoint, which resides in <YOUR_CHECKPOINTS_ROOT_DIRECTORY>/yolo_nas_s_soccer_players/ckpt_best.pth, reaches 0.967 mAP!

Let's visualize some results:

from super_gradients.common.object_names import Models
from super_gradients.training import models

model = models.get(Models.YOLO_NAS_S,
                   checkpoint_path="<YOUR_CHECKPOINTS_ROOT_DIRECTORY>/yolo_nas_s_soccer_players/ckpt_best.pth",
                   num_classes=4)
predictions = model.predict("messi_penalty.mp4")
predictions.show(show_confidence=False)

QAT and PTQ

Now, we will take our checkpoint from our previous section and perform post-training quantization, then quantization-aware training. To do so, we will need to launch training with our qat_from_recipe example script, which simplifies taking any existing training recipe and making it a quantization-aware one with the help of some of our recommended practices. So this time, we navigate to the qat_from_recipe example directory:

cd <YOUR-LOCAL-PATH>/super_gradients/src/super_gradients/examples/qat_from_recipe_example

Before we launch, let's see how we can easily create a configuration from our roboflow_yolo_nas_s config to get the most out of QAT and PTQ. We added a new config that inherits from our previous one, called roboflow_yolo_nas_s_qat.yaml. Let's peek at it:


defaults:
  - roboflow_yolo_nas_s
  - quantization_params: default_quantization_params
  - _self_

checkpoint_params:
  checkpoint_path: ???
  strict_load: no_key_matching

experiment_name: soccer_players_qat_yolo_nas_s

pre_launch_callbacks_list:
    - QATRecipeModificationCallback:
        batch_size_divisor: 2
        max_epochs_divisor: 10
        lr_decay_factor: 0.01
        warmup_epochs_divisor: 10
        cosine_final_lr_ratio: 0.01
        disable_phase_callbacks: True
        disable_augmentations: False

Let's break it down:

  • We inherit from our original non-QA recipe

  • We set quantization_params to the default ones. Reminder - this is where QAT and PTQ hyper-parameters are defined.

  • We set our checkpoint_params.checkpoint_path to ??? so that passing a checkpoint is required. We will override this value when launching from the command line.

  • We add a QATRecipeModificationCallback to our pre_launch_callbacks_list: This callback accepts the entire cfg: DictConfig and manipulates it right before we start the training. This allows us to adapt any non-QA recipe to a QA one quickly. Here we will:

    • Use half the batch size of the original recipe.
    • Use 10 percent of the original number of epochs (and warmup epochs).
    • Use 1 percent of the original learning rate.
    • Set the final learning rate ratio of the cosine scheduling to 0.01
    • Disable augmentations and the phase_callbacks.
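
Putting these divisors in concrete terms, here is a minimal sketch of the arithmetic the callback applies (the non-QA recipe values below are illustrative, not read from the actual config):

# Hypothetical starting values from a non-QA recipe
batch_size, max_epochs, warmup_epochs, initial_lr = 16, 100, 3, 5e-4

qat_batch_size = batch_size // 2          # batch_size_divisor: 2
qat_max_epochs = max_epochs // 10         # max_epochs_divisor: 10
qat_warmup_epochs = warmup_epochs // 10   # warmup_epochs_divisor: 10
qat_lr = initial_lr * 0.01                # lr_decay_factor: 0.01

print(qat_batch_size, qat_max_epochs, qat_warmup_epochs, qat_lr)  # 8 10 0 5e-06

This lines up with the log below, which reports a new number of epochs of 10 and a new learning rate of 5e-06.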

Now we can launch PTQ and QAT from the command line:

python -m qat_from_recipe --config-name=roboflow_yolo_nas_s_qat experiment_name=soccer_players_qat_yolo_nas_s dataset_name=soccer-players-5fuqs dataset_params.data_dir=<PATH_TO_RF100_ROOT> checkpoint_params.checkpoint_path=<YOUR_CHECKPOINTS_ROOT_DIRECTORY>/yolo_nas_s_soccer_players/ckpt_best.pth ckpt_root_dir=<YOUR_CHECKPOINTS_ROOT_DIRECTORY>
...

[2023-04-02 11:37:56,848][super_gradients.training.pre_launch_callbacks.pre_launch_callbacks][INFO] - Modifying recipe to suit QAT rules of thumb. Remove QATRecipeModificationCallback to disable.
[2023-04-02 11:37:56,858][super_gradients.training.pre_launch_callbacks.pre_launch_callbacks][WARNING] - New number of epochs: 10
[2023-04-02 11:37:56,858][super_gradients.training.pre_launch_callbacks.pre_launch_callbacks][WARNING] - New learning rate: 5e-06
[2023-04-02 11:37:56,858][super_gradients.training.pre_launch_callbacks.pre_launch_callbacks][WARNING] - New weight decay: 1.0000000000000002e-06
[2023-04-02 11:37:56,858][super_gradients.training.pre_launch_callbacks.pre_launch_callbacks][WARNING] - EMA will be disabled for QAT run.
[2023-04-02 11:37:56,859][super_gradients.training.pre_launch_callbacks.pre_launch_callbacks][WARNING] - SyncBatchNorm will be disabled for QAT run.
[2023-04-02 11:37:56,859][super_gradients.training.pre_launch_callbacks.pre_launch_callbacks][WARNING] - Recipe requests multi_gpu=False and num_gpus=1. Changing to multi_gpu=OFF and num_gpus=1
100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 32/32 [00:32<00:00,  1.01s/it]
[2023-04-02 11:38:34,316][super_gradients.training.qat_trainer.qat_trainer][INFO] - Validating PTQ model...

  0%|          | 0/3 [00:00<?, ?it/s]
Test:   0%|          | 0/3 [00:00<?, ?it/s]
Test:  33%|███▎      | 1/3 [00:00<00:00,  2.87it/s]
Test:  67%|██████▋   | 2/3 [00:00<00:00,  2.90it/s]
Test: 100%|██████████| 3/3 [00:00<00:00,  3.86it/s]
[2023-04-02 11:38:35,106][super_gradients.training.qat_trainer.qat_trainer][INFO] - PTQ Model Validation Results
   - Precision@0.50: 0.6727069020271301
   - Recall@0.50: 0.95766681432724
   - mAP@0.50  : 0.9465919137001038
   - F1@0.50   : 0.7861716747283936

Observe that for PTQ, our model's mAP decreased from 0.967 to 0.9466. After PTQ, QAT is performed automatically:


[2023-04-02 11:38:47] INFO - sg_trainer.py - Started training for 10 epochs (0/9)

Train epoch 0: 100%|██████████| 32/32 [00:26<00:00,  1.21it/s, PPYoloELoss/loss=0.909, PPYoloELoss/loss_cls=0.444, PPYoloELoss/loss_dfl=0.57, PPYoloELoss/loss_iou=0.0721, gpu_mem=10.1]
Validation epoch 0: 100%|██████████| 3/3 [00:00<00:00,  3.75it/s]
===========================================================
SUMMARY OF EPOCH 0
├── Training
│   ├── Ppyoloeloss/loss = 0.9088
│   ├── Ppyoloeloss/loss_cls = 0.4436
│   ├── Ppyoloeloss/loss_dfl = 0.5696
│   └── Ppyoloeloss/loss_iou = 0.0721
└── Validation
    ├── F1@0.50 = 0.7885
    ├── Map@0.50 = 0.9556
    ├── Ppyoloeloss/loss = 1.4303
    ├── Ppyoloeloss/loss_cls = 0.5847
    ├── Ppyoloeloss/loss_dfl = 0.8186
    ├── Ppyoloeloss/loss_iou = 0.1745
    ├── Precision@0.50 = 0.671
    └── Recall@0.50 = 0.9734

===========================================================
[2023-04-02 11:39:14] INFO - sg_trainer.py - Best checkpoint overriden: validation mAP@0.50: 0.9556358456611633
Train epoch 1: 100%|██████████| 32/32 [00:26<00:00,  1.22it/s, PPYoloELoss/loss=0.91, PPYoloELoss/loss_cls=0.445, PPYoloELoss/loss_dfl=0.574, PPYoloELoss/loss_iou=0.0712, gpu_mem=10.1]
Validation epoch 1: 100%|██████████| 3/3 [00:00<00:00,  3.88it/s]
===========================================================
SUMMARY OF EPOCH 1
├── Training
│   ├── Ppyoloeloss/loss = 0.9097
│   │   ├── Best until now = 0.9088 (↗ 0.001)
│   │   └── Epoch N-1      = 0.9088 (↗ 0.001)
│   ├── Ppyoloeloss/loss_cls = 0.4448
│   │   ├── Best until now = 0.4436 (↗ 0.0011)
│   │   └── Epoch N-1      = 0.4436 (↗ 0.0011)
│   ├── Ppyoloeloss/loss_dfl = 0.5739
│   │   ├── Best until now = 0.5696 (↗ 0.0044)
│   │   └── Epoch N-1      = 0.5696 (↗ 0.0044)
│   └── Ppyoloeloss/loss_iou = 0.0712
│       ├── Best until now = 0.0721 (↘ -0.0009)
│       └── Epoch N-1      = 0.0721 (↘ -0.0009)
└── Validation
    ├── F1@0.50 = 0.7537
    │   ├── Best until now = 0.7885 (↘ -0.0348)
    │   └── Epoch N-1      = 0.7885 (↘ -0.0348)
    ├── Map@0.50 = 0.9581
    │   ├── Best until now = 0.9556 (↗ 0.0025)
    │   └── Epoch N-1      = 0.9556 (↗ 0.0025)
    ├── Ppyoloeloss/loss = 1.4312
    │   ├── Best until now = 1.4303 (↗ 0.0009)
    │   └── Epoch N-1      = 1.4303 (↗ 0.0009)
    ├── Ppyoloeloss/loss_cls = 0.5881
    │   ├── Best until now = 0.5847 (↗ 0.0034)
    │   └── Epoch N-1      = 0.5847 (↗ 0.0034)
    ├── Ppyoloeloss/loss_dfl = 0.8166
    │   ├── Best until now = 0.8186 (↘ -0.002)
    │   └── Epoch N-1      = 0.8186 (↘ -0.002)
    ├── Ppyoloeloss/loss_iou = 0.1739
    │   ├── Best until now = 0.1745 (↘ -0.0006)
    │   └── Epoch N-1      = 0.1745 (↘ -0.0006)
    ├── Precision@0.50 = 0.6262
    │   ├── Best until now = 0.671  (↘ -0.0448)
    │   └── Epoch N-1      = 0.671  (↘ -0.0448)
    └── Recall@0.50 = 0.9734
        ├── Best until now = 0.9734 (= 0.0)
        └── Epoch N-1      = 0.9734 (= 0.0)

===========================================================
...
...
Validation epoch 10: 100%|██████████| 3/3 [00:00<00:00,  4.07it/s]
===========================================================
SUMMARY OF EPOCH 10
├── Training
│   ├── Ppyoloeloss/loss = 0.8901
│   │   ├── Best until now = 0.889  (↗ 0.0011)
│   │   └── Epoch N-1      = 0.8957 (↘ -0.0056)
│   ├── Ppyoloeloss/loss_cls = 0.4365
│   │   ├── Best until now = 0.4359 (↗ 0.0005)
│   │   └── Epoch N-1      = 0.4384 (↘ -0.002)
│   ├── Ppyoloeloss/loss_dfl = 0.5677
│   │   ├── Best until now = 0.5665 (↗ 0.0012)
│   │   └── Epoch N-1      = 0.5702 (↘ -0.0025)
│   └── Ppyoloeloss/loss_iou = 0.0679
│       ├── Best until now = 0.0672 (↗ 0.0007)
│       └── Epoch N-1      = 0.0689 (↘ -0.001)
└── Validation
    ├── F1@0.50 = 0.7373
    │   ├── Best until now = 0.7885 (↘ -0.0512)
    │   └── Epoch N-1      = 0.721  (↗ 0.0164)
    ├── Map@0.50 = 0.968
    │   ├── Best until now = 0.9672 (↗ 0.0007)
    │   └── Epoch N-1      = 0.9517 (↗ 0.0163)
    ├── Ppyoloeloss/loss = 1.4326
    │   ├── Best until now = 1.4303 (↗ 0.0023)
    │   └── Epoch N-1      = 1.4322 (↗ 0.0004)
    ├── Ppyoloeloss/loss_cls = 0.5887
    │   ├── Best until now = 0.5847 (↗ 0.004)
    │   └── Epoch N-1      = 0.5889 (↘ -0.0002)
    ├── Ppyoloeloss/loss_dfl = 0.8164
    │   ├── Best until now = 0.8154 (↗ 0.001)
    │   └── Epoch N-1      = 0.8158 (↗ 0.0006)
    ├── Ppyoloeloss/loss_iou = 0.1743
    │   ├── Best until now = 0.1737 (↗ 0.0006)
    │   └── Epoch N-1      = 0.1742 (↗ 1e-04)
    ├── Precision@0.50 = 0.6052
    │   ├── Best until now = 0.671  (↘ -0.0658)
    │   └── Epoch N-1      = 0.5953 (↗ 0.01)
    └── Recall@0.50 = 0.9853
        ├── Best until now = 0.9853 (= 0.0)
        └── Epoch N-1      = 0.9734 (↗ 0.0119)

We not only observed no decline in the accuracy of our quantized model, but we also gained an improvement of 0.08 mAP! The QAT model is available in our checkpoints directory, already converted to .onnx format under <YOUR_CHECKPOINTS_ROOT_DIRECTORY>/soccer_players_qat_yolo_nas_s/soccer_players_qat_yolo_nas_s_16x3x640x640_qat.onnx, ready to be compiled and deployed in INT8 using TRT.
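
As a final step, outside this tutorial's scope, that exported ONNX file can be compiled into an INT8 TensorRT engine. Assuming TensorRT's trtexec CLI is available, the compilation could look roughly like:

trtexec --onnx=<YOUR_CHECKPOINTS_ROOT_DIRECTORY>/soccer_players_qat_yolo_nas_s/soccer_players_qat_yolo_nas_s_16x3x640x640_qat.onnx --int8 --saveEngine=yolo_nas_s_qat.engine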

# scripts/Dockerfile.branch
ARG DOCKER_IMAGE_TAG=11.3.1-devel-ubuntu20.04
FROM nvidia/cuda:${DOCKER_IMAGE_TAG}

LABEL maintainer "DECI.AI <services@deci.ai>"

ARG DEBIAN_FRONTEND=noninteractive

RUN mkdir /SG
WORKDIR /SG
RUN apt-get update && apt-get install -y python3-pip python-is-python3 pip libgl1 libglib2.0-0 git python3-distutils python3-typing-extensions \
    && rm -rf /var/lib/apt/lists/*
COPY . .
RUN pip install . --no-cache-dir && pip install torch==1.11.0+cu113 torchvision==0.12.0+cu113 --extra-index-url https://download.pytorch.org/whl/cu113 --no-cache-dir
RUN pip uninstall -y typing_extensions && pip install wandb --no-cache-dir
# scripts/Dockerfile.branch.code
ARG BASE_TAG
ARG BASE_DOCKER_REPO
FROM ${BASE_DOCKER_REPO:-307629990626.dkr.ecr.us-east-1.amazonaws.com/deci/super-gradients}:${BASE_TAG:-latest}
LABEL maintainer "DECI.AI <services@deci.ai>"
ARG DEBIAN_FRONTEND=noninteractive
WORKDIR /
RUN rm -rf /SG && mkdir /SG
WORKDIR /SG
COPY . .
RUN pip install . --no-cache-dir
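
For local experimentation, building the code-only branch image from the repo root could look roughly like this (tags are illustrative; in CI, BASE_TAG is supplied from the container lint-tag step):

docker build -f scripts/Dockerfile.branch.code --build-arg BASE_TAG=<BASE_TAG> -t super-gradients:<BRANCH_TAG> .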
@@ -304,6 +304,9 @@ class Models:
     DEKR_W32_NO_DC = "dekr_w32_no_dc"
     POSE_PP_YOLO_L = "pose_ppyolo_l"
     POSE_DDRNET_39 = "pose_ddrnet39"
+    YOLO_NAS_S = "yolo_nas_s"
+    YOLO_NAS_M = "yolo_nas_m"
+    YOLO_NAS_L = "yolo_nas_l"
 
 
 class ConcatenatedTensorFormats:
@@ -326,8 +329,8 @@ class Dataloaders:
     COCO2017_VAL = "coco2017_val"
     COCO2017_TRAIN_YOLOX = "coco2017_train_yolox"
     COCO2017_VAL_YOLOX = "coco2017_val_yolox"
-    COCO2017_TRAIN_DECIYOLO = "coco2017_train_deci_yolo"
-    COCO2017_VAL_DECIYOLO = "coco2017_val_deci_yolo"
+    COCO2017_TRAIN_YOLO_NAS = "coco2017_train_yolo_nas"
+    COCO2017_VAL_YOLO_NAS = "coco2017_val_yolo_nas"
     COCO2017_TRAIN_PPYOLOE = "coco2017_train_ppyoloe"
     COCO2017_VAL_PPYOLOE = "coco2017_val_ppyoloe"
     COCO2017_TRAIN_SSD_LITE_MOBILENET_V2 = "coco2017_train_ssd_lite_mobilenet_v2"
@@ -1,8 +1,8 @@
 from super_gradients.common.object_names import Models
 from super_gradients.training import models
 
-# Note that currently only YoloX and PPYoloE are supported.
-model = models.get(Models.PP_YOLOE_S, pretrained_weights="coco")
+# Note that currently only YoloX, PPYoloE and YOLO-NAS are supported.
+model = models.get(Models.YOLO_NAS_L, pretrained_weights="coco")
 
 IMAGES = [
     "../../../../documentation/source/images/examples/countryside.jpg",
@@ -1,8 +1,8 @@
 from super_gradients.common.object_names import Models
 from super_gradients.training import models
 
-# Note that currently only YoloX and PPYoloE are supported.
-model = models.get(Models.YOLOX_N, pretrained_weights="coco")
+# Note that currently only YoloX, PPYoloE and YOLO-NAS are supported.
+model = models.get(Models.YOLO_NAS_L, pretrained_weights="coco")
 
 image_folder_path = "../../../../documentation/source/images/examples"
 
@@ -2,8 +2,8 @@ import torch
 from super_gradients.common.object_names import Models
 from super_gradients.training import models
 
-# Note that currently only YoloX and PPYoloE are supported.
-model = models.get(Models.YOLOX_N, pretrained_weights="coco")
+# Note that currently only YoloX, PPYoloE and YOLO-NAS are supported.
+model = models.get(Models.YOLO_NAS_L, pretrained_weights="coco")
 
 # We want to use cuda if available to speed up inference.
 model = model.to("cuda" if torch.cuda.is_available() else "cpu")
@@ -3,8 +3,8 @@ import torch
 from super_gradients.common.object_names import Models
 from super_gradients.training import models
 
-# Note that currently only YoloX and PPYoloE are supported.
-model = models.get(Models.YOLOX_N, pretrained_weights="coco")
+# Note that currently only YoloX, PPYoloE and YOLO-NAS are supported.
+model = models.get(Models.YOLO_NAS_L, pretrained_weights="coco")
 
 # We want to use cuda if available to speed up inference.
 model = model.to("cuda" if torch.cuda.is_available() else "cpu")
@@ -1,3 +1,3 @@
-from .module_interfaces import HasPredict, HasPreprocessingParams
+from .module_interfaces import HasPredict, HasPreprocessingParams, SupportsReplaceNumClasses
 
-__all__ = ["HasPredict", "HasPreprocessingParams"]
+__all__ = ["HasPredict", "HasPreprocessingParams", "SupportsReplaceNumClasses"]
@@ -1,3 +1,6 @@
+from typing import Callable
+
+from torch import nn
 from typing_extensions import Protocol, runtime_checkable
 
 
@@ -31,3 +34,28 @@ class HasPredict(Protocol):
 
     def predict_webcam(self, *args, **kwargs):
         ...
+
+
+@runtime_checkable
+class SupportsReplaceNumClasses(Protocol):
+    """
+    Protocol interface for modules that support replacing the number of classes.
+    Derived classes should implement the `replace_num_classes` method.
+
+    This interface class serves a purpose of explicitly indicating whether a class supports optimized head replacement:
+
+    >>> class PredictionHead(nn.Module, SupportsReplaceNumClasses):
+    >>>    def replace_num_classes(self, num_classes: int, compute_new_weights_fn: Callable[[nn.Module, int], nn.Module] = None):
+    >>>       ...
+    """
+
+    def replace_num_classes(self, num_classes: int, compute_new_weights_fn: Callable[[nn.Module, int], nn.Module]):
+        """
+        Replace the number of classes in the module.
+
+        :param num_classes: New number of classes.
+        :param compute_new_weights_fn: (callable) An optional function that computes the new weights for the new classes.
+            It takes existing nn.Module and returns a new one.
+        :return: None
+        """
+        ...
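
Because the protocol is declared runtime_checkable, head-replacement utilities can probe arbitrary modules with isinstance before acting on them. A minimal sketch (the toy head below is illustrative, not part of this PR):

from typing import Callable

from torch import nn

from super_gradients.module_interfaces import SupportsReplaceNumClasses


class ToyHead(nn.Module):
    """Illustrative head that satisfies SupportsReplaceNumClasses structurally."""

    def __init__(self, in_channels: int, num_classes: int):
        super().__init__()
        self.cls_conv = nn.Conv2d(in_channels, num_classes, kernel_size=1)

    def replace_num_classes(self, num_classes: int, compute_new_weights_fn: Callable[[nn.Module, int], nn.Module]):
        # Delegate construction of the re-sized layer to the supplied function
        self.cls_conv = compute_new_weights_fn(self.cls_conv, num_classes)


head = ToyHead(in_channels=64, num_classes=80)
print(isinstance(head, SupportsReplaceNumClasses))  # True - structural check, no explicit inheritance needed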
@@ -16,7 +16,23 @@ from super_gradients.modules.skip_connections import (
 from super_gradients.common.abstractions.abstract_logger import get_logger
 from super_gradients.common.registry.registry import ALL_DETECTION_MODULES
 
+from super_gradients.modules.base_modules import BaseDetectionModule
+from super_gradients.modules.detection_modules import (
+    PANNeck,
+    NHeads,
+    MultiOutputBackbone,
+    NStageBackbone,
+    MobileNetV1Backbone,
+    MobileNetV2Backbone,
+    SSDNeck,
+    SSDInvertedResidualNeck,
+    SSDBottleneckNeck,
+    SSDHead,
+)
+from super_gradients.module_interfaces import SupportsReplaceNumClasses
+
 __all__ = [
+    "BaseDetectionModule",
     "ALL_DETECTION_MODULES",
     "PixelShuffle",
     "AntiAliasDownsample",
@@ -33,6 +49,17 @@ __all__ = [
     "BackboneInternalSkipConnection",
     "HeadInternalSkipConnection",
     "LightweightDEKRHead",
+    "PANNeck",
+    "NHeads",
+    "MultiOutputBackbone",
+    "NStageBackbone",
+    "MobileNetV1Backbone",
+    "MobileNetV2Backbone",
+    "SSDNeck",
+    "SSDInvertedResidualNeck",
+    "SSDBottleneckNeck",
+    "SSDHead",
+    "SupportsReplaceNumClasses",
 ]
 
 logger = get_logger(__name__)
# src/super_gradients/modules/base_modules.py
from abc import abstractmethod, ABC
from typing import Union, List

from torch import nn

__all__ = ["BaseDetectionModule"]


class BaseDetectionModule(nn.Module, ABC):
    """
    An interface for a module that is easy to integrate into a model with complex connections
    """

    def __init__(self, in_channels: Union[List[int], int], **kwargs):
        """
        :param in_channels: defines channels of tensor(s) that will be accepted by a module in forward
        """
        super().__init__()
        self.in_channels = in_channels

    @property
    @abstractmethod
    def out_channels(self) -> Union[List[int], int]:
        """
        :return: channels of tensor(s) that will be returned by a module in forward
        """
        raise NotImplementedError()
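
To illustrate the contract, here is a minimal hypothetical concrete module: it inherits the in_channels handling and only has to implement the out_channels property (plus a forward):

from typing import List, Union

from super_gradients.modules.base_modules import BaseDetectionModule


class IdentityNeck(BaseDetectionModule):
    """Toy neck that passes its inputs through unchanged (illustration only)."""

    @property
    def out_channels(self) -> Union[List[int], int]:
        # Nothing is transformed, so output channels mirror input channels.
        return self.in_channels

    def forward(self, inputs):
        return inputs


neck = IdentityNeck(in_channels=[96, 192, 384])
print(neck.out_channels)  # [96, 192, 384]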
@@ -1,37 +1,30 @@
+from abc import ABC, abstractmethod
 from typing import Union, List
-from abc import abstractmethod, ABC
 
 import torch
-from torch import nn
-from omegaconf.listconfig import ListConfig
 from omegaconf import DictConfig
-
+from omegaconf.listconfig import ListConfig
 from super_gradients.common.registry.registry import register_detection_module
+from super_gradients.modules.base_modules import BaseDetectionModule
+from super_gradients.modules.multi_output_modules import MultiOutputModule
+from super_gradients.training.models import MobileNet, MobileNetV2
 from super_gradients.training.models.classification_models.mobilenetv2 import InvertedResidual
 from super_gradients.training.utils.utils import HpmStruct
-from super_gradients.training.models import MobileNet, MobileNetV2
-from super_gradients.modules.multi_output_modules import MultiOutputModule
-
-
-class BaseDetectionModule(nn.Module, ABC):
-    """
-    An interface for a module that is easy to integrate into a model with complex connections
-    """
-
-    def __init__(self, in_channels: Union[List[int], int], **kwargs):
-        """
-        :param in_channels: defines channels of tensor(s) that will be accepted by a module in forward
-        """
-        super().__init__()
-        self.in_channels = in_channels
+from torch import nn
 
-    @property
-    @abstractmethod
-    def out_channels(self) -> Union[List[int], int]:
-        """
-        :return: channels of tensor(s) that will be returned by a module  in forward
-        """
-        raise NotImplementedError()
+__all__ = [
+    "PANNeck",
+    "NHeads",
+    "MultiOutputBackbone",
+    "NStageBackbone",
+    "MobileNetV1Backbone",
+    "MobileNetV2Backbone",
+    "SSDNeck",
+    "SSDInvertedResidualNeck",
+    "SSDBottleneckNeck",
+    "SSDHead",
+    "BaseDetectionModule",
+]
 
 
 @register_detection_module()
# src/super_gradients/modules/head_replacement_utils.py
from typing import Union

import torch
from torch import nn

__all__ = ["replace_num_classes_with_random_weights"]


def replace_num_classes_with_random_weights(module: Union[nn.Conv2d, nn.Linear, nn.Module], num_classes: int) -> nn.Module:
    """
    Replace the number of classes in the module with random weights.
    This is useful for replacing the output layer of a detection/classification head.
    This implementation supports Conv2d and Linear layers.
    The returned module will have the same device and dtype as the original module.
    Random weights are initialized with the same mean and std as the original weights.

    :param module: (nn.Module) Module to replace the number of classes in.
    :param num_classes: New number of classes.
    :return: nn.Module
    """
    if isinstance(module, nn.Conv2d):
        new_module = nn.Conv2d(
            module.in_channels,
            num_classes,
            kernel_size=module.kernel_size,
            stride=module.stride,
            padding=module.padding,
            dilation=module.dilation,
            groups=module.groups,
            bias=module.bias is not None,
            device=module.weight.device,
            dtype=module.weight.dtype,
        )
        torch.nn.init.normal_(new_module.weight, mean=module.weight.mean().item(), std=module.weight.std(dim=(0, 1, 2, 3)).item())
        if module.bias is not None:
            torch.nn.init.normal_(new_module.bias, mean=module.bias.mean().item(), std=module.bias.std(dim=0).item())
        return new_module
    elif isinstance(module, nn.Linear):
        new_module = nn.Linear(module.in_features, num_classes, device=module.weight.device, dtype=module.weight.dtype, bias=module.bias is not None)
        torch.nn.init.normal_(new_module.weight, mean=module.weight.mean().item(), std=module.weight.std(dim=(0, 1)).item())
        if module.bias is not None:
            torch.nn.init.normal_(new_module.bias, mean=module.bias.mean().item(), std=module.bias.std(dim=0).item())
        return new_module
    else:
        raise ValueError(f"Module {module} does not support replacing the number of classes")
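
As a quick sanity check of the helper above, we can swap the class dimension of a detection-head convolution (values below are illustrative):

import torch
from torch import nn

from super_gradients.modules.head_replacement_utils import replace_num_classes_with_random_weights

# An 80-class COCO head remapped to a 4-class downstream task
old_head = nn.Conv2d(in_channels=128, out_channels=80, kernel_size=1)
new_head = replace_num_classes_with_random_weights(old_head, num_classes=4)

assert new_head.out_channels == 4
assert new_head.weight.dtype == old_head.weight.dtype
print(new_head(torch.randn(1, 128, 20, 20)).shape)  # torch.Size([1, 4, 20, 20])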
@@ -5,7 +5,7 @@ from super_gradients.common.decorators.factory_decorator import resolve_param
 from super_gradients.common.factories.activations_type_factory import ActivationsTypeFactory
 from torch import nn, Tensor
 
-from super_gradients.modules.detection_modules import BaseDetectionModule
+from super_gradients.modules.base_modules import BaseDetectionModule
 from super_gradients.common.registry.registry import register_detection_module
 
 
# src/super_gradients/recipes/arch_params/yolo_nas_l_arch_params.yaml
backbone:
  NStageBackbone:
    stem:
      YoloNASStem:
        out_channels: 48
    stages:
      - YoloNASStage:
          out_channels: 96
          num_blocks: 2
          activation_type: relu
          hidden_channels: 96
          concat_intermediates: True
      - YoloNASStage:
          out_channels: 192
          num_blocks: 3
          activation_type: relu
          hidden_channels: 128
          concat_intermediates: True
      - YoloNASStage:
          out_channels: 384
          num_blocks: 5
          activation_type: relu
          hidden_channels: 256
          concat_intermediates: True
      - YoloNASStage:
          out_channels: 768
          num_blocks: 2
          activation_type: relu
          hidden_channels: 512
          concat_intermediates: True
    context_module:
      SPP:
        output_channels: 768
        activation_type: relu
        k: [5,9,13]
    out_layers: [stage1, stage2, stage3, context_module]

neck:
  YoloNASPANNeckWithC2:
    neck1:
      YoloNASUpStage:
        out_channels: 192
        num_blocks: 4
        hidden_channels: 128
        width_mult: 1
        depth_mult: 1
        activation_type: relu
        reduce_channels: True
    neck2:
      YoloNASUpStage:
        out_channels: 96
        num_blocks: 4
        hidden_channels: 128
        width_mult: 1
        depth_mult: 1
        activation_type: relu
        reduce_channels: True
    neck3:
      YoloNASDownStage:
        out_channels: 192
        num_blocks: 4
        hidden_channels: 128
        activation_type: relu
        width_mult: 1
        depth_mult: 1
    neck4:
      YoloNASDownStage:
        out_channels: 384
        num_blocks: 4
        hidden_channels: 256
        activation_type: relu
        width_mult: 1
        depth_mult: 1

heads:
  NDFLHeads:
    num_classes: 80
    reg_max: 16
    heads_list:
      - YoloNASDFLHead:
          inter_channels: 128
          width_mult: 1
          first_conv_group_size: 0
          stride: 8
      - YoloNASDFLHead:
          inter_channels: 256
          width_mult: 1
          first_conv_group_size: 0
          stride: 16
      - YoloNASDFLHead:
          inter_channels: 512
          width_mult: 1
          first_conv_group_size: 0
          stride: 32

bn_eps: 1e-3
bn_momentum: 0.03
inplace_act: True

_convert_: all
# src/super_gradients/recipes/arch_params/yolo_nas_m_arch_params.yaml
backbone:
  NStageBackbone:
    stem:
      YoloNASStem:
        out_channels: 48
    stages:
      - YoloNASStage:
          out_channels: 96
          num_blocks: 2
          activation_type: relu
          hidden_channels: 64
          concat_intermediates: True
      - YoloNASStage:
          out_channels: 192
          num_blocks: 3
          activation_type: relu
          hidden_channels: 128
          concat_intermediates: True
      - YoloNASStage:
          out_channels: 384
          num_blocks: 5
          activation_type: relu
          hidden_channels: 256
          concat_intermediates: True
      - YoloNASStage:
          out_channels: 768
          num_blocks: 2
          activation_type: relu
          hidden_channels: 384
          concat_intermediates: False
    context_module:
      SPP:
        output_channels: 768
        activation_type: relu
        k: [5,9,13]
    out_layers: [stage1, stage2, stage3, context_module]

neck:
  YoloNASPANNeckWithC2:
    neck1:
      YoloNASUpStage:
        out_channels: 192
        num_blocks: 2
        hidden_channels: 192
        width_mult: 1
        depth_mult: 1
        activation_type: relu
        reduce_channels: True
    neck2:
      YoloNASUpStage:
        out_channels: 96
        num_blocks: 3
        hidden_channels: 64
        width_mult: 1
        depth_mult: 1
        activation_type: relu
        reduce_channels: True
    neck3:
      YoloNASDownStage:
        out_channels: 192
        num_blocks: 2
        hidden_channels: 192
        activation_type: relu
        width_mult: 1
        depth_mult: 1
    neck4:
      YoloNASDownStage:
        out_channels: 384
        num_blocks: 3
        hidden_channels: 256
        activation_type: relu
        width_mult: 1
        depth_mult: 1

heads:
  NDFLHeads:
    num_classes: 80
    reg_max: 16
    heads_list:
      - YoloNASDFLHead:
          inter_channels: 128
          width_mult: 0.75
          first_conv_group_size: 0
          stride: 8
      - YoloNASDFLHead:
          inter_channels: 256
          width_mult: 0.75
          first_conv_group_size: 0
          stride: 16
      - YoloNASDFLHead:
          inter_channels: 512
          width_mult: 0.75
          first_conv_group_size: 0
          stride: 32

bn_eps: 1e-3
bn_momentum: 0.03
inplace_act: True

_convert_: all
backbone:
  NStageBackbone:
    stem:
      YoloNASStem:
        out_channels: 48

    stages:
      - YoloNASStage:
          out_channels: 96
          num_blocks: 2
          activation_type: relu
          hidden_channels: 32
          concat_intermediates: False
      - YoloNASStage:
          out_channels: 192
          num_blocks: 3
          activation_type: relu
          hidden_channels: 64
          concat_intermediates: False
      - YoloNASStage:
          out_channels: 384
          num_blocks: 5
          activation_type: relu
          hidden_channels: 96
          concat_intermediates: False
      - YoloNASStage:
          out_channels: 768
          num_blocks: 2
          activation_type: relu
          hidden_channels: 192
          concat_intermediates: False

    context_module:
      SPP:
        output_channels: 768
        activation_type: relu
        k: [5,9,13]

    out_layers: [stage1, stage2, stage3, context_module]

neck:
  YoloNASPANNeckWithC2:
    neck1:
      YoloNASUpStage:
        out_channels: 192
        num_blocks: 2
        hidden_channels: 64
        width_mult: 1
        depth_mult: 1
        activation_type: relu
        reduce_channels: True

    neck2:
      YoloNASUpStage:
        out_channels: 96
        num_blocks: 2
        hidden_channels: 48
        width_mult: 1
        depth_mult: 1
        activation_type: relu
        reduce_channels: True

    neck3:
      YoloNASDownStage:
        out_channels: 192
        num_blocks: 2
        hidden_channels: 64
        activation_type: relu
        width_mult: 1
        depth_mult: 1

    neck4:
      YoloNASDownStage:
        out_channels: 384
        num_blocks: 2
        hidden_channels: 64
        activation_type: relu
        width_mult: 1
        depth_mult: 1

heads:
  NDFLHeads:
    num_classes: 80
    reg_max: 16

    heads_list:
      - YoloNASDFLHead:
          inter_channels: 128
          width_mult: 0.5
          first_conv_group_size: 0
          stride: 8
      - YoloNASDFLHead:
          inter_channels: 256
          width_mult: 0.5
          first_conv_group_size: 0
          stride: 16
      - YoloNASDFLHead:
          inter_channels: 512
          width_mult: 0.5
          first_conv_group_size: 0
          stride: 32

bn_eps: 1e-3
bn_momentum: 0.03
inplace_act: True

_convert_: all
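The three arch-params files differ only in depth (num_blocks), width (hidden_channels and the heads' width_mult) and whether stages concatenate intermediate outputs; the graph layout is identical. A minimal sketch of how one variant is built from its YAML defaults (the num_classes override here is an arbitrary example, not part of this PR):

from super_gradients.training import models
from super_gradients.common.object_names import Models

# Resolves yolo_nas_s_arch_params.yaml via get_arch_params and overrides num_classes (YAML default: 80).
model = models.get(Models.YOLO_NAS_S, num_classes=20)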
# YoloNAS-S Detection training on COCO2017 Dataset:
# This training recipe is for demonstration purposes only. Pretrained models were trained using a different recipe.
# So it will not be possible to reproduce the results of the pretrained models using this recipe.

# Instructions:
#   0. Make sure that the data is stored in dataset_params.dataset_dir or add "dataset_params.data_dir=<PATH-TO-DATASET>" at the end of the command below (feel free to check ReadMe)
#   1. Move to the project root (where you will find the ReadMe and src folder)
#   2. Run the command you want:
#       yolo_nas_s: python src/super_gradients/examples/train_from_recipe_example/train_from_recipe.py --config-name=coco2017_yolo_nas_s
#

defaults:
  - training_hyperparams: coco2017_yolo_nas_train_params
  - dataset_params: coco_detection_yolo_nas_dataset_params
  - arch_params: yolo_nas_s_arch_params
  - checkpoint_params: default_checkpoint_params
  - _self_
  - variable_setup

train_dataloader: coco2017_train_yolo_nas
val_dataloader: coco2017_val_yolo_nas

load_checkpoint: False
resume: False

dataset_params:
  train_dataloader_params:
    batch_size: 32

arch_params:
  num_classes: 80

training_hyperparams:
  resume: ${resume}
  mixed_precision: True

architecture: yolo_nas_s

multi_gpu: DDP
num_gpus: 8

experiment_suffix: ""
experiment_name: coco2017_${architecture}${experiment_suffix}
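For reference, a rough Python-API equivalent of what this recipe wires together. This is a minimal sketch, not part of the PR: the training_params dict spells out only a representative subset of coco2017_yolo_nas_train_params (shown in full later in this diff), with omitted keys falling back to SuperGradients' defaults.

from super_gradients import Trainer
from super_gradients.training import dataloaders, models

trainer = Trainer(experiment_name="coco2017_yolo_nas_s")
model = models.get("yolo_nas_s", num_classes=80)

# Dataloader names registered in this PR; COCO2017 must already be available locally (see instruction 0 above).
train_loader = dataloaders.get("coco2017_train_yolo_nas", dataloader_params={"batch_size": 32})
valid_loader = dataloaders.get("coco2017_val_yolo_nas")

training_params = {
    "max_epochs": 300,
    "lr_mode": "cosine",
    "initial_lr": 2e-4,
    "cosine_final_lr_ratio": 0.1,
    "optimizer": "AdamW",
    "optimizer_params": {"weight_decay": 0.00001},
    "ema": True,
    "loss": "ppyoloe_loss",
    "criterion_params": {"use_static_assigner": False, "num_classes": 80, "reg_max": 16},
}

trainer.train(model=model, training_params=training_params, train_loader=train_loader, valid_loader=valid_loader)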
@@ -30,15 +30,15 @@ train_dataset_params:
         mixup_scale: [ 0.5, 1.5 ]         # random rescale range for the additional sample in mixup
         prob: 0.5                       # probability to apply per-sample mixup
         flip_prob: 0.5                  # probability to apply horizontal flip
-    - DetectionStandardizeImage:
-        max_value: 255.
     - DetectionPaddedRescale:
         input_dim: [640, 640]
         max_targets: 120
         pad_value: 114
+    - DetectionStandardize:
+        max_value: 255.
     - DetectionTargetsFormatTransform:
         max_targets: 256
-        output_format: LABEL_NORMALIZED_CXCYWH
+        output_format: LABEL_CXCYWH
 
   tight_box_rotation: False
   class_inclusion_list:
@@ -67,13 +67,13 @@ val_dataset_params:
     - DetectionPadToSize:
         output_size: [640, 640]
         pad_value: 114
-    - DetectionStandardizeImage:
+    - DetectionStandardize:
         max_value: 255.
     - DetectionImagePermute
     - DetectionTargetsFormatTransform:
         max_targets: 50
         input_dim: [640, 640]
-        output_format: LABEL_NORMALIZED_CXCYWH
+        output_format: LABEL_CXCYWH
   tight_box_rotation: False
   class_inclusion_list:
   max_num_samples:
@@ -83,6 +83,7 @@ val_dataloader_params:
   batch_size: 25
   num_workers: 8
   drop_last: False
+  shuffle: False
   pin_memory: True
   collate_fn:
     _target_: super_gradients.training.utils.detection_utils.CrowdDetectionCollateFN
@@ -9,18 +9,27 @@ train_dataset_params:
   input_dim: [640, 640]
   cache_dir:
   cache: False
+  ignore_empty_annotations: False
   transforms:
+    - DetectionMosaic:
+        input_dim: ${dataset_params.train_dataset_params.input_dim}
+        prob: 1.
     - DetectionRandomAffine:
         degrees: 0.                  # rotation degrees, randomly sampled from [-degrees, degrees]
         translate: 0.1                # image translation fraction
         scales: [ 0.5, 1.5 ]              # random rescale range (keeps size by padding/cropping) after mosaic transform.
         shear: 0.0                    # shear degrees, randomly sampled from [-degrees, degrees]
         target_size: ${dataset_params.train_dataset_params.input_dim}
-        filter_box_candidates: True   # whether to filter out transformed bboxes by edge size, area ratio, and aspect ratio.
+        filter_box_candidates: False  # whether to filter out transformed bboxes by edge size, area ratio, and aspect ratio.
         wh_thr: 2                     # edge size threshold when filter_box_candidates = True (pixels)
         area_thr: 0.1                 # threshold for area ratio between original image and the transformed one, when filter_box_candidates = True
         ar_thr: 20                    # aspect ratio threshold when filter_box_candidates = True
         border_value: 128
+#    - DetectionMixup:
+#        input_dim: ${dataset_params.train_dataset_params.input_dim}
+#        mixup_scale: [ 0.5, 1.5 ]         # random rescale range for the additional sample in mixup
+#        prob: 1.0                       # probability to apply per-sample mixup
+#        flip_prob: 0.5                  # probability to apply horizontal flip
     - DetectionHSV:
         prob: 1.0                       # probability to apply HSV transform
         hgain: 5                        # HSV transform hue gain (randomly sampled from [-hgain, hgain])
@@ -30,8 +39,11 @@ train_dataset_params:
         prob: 0.5                       # probability to apply horizontal flip
     - DetectionPaddedRescale:
         input_dim: ${dataset_params.train_dataset_params.input_dim}
-        max_targets: 120
+        max_targets: 300
+    - DetectionStandardize:
+        max_value: 255.
     - DetectionTargetsFormatTransform:
+        max_targets: 300
         input_dim: ${dataset_params.train_dataset_params.input_dim}
         output_format: LABEL_CXCYWH
   tight_box_rotation: False
@@ -43,8 +55,8 @@ train_dataset_params:
 train_dataloader_params:
   shuffle: True
   batch_size: 16
-  num_workers: 0
-  sampler:
+  min_samples: 512
+  num_workers: 4
   drop_last: False
   pin_memory: True
   worker_init_fn:
@@ -60,11 +72,16 @@ val_dataset_params:
   input_dim: [640, 640]
   cache_dir:
   cache: False
+  ignore_empty_annotations: False
   transforms:
   - DetectionPaddedRescale:
       input_dim: ${dataset_params.val_dataset_params.input_dim}
+      max_targets: 300
+      pad_value: 114
+  - DetectionStandardize:
+      max_value: 255.
   - DetectionTargetsFormatTransform:
-      max_targets: 50
+      max_targets: 300
       input_dim: ${dataset_params.val_dataset_params.input_dim}
       output_format: LABEL_CXCYWH
   tight_box_rotation: False
@@ -74,10 +91,10 @@ val_dataset_params:
   verbose: 0
 
 val_dataloader_params:
-  batch_size: 64
-  num_workers: 0
-  sampler:
+  batch_size: 32
+  num_workers: 4
   drop_last: False
+  shuffle: False
   pin_memory: True
   collate_fn: # collate function for valset
     _target_: super_gradients.training.utils.detection_utils.CrowdDetectionCollateFN
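Both pipelines above now insert DetectionStandardize after DetectionPaddedRescale. Assuming the transform simply divides the image by max_value (as its parameters suggest), its effect is:

import numpy as np

# Hypothetical stand-in for DetectionStandardize(max_value=255.): pixel intensities are
# scaled into [0, 1]; bounding-box targets are left untouched.
image = np.random.randint(0, 256, size=(640, 640, 3)).astype(np.float32)
standardized = image / 255.0
assert 0.0 <= standardized.min() and standardized.max() <= 1.0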
# A recipe to fine-tune YoloNAS on Roboflow datasets.
# Check out the datasets at https://universe.roboflow.com/roboflow-100?ref=blog.roboflow.com
#
# `dataset_name` refers to the official name of the dataset.
# You can find it in the url of the dataset: https://universe.roboflow.com/roboflow-100/digits-t2eg6 -> digits-t2eg6
#
# Example: python -m super_gradients.train_from_recipe --config-name=roboflow_yolo_nas_m dataset_name=digits-t2eg6

defaults:
  - training_hyperparams: coco2017_yolo_nas_train_params
  - dataset_params: roboflow_detection_dataset_params
  - checkpoint_params: default_checkpoint_params
  - arch_params: yolo_nas_m_arch_params
  - _self_
  - variable_setup

train_dataloader: roboflow_train_yolox
val_dataloader: roboflow_val_yolox

dataset_name: ??? # Placeholder for the name of the dataset you want to use (e.g. "digits-t2eg6")

dataset_params:
  dataset_name: ${dataset_name}
  train_dataloader_params:
    batch_size: 12
  val_dataloader_params:
    batch_size: 16

num_classes: ${roboflow_dataset_num_classes:${dataset_name}}

architecture: yolo_nas_m

arch_params:
  num_classes: ${num_classes}

load_checkpoint: False
checkpoint_params:
  pretrained_weights: coco

result_path: # By default, results are saved in the checkpoints directory.

resume: False
training_hyperparams:
  resume: ${resume}
  zero_weight_decay_on_bias_and_bn: True
  lr_warmup_epochs: 3
  warmup_mode: linear_epoch_step
  initial_lr: 4e-4
  cosine_final_lr_ratio: 0.1
  optimizer_params:
    weight_decay: 0.0001

  ema: True
  ema_params:
    decay: 0.9

  max_epochs: 100
  mixed_precision: True

  criterion_params:
    num_classes: ${num_classes}

  phase_callbacks: []

  loss:
    ppyoloe_loss:
      num_classes: ${num_classes}
      reg_max: 16

  valid_metrics_list:
    - DetectionMetrics_050:
        score_thres: 0.1
        top_k_predictions: 300
        num_cls: ${num_classes}
        normalize_targets: True
        post_prediction_callback:
          _target_: super_gradients.training.models.detection_models.pp_yolo_e.PPYoloEPostPredictionCallback
          score_threshold: 0.01
          nms_top_k: 1000
          max_predictions: 300
          nms_threshold: 0.7

  metric_to_watch: 'mAP@0.50'

multi_gpu: Off
num_gpus: 1

experiment_suffix: ""
experiment_name: ${architecture}_roboflow_${dataset_name}${experiment_suffix}
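Since pretrained_weights: coco is combined with a dataset-specific num_classes, fine-tuning goes through the head-replacement path added in this PR (CustomizableDetector.replace_head delegating to NDFLHeads.replace_num_classes). A hedged sketch with an arbitrary class count:

from super_gradients.training import models

# Loads COCO weights, then rebuilds the 80-class heads for 7 classes with randomly
# initialized weights (see replace_num_classes_with_random_weights elsewhere in this PR).
model = models.get("yolo_nas_m", num_classes=7, pretrained_weights="coco")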
# A recipe to fine-tune YoloNAS on Roboflow datasets.
# Check out the datasets at https://universe.roboflow.com/roboflow-100?ref=blog.roboflow.com
#
# `dataset_name` refers to the official name of the dataset.
# You can find it in the url of the dataset: https://universe.roboflow.com/roboflow-100/digits-t2eg6 -> digits-t2eg6
#
# Example: python -m super_gradients.train_from_recipe --config-name=roboflow_yolo_nas_s dataset_name=digits-t2eg6

defaults:
  - training_hyperparams: coco2017_yolo_nas_train_params
  - dataset_params: roboflow_detection_dataset_params
  - checkpoint_params: default_checkpoint_params
  - arch_params: yolo_nas_s_arch_params
  - _self_
  - variable_setup

train_dataloader: roboflow_train_yolox
val_dataloader: roboflow_val_yolox

dataset_name: ??? # Placeholder for the name of the dataset you want to use (e.g. "digits-t2eg6")

dataset_params:
  dataset_name: ${dataset_name}
  train_dataloader_params:
    batch_size: 16
  val_dataloader_params:
    batch_size: 16

num_classes: ${roboflow_dataset_num_classes:${dataset_name}}

architecture: yolo_nas_s

arch_params:
  num_classes: ${num_classes}

load_checkpoint: False
checkpoint_params:
  pretrained_weights: coco

result_path: # By default, results are saved in the checkpoints directory.

resume: False
training_hyperparams:
  resume: ${resume}
  zero_weight_decay_on_bias_and_bn: True
  lr_warmup_epochs: 3
  warmup_mode: linear_epoch_step
  initial_lr: 5e-4
  cosine_final_lr_ratio: 0.1
  optimizer_params:
    weight_decay: 0.0001

  ema: True
  ema_params:
    decay: 0.9

  max_epochs: 100
  mixed_precision: True

  criterion_params:
    num_classes: ${num_classes}

  phase_callbacks: []

  loss:
    ppyoloe_loss:
      num_classes: ${num_classes}
      reg_max: 16

  valid_metrics_list:
    - DetectionMetrics_050:
        score_thres: 0.1
        top_k_predictions: 300
        num_cls: ${num_classes}
        normalize_targets: True
        post_prediction_callback:
          _target_: super_gradients.training.models.detection_models.pp_yolo_e.PPYoloEPostPredictionCallback
          score_threshold: 0.01
          nms_top_k: 1000
          max_predictions: 300
          nms_threshold: 0.7

  metric_to_watch: 'mAP@0.50'

multi_gpu: Off
num_gpus: 1

experiment_suffix: ""
experiment_name: ${architecture}_roboflow_${dataset_name}${experiment_suffix}
defaults:
  - roboflow_yolo_nas_s
  - quantization_params: default_quantization_params
  - _self_

checkpoint_params:
  checkpoint_path: ???
  strict_load: no_key_matching

pre_launch_callbacks_list:
  - QATRecipeModificationCallback:
      batch_size_divisor: 2
      max_epochs_divisor: 10
      lr_decay_factor: 0.01
      warmup_epochs_divisor: 10
      cosine_final_lr_ratio: 0.01
      disable_phase_callbacks: True
      disable_augmentations: False
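The divisors above rescale the inherited roboflow_yolo_nas_s recipe rather than restating it. Applied to that recipe's values, the effective QAT schedule works out as follows (a sketch of the arithmetic, not code from the callback itself):

# Base values from roboflow_yolo_nas_s above; divisors/factors from this file.
batch_size = 16 // 2         # batch_size_divisor: 2     -> 8
max_epochs = 100 // 10       # max_epochs_divisor: 10    -> 10
initial_lr = 5e-4 * 0.01     # lr_decay_factor: 0.01     -> 5e-6
lr_warmup_epochs = 3 // 10   # warmup_epochs_divisor: 10 -> 0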
defaults:
  - default_train_params

max_epochs: 300

warmup_mode: "linear_batch_step"
warmup_initial_lr: 1e-6
lr_warmup_steps: 1000
lr_warmup_epochs: 0

initial_lr: 2e-4
lr_mode: cosine
cosine_final_lr_ratio: 0.1

zero_weight_decay_on_bias_and_bn: True
batch_accumulate: 1

save_ckpt_epoch_list: [100, 200, 250]

loss:
  ppyoloe_loss:
    use_static_assigner: False
    num_classes: ${arch_params.num_classes}
    reg_max: 16

optimizer: AdamW
optimizer_params:
  weight_decay: 0.00001

ema: True
ema_params:
  decay: 0.9997
  decay_type: threshold

mixed_precision: False
sync_bn: True

valid_metrics_list:
  - DetectionMetrics:
      score_thres: 0.1
      top_k_predictions: 300
      num_cls: ${arch_params.num_classes}
      normalize_targets: True
      post_prediction_callback:
        _target_: super_gradients.training.models.detection_models.pp_yolo_e.PPYoloEPostPredictionCallback
        score_threshold: 0.01
        nms_top_k: 1000
        max_predictions: 300
        nms_threshold: 0.7

pre_prediction_callback:

metric_to_watch: 'mAP@0.50:0.95'
greater_metric_to_watch_is_better: True

_convert_: all
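The valid_metrics_list entry above maps one-to-one onto the Python API. A sketch of the same metric constructed directly, with num_cls hard-coded to 80 here purely for illustration:

from super_gradients.training.metrics import DetectionMetrics
from super_gradients.training.models.detection_models.pp_yolo_e import PPYoloEPostPredictionCallback

metric = DetectionMetrics(
    num_cls=80,
    score_thres=0.1,
    top_k_predictions=300,
    normalize_targets=True,
    post_prediction_callback=PPYoloEPostPredictionCallback(score_threshold=0.01, nms_top_k=1000, max_predictions=300, nms_threshold=0.7),
)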
@@ -9,8 +9,8 @@ from .dataloaders import (
     coco2017_val_ppyoloe,
     coco2017_pose_train,
     coco2017_pose_val,
-    coco2017_train_deci_yolo,
-    coco2017_val_deci_yolo,
+    coco2017_train_yolo_nas,
+    coco2017_val_yolo_nas,
     imagenet_train,
     imagenet_val,
     imagenet_efficientnet_train,
@@ -68,8 +68,8 @@ __all__ = [
     "coco2017_val_ppyoloe",
     "coco2017_pose_train",
     "coco2017_pose_val",
-    "coco2017_train_deci_yolo",
-    "coco2017_val_deci_yolo",
+    "coco2017_train_yolo_nas",
+    "coco2017_val_yolo_nas",
     "imagenet_train",
     "imagenet_val",
     "imagenet_efficientnet_train",
Discard
@@ -172,10 +172,10 @@ def coco2017_val(dataset_params: Dict = None, dataloader_params: Dict = None) ->
     )
 
 
-@register_dataloader(Dataloaders.COCO2017_TRAIN_DECIYOLO)
-def coco2017_train_deci_yolo(dataset_params: Dict = None, dataloader_params: Dict = None) -> DataLoader:
+@register_dataloader(Dataloaders.COCO2017_TRAIN_YOLO_NAS)
+def coco2017_train_yolo_nas(dataset_params: Dict = None, dataloader_params: Dict = None) -> DataLoader:
     return get_data_loader(
-        config_name="coco_detection_deci_yolo_dataset_params",
+        config_name="coco_detection_yolo_nas_dataset_params",
         dataset_cls=COCODetectionDataset,
         train=True,
         dataset_params=dataset_params,
@@ -183,10 +183,10 @@ def coco2017_train_deci_yolo(dataset_params: Dict = None, dataloader_params: Dic
     )
 
 
-@register_dataloader(Dataloaders.COCO2017_VAL_DECIYOLO)
-def coco2017_val_deci_yolo(dataset_params: Dict = None, dataloader_params: Dict = None) -> DataLoader:
+@register_dataloader(Dataloaders.COCO2017_VAL_YOLO_NAS)
+def coco2017_val_yolo_nas(dataset_params: Dict = None, dataloader_params: Dict = None) -> DataLoader:
     return get_data_loader(
-        config_name="coco_detection_deci_yolo_dataset_params",
+        config_name="coco_detection_yolo_nas_dataset_params",
         dataset_cls=COCODetectionDataset,
         train=False,
         dataset_params=dataset_params,
@@ -33,7 +33,7 @@ DATASETS_METADATA = {
     "underwater-objects-5v7p8": {"category": "underwater", "train": 5320, "test": 760, "valid": 1520, "size": 7600, "num_classes": 5, "num_classes_found": 5},
     "coral-lwptl": {"category": "underwater", "train": 427, "test": 74, "valid": 93, "size": 594, "num_classes": 14, "num_classes_found": 14},
     "tweeter-posts": {"category": "documents", "train": 87, "test": 9, "valid": 21, "size": 117, "num_classes": 2, "num_classes_found": 2},
-    "tweeter-profile": {"category": "documents", "train": 425, "test": 61, "valid": 121, "size": 607, "num_classes": 1, "num_classes_found": 0},
+    "tweeter-profile": {"category": "documents", "train": 425, "test": 61, "valid": 121, "size": 607, "num_classes": 1, "num_classes_found": 1},
     "document-parts": {"category": "documents", "train": 906, "test": 150, "valid": 318, "size": 1374, "num_classes": 2, "num_classes_found": 2},
     "activity-diagrams-qdobr": {"category": "documents", "train": 259, "test": 45, "valid": 74, "size": 378, "num_classes": 19, "num_classes_found": 19},
     "signatures-xc8up": {"category": "documents", "train": 257, "test": 37, "valid": 74, "size": 368, "num_classes": 1, "num_classes_found": 1},
@@ -148,7 +148,7 @@ _NUM_CLASSES_FOUND = {
     "underwater-objects-5v7p8": 5,
     "coral-lwptl": 14,
     "tweeter-posts": 2,
-    "tweeter-profile": 0,
+    "tweeter-profile": 1,
     "document-parts": 2,
     "activity-diagrams-qdobr": 19,
     "signatures-xc8up": 1,
@@ -62,13 +62,26 @@ from super_gradients.training.models.classification_models.vgg import VGG
 from super_gradients.training.models.classification_models.vit import ViT, ViTBase, ViTLarge, ViTHuge
 
 # Detection models
-from super_gradients.training.models.detection_models.csp_darknet53 import CSPDarknet53
-from super_gradients.training.models.detection_models.pp_yolo_e.pp_yolo_e import PPYoloE, PPYoloE_S, PPYoloE_M, PPYoloE_L, PPYoloE_X
+from super_gradients.training.models.detection_models.csp_darknet53 import CSPDarknet53, SPP
+from super_gradients.training.models.detection_models.pp_yolo_e import PPYoloE, PPYoloE_S, PPYoloE_M, PPYoloE_L, PPYoloE_X
 from super_gradients.training.models.detection_models.darknet53 import Darknet53, Darknet53Base
 from super_gradients.training.models.detection_models.ssd import SSDMobileNetV1, SSDLiteMobileNetV2
 from super_gradients.training.models.detection_models.yolo_base import YoloBase, YoloPostPredictionCallback
 from super_gradients.training.models.detection_models.yolox import YoloX_N, YoloX_T, YoloX_S, YoloX_M, YoloX_L, YoloX_X, CustomYoloX
 from super_gradients.training.models.detection_models.customizable_detector import CustomizableDetector
+from super_gradients.training.models.detection_models.yolo_nas import (
+    YoloNASStage,
+    YoloNASStem,
+    YoloNASDownStage,
+    YoloNASUpStage,
+    YoloNASBottleneck,
+    YoloNASDFLHead,
+    NDFLHeads,
+    YoloNASPANNeckWithC2,
+    YoloNAS_S,
+    YoloNAS_M,
+    YoloNAS_L,
+)
 
 # Segmentation models
 from super_gradients.training.models.segmentation_models.shelfnet import (
@@ -96,7 +109,6 @@ from super_gradients.training.models.segmentation_models.stdc import (
     STDCSegmentationBase,
     CustomSTDCSegmentation,
 )
-from super_gradients.training.models.segmentation_models.segformer import SegFormerB0, SegFormerB1, SegFormerB2, SegFormerB3, SegFormerB4, SegFormerB5
 
 # Pose estimation
 from super_gradients.training.models.pose_estimation_models.pose_ppyolo import PosePPYoloL
@@ -116,6 +128,18 @@ from super_gradients.common.object_names import Models
 from super_gradients.common.registry.registry import ARCHITECTURES
 
 __all__ = [
+    "SPP",
+    "YoloNAS_S",
+    "YoloNAS_M",
+    "YoloNAS_L",
+    "YoloNASStage",
+    "YoloNASUpStage",
+    "YoloNASStem",
+    "YoloNASDownStage",
+    "YoloNASDFLHead",
+    "YoloNASBottleneck",
+    "NDFLHeads",
+    "YoloNASPANNeckWithC2",
     "SgModule",
     "Beit",
     "BeitLargePatch16_224",
@@ -259,10 +283,4 @@ __all__ = [
     "ARCHITECTURES",
     "Models",
     "user_models",
-    "SegFormerB0",
-    "SegFormerB1",
-    "SegFormerB2",
-    "SegFormerB3",
-    "SegFormerB4",
-    "SegFormerB5",
 ]
@@ -7,9 +7,11 @@ from typing import Tuple, Type
 import torch
 import torch.nn as nn
 
+from super_gradients.common.decorators.factory_decorator import resolve_param
+from super_gradients.common.factories.activations_type_factory import ActivationsTypeFactory
 from super_gradients.common.object_names import Models
-from super_gradients.common.registry.registry import register_model
-from super_gradients.modules import Residual, Conv
+from super_gradients.common.registry.registry import register_model, register_detection_module
+from super_gradients.modules import Residual, Conv, BaseDetectionModule
 from super_gradients.modules.utils import width_multiplier
 from super_gradients.training.models.sg_module import SgModule
 from super_gradients.training.utils.utils import get_param, HpmStruct
@@ -127,13 +129,16 @@ class BottleneckCSP(nn.Module):
         return self.cv4(self.act(self.bn(torch.cat((y1, y2), dim=1))))
 
 
-class SPP(nn.Module):
+@register_detection_module()
+class SPP(BaseDetectionModule):
     # SPATIAL PYRAMID POOLING LAYER
-    def __init__(self, input_channels, output_channels, k: Tuple, activation_type: Type[nn.Module]):
-        super().__init__()
+    @resolve_param("activation_type", ActivationsTypeFactory())
+    def __init__(self, in_channels, output_channels, k: Tuple, activation_type: Type[nn.Module]):
+        super().__init__(in_channels)
+        self._output_channels = output_channels
 
-        hidden_channels = input_channels // 2
-        self.cv1 = Conv(input_channels, hidden_channels, 1, 1, activation_type)
+        hidden_channels = in_channels // 2
+        self.cv1 = Conv(in_channels, hidden_channels, 1, 1, activation_type)
         self.cv2 = Conv(hidden_channels * (len(k) + 1), output_channels, 1, 1, activation_type)
         self.m = nn.ModuleList([nn.MaxPool2d(kernel_size=x, stride=1, padding=x // 2) for x in k])
 
@@ -141,6 +146,13 @@ class SPP(nn.Module):
         x = self.cv1(x)
         return self.cv2(torch.cat([x] + [m(x) for m in self.m], 1))
 
+    @property
+    def out_channels(self):
+        """
+        :return: channels of tensor(s) that will be returned by a module  in forward
+        """
+        return self._output_channels
+
 
 class ViewModule(nn.Module):
     """
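For reference, the channel bookkeeping inside the refactored SPP, plugging in the values the YoloNAS arch params use (in_channels = output_channels = 768, k = (5, 9, 13)):

# Mirrors SPP.__init__ above: cv1 halves the channels, the concat multiplies them by len(k) + 1.
in_channels, output_channels, k = 768, 768, (5, 9, 13)
hidden_channels = in_channels // 2                # 384 channels after cv1
concat_channels = hidden_channels * (len(k) + 1)  # 1536: cv1 output + three max-pool branches
# cv2 projects concat_channels back to output_channels (768), which out_channels now reports.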
@@ -12,6 +12,8 @@ from omegaconf import DictConfig
 
 from super_gradients.common.decorators.factory_decorator import resolve_param
 from super_gradients.common.factories.processing_factory import ProcessingFactory
+from super_gradients.module_interfaces import SupportsReplaceNumClasses
+from super_gradients.modules.head_replacement_utils import replace_num_classes_with_random_weights
 from super_gradients.training.utils.utils import HpmStruct
 from super_gradients.training.models.sg_module import SgModule
 import super_gradients.common.factories.detection_modules_factory as det_factory
@@ -102,6 +104,8 @@ class CustomizableDetector(SgModule):
             raise ValueError("At least one of new_num_classes, new_head must be given to replace output layer.")
         if new_head is not None:
             self.heads = new_head
+        elif isinstance(self.heads, SupportsReplaceNumClasses):
+            self.heads.replace_num_classes(new_num_classes, replace_num_classes_with_random_weights)
         else:
             factory = det_factory.DetectionModulesFactory()
             self.heads_params = factory.insert_module_param(self.heads_params, "num_classes", new_num_classes)
@@ -1,4 +1,4 @@
-from .pp_yolo_e import PPYoloE
+from .pp_yolo_e import PPYoloE, PPYoloE_S, PPYoloE_M, PPYoloE_L, PPYoloE_X
 from .post_prediction_callback import PPYoloEPostPredictionCallback
 
-__all__ = ["PPYoloE", "PPYoloEPostPredictionCallback"]
+__all__ = ["PPYoloE", "PPYoloEPostPredictionCallback", "PPYoloE_L", "PPYoloE_M", "PPYoloE_S", "PPYoloE_X"]
@@ -10,10 +10,10 @@ from super_gradients.common.factories.activations_type_factory import Activation
 from super_gradients.training.models.detection_models.csp_resnet import CSPResNetBasicBlock
 from super_gradients.modules import ConvBNAct
 
-__all__ = ["CustomCSPPAN"]
+__all__ = ["PPYoloECSPPAN"]
 
 
-class SPP(nn.Module):
+class PPYoloESPP(nn.Module):
     def __init__(
         self,
         in_channels: int,
@@ -52,7 +52,7 @@ class CSPStage(nn.Module):
         for i in range(n):
             convs.append((str(i), CSPResNetBasicBlock(next_ch_in, ch_mid, activation_type=activation_type, use_residual_connection=False)))
             if i == (n - 1) // 2 and spp:
-                convs.append(("spp", SPP(ch_mid, ch_mid, 1, (5, 9, 13), activation_type=activation_type)))
+                convs.append(("spp", PPYoloESPP(ch_mid, ch_mid, 1, (5, 9, 13), activation_type=activation_type)))
             next_ch_in = ch_mid
 
         self.convs = nn.Sequential(collections.OrderedDict(convs))
@@ -68,7 +68,7 @@ class CSPStage(nn.Module):
 
 
 @register_detection_module()
-class CustomCSPPAN(nn.Module):
+class PPYoloECSPPAN(nn.Module):
     @resolve_param("activation", ActivationsTypeFactory())
     def __init__(
         self,
@@ -1,6 +1,7 @@
 from typing import Union, Optional, List
 
 from torch import Tensor
+
 from super_gradients.common.decorators.factory_decorator import resolve_param
 from super_gradients.common.factories.processing_factory import ProcessingFactory
 from super_gradients.common.registry.registry import register_model
@@ -8,7 +9,7 @@ from super_gradients.common.object_names import Models
 from super_gradients.modules import RepVGGBlock
 from super_gradients.training.models.sg_module import SgModule
 from super_gradients.training.models.detection_models.csp_resnet import CSPResNetBackbone
-from super_gradients.training.models.detection_models.pp_yolo_e.pan import CustomCSPPAN
+from super_gradients.training.models.detection_models.pp_yolo_e.pan import PPYoloECSPPAN
 from super_gradients.training.models.detection_models.pp_yolo_e.pp_yolo_head import PPYOLOEHead
 from super_gradients.training.utils import HpmStruct
 from super_gradients.training.models.arch_params_factory import get_arch_params
@@ -26,7 +27,7 @@ class PPYoloE(SgModule):
             arch_params = arch_params.to_dict()
 
         self.backbone = CSPResNetBackbone(**arch_params["backbone"], depth_mult=arch_params["depth_mult"], width_mult=arch_params["width_mult"])
-        self.neck = CustomCSPPAN(**arch_params["neck"], depth_mult=arch_params["depth_mult"], width_mult=arch_params["width_mult"])
+        self.neck = PPYoloECSPPAN(**arch_params["neck"], depth_mult=arch_params["depth_mult"], width_mult=arch_params["width_mult"])
         self.head = PPYOLOEHead(**arch_params["head"], width_mult=arch_params["width_mult"], num_classes=arch_params["num_classes"])
 
         self._class_names: Optional[List[str]] = None
@@ -175,11 +175,12 @@ class PPYOLOEHead(nn.Module):
     @torch.jit.ignore
     def replace_num_classes(self, num_classes: int):
         bias_cls = bias_init_with_prob(0.01)
+        device = self.pred_cls[0].weight.device
         self.pred_cls = nn.ModuleList()
         self.num_classes = num_classes
 
         for in_c in self.in_channels:
-            predict_layer = nn.Conv2d(in_c, num_classes, 3, padding=1)
+            predict_layer = nn.Conv2d(in_c, num_classes, 3, padding=1, device=device)
             torch.nn.init.constant_(predict_layer.weight, 0.0)
             torch.nn.init.constant_(predict_layer.bias, bias_cls)
             self.pred_cls.append(predict_layer)
from super_gradients.training.models.detection_models.yolo_nas.dfl_heads import YoloNASDFLHead, NDFLHeads
from super_gradients.training.models.detection_models.yolo_nas.panneck import YoloNASPANNeckWithC2
from super_gradients.training.models.detection_models.yolo_nas.yolo_stages import (
    YoloNASStage,
    YoloNASStem,
    YoloNASDownStage,
    YoloNASUpStage,
    YoloNASBottleneck,
)
from super_gradients.training.models.detection_models.yolo_nas.yolo_nas_variants import YoloNAS_S, YoloNAS_M, YoloNAS_L

__all__ = [
    "YoloNASBottleneck",
    "YoloNASUpStage",
    "YoloNASDownStage",
    "YoloNASStem",
    "YoloNASStage",
    "NDFLHeads",
    "YoloNASDFLHead",
    "YoloNASPANNeckWithC2",
    "YoloNAS_S",
    "YoloNAS_M",
    "YoloNAS_L",
]
import math
from typing import Tuple, Union, List, Callable, Optional

import torch
from omegaconf import DictConfig
from torch import nn, Tensor

import super_gradients.common.factories.detection_modules_factory as det_factory
from super_gradients.common.registry import register_detection_module
from super_gradients.modules import ConvBNReLU
from super_gradients.modules.base_modules import BaseDetectionModule
from super_gradients.module_interfaces import SupportsReplaceNumClasses
from super_gradients.modules.utils import width_multiplier
from super_gradients.training.models.detection_models.pp_yolo_e.pp_yolo_head import generate_anchors_for_grid_cell
from super_gradients.training.utils import HpmStruct, torch_version_is_greater_or_equal
from super_gradients.training.utils.bbox_utils import batch_distance2bbox


@register_detection_module()
class YoloNASDFLHead(BaseDetectionModule, SupportsReplaceNumClasses):
    def __init__(self, in_channels: int, inter_channels: int, width_mult: float, first_conv_group_size: int, num_classes: int, stride: int, reg_max: int):
        """
        Initialize the YoloNASDFLHead
        :param in_channels: Input channels
        :param inter_channels: Intermediate number of channels
        :param width_mult: Width multiplier
        :param first_conv_group_size: Group size
        :param num_classes: Number of detection classes
        :param stride: Output stride for this head
        :param reg_max: Number of bins in the regression head
        """
        super().__init__(in_channels)

        inter_channels = width_multiplier(inter_channels, width_mult, 8)
        if first_conv_group_size == 0:
            groups = 0
        elif first_conv_group_size == -1:
            groups = 1
        else:
            groups = inter_channels // first_conv_group_size

        self.num_classes = num_classes
        self.stem = ConvBNReLU(in_channels, inter_channels, kernel_size=1, stride=1, padding=0, bias=False)

        first_cls_conv = [ConvBNReLU(inter_channels, inter_channels, kernel_size=3, stride=1, padding=1, groups=groups, bias=False)] if groups else []
        self.cls_convs = nn.Sequential(*first_cls_conv, ConvBNReLU(inter_channels, inter_channels, kernel_size=3, stride=1, padding=1, bias=False))

        first_reg_conv = [ConvBNReLU(inter_channels, inter_channels, kernel_size=3, stride=1, padding=1, groups=groups, bias=False)] if groups else []
        self.reg_convs = nn.Sequential(*first_reg_conv, ConvBNReLU(inter_channels, inter_channels, kernel_size=3, stride=1, padding=1, bias=False))

        self.cls_pred = nn.Conv2d(inter_channels, self.num_classes, 1, 1, 0)
        self.reg_pred = nn.Conv2d(inter_channels, 4 * (reg_max + 1), 1, 1, 0)

        self.grid = torch.zeros(1)
        self.stride = stride

        self.prior_prob = 1e-2
        self._initialize_biases()

    def replace_num_classes(self, num_classes: int, compute_new_weights_fn: Callable[[nn.Module, int], nn.Module]):
        self.cls_pred = compute_new_weights_fn(self.cls_pred, num_classes)
        self.num_classes = num_classes

    @property
    def out_channels(self):
        return None

    def forward(self, x):
        x = self.stem(x)

        cls_feat = self.cls_convs(x)
        cls_output = self.cls_pred(cls_feat)

        reg_feat = self.reg_convs(x)
        reg_output = self.reg_pred(reg_feat)

        return reg_output, cls_output

    def _initialize_biases(self):
        prior_bias = -math.log((1 - self.prior_prob) / self.prior_prob)
        torch.nn.init.constant_(self.cls_pred.bias, prior_bias)

    @staticmethod
    def _make_grid(nx=20, ny=20):
        if torch_version_is_greater_or_equal(1, 10):
            # https://github.com/pytorch/pytorch/issues/50276
            yv, xv = torch.meshgrid([torch.arange(ny), torch.arange(nx)], indexing="ij")
        else:
            yv, xv = torch.meshgrid([torch.arange(ny), torch.arange(nx)])
        return torch.stack((xv, yv), 2).view((1, 1, ny, nx, 2)).float()


@register_detection_module()
class NDFLHeads(BaseDetectionModule, SupportsReplaceNumClasses):
    def __init__(
        self,
        num_classes: int,
        in_channels: Tuple[int, int, int],
        heads_list: Union[str, HpmStruct, DictConfig],
        grid_cell_scale: float = 5.0,
        grid_cell_offset: float = 0.5,
        reg_max: int = 16,
        eval_size: Optional[Tuple[int, int]] = None,
        width_mult: float = 1.0,
    ):
        """
        Initializes the NDFLHeads module.

        :param num_classes: Number of detection classes
        :param in_channels: Number of channels for each feature map (See width_mult)
        :param grid_cell_scale:
        :param grid_cell_offset:
        :param reg_max: Number of bins in the regression head
        :param eval_size: (rows, cols) Size of the image for evaluation. Setting this value can be beneficial for inference speed,
               since anchors will not be regenerated for each forward call.
        :param width_mult: A scaling factor applied to in_channels.
        """
        super(NDFLHeads, self).__init__(in_channels)
        in_channels = [max(round(c * width_mult), 1) for c in in_channels]

        self.in_channels = tuple(in_channels)
        self.num_classes = num_classes
        self.grid_cell_scale = grid_cell_scale
        self.grid_cell_offset = grid_cell_offset
        self.reg_max = reg_max
        self.eval_size = eval_size

        # Do not apply quantization to this tensor
        proj = torch.linspace(0, self.reg_max, self.reg_max + 1).reshape([1, self.reg_max + 1, 1, 1])
        self.register_buffer("proj_conv", proj, persistent=False)

        self._init_weights()

        factory = det_factory.DetectionModulesFactory()
        heads_list = self._pass_args(heads_list, factory, num_classes, reg_max)

        self.num_heads = len(heads_list)
        fpn_strides: List[int] = []
        for i in range(self.num_heads):
            new_head = factory.get(factory.insert_module_param(heads_list[i], "in_channels", in_channels[i]))
            fpn_strides.append(new_head.stride)
            setattr(self, f"head{i + 1}", new_head)

        self.fpn_strides = tuple(fpn_strides)

    def replace_num_classes(self, num_classes: int, compute_new_weights_fn: Callable[[nn.Module, int], nn.Module]):
        for i in range(self.num_heads):
            head = getattr(self, f"head{i + 1}")
            head.replace_num_classes(num_classes, compute_new_weights_fn)

        self.num_classes = num_classes

    @staticmethod
    def _pass_args(heads_list, factory, num_classes, reg_max):
        for i in range(len(heads_list)):
            heads_list[i] = factory.insert_module_param(heads_list[i], "num_classes", num_classes)
            heads_list[i] = factory.insert_module_param(heads_list[i], "reg_max", reg_max)
        return heads_list

    @torch.jit.ignore
    def cache_anchors(self, input_size: Tuple[int, int]):
        self.eval_size = input_size
        anchor_points, stride_tensor = self._generate_anchors()
        self.anchor_points = anchor_points
        self.stride_tensor = stride_tensor

    @torch.jit.ignore
    def _init_weights(self):
        if self.eval_size:
            anchor_points, stride_tensor = self._generate_anchors()
            self.anchor_points = anchor_points
            self.stride_tensor = stride_tensor

    @torch.jit.ignore
    def forward_train(self, feats: Tuple[Tensor, ...]):
        anchors, anchor_points, num_anchors_list, stride_tensor = generate_anchors_for_grid_cell(
            feats, self.fpn_strides, self.grid_cell_scale, self.grid_cell_offset
        )

        cls_score_list, reg_distri_list = [], []
        for i, feat in enumerate(feats):
            reg_distri, cls_logit = getattr(self, f"head{i + 1}")(feat)
            # cls and reg
            # Note we don't apply sigmoid on class predictions to ensure good numerical stability at loss computation
            cls_score_list.append(torch.permute(cls_logit.flatten(2), [0, 2, 1]))
            reg_distri_list.append(torch.permute(reg_distri.flatten(2), [0, 2, 1]))

        cls_score_list = torch.cat(cls_score_list, dim=1)
        reg_distri_list = torch.cat(reg_distri_list, dim=1)

        return cls_score_list, reg_distri_list, anchors, anchor_points, num_anchors_list, stride_tensor

    def forward_eval(self, feats: Tuple[Tensor, ...]) -> Tuple[Tuple[Tensor, Tensor], Tuple[Tensor, ...]]:
        cls_score_list, reg_distri_list, reg_dist_reduced_list = [], [], []

        for i, feat in enumerate(feats):
            b, _, h, w = feat.shape
            height_mul_width = h * w
            reg_distri, cls_logit = getattr(self, f"head{i + 1}")(feat)
            reg_distri_list.append(torch.permute(reg_distri.flatten(2), [0, 2, 1]))

            reg_dist_reduced = torch.permute(reg_distri.reshape([-1, 4, self.reg_max + 1, height_mul_width]), [0, 2, 3, 1])
            reg_dist_reduced = torch.nn.functional.conv2d(torch.nn.functional.softmax(reg_dist_reduced, dim=1), weight=self.proj_conv).squeeze(1)

            # cls and reg
            cls_score_list.append(cls_logit.reshape([b, self.num_classes, height_mul_width]))
            reg_dist_reduced_list.append(reg_dist_reduced)

        cls_score_list = torch.cat(cls_score_list, dim=-1)  # [B, C, Anchors]
        cls_score_list = torch.permute(cls_score_list, [0, 2, 1])  # [B, Anchors, C]

        reg_distri_list = torch.cat(reg_distri_list, dim=1)  # [B, Anchors, 4 * (self.reg_max + 1)]
        reg_dist_reduced_list = torch.cat(reg_dist_reduced_list, dim=1)  # [B, Anchors, 4]

        # Decode bboxes
        # Note in eval mode, anchor_points_inference is different from anchor_points computed on train
        if self.eval_size:
            anchor_points_inference, stride_tensor = self.anchor_points, self.stride_tensor
        else:
            anchor_points_inference, stride_tensor = self._generate_anchors(feats)

        pred_scores = cls_score_list.sigmoid()
        pred_bboxes = batch_distance2bbox(anchor_points_inference, reg_dist_reduced_list) * stride_tensor  # [B, Anchors, 4]

        decoded_predictions = pred_bboxes, pred_scores

        if torch.jit.is_tracing():
            return decoded_predictions

        anchors, anchor_points, num_anchors_list, _ = generate_anchors_for_grid_cell(feats, self.fpn_strides, self.grid_cell_scale, self.grid_cell_offset)

        raw_predictions = cls_score_list, reg_distri_list, anchors, anchor_points, num_anchors_list, stride_tensor
        return decoded_predictions, raw_predictions

    @property
    def out_channels(self):
        return None

    def forward(self, feats: Tuple[Tensor]):
        if self.training:
            return self.forward_train(feats)
        else:
            return self.forward_eval(feats)

    def _generate_anchors(self, feats=None, dtype=torch.float):
        # just use in eval time
        anchor_points = []
        stride_tensor = []
        for i, stride in enumerate(self.fpn_strides):
            if feats is not None:
                _, _, h, w = feats[i].shape
            else:
                h = int(self.eval_size[0] / stride)
                w = int(self.eval_size[1] / stride)
            shift_x = torch.arange(end=w) + self.grid_cell_offset
            shift_y = torch.arange(end=h) + self.grid_cell_offset
            if torch_version_is_greater_or_equal(1, 10):
                shift_y, shift_x = torch.meshgrid(shift_y, shift_x, indexing="ij")
            else:
                shift_y, shift_x = torch.meshgrid(shift_y, shift_x)

            anchor_point = torch.stack([shift_x, shift_y], dim=-1).to(dtype=dtype)
            anchor_points.append(anchor_point.reshape([-1, 2]))
            stride_tensor.append(torch.full([h * w, 1], stride, dtype=dtype))
        anchor_points = torch.cat(anchor_points)
        stride_tensor = torch.cat(stride_tensor)
        if feats is not None:
            anchor_points = anchor_points.to(feats[0].device)
            stride_tensor = stride_tensor.to(feats[0].device)
        return anchor_points, stride_tensor
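The distribution-focal decoding in forward_eval reduces each box side's (reg_max + 1)-bin distribution to its expected value: a softmax over the bins followed by a 1x1 convolution with the linspace projection buffer. A standalone sketch with random inputs:

import torch

reg_max = 16
proj = torch.linspace(0, reg_max, reg_max + 1).reshape(1, reg_max + 1, 1, 1)

# Fake regression logits for one head: [B, bins, H*W, 4 sides], shaped like reg_dist_reduced above.
reg_dist = torch.randn(2, reg_max + 1, 100, 4)
expected = torch.nn.functional.conv2d(torch.nn.functional.softmax(reg_dist, dim=1), weight=proj).squeeze(1)
print(expected.shape)  # torch.Size([2, 100, 4]) -- per-anchor distances in stride units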
from typing import Union, List, Tuple

from omegaconf import DictConfig
from torch import Tensor

from super_gradients.common.registry import register_detection_module
from super_gradients.modules.detection_modules import BaseDetectionModule
from super_gradients.training.utils.utils import HpmStruct
import super_gradients.common.factories.detection_modules_factory as det_factory


@register_detection_module("YoloNASPANNeckWithC2")
class YoloNASPANNeckWithC2(BaseDetectionModule):
    """
    A PAN (path aggregation network) neck with 4 stages (2 up-sampling and 2 down-sampling stages)
    where the up-sampling stages include a higher resolution skip
    Returns outputs of neck stage 2, stage 3, stage 4
    """

    def __init__(
        self,
        in_channels: List[int],
        neck1: Union[str, HpmStruct, DictConfig],
        neck2: Union[str, HpmStruct, DictConfig],
        neck3: Union[str, HpmStruct, DictConfig],
        neck4: Union[str, HpmStruct, DictConfig],
    ):
        """
        Initialize the PAN neck

        :param in_channels: Input channels of the 4 feature maps from the backbone
        :param neck1: First neck stage config
        :param neck2: Second neck stage config
        :param neck3: Third neck stage config
        :param neck4: Fourth neck stage config
        """
        super().__init__(in_channels)
        c2_out_channels, c3_out_channels, c4_out_channels, c5_out_channels = in_channels

        factory = det_factory.DetectionModulesFactory()
        self.neck1 = factory.get(factory.insert_module_param(neck1, "in_channels", [c5_out_channels, c4_out_channels, c3_out_channels]))
        self.neck2 = factory.get(factory.insert_module_param(neck2, "in_channels", [self.neck1.out_channels[1], c3_out_channels, c2_out_channels]))
        self.neck3 = factory.get(factory.insert_module_param(neck3, "in_channels", [self.neck2.out_channels[1], self.neck2.out_channels[0]]))
        self.neck4 = factory.get(factory.insert_module_param(neck4, "in_channels", [self.neck3.out_channels, self.neck1.out_channels[0]]))

        self._out_channels = [
            self.neck2.out_channels[1],
            self.neck3.out_channels,
            self.neck4.out_channels,
        ]

    @property
    def out_channels(self):
        return self._out_channels

    def forward(self, inputs: Tuple[Tensor, Tensor, Tensor, Tensor]) -> Tuple[Tensor, Tensor, Tensor]:
        c2, c3, c4, c5 = inputs

        x_n1_inter, x = self.neck1([c5, c4, c3])
        x_n2_inter, p3 = self.neck2([x, c3, c2])
        p4 = self.neck3([p3, x_n2_inter])
        p5 = self.neck4([p4, x_n1_inter])

        return p3, p4, p5
import copy
from typing import Union

from omegaconf import DictConfig

from super_gradients.common.object_names import Models
from super_gradients.common.registry import register_model
from super_gradients.training.models.arch_params_factory import get_arch_params
from super_gradients.training.models.detection_models.customizable_detector import CustomizableDetector
from super_gradients.training.utils import HpmStruct, get_param
from super_gradients.training.models.detection_models.pp_yolo_e import PPYoloEPostPredictionCallback


@register_model(Models.YOLO_NAS_S)
class YoloNAS_S(CustomizableDetector):
    def __init__(self, arch_params: Union[HpmStruct, DictConfig], in_channels: int = 3):
        default_arch_params = get_arch_params("yolo_nas_s_arch_params")
        merged_arch_params = HpmStruct(**copy.deepcopy(default_arch_params))
        merged_arch_params.override(**arch_params.to_dict())
        super().__init__(
            backbone=merged_arch_params.backbone,
            neck=merged_arch_params.neck,
            heads=merged_arch_params.heads,
            num_classes=get_param(merged_arch_params, "num_classes", None),
            in_channels=in_channels,
            bn_momentum=get_param(merged_arch_params, "bn_momentum", None),
            bn_eps=get_param(merged_arch_params, "bn_eps", None),
            inplace_act=get_param(merged_arch_params, "inplace_act", None),
        )

    @staticmethod
    def get_post_prediction_callback(conf: float, iou: float) -> PPYoloEPostPredictionCallback:
        return PPYoloEPostPredictionCallback(score_threshold=conf, nms_threshold=iou, nms_top_k=1000, max_predictions=300)

    @property
    def num_classes(self):
        return self.heads.num_classes


@register_model(Models.YOLO_NAS_M)
class YoloNAS_M(CustomizableDetector):
    def __init__(self, arch_params: Union[HpmStruct, DictConfig], in_channels: int = 3):
        default_arch_params = get_arch_params("yolo_nas_m_arch_params")
        merged_arch_params = HpmStruct(**copy.deepcopy(default_arch_params))
        merged_arch_params.override(**arch_params.to_dict())
        super().__init__(
            backbone=merged_arch_params.backbone,
            neck=merged_arch_params.neck,
            heads=merged_arch_params.heads,
            num_classes=get_param(merged_arch_params, "num_classes", None),
            in_channels=in_channels,
            bn_momentum=get_param(merged_arch_params, "bn_momentum", None),
            bn_eps=get_param(merged_arch_params, "bn_eps", None),
            inplace_act=get_param(merged_arch_params, "inplace_act", None),
        )

    @staticmethod
    def get_post_prediction_callback(conf: float, iou: float) -> PPYoloEPostPredictionCallback:
        return PPYoloEPostPredictionCallback(score_threshold=conf, nms_threshold=iou, nms_top_k=1000, max_predictions=300)

    @property
    def num_classes(self):
        return self.heads.num_classes


@register_model(Models.YOLO_NAS_L)
class YoloNAS_L(CustomizableDetector):
    def __init__(self, arch_params: Union[HpmStruct, DictConfig], in_channels: int = 3):
        default_arch_params = get_arch_params("yolo_nas_l_arch_params")
        merged_arch_params = HpmStruct(**copy.deepcopy(default_arch_params))
        merged_arch_params.override(**arch_params.to_dict())
        super().__init__(
            backbone=merged_arch_params.backbone,
            neck=merged_arch_params.neck,
            heads=merged_arch_params.heads,
            num_classes=get_param(merged_arch_params, "num_classes", None),
            in_channels=in_channels,
            bn_momentum=get_param(merged_arch_params, "bn_momentum", None),
            bn_eps=get_param(merged_arch_params, "bn_eps", None),
            inplace_act=get_param(merged_arch_params, "inplace_act", None),
        )

    @staticmethod
    def get_post_prediction_callback(conf: float, iou: float) -> PPYoloEPostPredictionCallback:
        return PPYoloEPostPredictionCallback(score_threshold=conf, nms_threshold=iou, nms_top_k=1000, max_predictions=300)

    @property
    def num_classes(self):
        return self.heads.num_classes
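A short usage sketch for the registered variants; the confidence and IoU thresholds here are arbitrary examples:

from super_gradients.training import models
from super_gradients.common.object_names import Models

model = models.get(Models.YOLO_NAS_L, pretrained_weights="coco")
callback = model.get_post_prediction_callback(conf=0.25, iou=0.7)  # a PPYoloEPostPredictionCallback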
from functools import partial
from typing import Type, List

import torch
from torch import nn, Tensor

from super_gradients.common.registry import register_detection_module
from super_gradients.modules import Residual, BaseDetectionModule
from super_gradients.common.decorators.factory_decorator import resolve_param
from super_gradients.common.factories.activations_type_factory import ActivationsTypeFactory
from super_gradients.modules import QARepVGGBlock, Conv
from super_gradients.modules.utils import width_multiplier

__all__ = ["YoloNASStage", "YoloNASUpStage", "YoloNASStem", "YoloNASDownStage", "YoloNASBottleneck"]


class YoloNASBottleneck(nn.Module):
    """
    A bottleneck block for YoloNAS. Consists of two consecutive blocks and optional residual connection.
    """

    def __init__(
        self, input_channels: int, output_channels: int, block_type: Type[nn.Module], activation_type: Type[nn.Module], shortcut: bool, use_alpha: bool
    ):
        """
        Initialize the YoloNASBottleneck block

        :param input_channels: Number of input channels
        :param output_channels: Number of output channels
        :param block_type: Type of the convolutional block
        :param activation_type: Activation type for the convolutional block
        :param shortcut: If True, adds the residual connection from input to output.
        :param use_alpha: If True, adds the learnable alpha parameter (multiplier for the residual connection).
        """
        super().__init__()

        self.cv1 = block_type(input_channels, output_channels, activation_type=activation_type)
        self.cv2 = block_type(output_channels, output_channels, activation_type=activation_type)
        self.add = shortcut and input_channels == output_channels
        self.shortcut = Residual() if self.add else None
        if use_alpha:
            self.alpha = torch.nn.Parameter(torch.tensor([1.0]), requires_grad=True)
        else:
            self.alpha = 1.0

    def forward(self, x):
        return self.alpha * self.shortcut(x) + self.cv2(self.cv1(x)) if self.add else self.cv2(self.cv1(x))


class SequentialWithIntermediates(nn.Sequential):
    """
    A Sequential module that can return all intermediate values as a list of Tensors
    """

    def __init__(self, output_intermediates: bool, *args):
        super(SequentialWithIntermediates, self).__init__(*args)
        self.output_intermediates = output_intermediates

    def forward(self, input: Tensor) -> List[Tensor]:
        if self.output_intermediates:
            output = [input]
            for module in self:
                output.append(module(output[-1]))
            return output
        # For uniformity, we return a list even if we don't output intermediates
        return [super(SequentialWithIntermediates, self).forward(input)]


class YoloNASCSPLayer(nn.Module):
    """
    Cross-stage layer module for YoloNAS.
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        num_bottlenecks: int,
        block_type: Type[nn.Module],
        activation_type: Type[nn.Module],
        shortcut: bool = True,
        use_alpha: bool = True,
        expansion: float = 0.5,
        hidden_channels: int = None,
        concat_intermediates: bool = False,
    ):
        """
        :param in_channels: Number of input channels.
        :param out_channels: Number of output channels.
        :param num_bottlenecks: Number of bottleneck blocks.
        :param block_type: Bottleneck block type.
        :param activation_type: Activation type for all blocks.
        :param shortcut: If True, adds the residual connection from input to output.
        :param use_alpha: If True, adds the learnable alpha parameter (multiplier for the residual connection).
        :param expansion: If hidden_channels is None, hidden_channels is set to in_channels * expansion.
        :param hidden_channels: If not None, sets the number of hidden channels used inside the bottleneck blocks.
        :param concat_intermediates:
        """
        super(YoloNASCSPLayer, self).__init__()
        if hidden_channels is None:
            hidden_channels = int(out_channels * expansion)
        self.conv1 = Conv(in_channels, hidden_channels, 1, stride=1, activation_type=activation_type)
        self.conv2 = Conv(in_channels, hidden_channels, 1, stride=1, activation_type=activation_type)
        self.conv3 = Conv(hidden_channels * (2 + concat_intermediates * num_bottlenecks), out_channels, 1, stride=1, activation_type=activation_type)
        module_list = [YoloNASBottleneck(hidden_channels, hidden_channels, block_type, activation_type, shortcut, use_alpha) for _ in range(num_bottlenecks)]
        self.bottlenecks = SequentialWithIntermediates(concat_intermediates, *module_list)

    def forward(self, x: Tensor) -> Tensor:
        x_1 = self.conv1(x)
        x_1 = self.bottlenecks(x_1)
        x_2 = self.conv2(x)
        x = torch.cat((*x_1, x_2), dim=1)
        return self.conv3(x)


@register_detection_module()
class YoloNASStem(BaseDetectionModule):
    """
    Stem module for YoloNAS. Consists of a single QARepVGGBlock with stride of two.
    """

    def __init__(self, in_channels: int, out_channels: int):
        """
        Initialize the YoloNASStem module
        :param in_channels: Number of input channels
        :param out_channels: Number of output channels
        """
        super().__init__(in_channels)
        self._out_channels = out_channels
        self.conv = QARepVGGBlock(in_channels, out_channels, stride=2, use_residual_connection=False)

    @property
    def out_channels(self):
        return self._out_channels

    def forward(self, x: Tensor) -> Tensor:
        return self.conv(x)


@register_detection_module()
class YoloNASStage(BaseDetectionModule):
    """
    A single stage module for YoloNAS. It consists of a downsample block (QARepVGGBlock) followed by YoloNASCSPLayer.
    """

    @resolve_param("activation_type", ActivationsTypeFactory())
    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        num_blocks: int,
        activation_type: Type[nn.Module],
        hidden_channels: int = None,
        concat_intermediates: bool = False,
    ):
        """
        Initialize the YoloNASStage module
        :param in_channels: Number of input channels
        :param out_channels: Number of output channels
        :param num_blocks: Number of bottleneck blocks in the YoloNASCSPLayer
        :param activation_type: Activation type for all blocks
        :param hidden_channels: If not None, sets the number of hidden channels used inside the bottleneck blocks.
        :param concat_intermediates: If True, concatenates the intermediate values from the YoloNASCSPLayer.
        """
        super().__init__(in_channels)
        self._out_channels = out_channels
        self.downsample = QARepVGGBlock(in_channels, out_channels, stride=2, activation_type=activation_type, use_residual_connection=False)
        self.blocks = YoloNASCSPLayer(
            out_channels,
            out_channels,
  47. if self.output_intermediates:
  48. output = [input]
  49. for module in self:
  50. output.append(module(output[-1]))
  51. return output
  52. # For uniformity, we return a list even if we don't output intermediates
  53. return [super(SequentialWithIntermediates, self).forward(input)]
class YoloNASCSPLayer(nn.Module):
    """
    Cross-stage layer module for YoloNAS.
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        num_bottlenecks: int,
        block_type: Type[nn.Module],
        activation_type: Type[nn.Module],
        shortcut: bool = True,
        use_alpha: bool = True,
        expansion: float = 0.5,
        hidden_channels: int = None,
        concat_intermediates: bool = False,
    ):
        """
        :param in_channels: Number of input channels.
        :param out_channels: Number of output channels.
        :param num_bottlenecks: Number of bottleneck blocks.
        :param block_type: Bottleneck block type.
        :param activation_type: Activation type for all blocks.
        :param shortcut: If True, adds the residual connection from input to output.
        :param use_alpha: If True, adds the learnable alpha parameter (multiplier for the residual connection).
        :param expansion: If hidden_channels is None, hidden_channels is set to out_channels * expansion.
        :param hidden_channels: If not None, sets the number of hidden channels used inside the bottleneck blocks.
        :param concat_intermediates: If True, concatenates the outputs of all bottleneck blocks (not only the last one) before the final convolution.
        """
        super(YoloNASCSPLayer, self).__init__()
        if hidden_channels is None:
            hidden_channels = int(out_channels * expansion)
        self.conv1 = Conv(in_channels, hidden_channels, 1, stride=1, activation_type=activation_type)
        self.conv2 = Conv(in_channels, hidden_channels, 1, stride=1, activation_type=activation_type)
        # With concat_intermediates the bottleneck branch contributes (1 + num_bottlenecks) tensors
        # instead of one, hence the channel count of the fusing 1x1 convolution below.
        self.conv3 = Conv(hidden_channels * (2 + concat_intermediates * num_bottlenecks), out_channels, 1, stride=1, activation_type=activation_type)
        module_list = [YoloNASBottleneck(hidden_channels, hidden_channels, block_type, activation_type, shortcut, use_alpha) for _ in range(num_bottlenecks)]
        self.bottlenecks = SequentialWithIntermediates(concat_intermediates, *module_list)

    def forward(self, x: Tensor) -> Tensor:
        x_1 = self.conv1(x)
        x_1 = self.bottlenecks(x_1)
        x_2 = self.conv2(x)
        x = torch.cat((*x_1, x_2), dim=1)
        return self.conv3(x)
@register_detection_module()
class YoloNASStem(BaseDetectionModule):
    """
    Stem module for YoloNAS. Consists of a single QARepVGGBlock with stride of two.
    """

    def __init__(self, in_channels: int, out_channels: int):
        """
        Initialize the YoloNASStem module

        :param in_channels: Number of input channels
        :param out_channels: Number of output channels
        """
        super().__init__(in_channels)
        self._out_channels = out_channels
        self.conv = QARepVGGBlock(in_channels, out_channels, stride=2, use_residual_connection=False)

    @property
    def out_channels(self):
        return self._out_channels

    def forward(self, x: Tensor) -> Tensor:
        return self.conv(x)


@register_detection_module()
class YoloNASStage(BaseDetectionModule):
    """
    A single stage module for YoloNAS. It consists of a downsample block (QARepVGGBlock) followed by YoloNASCSPLayer.
    """

    @resolve_param("activation_type", ActivationsTypeFactory())
    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        num_blocks: int,
        activation_type: Type[nn.Module],
        hidden_channels: int = None,
        concat_intermediates: bool = False,
    ):
        """
        Initialize the YoloNASStage module

        :param in_channels: Number of input channels
        :param out_channels: Number of output channels
        :param num_blocks: Number of bottleneck blocks in the YoloNASCSPLayer
        :param activation_type: Activation type for all blocks
        :param hidden_channels: If not None, sets the number of hidden channels used inside the bottleneck blocks.
        :param concat_intermediates: If True, concatenates the intermediate values from the YoloNASCSPLayer.
        """
        super().__init__(in_channels)
        self._out_channels = out_channels
        self.downsample = QARepVGGBlock(in_channels, out_channels, stride=2, activation_type=activation_type, use_residual_connection=False)
        self.blocks = YoloNASCSPLayer(
            out_channels,
            out_channels,
            num_blocks,
            QARepVGGBlock,
            activation_type,
            True,
            hidden_channels=hidden_channels,
            concat_intermediates=concat_intermediates,
        )

    @property
    def out_channels(self):
        return self._out_channels

    def forward(self, x):
        return self.blocks(self.downsample(x))
@register_detection_module()
class YoloNASUpStage(BaseDetectionModule):
    """
    Upsampling stage for YoloNAS.
    """

    @resolve_param("activation_type", ActivationsTypeFactory())
    def __init__(
        self,
        in_channels: List[int],
        out_channels: int,
        width_mult: float,
        num_blocks: int,
        depth_mult: float,
        activation_type: Type[nn.Module],
        hidden_channels: int = None,
        concat_intermediates: bool = False,
        reduce_channels: bool = False,
    ):
        """
        Initialize the YoloNASUpStage module

        :param in_channels: Number of input channels for each input branch (two or three inputs are supported)
        :param out_channels: Number of output channels
        :param width_mult: Multiplier for the number of channels in the stage.
        :param num_blocks: Number of bottleneck blocks
        :param depth_mult: Multiplier for the number of blocks in the stage.
        :param activation_type: Activation type for all blocks
        :param hidden_channels: If not None, sets the number of hidden channels used inside the bottleneck blocks
        :param concat_intermediates: If True, concatenates the intermediate values from the YoloNASCSPLayer.
        :param reduce_channels: If True, reduces the skip connections and the concatenated features to out_channels with 1x1 convolutions.
        """
        super().__init__(in_channels)
        num_inputs = len(in_channels)
        if num_inputs == 2:
            in_channels, skip_in_channels = in_channels
        else:
            in_channels, skip_in_channels1, skip_in_channels2 = in_channels
            skip_in_channels = skip_in_channels1 + out_channels  # skip2 downsample results in out_channels channels
        out_channels = width_multiplier(out_channels, width_mult, 8)
        num_blocks = max(round(num_blocks * depth_mult), 1) if num_blocks > 1 else num_blocks

        if num_inputs == 2:
            self.reduce_skip = Conv(skip_in_channels, out_channels, 1, 1, activation_type) if reduce_channels else nn.Identity()
        else:
            self.reduce_skip1 = Conv(skip_in_channels1, out_channels, 1, 1, activation_type) if reduce_channels else nn.Identity()
            self.reduce_skip2 = Conv(skip_in_channels2, out_channels, 1, 1, activation_type) if reduce_channels else nn.Identity()

        self.conv = Conv(in_channels, out_channels, 1, 1, activation_type)
        self.upsample = nn.ConvTranspose2d(in_channels=out_channels, out_channels=out_channels, kernel_size=2, stride=2)
        if num_inputs == 3:
            self.downsample = Conv(out_channels if reduce_channels else skip_in_channels2, out_channels, kernel=3, stride=2, activation_type=activation_type)

        self.reduce_after_concat = Conv(num_inputs * out_channels, out_channels, 1, 1, activation_type) if reduce_channels else nn.Identity()

        after_concat_channels = out_channels if reduce_channels else out_channels + skip_in_channels

        self.blocks = YoloNASCSPLayer(
            after_concat_channels,
            out_channels,
            num_blocks,
            QARepVGGBlock,
            activation_type,
            hidden_channels=hidden_channels,
            concat_intermediates=concat_intermediates,
        )

        self._out_channels = [out_channels, out_channels]

    @property
    def out_channels(self):
        return self._out_channels

    def forward(self, inputs):
        if len(inputs) == 2:
            x, skip_x = inputs
            skip_x = [self.reduce_skip(skip_x)]
        else:
            x, skip_x1, skip_x2 = inputs
            skip_x1, skip_x2 = self.reduce_skip1(skip_x1), self.reduce_skip2(skip_x2)
            skip_x = [skip_x1, self.downsample(skip_x2)]
        x_inter = self.conv(x)
        x = self.upsample(x_inter)
        x = torch.cat([x, *skip_x], 1)
        x = self.reduce_after_concat(x)
        x = self.blocks(x)
        return x_inter, x
@register_detection_module()
class YoloNASDownStage(BaseDetectionModule):
    """
    Downsampling stage for YoloNAS: downsamples the first input, concatenates it with the skip input and applies a YoloNASCSPLayer.
    """

    @resolve_param("activation_type", ActivationsTypeFactory())
    def __init__(
        self,
        in_channels: List[int],
        out_channels: int,
        width_mult: float,
        num_blocks: int,
        depth_mult: float,
        activation_type: Type[nn.Module],
        hidden_channels: int = None,
        concat_intermediates: bool = False,
    ):
        """
        Initializes a YoloNASDownStage.

        :param in_channels: Number of input channels for the two inputs (the feature map to downsample and the skip connection).
        :param out_channels: Number of output channels.
        :param width_mult: Multiplier for the number of channels in the stage.
        :param num_blocks: Number of blocks in the stage.
        :param depth_mult: Multiplier for the number of blocks in the stage.
        :param activation_type: Type of activation to use inside the blocks.
        :param hidden_channels: If not None, sets the number of hidden channels used inside the bottleneck blocks.
        :param concat_intermediates: If True, concatenates the intermediate values from the YoloNASCSPLayer.
        """
        super().__init__(in_channels)

        in_channels, skip_in_channels = in_channels
        out_channels = width_multiplier(out_channels, width_mult, 8)
        num_blocks = max(round(num_blocks * depth_mult), 1) if num_blocks > 1 else num_blocks

        self.conv = Conv(in_channels, out_channels // 2, 3, 2, activation_type)
        after_concat_channels = out_channels // 2 + skip_in_channels
        self.blocks = YoloNASCSPLayer(
            in_channels=after_concat_channels,
            out_channels=out_channels,
            num_bottlenecks=num_blocks,
            block_type=partial(Conv, kernel=3, stride=1),
            activation_type=activation_type,
            hidden_channels=hidden_channels,
            concat_intermediates=concat_intermediates,
        )
        self._out_channels = out_channels

    @property
    def out_channels(self):
        return self._out_channels

    def forward(self, inputs):
        x, skip_x = inputs
        x = self.conv(x)
        x = torch.cat([x, skip_x], 1)
        x = self.blocks(x)
        return x
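For reference, a minimal smoke-test sketch (not part of the diff) showing how the stem and a stage compose. It assumes the classes above are in scope and that the activation factory accepts a type directly; the channel and stride values are arbitrary:

```python
import torch
from torch import nn

# Hypothetical sanity check for the modules above.
stem = YoloNASStem(in_channels=3, out_channels=48)
stage = YoloNASStage(in_channels=48, out_channels=96, num_blocks=2, activation_type=nn.ReLU)

x = torch.randn(1, 3, 640, 640)
y = stage(stem(x))

# Stem and stage each downsample by 2, so 640 -> 160 spatially.
print(y.shape)  # torch.Size([1, 96, 160, 160])
```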
@@ -52,7 +52,7 @@ class Pipeline(ABC):
     def __init__(self, model: SgModule, image_processor: Union[Processing, List[Processing]], class_names: List[str], device: Optional[str] = None):
         super().__init__()
         self.device = device or next(model.parameters()).device
-        self.model = model.to(device)
+        self.model = model.to(self.device)
         self.class_names = class_names
 
         if isinstance(image_processor, list):
@@ -265,7 +265,12 @@ class DetectionPipeline(Pipeline):
     def _combine_image_prediction_to_images(
         self, images_predictions: Iterable[ImageDetectionPrediction], n_images: Optional[int] = None
     ) -> ImagesDetectionPrediction:
-        images_predictions = [image_predictions for image_predictions in tqdm(images_predictions, total=n_images, desc="Predicting Images")]
+        if n_images is not None and n_images == 1:
+            # Do not show tqdm progress bar if there is only one image
+            pass
+        else:
+            images_predictions = [image_predictions for image_predictions in tqdm(images_predictions, total=n_images, desc="Predicting Images")]
+
         return ImagesDetectionPrediction(_images_prediction_lst=images_predictions)
 
     def _combine_image_prediction_to_video(
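Two small fixes above: the constructor now moves the model to the resolved `self.device` (which can differ from the raw `device` argument when `None` is passed), and the tqdm progress bar is skipped for single-image prediction. A standalone sketch of the device-resolution pattern (not the SG class itself):

```python
from torch import nn

class TinyPipeline:
    """Illustrates the device-resolution pattern from the hunk above (hypothetical class)."""

    def __init__(self, model: nn.Module, device: str = None):
        # Fall back to wherever the model's parameters already live.
        self.device = device or next(model.parameters()).device
        self.model = model.to(self.device)  # the fix: move to the *resolved* device

p = TinyPipeline(nn.Linear(4, 2))
print(p.device)  # cpu (unless the model was already on another device)
```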
@@ -59,6 +59,10 @@ MODEL_URLS = {
     "ppyoloe_m_coco": "https://deci-pretrained-models.s3.amazonaws.com/ppyolo_e/coco2017_ppyoloe_m.pth",
     "ppyoloe_l_coco": "https://deci-pretrained-models.s3.amazonaws.com/ppyolo_e/coco2017_pp_yoloe_l_best_model_21uffbb8.pth",  # 0.4948
     "ppyoloe_x_coco": "https://deci-pretrained-models.s3.amazonaws.com/ppyolo_e/coco2017_pp_yoloe_x_best_model_z03if91o.pth",  # 0.5115
+    #
+    "yolo_nas_s_coco": "https://deci-pretrained-models.s3.amazonaws.com/yolo_nas/yolo_nas_s_coco2017.pth",
+    "yolo_nas_m_coco": "https://deci-pretrained-models.s3.amazonaws.com/yolo_nas/yolo_nas_m_coco2017.pth",
+    "yolo_nas_l_coco": "https://deci-pretrained-models.s3.amazonaws.com/yolo_nas/yolo_nas_l_coco2017.pth",
 }
 
 PRETRAINED_NUM_CLASSES = {
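With these entries in place, COCO checkpoints resolve for all three YOLO-NAS sizes through the usual factory; the call below mirrors the integration tests added later in this PR:

```python
from super_gradients.training import models

# Downloads yolo_nas_s_coco2017.pth from the URL registered above.
model = models.get("yolo_nas_s", num_classes=80, pretrained_weights="coco")
```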
@@ -305,8 +305,8 @@ def default_ppyoloe_coco_processing_params() -> dict:
     return params
 
 
-def default_deciyolo_coco_processing_params() -> dict:
-    """Processing parameters commonly used for training DeciYolo on COCO dataset.
+def default_yolo_nas_coco_processing_params() -> dict:
+    """Processing parameters commonly used for training YoloNAS on COCO dataset.
     TODO: remove once we load it from the checkpoint
     """
 
@@ -322,8 +322,8 @@ def default_deciyolo_coco_processing_params() -> dict:
     params = dict(
         class_names=COCO_DETECTION_CLASSES_LIST,
         image_processor=image_processor,
-        iou=0.65,
-        conf=0.5,
+        iou=0.7,
+        conf=0.25,
     )
     return params
 
@@ -337,6 +337,6 @@ def get_pretrained_processing_params(model_name: str, pretrained_weights: str) -
             return default_yolox_coco_processing_params()
         elif "ppyoloe" in model_name:
             return default_ppyoloe_coco_processing_params()
-        elif "deciyolo" in model_name:
-            return default_deciyolo_coco_processing_params()
+        elif "yolo_nas" in model_name:
+            return default_yolo_nas_coco_processing_params()
     return dict()
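Besides the DeciYolo -> YoloNAS rename, the default post-processing is loosened (iou 0.65 -> 0.7, conf 0.5 -> 0.25). These are only defaults; a sketch of overriding them per call, assuming `predict` accepts `iou`/`conf` keyword overrides:

```python
from super_gradients.training import models

model = models.get("yolo_nas_s", pretrained_weights="coco")

# Picks up the new defaults (iou=0.7, conf=0.25); the image path is a placeholder.
predictions = model.predict("path/to/image.jpg")

# Hypothetical per-call override back to the previous, stricter values:
predictions = model.predict("path/to/image.jpg", iou=0.65, conf=0.5)
```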
@@ -291,11 +291,21 @@ def load_pretrained_weights(model: torch.nn.Module, architecture: str, pretraine
     :param pretrained_weights: name for the pretrained weights (e.g. imagenet)
     :return: None
     """
+    from super_gradients.common.object_names import Models
+
     model_url_key = architecture + "_" + str(pretrained_weights)
     if model_url_key not in MODEL_URLS.keys():
         raise MissingPretrainedWeightsException(model_url_key)
 
     url = MODEL_URLS[model_url_key]
+
+    if architecture in {Models.YOLO_NAS_S, Models.YOLO_NAS_M, Models.YOLO_NAS_L}:
+        logger.info(
+            "License Notification: YOLO-NAS pre-trained weights are subject to the specific license terms and conditions detailed in \n"
+            "https://github.com/Deci-AI/super-gradients/LICENSE.YOLONAS.md. \n"
+            "By downloading the pre-trained weight files you agree to comply with these terms."
+        )
+
     unique_filename = url.split("https://deci-pretrained-models.s3.amazonaws.com/")[1].replace("/", "_").replace(" ", "_")
     map_location = torch.device("cpu")
     pretrained_state_dict = load_state_dict_from_url(url=url, map_location=map_location, file_name=unique_filename)
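The cache file name is derived from the S3 URL; a worked example using one of the URLs registered above:

```python
url = "https://deci-pretrained-models.s3.amazonaws.com/yolo_nas/yolo_nas_s_coco2017.pth"
unique_filename = url.split("https://deci-pretrained-models.s3.amazonaws.com/")[1].replace("/", "_").replace(" ", "_")
print(unique_filename)  # yolo_nas_yolo_nas_s_coco2017.pth
```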
@@ -1,7 +1,7 @@
 import sys
 import unittest
 
-from tests.integration_tests import EMAIntegrationTest, LRTest, PoseEstimationDatasetIntegrationTest
+from tests.integration_tests import EMAIntegrationTest, LRTest, PoseEstimationDatasetIntegrationTest, YoloNASIntegrationTest
 
 
 class CoreIntegrationTestSuiteRunner:
@@ -19,6 +19,7 @@ class CoreIntegrationTestSuiteRunner:
         self.integration_tests_suite.addTest(self.test_loader.loadTestsFromModule(EMAIntegrationTest))
         self.integration_tests_suite.addTest(self.test_loader.loadTestsFromModule(LRTest))
         self.integration_tests_suite.addTest(self.test_loader.loadTestsFromModule(PoseEstimationDatasetIntegrationTest))
+        self.integration_tests_suite.addTest(self.test_loader.loadTestsFromModule(YoloNASIntegrationTest))
 
 
 if __name__ == "__main__":
@@ -34,6 +34,7 @@ from tests.unit_tests.pose_estimation_dataset_test import TestPoseEstimationData
 from tests.unit_tests.preprocessing_unit_test import PreprocessingUnitTest
 from tests.unit_tests.quantization_utility_tests import QuantizationUtilityTest
 from tests.unit_tests.random_erase_test import RandomEraseTest
+from tests.unit_tests.replace_head_test import ReplaceHeadUnitTest
 from tests.unit_tests.strictload_enum_test import StrictLoadEnumTest
 from tests.unit_tests.train_with_intialized_param_args_test import TrainWithInitializedObjectsTest
 from tests.unit_tests.pretrained_models_unit_test import PretrainedModelsUnitTest
@@ -132,6 +133,7 @@ class CoreUnitTestSuiteRunner:
         self.unit_tests_suite.addTest(self.test_loader.loadTestsFromModule(TestPoseEstimationMetrics))
         self.unit_tests_suite.addTest(self.test_loader.loadTestsFromModule(TestPoseEstimationDataset))
         self.unit_tests_suite.addTest(self.test_loader.loadTestsFromModule(LoadCheckpointTest))
+        self.unit_tests_suite.addTest(self.test_loader.loadTestsFromModule(ReplaceHeadUnitTest))
         self.unit_tests_suite.addTest(self.test_loader.loadTestsFromModule(PreprocessingUnitTest))
 
     def _add_modules_to_end_to_end_tests_suite(self):
@@ -3,5 +3,6 @@
 from tests.integration_tests.ema_train_integration_test import EMAIntegrationTest
 from tests.integration_tests.lr_test import LRTest
 from tests.integration_tests.pose_estimation_dataset_test import PoseEstimationDatasetIntegrationTest
+from tests.integration_tests.yolo_nas_integration_test import YoloNASIntegrationTest
 
-__all__ = ["EMAIntegrationTest", "LRTest", "PoseEstimationDatasetIntegrationTest"]
+__all__ = ["EMAIntegrationTest", "LRTest", "PoseEstimationDatasetIntegrationTest", "YoloNASIntegrationTest"]
import unittest

from super_gradients.training import models
from super_gradients.training.dataloaders import coco2017_val_yolo_nas
from super_gradients.training import Trainer
from super_gradients.training.metrics import DetectionMetrics
from super_gradients.training.models.detection_models.pp_yolo_e import PPYoloEPostPredictionCallback


class YoloNASIntegrationTest(unittest.TestCase):
    def test_yolo_nas_s_coco(self):
        trainer = Trainer("test_yolo_nas_s")
        model = models.get("yolo_nas_s", num_classes=80, pretrained_weights="coco")
        dl = coco2017_val_yolo_nas()
        metric = DetectionMetrics(
            normalize_targets=True,
            post_prediction_callback=PPYoloEPostPredictionCallback(score_threshold=0.03, nms_top_k=1000, max_predictions=300, nms_threshold=0.65),
            num_cls=80,
        )
        metric_values = trainer.test(model=model, test_loader=dl, test_metrics_list=[metric])
        self.assertAlmostEqual(metric_values[metric.map_str], 0.475, delta=0.001)

    def test_yolo_nas_m_coco(self):
        trainer = Trainer("test_yolo_nas_m")
        model = models.get("yolo_nas_m", num_classes=80, pretrained_weights="coco")
        dl = coco2017_val_yolo_nas()
        metric = DetectionMetrics(
            normalize_targets=True,
            post_prediction_callback=PPYoloEPostPredictionCallback(score_threshold=0.03, nms_top_k=1000, max_predictions=300, nms_threshold=0.65),
            num_cls=80,
        )
        metric_values = trainer.test(model=model, test_loader=dl, test_metrics_list=[metric])
        self.assertAlmostEqual(metric_values[metric.map_str], 0.5155, delta=0.001)

    def test_yolo_nas_l_coco(self):
        trainer = Trainer("test_yolo_nas_l")
        model = models.get("yolo_nas_l", num_classes=80, pretrained_weights="coco")
        dl = coco2017_val_yolo_nas()
        metric = DetectionMetrics(
            normalize_targets=True,
            post_prediction_callback=PPYoloEPostPredictionCallback(score_threshold=0.03, nms_top_k=1000, max_predictions=300, nms_threshold=0.65),
            num_cls=80,
        )
        metric_values = trainer.test(model=model, test_loader=dl, test_metrics_list=[metric])
        self.assertAlmostEqual(metric_values[metric.map_str], 0.5222, delta=0.001)


if __name__ == "__main__":
    unittest.main()
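The three tests differ only in model name and expected mAP; a parametrized helper (a sketch, not part of the diff) would collapse the duplication:

```python
def _coco_map(self, model_name: str) -> float:
    # Hypothetical helper: same body as the tests above, parametrized on the model name.
    trainer = Trainer(f"test_{model_name}")
    model = models.get(model_name, num_classes=80, pretrained_weights="coco")
    metric = DetectionMetrics(
        normalize_targets=True,
        post_prediction_callback=PPYoloEPostPredictionCallback(score_threshold=0.03, nms_top_k=1000, max_predictions=300, nms_threshold=0.65),
        num_cls=80,
    )
    metric_values = trainer.test(model=model, test_loader=coco2017_val_yolo_nas(), test_metrics_list=[metric])
    return metric_values[metric.map_str]
```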
import os
import shutil
import unittest

import torch

import super_gradients
from super_gradients.common.object_names import Models
from super_gradients.training import models


class ReplaceHeadUnitTest(unittest.TestCase):
    def setUp(self) -> None:
        self.device = "cuda" if torch.cuda.is_available() and torch.cuda.device_count() > 0 else "cpu"
        super_gradients.init_trainer()

    def test_ppyolo_replace_head(self):
        input = torch.randn(1, 3, 640, 640).to(self.device)
        for model in [Models.PP_YOLOE_S, Models.PP_YOLOE_M, Models.PP_YOLOE_L, Models.PP_YOLOE_X]:
            model = models.get(model, pretrained_weights="coco").to(self.device).eval()
            model.replace_head(new_num_classes=100)
            (_, pred_scores), _ = model.forward(input)
            self.assertEqual(pred_scores.size(2), 100)

    def test_yolo_nas_replace_head(self):
        input = torch.randn(1, 3, 640, 640).to(self.device)
        for model in [Models.YOLO_NAS_S, Models.YOLO_NAS_M, Models.YOLO_NAS_L]:
            model = models.get(model, pretrained_weights="coco").to(self.device).eval()
            model.replace_head(new_num_classes=100)
            (_, pred_scores), _ = model.forward(input)
            self.assertEqual(pred_scores.size(2), 100)

    def tearDown(self) -> None:
        # os.path.exists does not expand "~"; expand it explicitly or the cache is never removed.
        hub_cache = os.path.expanduser("~/.cache/torch/hub/")
        if os.path.exists(hub_cache):
            shutil.rmtree(hub_cache)


if __name__ == "__main__":
    unittest.main()
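The replace-head test above guards the typical transfer-learning flow; in user code it looks roughly like this (the class count of 20 is an arbitrary placeholder):

```python
from super_gradients.common.object_names import Models
from super_gradients.training import models

# Load COCO weights, then swap the head for a custom class count.
model = models.get(Models.YOLO_NAS_S, pretrained_weights="coco")
model.replace_head(new_num_classes=20)
# ... fine-tune with Trainer on the custom dataset ...
```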