Register
Login
Resources
Docs Blog Datasets Glossary Case Studies Tutorials & Webinars
Product
Data Engine LLMs Platform Enterprise
Pricing Explore
Connect to our Discord channel

augmentation_grid.py 1.5 KB

You have to be logged in to leave a comment. Sign In
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
  1. from typing import Dict, List, Optional
  2. import numpy as np
  3. def make_augmentation_params_grid(
  4. p_x_min: float, p_x_max: float, p_x_step: float,
  5. p_e_min: float, p_e_max: float, p_e_step: float,
  6. grid_type: str,
  7. ) -> List[Dict[str, Optional[float]]]:
  8. all_pairs = [
  9. (p_x, p_e)
  10. for p_x in np.arange(p_x_min, p_x_max + p_x_step, p_x_step)
  11. for p_e in np.arange(p_e_min, p_e_max + p_e_step, p_e_step)
  12. ]
  13. if grid_type == "ONLY_SAME":
  14. return [
  15. {"p_x_1": p_x, "p_e_1": p_e, "p_x_2": None, "p_e_2": None}
  16. for p_x, p_e in all_pairs
  17. ]
  18. else:
  19. assert grid_type == "ALL"
  20. # Build all parameter combinations
  21. grid = set()
  22. for p1 in all_pairs:
  23. for p2 in all_pairs:
  24. # Filter out symmetric pairs
  25. if (p1, p2) in grid or (p2, p1) in grid:
  26. continue
  27. grid.add((p1, p2))
  28. return [
  29. {"p_x_1": p_x_1, "p_e_1": p_e_1, "p_x_2": p_x_2, "p_e_2": p_e_2}
  30. for (p_x_1, p_e_1), (p_x_2, p_e_2) in sorted(grid)
  31. ]
  32. def is_same(augmentation_parameters: Dict[str, Optional[float]]) -> bool:
  33. return (
  34. augmentation_parameters["p_x_1"] == augmentation_parameters["p_x_2"]
  35. and
  36. augmentation_parameters["p_e_1"] == augmentation_parameters["p_e_2"]
  37. ) or (
  38. augmentation_parameters["p_x_2"] is None
  39. and
  40. augmentation_parameters["p_e_2"] is None
  41. )
Tip!

Press p to see the previous file, or n to see the next file

Comments

Loading...