Register
Login
Resources
Docs Blog Datasets Glossary Case Studies Tutorials & Webinars
Product
Data Engine LLMs Platform Enterprise
Pricing Explore
Connect to our Discord channel

evaluate.py 1.6 KB

You have to be logged in to leave a comment. Sign In
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
  1. import pandas as pd
  2. import numpy as np
  3. import mlflow
  4. import yaml
  5. import os
  6. from os.path import join
  7. from transformers import pipeline
  8. from packages.evaluate_utilities import get_data_sample, predictions_evaluation
  9. # Get all the required yaml files
  10. params_process = yaml.safe_load(open("params.yaml"))["preprocess"]
  11. params_eval = yaml.safe_load(open("params.yaml"))["evaluate"]
  12. mlflow_config = yaml.safe_load(open("credentials.yaml"))["mlflow_config"]
  13. combined_file_path = join("data", "processed", params_process['final_file_name']+params_process['final_ext'])
  14. combined_files = pd.read_csv(combined_file_path)
  15. # Model definition
  16. zsmlc_classifier = pipeline("zero-shot-classification",
  17. model='joeddav/xlm-roberta-large-xnli')
  18. # Getting Mlflow credentials
  19. MLFLOW_TRACKING_URI= mlflow_config['MLFLOW_TRACKING_URI']
  20. MLFLOW_TRACKING_USERNAME = mlflow_config['MLFLOW_TRACKING_USERNAME']
  21. MLFLOW_TRACKING_PASSWORD = mlflow_config['MLFLOW_TRACKING_PASSWORD']
  22. os.environ['MLFLOW_TRACKING_USERNAME'] = MLFLOW_TRACKING_USERNAME
  23. os.environ['MLFLOW_TRACKING_PASSWORD'] = MLFLOW_TRACKING_PASSWORD
  24. mlflow.set_tracking_uri(MLFLOW_TRACKING_URI)
  25. if __name__ == "__main__":
  26. mlflow.set_experiment("Multi-linguage-classification")
  27. with mlflow.start_run():
  28. for language in combined_files['language'].unique():
  29. lang_sample_data = get_data_sample(combined_files, language)
  30. pred_eval = predictions_evaluation(lang_sample_data, zsmlc_classifier)
  31. # Log different metrics
  32. mlflow.log_metric(language+"_accuracy", pred_eval['accuracy'])
  33. mlflow.log_metric(language+"_f1_score", pred_eval['f1_score'])
Tip!

Press p to see the previous file, or n to see the next file

Comments

Loading...