Register
Login
Resources
Docs Blog Datasets Glossary Case Studies Tutorials & Webinars
Product
Data Engine LLMs Platform Enterprise
Pricing Explore
Connect to our Discord channel

metrics.py 1.5 KB

You have to be logged in to leave a comment. Sign In
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
  1. from scipy.special import softmax, expit as sigmoid
  2. from sklearn.metrics import f1_score
  3. def ktiv_male_metrics(eval_preds, padding_idx=-100):
  4. logits, labels = eval_preds
  5. # shape of logits: (n_samples, n_chars, n_labels=3)
  6. # shape of labels: (n_samples, n_chars)
  7. # reshape to 2d to treat each character as a sample: (N, n_labels=15)
  8. n_labels = logits.shape[-1]
  9. logits = logits.reshape(-1, n_labels)
  10. labels = labels.reshape(-1)
  11. mask = (labels != padding_idx)
  12. logits = logits[mask]
  13. labels = labels[mask]
  14. probs = softmax(logits, axis=-1)
  15. preds = probs.argmax(axis=-1)
  16. accuracy = (labels == preds).mean()
  17. macro_f1 = f1_score(labels, preds, average='macro')
  18. return {
  19. 'accuracy': accuracy,
  20. 'macro_f1': macro_f1
  21. }
  22. def unikud_metrics(eval_preds, prob_threshold=0.5, padding_idx=-100):
  23. logits, labels = eval_preds
  24. # shapes of each: (n_samples, n_chars, n_labels=15)
  25. # reshape to 2d to treat each character as a sample: (N, n_labels=15)
  26. n_labels = logits.shape[-1]
  27. logits = logits.reshape(-1, n_labels)
  28. labels = labels.reshape(-1, n_labels)
  29. mask = (labels != padding_idx).all(axis=-1)
  30. logits = logits[mask]
  31. labels = labels[mask]
  32. probs = sigmoid(logits)
  33. preds = (probs >= prob_threshold).astype(int)
  34. accuracy = (labels == preds).all(axis=-1).mean()
  35. macro_f1 = f1_score(labels, preds, average='macro')
  36. return {
  37. 'accuracy': accuracy,
  38. 'macro_f1': macro_f1
  39. }
Tip!

Press p or ← to see the previous file, or n or → to see the next file

Comments

Loading...