test_min_samples_single_node.py

```python
import unittest

import numpy as np
import torch
from torch.utils.data import TensorDataset

from super_gradients.training import dataloaders


class TestMinSamplesSingleNode(unittest.TestCase):
    def test_min_samples(self):
        # A dummy dataset of 64 all-zero 3x32x32 images with integer labels.
        dataset_size = 64
        image_size = 32
        images = torch.Tensor(np.zeros((dataset_size, 3, image_size, image_size)))
        ground_truth = torch.LongTensor(np.zeros(dataset_size))

        # min_samples (80) exceeds the dataset size (64), so the loader
        # is expected to yield 80 samples per epoch instead of 64.
        dataloader = dataloaders.get(
            dataset=TensorDataset(images, ground_truth),
            dataloader_params={"batch_size": 4, "min_samples": 80, "drop_last": True},
        )

        # 80 samples in batches of 4 with drop_last=True -> 20 batches.
        self.assertEqual(len(dataloader), 80 // 4)


if __name__ == "__main__":
    unittest.main()
```
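The expected length in the test's assertion follows from simple arithmetic. The sketch below reproduces it in plain Python. `expected_num_batches` is a hypothetical helper, not part of super_gradients, and it assumes that `min_samples` only ever raises the number of samples drawn per epoch (a dataset already larger than `min_samples` is assumed to be iterated once, unchanged).

```python
from typing import Optional


def expected_num_batches(
    dataset_size: int,
    batch_size: int,
    min_samples: Optional[int] = None,
    drop_last: bool = False,
) -> int:
    """Hypothetical helper: predict len(dataloader), assuming the sampler
    draws max(dataset_size, min_samples) samples per epoch."""
    if batch_size <= 0:
        raise ValueError("batch_size must be positive")
    # Assumption: min_samples can lengthen an epoch but never shorten it.
    if min_samples is None:
        num_samples = dataset_size
    else:
        num_samples = max(dataset_size, min_samples)
    if drop_last:
        # The trailing partial batch is discarded.
        return num_samples // batch_size
    # The trailing partial batch is kept as one extra, smaller batch
    # (integer ceiling division).
    return (num_samples + batch_size - 1) // batch_size


# The configuration from the test: 64 images oversampled to 80, batches of 4.
print(expected_num_batches(64, 4, min_samples=80, drop_last=True))  # 20
# Without min_samples, one epoch covers the 64 images once.
print(expected_num_batches(64, 4, drop_last=True))  # 16
```

Setting `drop_last=True` in the test keeps the expected value an exact integer division, so the assertion does not depend on how a trailing partial batch would be handled.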