Register
Login
Resources
Docs Blog Datasets Glossary Case Studies Tutorials & Webinars
Product
Data Engine LLMs Platform Enterprise
Pricing Explore
Connect to our Discord channel

evaluate.py 1.1 KB

You have to be logged in to leave a comment. Sign In
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
  1. import numpy as np
  2. import pandas as pd
  3. import seaborn as sns
  4. from sklearn.metrics import confusion_matrix
  5. from zntrack import Node, zn
  6. class EvaluateModel(Node):
  7. # dependencies
  8. ml_model = zn.deps()
  9. test_data = zn.deps()
  10. # metrics
  11. metrics = zn.metrics()
  12. confusion_matrix = zn.plots(template="confusion", x="predicted", y="actual")
  13. def run(self):
  14. """Primary Node Method"""
  15. loss, accuracy = self.ml_model.evaluate(
  16. self.test_data.features, self.test_data.labels
  17. )
  18. self.metrics = {"loss": loss, "accuracy": accuracy}
  19. prediction = self.ml_model.predict(self.test_data.features)
  20. self.confusion_matrix = pd.DataFrame(
  21. [
  22. {"actual": np.argmax(true), "predicted": np.argmax(false)}
  23. for true, false in zip(self.test_data.labels, prediction)
  24. ]
  25. )
  26. def plot_confusion_matrix(self):
  27. cf_mat = confusion_matrix(
  28. self.confusion_matrix["actual"], self.confusion_matrix["predicted"]
  29. )
  30. sns.heatmap(cf_mat)
Tip!

Press p or ← to see the previous file, or n or → to see the next file

Comments

Loading...