Register
Login
Resources
Docs Blog Datasets Glossary Case Studies Tutorials & Webinars
Product
Data Engine LLMs Platform Enterprise
Pricing Explore
Connect to our Discord channel

hydra_utils.py 3.6 KB

You have to be logged in to leave a comment. Sign In
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
import os
from pathlib import Path
from typing import List, Optional, Union

import pkg_resources
import hydra
from hydra import initialize_config_dir, compose
from hydra.core.global_hydra import GlobalHydra
from omegaconf import OmegaConf, open_dict, DictConfig

from super_gradients.training.utils.checkpoint_utils import get_checkpoints_dir_path
  10. def load_experiment_cfg(experiment_name: str, ckpt_root_dir: str = None) -> DictConfig:
  11. """
  12. Load the hydra config associated to a specific experiment.
  13. Background Information: every time an experiment is launched based on a recipe, all the hydra config params are stored in a hidden folder ".hydra".
  14. This hidden folder is used here to recreate the exact same config as the one that was used to launch the experiment (Also include hydra overrides).
  15. The motivation is to be able to resume or evaluate an experiment with the exact same config as the one that was used when the experiment was
  16. initially started, regardless of any change that might have been introduced to the recipe, and also while using the same overrides that were used
  17. for that experiment.
  18. :param experiment_name: Name of the experiment to resume
  19. :param ckpt_root_dir: Directory including the checkpoints
  20. :return: The config that was used for that experiment
  21. """
  22. if not experiment_name:
  23. raise ValueError(f"experiment_name should be non empty string but got :{experiment_name}")
  24. checkpoints_dir_path = Path(get_checkpoints_dir_path(experiment_name, ckpt_root_dir))
  25. if not checkpoints_dir_path.exists():
  26. raise FileNotFoundError(f"Impossible to find checkpoint dir ({checkpoints_dir_path})")
  27. resume_dir = Path(checkpoints_dir_path) / ".hydra"
  28. if not resume_dir.exists():
  29. raise FileNotFoundError(f"The checkpoint directory {checkpoints_dir_path} does not include .hydra artifacts to resume the experiment.")
  30. # Load overrides that were used in previous run
  31. overrides_cfg = list(OmegaConf.load(resume_dir / "overrides.yaml"))
  32. GlobalHydra.instance().clear()
  33. with initialize_config_dir(config_dir=normalize_path(str(resume_dir)), version_base="1.2"):
  34. cfg = compose(config_name="config.yaml", overrides=overrides_cfg)
  35. return cfg
  36. def add_params_to_cfg(cfg: DictConfig, params: List[str]):
  37. """Add parameters to an existing config
  38. :param cfg: OmegaConf config
  39. :param params: List of parameters to add, in dotlist format (i.e. ["training_hyperparams.resume=True"])"""
  40. new_cfg = OmegaConf.from_dotlist(params)
  41. with open_dict(cfg): # This is required to add new fields to existing config
  42. cfg.merge_with(new_cfg)
  43. def normalize_path(path: str) -> str:
  44. """Normalize the directory of file path. Replace the Windows-style (\\) path separators with unix ones (/).
  45. This is necessary when running on Windows since Hydra compose fails to find a configuration file is the config
  46. directory contains backward slash symbol.
  47. :param path: Input path string
  48. :return: Output path string with all \\ symbols replaces with /.
  49. """
  50. return path.replace("\\", "/")
  51. def load_arch_params(config_name: str) -> DictConfig:
  52. """
  53. :param config_name: name of a yaml with arch parameters
  54. """
  55. GlobalHydra.instance().clear()
  56. sg_recipes_dir = pkg_resources.resource_filename("super_gradients.recipes", "")
  57. dataset_config = os.path.join("arch_params", config_name)
  58. with initialize_config_dir(config_dir=normalize_path(sg_recipes_dir), version_base="1.2"):
  59. # config is relative to a module
  60. return hydra.utils.instantiate(compose(config_name=normalize_path(dataset_config)).arch_params)
Tip!

Press p or ← to see the previous file, or n or → to see the next file

Comments

Loading...