Register
Login
Resources
Docs Blog Datasets Glossary Case Studies Tutorials & Webinars
Product
Data Engine LLMs Platform Enterprise
Pricing Explore
Connect to our Discord channel

#471 Feature/sg 415 add registry callback and transform

Merged
Ghost merged 1 commits into Deci-AI:master from deci-ai:feature/SG-415-add_registry_callback
@@ -24,6 +24,8 @@ Recipes support out of the box every model, metric or loss that is implemented i
    * model: `from super_gradients.training.utils.registry import register_model`
    * model: `from super_gradients.training.utils.registry import register_model`
    * loss: `from super_gradients.training.utils.registry import register_loss`
    * loss: `from super_gradients.training.utils.registry import register_loss`
    * dataloader: `from super_gradients.training.utils.registry import register_dataloader`
    * dataloader: `from super_gradients.training.utils.registry import register_dataloader`
+   * callback: `from super_gradients.training.utils.registry import register_callback`
+   * transform: `from super_gradients.training.utils.registry import register_transform`
 3. Apply it on your object.
 3. Apply it on your object.
    * The decorator takes an optional `name: str` argument. If not specified, the decorated class name will be registered.
    * The decorator takes an optional `name: str` argument. If not specified, the decorated class name will be registered.
 
 
Discard
@@ -6,6 +6,8 @@ from super_gradients.training.models.all_architectures import ARCHITECTURES
 from super_gradients.training.metrics.all_metrics import METRICS
 from super_gradients.training.metrics.all_metrics import METRICS
 from super_gradients.training.losses.all_losses import LOSSES
 from super_gradients.training.losses.all_losses import LOSSES
 from super_gradients.modules.detection_modules import ALL_DETECTION_MODULES
 from super_gradients.modules.detection_modules import ALL_DETECTION_MODULES
+from super_gradients.training.utils.callbacks.all_callbacks import CALLBACKS
+from super_gradients.training.transforms.all_transforms import TRANSFORMS
 
 
 
 
 def create_register_decorator(registry: Dict[str, Callable]) -> Callable:
 def create_register_decorator(registry: Dict[str, Callable]) -> Callable:
@@ -45,3 +47,5 @@ register_detection_module = create_register_decorator(registry=ALL_DETECTION_MOD
 register_metric = create_register_decorator(registry=METRICS)
 register_metric = create_register_decorator(registry=METRICS)
 register_loss = create_register_decorator(registry=LOSSES)
 register_loss = create_register_decorator(registry=LOSSES)
 register_dataloader = create_register_decorator(registry=ALL_DATALOADERS)
 register_dataloader = create_register_decorator(registry=ALL_DATALOADERS)
+register_callback = create_register_decorator(registry=CALLBACKS)
+register_transform = create_register_decorator(registry=TRANSFORMS)
Discard