Register
Login
Resources
Docs Blog Datasets Glossary Case Studies Tutorials & Webinars
Product
Data Engine LLMs Platform Enterprise
Pricing Explore
Connect to our Discord channel

st_utils.py 2.3 KB

You have to be logged in to leave a comment. Sign In
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
  1. from contextlib import contextmanager
  2. from pathlib import Path
  3. from typing import Union, Optional
  4. import os
  5. import json
  6. import git
  7. import streamlit as st
  8. import yaml
  9. from scripts.params import EVALUATION_DIR
  10. ROOT_DIR = Path(os.path.dirname(os.path.realpath(__file__))).parent
  11. REPO = git.Repo(str(ROOT_DIR))
  12. FIRST_COMMIT = list(REPO.iter_commits())[-1]
  13. # First commit a full training pipeline was run
  14. FIRST_PIPELINE_COMMIT = "6f96e3ffb98b0ba833937c510e93a6cdd3555f05"
  15. #%% Utils for git, dvc, streamlit
  16. @contextmanager
  17. def git_open(path: Union[str, Path], rev: str):
  18. commit = REPO.commit(rev)
  19. # Hack to get the full blob data stream: compute diff with initial commit
  20. diff = commit.diff(FIRST_COMMIT, str(path))[0]
  21. yield diff.a_blob.data_stream
  22. #%% Retrieve commits for trained model
  23. MODELS_COMMITS = list(REPO.iter_commits(
  24. rev=f"...{FIRST_PIPELINE_COMMIT}",
  25. paths="dvc.lock",
  26. ))
  27. #%% Utils for model parameters
  28. def _read_train_params(rev: str) -> dict:
  29. with git_open(ROOT_DIR / "dvc.lock", rev=rev) as file:
  30. dvc_lock = yaml.safe_load(file)
  31. return dvc_lock["stages"]["train"]["params"]["params.yaml"]
  32. MODELS_PARAMETERS = {
  33. commit.hexsha: _read_train_params(rev=commit.hexsha)
  34. for commit in MODELS_COMMITS
  35. }
  36. #%%
  37. def _read_model_evaluation_metrics(model_rev: str) -> dict:
  38. with git_open(EVALUATION_DIR / "metrics.json", rev=model_rev) as file:
  39. return json.load(file)
  40. MODELS_EVALUATION_METRICS = {
  41. commit.hexsha: _read_model_evaluation_metrics(model_rev=commit.hexsha)
  42. for commit in MODELS_COMMITS
  43. }
  44. #%%
  45. def get_model_backbone(model_rev: str) -> Optional[str]:
  46. model_parameters = MODELS_PARAMETERS[model_rev]
  47. try:
  48. return model_parameters["model"]["backbone"].split(".")[-1]
  49. except KeyError:
  50. pass
  51. def _display_model(hexsha: str) -> str:
  52. commit = REPO.commit(hexsha)
  53. backbone = get_model_backbone(hexsha) or "-"
  54. return f"{commit.message} / {backbone} / {commit.committed_datetime}"
  55. def st_model_multiselect():
  56. return st.sidebar.multiselect(
  57. "Choose your model(s)",
  58. [commit.hexsha for commit in MODELS_COMMITS],
  59. format_func=_display_model,
  60. )
  61. def st_model_selectbox():
  62. return st.sidebar.selectbox(
  63. "Choose your model",
  64. [commit.hexsha for commit in MODELS_COMMITS],
  65. format_func=_display_model,
  66. )
Tip!

Press p or ← to see the previous file, or n or → to see the next file

Comments

Loading...