Register
Login
Resources
Docs Blog Datasets Glossary Case Studies Tutorials & Webinars
Product
Data Engine LLMs Platform Enterprise
Pricing Explore
Connect to our Discord channel

#821 Feature/sg 735 deci yolo qs

Merged
Ghost merged 1 commits into Deci-AI:feature/SG-736_deci_yolo_rf100 from deci-ai:feature/SG-735_deci_yolo_qs
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
  1. import importlib
  2. import sys
  3. from omegaconf import OmegaConf
  4. from super_gradients.common.environment.checkpoints_dir_utils import get_checkpoints_dir_path
  5. def get_cls(cls_path: str):
  6. """
  7. A resolver for Hydra/OmegaConf to allow getting a class instead on an instance.
  8. usage:
  9. class_of_optimizer: ${class:torch.optim.Adam}
  10. """
  11. module = ".".join(cls_path.split(".")[:-1])
  12. name = cls_path.split(".")[-1]
  13. importlib.import_module(module)
  14. return getattr(sys.modules[module], name)
  15. def hydra_output_dir_resolver(ckpt_root_dir: str, experiment_name: str) -> str:
  16. return get_checkpoints_dir_path(experiment_name=experiment_name, ckpt_root_dir=ckpt_root_dir)
  17. def register_hydra_resolvers():
  18. """Register all the hydra resolvers required for the super-gradients recipes."""
  19. from super_gradients.training.datasets.detection_datasets.roboflow.utils import get_dataset_num_classes
  20. OmegaConf.register_new_resolver("hydra_output_dir", hydra_output_dir_resolver, replace=True)
  21. OmegaConf.register_new_resolver("class", lambda *args: get_cls(*args), replace=True)
  22. OmegaConf.register_new_resolver("add", lambda *args: sum(args), replace=True)
  23. OmegaConf.register_new_resolver("cond", lambda boolean, x, y: x if boolean else y, replace=True)
  24. OmegaConf.register_new_resolver("getitem", lambda container, key: container[key], replace=True) # get item from a container (list, dict...)
  25. OmegaConf.register_new_resolver("first", lambda lst: lst[0], replace=True) # get the first item from a list
  26. OmegaConf.register_new_resolver("last", lambda lst: lst[-1], replace=True) # get the last item from a list
  27. OmegaConf.register_new_resolver("roboflow_dataset_num_classes", get_dataset_num_classes, replace=True)
Discard
Tip!

Press p or to see the previous file or, n or to see the next file