
test_export_recipe.py 934 B

import os
import tempfile
import unittest

import hydra
from hydra import compose, initialize_config_dir

from super_gradients.common.environment.cfg_utils import export_recipe


class TestExportRecipe(unittest.TestCase):
    def test_export_recipe(self):
        with tempfile.TemporaryDirectory() as td:
            save_path = os.path.join(td, "cifar10_resnet_complete.yaml")
            # Export the fully resolved cifar10_resnet recipe into a single YAML file
            export_recipe(config_name="cifar10_resnet", save_path=save_path)
            # Check that the output file was created
            self.assertTrue(os.path.exists(save_path))
            # Load the exported file back with Hydra and instantiate it
            with initialize_config_dir(config_dir=td, version_base="1.2"):
                cfg = compose(config_name="cifar10_resnet_complete.yaml")
                cfg = hydra.utils.instantiate(cfg)
            # The exported recipe must keep the original training settings
            self.assertEqual(cfg.training_hyperparams.max_epochs, 250)


if __name__ == "__main__":
    unittest.main()
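The test above depends on super_gradients and Hydra. As a dependency-free illustration of the same round-trip pattern (export a fully resolved recipe to one file, check that the file exists, reload it, and assert on a resolved value), here is a minimal sketch. It assumes a recipe is just a nested dict of defaults plus overrides, merged with a simple recursive deep-merge and written as JSON instead of YAML. `deep_merge`, `export_flat_recipe`, `BASE_DEFAULTS`, and `RECIPE_OVERRIDES` are hypothetical names, not part of super_gradients, and the values are illustrative only.

```python
import json
import os
import tempfile
import unittest

# Hypothetical recipe pieces standing in for YAML files: a base set of defaults
# plus recipe-level overrides. These are NOT the real cifar10_resnet recipe.
BASE_DEFAULTS = {"training_hyperparams": {"max_epochs": 100, "initial_lr": 0.1}}
RECIPE_OVERRIDES = {"training_hyperparams": {"max_epochs": 250}}


def deep_merge(base: dict, override: dict) -> dict:
    """Return a new dict with override recursively merged over base (override wins)."""
    merged = dict(base)
    for key, value in override.items():
        if isinstance(value, dict) and isinstance(merged.get(key), dict):
            merged[key] = deep_merge(merged[key], value)
        else:
            merged[key] = value
    return merged


def export_flat_recipe(save_path: str) -> None:
    """Hypothetical stand-in for export_recipe: write one self-contained config file."""
    resolved = deep_merge(BASE_DEFAULTS, RECIPE_OVERRIDES)
    with open(save_path, "w", encoding="utf-8") as f:
        json.dump(resolved, f, indent=2)


class TestExportFlatRecipe(unittest.TestCase):
    def test_round_trip(self):
        with tempfile.TemporaryDirectory() as td:
            save_path = os.path.join(td, "recipe_complete.json")
            export_flat_recipe(save_path)
            # The exported file must exist before we try to reload it
            self.assertTrue(os.path.exists(save_path))
            with open(save_path, encoding="utf-8") as f:
                cfg = json.load(f)
            # The override wins, and untouched defaults survive the merge
            self.assertEqual(cfg["training_hyperparams"]["max_epochs"], 250)
            self.assertEqual(cfg["training_hyperparams"]["initial_lr"], 0.1)
```

Run it with `python -m unittest <file>`. Checking both an overridden key and an untouched default is the useful part of the design: it shows the exported file is a complete, standalone recipe rather than just the override layer.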