random_erase_test.py

import torch
import unittest

from super_gradients.training.datasets.data_augmentation import RandomErase


class RandomEraseTest(unittest.TestCase):
    def test_random_erase(self):
        # Dummy batch: 1 image, 3 channels, 32x32 pixels
        dummy_input = torch.randn(1, 3, 32, 32)

        # A numeric string value should be converted to a float.
        # probability=0 means the transform never erases anything,
        # so the call below only smoke-tests the forward pass.
        one_erase = RandomErase(probability=0, value="1.")
        self.assertEqual(one_erase.p, 0)
        self.assertEqual(one_erase.value, 1.0)
        one_erase(dummy_input)

        # The special string "random" should be kept as-is
        rndm_erase = RandomErase(probability=0, value="random")
        self.assertEqual(rndm_erase.value, "random")
        rndm_erase(dummy_input)


if __name__ == "__main__":
    unittest.main()
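The assertions on `value` suggest that `RandomErase` converts a numeric string such as `"1."` to a float while leaving the special string `"random"` untouched. A minimal sketch of that coercion, assuming it is done with a `try`/`except` around `float()`; the helper name `parse_erase_value` is hypothetical and is not part of super_gradients:

```python
from typing import Union


def parse_erase_value(value: str) -> Union[float, str]:
    """Hypothetical helper: convert a numeric string to float, else return it unchanged.

    Mirrors the behaviour the test asserts on RandomErase.value
    ("1." -> 1.0, "random" -> "random"); this is an illustrative sketch,
    not super_gradients' own implementation.
    """
    try:
        return float(value)
    except ValueError:
        # Non-numeric strings such as "random" are passed through as-is,
        # e.g. for torchvision.transforms.RandomErasing to interpret.
        return value


print(parse_erase_value("1."))      # 1.0
print(parse_erase_value("random"))  # random
```

Passing strings through unchanged on failure keeps the helper agnostic about which special values the downstream transform accepts.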