To train a model, you need to configure four main components: dataset params, architecture params, training hyperparams, and checkpoint params. These components are aggregated into a single "main" recipe .yaml file that inherits from each of them.

Recipes support every model, metric, and loss implemented in SuperGradients out of the box, and you can easily extend this to any custom object you need by registering it.
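For orientation, here is a minimal sketch of what such a main recipe could look like. The component file names (my_training_hyperparams, my_dataset_params, my_arch_params) are placeholders for this example, not files shipped with SuperGradients:

```yaml
# recipes/my_recipe.yaml -- illustrative structure only
defaults:
  - training_hyperparams: my_training_hyperparams  # placeholder
  - dataset_params: my_dataset_params              # placeholder
  - arch_params: my_arch_params                    # placeholder
  - checkpoint_params: default_checkpoint_params
  - _self_

architecture: resnet18
experiment_name: my_experiment
```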
Notes:

- In your python script:
    - Your custom class must inherit from the relevant base class: torchmetrics.Metric for a metric, torch.nn.Module for a model, and torch.nn.modules.loss._Loss for a loss.
    - Import the registration decorator that matches your object type:

      ```python
      from super_gradients.training.utils.registry import register_metric
      from super_gradients.training.utils.registry import register_model
      from super_gradients.training.utils.registry import register_loss
      from super_gradients.training.utils.registry import register_dataloader
      from super_gradients.training.utils.registry import register_callback
      from super_gradients.training.utils.registry import register_transform
      ```

    - Each decorator accepts an optional name: str argument. If not specified, the decorated class name will be registered (see the short sketch after this list).
- In your recipe (.yaml):
    - Refer to your custom object by its registered name, exactly as you would refer to a built-in one.
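A minimal sketch of the name argument, using hypothetical do-nothing metrics (only the decorator usage matters here):

```python
import torch
import torchmetrics
from super_gradients.training.utils.registry import register_metric


@register_metric()  # no name given -> registered under the class name, "DummyMetric"
class DummyMetric(torchmetrics.Metric):
    def update(self, preds: torch.Tensor, target: torch.Tensor) -> None:
        pass

    def compute(self) -> torch.Tensor:
        return torch.tensor(0.0)


@register_metric("dummy_metric")  # explicit name -> registered as "dummy_metric"
class DummyMetricRenamed(torchmetrics.Metric):
    def update(self, preds: torch.Tensor, target: torch.Tensor) -> None:
        pass

    def compute(self) -> torch.Tensor:
        return torch.tensor(0.0)
```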
SuperGradients works with torchmetrics.Metric. To write your own metric, you need to implement the update() and compute() methods.

To work with DDP, you also need to define states using add_state(). States are attributes that are reduced and broadcast across ranks when compute() is called in a distributed setting. An example of a state is the number of correct predictions, which is summed across processes and broadcast to all of them before the metric value is computed. You can see an example below.

Feel free to check the torchmetrics documentation for more information on how to implement your own metric.
main.py
```python
import omegaconf
import hydra
import torch
import torchmetrics

from super_gradients import Trainer, init_trainer
from super_gradients.common.registry.registry import register_metric


@register_metric()  # Will be registered as "CustomTop5"
class CustomTop5(torchmetrics.Metric):
    def __init__(self, dist_sync_on_step=False):
        super().__init__(dist_sync_on_step=dist_sync_on_step)
        # States are summed across processes when training with DDP
        self.add_state("correct", default=torch.tensor(0.), dist_reduce_fx="sum")
        self.add_state("total", default=torch.tensor(0), dist_reduce_fx="sum")

    def update(self, preds: torch.Tensor, target: torch.Tensor):
        batch_size = target.size(0)

        # Get the top 5 predictions for each sample
        _, pred = preds.topk(5, 1, True, True)
        pred = pred.t()

        # Count the predictions that match the target within the top 5
        correct = pred.eq(target.view(1, -1).expand_as(pred))
        correct5 = correct[:5].reshape(-1).float().sum(0)
        self.correct += correct5
        self.total += batch_size

    def compute(self):
        return self.correct.float() / self.total


@hydra.main(config_path="recipes")
def main(cfg: omegaconf.DictConfig) -> None:
    Trainer.train_from_config(cfg)


if __name__ == "__main__":
    init_trainer()
    main()
```
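To sanity-check the metric outside of a training run, you can feed it dummy tensors (the shapes below are arbitrary, for illustration only):

```python
import torch

metric = CustomTop5()
preds = torch.randn(8, 10)           # logits for 8 samples over 10 classes
target = torch.randint(0, 10, (8,))  # ground-truth class indices
metric.update(preds, target)
print(metric.compute())              # fraction of samples whose target appears in the top 5
```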
recipes/training_hyperparams/my_training_hyperparams.yaml
```yaml
... # Other training hyperparams

train_metrics_list:
  - CustomTop5

valid_metrics_list:
  - CustomTop5
```
Launch the script:

```bash
python main.py --config-name=my_recipe.yaml
```
The same pattern applies to a custom model:

main.py

```python
import omegaconf
import hydra
import torch
import torch.nn as nn
import torch.nn.functional as F

from super_gradients import Trainer, init_trainer
from super_gradients.common.registry import register_model


@register_model('my_conv_net')  # will be registered as "my_conv_net"
class MyConvNet(nn.Module):
    def __init__(self, num_classes: int):
        super().__init__()
        self.conv1 = nn.Conv2d(3, 6, 5)
        self.pool = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(6, 16, 5)
        self.fc1 = nn.Linear(16 * 5 * 5, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, num_classes)

    def forward(self, x):
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        x = torch.flatten(x, 1)  # flatten all dimensions except batch
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x


@hydra.main(config_path="recipes")
def main(cfg: omegaconf.DictConfig) -> None:
    Trainer.train_from_config(cfg)


if __name__ == "__main__":
    init_trainer()
    main()
```
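A quick shape check (illustrative; the 16 * 5 * 5 flattened size assumes CIFAR-like 3x32x32 inputs):

```python
import torch

model = MyConvNet(num_classes=10)
dummy = torch.randn(1, 3, 32, 32)  # one 32x32 RGB image
print(model(dummy).shape)          # torch.Size([1, 10])
```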
recipes/my_recipe.yaml
```yaml
... # Other recipe params

architecture: my_conv_net
```
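If your model's constructor takes arguments, such as num_classes above, they are typically supplied through the recipe's arch_params. This is a hedged sketch; check how your SuperGradients version forwards arch_params to the constructor of a registered model:

```yaml
# Illustrative only -- assumes arch_params entries reach MyConvNet.__init__
architecture: my_conv_net
arch_params:
  num_classes: 10
```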
Launch the script:

```bash
python main.py --config-name=my_recipe.yaml
```
The same pattern applies to a custom loss:

main.py
```python
import omegaconf
import hydra
import torch

from super_gradients import Trainer, init_trainer
from super_gradients.common.registry.registry import register_loss


@register_loss("custom_rsquared_loss")
class CustomRSquaredLoss(torch.nn.modules.loss._Loss):  # The loss needs to inherit from torch's _Loss class.
    def forward(self, output, target):
        criterion_mse = torch.nn.MSELoss()
        # Keep the result as a tensor (no .item()) so gradients can flow through the loss
        return 1 - criterion_mse(output, target) / torch.var(target)


@hydra.main(config_path="recipes")
def main(cfg: omegaconf.DictConfig) -> None:
    Trainer.train_from_config(cfg)


if __name__ == "__main__":
    init_trainer()
    main()
```
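As a quick check that the loss stays differentiable (dummy tensors, chosen only for illustration):

```python
import torch

loss_fn = CustomRSquaredLoss()
output = torch.randn(16, 1, requires_grad=True)
target = torch.randn(16, 1)

loss = loss_fn(output, target)
loss.backward()  # works because forward() returns a tensor in the autograd graph
print(loss.item(), output.grad.shape)
```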
recipes/training_hyperparams/my_training_hyperparams.yaml
```yaml
... # Other training hyperparams

loss: custom_rsquared_loss
```
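If your custom loss takes constructor arguments, SuperGradients training hyperparams also include a criterion_params entry that is forwarded to the loss when it is built. The argument below is hypothetical, purely to illustrate the shape of the config (CustomRSquaredLoss above takes no arguments):

```yaml
loss: custom_rsquared_loss
criterion_params:
  some_weight: 0.5  # hypothetical constructor argument, for illustration only
```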
Launch the script:

```bash
python main.py --config-name=my_recipe.yaml
```