Register
Login
Resources
Docs Blog Datasets Glossary Case Studies Tutorials & Webinars
Product
Data Engine LLMs Platform Enterprise
Pricing Explore
Connect to our Discord channel

st_evaluate_single_model.py 3.1 KB

You have to be logged in to leave a comment. Sign In
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
  1. import json
  2. import dvc.api
  3. import pandas as pd
  4. import streamlit as st
  5. from scripts.params import EVALUATION_DIR
  6. from st_scripts.st_utils import st_model_selectbox, MODELS_PARAMETERS, REPO
  7. with open("./st_scripts/vega_graphs/confusion_matrix.json") as file:
  8. VEGA_CONFUSION_MATRIX = json.load(file)
  9. @st.cache
  10. def load_predictions(model_rev: str) -> pd.DataFrame:
  11. with dvc.api.open(EVALUATION_DIR / "predictions.csv", rev=model_rev) as file:
  12. return pd.read_csv(file)
  13. def st_evaluate_single_model():
  14. st.markdown("### Explore Performance on the Test Set")
  15. selected_model_rev = st_model_selectbox()
  16. threshold = st.sidebar.slider("Choose model threshold", 0.0, 1.0, value=0.5)
  17. model_parameters = MODELS_PARAMETERS[selected_model_rev]
  18. model_commit = REPO.commit(selected_model_rev)
  19. st.write("Commit information:", model_commit)
  20. st.json({
  21. "message": model_commit.message,
  22. "committed_datetime": str(model_commit.committed_datetime),
  23. "committer": str(model_commit.committer),
  24. })
  25. st.text("Model parameters:")
  26. st.json(model_parameters)
  27. st.markdown("## Metrics")
  28. predictions = (
  29. load_predictions(model_rev=selected_model_rev)
  30. .assign(predicted_label=lambda df: (
  31. pd.Series("cats", index=df.index).where(df.prediction < threshold, other="dogs")
  32. ))
  33. )
  34. accuracy = (predictions.true_label == predictions.predicted_label).mean()
  35. st.write("Accuracy (%):", round(100 * accuracy, 2))
  36. st.vega_lite_chart(predictions, VEGA_CONFUSION_MATRIX["spec"])
  37. st.markdown("## Images")
  38. images_selector_columns = st.columns(5)
  39. with images_selector_columns[2]:
  40. st.write("True label")
  41. st.write("Predicted label")
  42. with images_selector_columns[3]:
  43. show_true_cats_images = st.checkbox(label="cats", key="true_cats_images", value=False)
  44. show_predicted_cats_images = st.checkbox(label="cats", key="predicted_cats_images", value=True)
  45. with images_selector_columns[4]:
  46. show_true_dogs_images = st.checkbox(label="dogs", key="true_dogs_images", value=True)
  47. show_predicted_dogs_images = st.checkbox(label="dogs", key="predicted_cats_images", value=False)
  48. selected_true_labels = []
  49. if show_true_cats_images: selected_true_labels.append("cats")
  50. if show_true_dogs_images: selected_true_labels.append("dogs")
  51. selected_predicted_labels = []
  52. if show_predicted_cats_images: selected_predicted_labels.append("cats")
  53. if show_predicted_dogs_images: selected_predicted_labels.append("dogs")
  54. selected_predictions = predictions.loc[
  55. lambda df: df.true_label.isin(selected_true_labels)
  56. ].loc[
  57. lambda df: df.predicted_label.isin(selected_predicted_labels)
  58. ]
  59. with images_selector_columns[0]:
  60. st.write("Selected images:", len(selected_predictions))
  61. images_columns = st.columns(4)
  62. for idx, (_, row) in enumerate(selected_predictions.iterrows()):
  63. images_columns[idx % 4].image(
  64. row["image_path"],
  65. caption=f"true={row['true_label']}, predicted={row['predicted_label']}, pred={row['prediction']:.3f}",
  66. )
Tip!

Press p or ← to see the previous file, or n or → to see the next file

Comments

Loading...