Register
Login
Resources
Docs Blog Datasets Glossary Case Studies Tutorials & Webinars
Product
Data Engine LLMs Platform Enterprise
Pricing Explore
Connect to our Discord channel

test_scores.py 3.6 KB

You have to be logged in to leave a comment. Sign In
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
  1. import pytest
  2. import pandas as pd
  3. import omegaconf as oc
  4. from molbart.utils.scores import (
  5. TanimotoSimilarityScore,
  6. TopKAccuracyScore,
  7. FractionInvalidScore,
  8. FractionUniqueScore,
  9. ScoreCollection,
  10. )
  11. from molbart.utils import trainer_utils
  12. def test_default_inference_scoring():
  13. config = oc.OmegaConf.load("molbart/config/inference_score.yaml")
  14. score_config = config.get("scorers")
  15. scorers = trainer_utils.instantiate_scorers(score_config)
  16. scorer_names = set(scorers.names())
  17. expected = set(["top_k_accuracy", "fraction_invalid", "top1_tanimoto_similarity", "fraction_unique"])
  18. assert scorer_names.issubset(expected) and expected.issubset(scorer_names)
  19. sampled_smiles = [["C!O", "CCO", "CCO"], ["c1ccccc1", "c1cccc1", "c1ccccc1"]]
  20. target_smiles = ["CCO", "c1ccccc1"]
  21. metrics_scores = scorers.score(sampled_smiles, target_smiles)
  22. assert round(metrics_scores["fraction_invalid"], 4) == 0.3333
  23. assert round(metrics_scores["fraction_unique"], 4) == 0.3333
  24. assert metrics_scores["top1_tanimoto_similarity"] == 1.0
  25. assert metrics_scores["accuracy_top_1"] == 0.5
  26. assert metrics_scores["accuracy_top_3"] == 1.0
  27. @pytest.mark.parametrize(
  28. ("sampled_smiles", "target_smiles", "expected_score"),
  29. [
  30. ([["CCO", "CCO", "CCO"], ["c1cc!ccc1", "c1cccc1", "c1ccccc1"]], ["CCO", "c1ccccc1"], 0.3333),
  31. ([["CCO", "CCO", "CCO"], ["c1ccccc1", "c1cccc1", "c1ccccc1"]], ["CCO", "c1ccccc1"], 0.1667),
  32. ([["CCO", "C!O", "CCO"], ["c1ccccc1", "c1cccc1", "c1ccccc1"]], ["CCO", "c1ccccc1"], 0.3333),
  33. ([["CCO", "CCO", "CCO"], ["c1ccccc1", "c1ccccc1", "c1ccccc1"]], ["CCO", "c1ccccc1"], 0.0),
  34. ],
  35. )
  36. def test_fraction_invalid(sampled_smiles, target_smiles, expected_score):
  37. scorer = ScoreCollection()
  38. scorer.load(FractionInvalidScore())
  39. score = scorer.score(sampled_smiles, target_smiles)["fraction_invalid"]
  40. assert round(score, 4) == expected_score
  41. @pytest.mark.parametrize(
  42. ("sampled_smiles", "target_smiles", "expected_score"),
  43. [
  44. ([["CCO", "CCO", "CCO"], ["c1cc!ccc1", "c1cccc1", "c1ccccc1"]], ["CCO", "c1ccccc1"], 0.3333),
  45. ([["CCO", "CCO", "CCO"], ["c1ccccc1", "c1cccc1", "c1ccccc1"]], ["CCO", "c1ccccc1"], 0.3333),
  46. ([["CCO", "C!O", "COO"], ["c1ccccc1", "c1cccc1", "c1cc(Br)ccc1"]], ["CCO", "c1ccccc1"], 0.6667),
  47. ],
  48. )
  49. def test_fraction_unique(sampled_smiles, target_smiles, expected_score):
  50. scorer = ScoreCollection()
  51. scorer.load(FractionUniqueScore())
  52. score = scorer.score(sampled_smiles, target_smiles)["fraction_unique"]
  53. print(round(score, 4))
  54. assert round(score, 4) == expected_score
  55. def test_accuracy_similarity(round_trip_converted_prediction_data):
  56. scorer = ScoreCollection()
  57. scorer.load(TanimotoSimilarityScore(statistics="all"))
  58. scorer.load(TopKAccuracyScore())
  59. sampled_smiles = round_trip_converted_prediction_data["sampled_smiles"]
  60. target_smiles = round_trip_converted_prediction_data["target_smiles"]
  61. metrics_out = []
  62. for sampled_batch, target_batch in zip(sampled_smiles, target_smiles):
  63. metrics = scorer.score(
  64. sampled_batch,
  65. target_batch,
  66. )
  67. metrics = {key: [val] for key, val in metrics.items()}
  68. metrics_out.append(pd.DataFrame(metrics))
  69. metrics_df = pd.concat(metrics_out, axis=0)
  70. assert all(sim == 1.0 for sim in metrics_df["top1_tanimoto_similarity"].values[0][0])
  71. assert round(metrics_df["accuracy_top_1"].values[0], 4) == 0.6667
  72. assert round(metrics_df["accuracy_top_3"].values[0], 4) == 0.6667
  73. assert round(metrics_df["accuracy_top_5"].values[0], 4) == 0.6667
Tip!

Press p to see the previous file, or n to see the next file

Comments

Loading...