#732 Feature/sg 564 explicit how to adapt dataset to SG

Merged
Ghost merged 1 commit into Deci-AI:master from deci-ai:feature/SG-564-dataset_adapter
@@ -199,12 +199,22 @@ import torch
 class MyCustomDataset(torch.utils.data.Dataset):
     def __init__(self, train: bool, image_size: int):
         ...
+
+    def __getitem__(self, item):
+        ...
+        return inputs, targets # Or inputs, targets, additional_batch_items
 ```

+#### A. `__getitem__`
+You need to make sure that the `__getitem__` method of your dataset complies with the following format (a minimal sketch follows this list):
+   - `inputs = batch_items[0]` : model input; its type may depend on the model you are using.
+   - `targets = batch_items[1]` : target used to compute the loss/metrics; its type may depend on the loss/metric functions you are using.
+   - [OPTIONAL] `additional_batch_items = batch_items[2]` : a dict of any additional items you might want to use.
+
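+For illustration only, here is a minimal sketch of a dataset that complies with this format (the input shape, the classification-style target, and the `"index"` key are placeholder assumptions, not requirements):
+
+```python
+import torch
+
+
+class MyCustomDataset(torch.utils.data.Dataset):
+    def __init__(self, train: bool, image_size: int):
+        self.train = train
+        self.image_size = image_size
+
+    def __len__(self):
+        return 100  # placeholder dataset length
+
+    def __getitem__(self, item):
+        inputs = torch.randn(3, self.image_size, self.image_size)  # model input
+        targets = torch.tensor(0)  # target used to compute the loss/metrics
+        additional_batch_items = {"index": item}  # optional dict of extra items
+        return inputs, targets, additional_batch_items
+```
+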
+#### B. Train with your dataset
 To launch training from code, we can instantiate the dataset and then use it, as in the first code snippet, to create
 the data loaders and call train():

-
 ```python

 from my_dataset import MyCustomDataset
@@ -220,7 +230,6 @@ model = ...
 train_params = {...}

 trainer.train(model=model, training_params=train_params, train_loader=train_dataloader, valid_loader=valid_dataloader)
-   
 ```

 ### Using Custom Datasets in SG - Training with Configuration Files
@@ -39,11 +39,13 @@ class UnsupportedBatchItemsFormat(ValueError):
         message -- explanation of the error
     """

-    def __init__(self):
+    def __init__(self, batch_items: tuple):
         self.message = (
-            "Batch items returned by the data loader expected format: \n"
-            "1. torch.Tensor or tuple, s.t inputs = batch_items[0], targets = batch_items[1] and len("
-            "batch_items) = 2 \n"
-            "2. tuple: (inputs, targets, additional_batch_items)"
+            f"The data loader is expected to return 2 to 3 items, but got {len(batch_items)} instead.\n"
+            "Items expected:\n"
+            "   - inputs = batch_items[0] # model input - The type might depend on the model you are using.\n"
+            "   - targets = batch_items[1] # Target that will be used to compute loss/metrics - The type might depend on the function you are using.\n"
+            "   - [OPTIONAL] additional_batch_items = batch_items[2] # Dict made of any additional item that you might want to use.\n"
+            "To fix this, please change the implementation of your dataset's __getitem__ method so that it returns the items defined above.\n"
         )
         super().__init__(self.message)
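As a quick illustration of the new constructor (hypothetical values; assumes the `UnsupportedBatchItemsFormat` class exactly as patched above):

```python
# Hypothetical demo: a data loader yielding four items now produces the explicit message.
try:
    raise UnsupportedBatchItemsFormat(batch_items=(1, 2, 3, 4))
except ValueError as err:
    print(err)  # "The data loader is expected to return 2 to 3 items, but got 4 instead. ..."
```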
@@ -1,5 +1,5 @@
 from pathlib import Path
-from typing import Tuple, Type, Optional
+from typing import Tuple, Type, Optional, Union

 import hydra
 import torch
@@ -79,7 +79,7 @@ def get_architecture(model_name: str, arch_params: HpmStruct, download_required_

 def instantiate_model(
     model_name: str, arch_params: dict, num_classes: int, pretrained_weights: str = None, download_required_code: bool = True
-) -> torch.nn.Module:
+) -> Union[SgModule, torch.nn.Module]:
     """
     Instantiates nn.Module according to architecture and arch_params, and handles pretrained weights and the required
         module manipulation (i.e. head replacement).
@@ -147,7 +147,7 @@ def get(
     load_backbone: bool = False,
     download_required_code: bool = True,
     checkpoint_num_classes: int = None,
-) -> SgModule:
+) -> Union[SgModule, torch.nn.Module]:
     """
     :param model_name:          Defines the model's architecture from models/ALL_ARCHITECTURES
     :param arch_params:         Architecture hyper parameters. e.g.: block, num_blocks, etc.
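A brief usage sketch of what the widened annotation means in practice (assuming the usual `super_gradients.training.models` import path; the model name and class count are placeholders):

```python
import torch

from super_gradients.training import models

# get() may return either an SgModule or a plain torch.nn.Module,
# so downstream code should rely only on the nn.Module interface.
model = models.get(model_name="resnet18", num_classes=10)
assert isinstance(model, torch.nn.Module)  # holds in both cases
```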
@@ -399,7 +399,7 @@ def unpack_batch_items(batch_items: Union[tuple, torch.Tensor]):
         inputs, target, additional_batch_items = batch_items

     else:
-        raise UnsupportedBatchItemsFormat()
+        raise UnsupportedBatchItemsFormat(batch_items)

     return inputs, target, additional_batch_items

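For context, here is a sketch of the dispatch this function appears to perform, reconstructed from the visible context lines (the real implementation may differ; the stand-in exception class is only there to keep the snippet self-contained):

```python
from typing import Union

import torch


class UnsupportedBatchItemsFormat(ValueError):
    # Stand-in for the exception patched above.
    def __init__(self, batch_items: tuple):
        super().__init__(f"The data loader is expected to return 2 to 3 items, but got {len(batch_items)} instead.")


def unpack_batch_items(batch_items: Union[tuple, torch.Tensor]):
    # Dispatch on the number of items returned by the dataset's __getitem__.
    additional_batch_items = {}
    if len(batch_items) == 2:
        inputs, target = batch_items
    elif len(batch_items) == 3:
        inputs, target, additional_batch_items = batch_items
    else:
        raise UnsupportedBatchItemsFormat(batch_items)  # now reports the offending count
    return inputs, target, additional_batch_items
```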