Register
Login
Resources
Docs Blog Datasets Glossary Case Studies Tutorials & Webinars
Product
Data Engine LLMs Platform Enterprise
Pricing Explore
Connect to our Discord channel

example.py 1.8 KB

You have to be logged in to leave a comment. Sign In
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
  1. import mlflow
  2. import mlflow.sklearn
  3. from sklearn.datasets import load_iris
  4. from sklearn.ensemble import RandomForestClassifier
  5. from sklearn.model_selection import train_test_split
  6. from urllib.parse import urlparse
  7. from sklearn.metrics import accuracy_score, precision_score, recall_score
  8. # Define model hyperparameters
  9. n_estimators = 25
  10. max_depth = 5
  11. min_samples_split = 8
  12. # Load dataset
  13. iris = load_iris()
  14. X = iris.data
  15. y = iris.target
  16. # Split the dataset
  17. X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)
  18. # Create a RandomForestClassifier model with specified hyperparameters
  19. model = RandomForestClassifier(n_estimators=n_estimators, max_depth=max_depth, min_samples_split=min_samples_split, random_state=42)
  20. # Train the model
  21. model.fit(X_train, y_train)
  22. # Make predictions
  23. y_pred = model.predict(X_test)
  24. # Calculate metrics
  25. accuracy = accuracy_score(y_test, y_pred)
  26. precision = precision_score(y_test, y_pred, average='macro')
  27. recall = recall_score(y_test, y_pred, average='macro')
  28. # Log parameters and metrics to MLflow
  29. with mlflow.start_run():
  30. # Log hyperparameters
  31. mlflow.log_param("n_estimators", n_estimators)
  32. mlflow.log_param("max_depth", max_depth)
  33. mlflow.log_param("min_samples_split", min_samples_split)
  34. # Log metrics
  35. mlflow.log_metric("accuracy", accuracy)
  36. mlflow.log_metric("precision", precision)
  37. mlflow.log_metric("recall", recall)
  38. # Log the model
  39. remote_server_uri ="https://dagshub.com/santoshraiii/mlops_demo.mlflow"
  40. mlflow.set_tracking_uri(remote_server_uri)
  41. tracking_url_type_store = urlparse(mlflow.get_tracking_uri()).scheme
  42. if tracking_url_type_store != "file":
  43. mlflow.sklearn.log_model(
  44. model,"model")
  45. else:
  46. mlflow.sklearn.log_model(model, "random_forest_model")
Tip!

Press p to see the previous file, or n to see the next file.

Comments

Loading...