Register
Login
Resources
Docs Blog Datasets Glossary Case Studies Tutorials & Webinars
Product
Data Engine LLMs Platform Enterprise
Pricing Explore
Connect to our Discord channel

example.py 3.4 KB

You have to be logged in to leave a comment. Sign In
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
  1. # The data set used in this example is from http://archive.ics.uci.edu/ml/datasets/Wine+Quality
  2. # P. Cortez, A. Cerdeira, F. Almeida, T. Matos and J. Reis.
  3. # Modeling wine preferences by data mining from physicochemical properties. In Decision Support Systems, Elsevier, 47(4):547-553, 2009.
  4. import logging
  5. import sys
  6. import os
  7. import warnings
  8. from urllib.parse import urlparse
  9. from dotenv import load_dotenv
  10. load_dotenv()
  11. import numpy as np
  12. import pandas as pd
  13. from sklearn.linear_model import ElasticNet
  14. from sklearn.metrics import mean_absolute_error, mean_squared_error, r2_score
  15. from sklearn.model_selection import train_test_split
  16. import mlflow
  17. import mlflow.sklearn
  18. from mlflow.models import infer_signature
  19. logging.basicConfig(level=logging.WARN)
  20. logger = logging.getLogger(__name__)
  21. def eval_metrics(actual, pred):
  22. rmse = np.sqrt(mean_squared_error(actual, pred))
  23. mae = mean_absolute_error(actual, pred)
  24. r2 = r2_score(actual, pred)
  25. return rmse, mae, r2
  26. if __name__ == "__main__":
  27. warnings.filterwarnings("ignore")
  28. np.random.seed(40)
  29. # Read the wine-quality csv file from the URL
  30. csv_url = (
  31. "https://raw.githubusercontent.com/mlflow/mlflow/master/tests/datasets/winequality-red.csv"
  32. )
  33. try:
  34. data = pd.read_csv(csv_url, sep=";")
  35. except Exception as e:
  36. logger.exception(
  37. "Unable to download training & test CSV, check your internet connection. Error: %s", e
  38. )
  39. # Split the data into training and test sets. (0.75, 0.25) split.
  40. train, test = train_test_split(data)
  41. # The predicted column is "quality" which is a scalar from [3, 9]
  42. train_x = train.drop(["quality"], axis=1)
  43. test_x = test.drop(["quality"], axis=1)
  44. train_y = train[["quality"]]
  45. test_y = test[["quality"]]
  46. alpha = float(sys.argv[1]) if len(sys.argv) > 1 else 0.5
  47. l1_ratio = float(sys.argv[2]) if len(sys.argv) > 2 else 0.5
  48. with mlflow.start_run():
  49. lr = ElasticNet(alpha=alpha, l1_ratio=l1_ratio, random_state=42)
  50. lr.fit(train_x, train_y)
  51. predicted_qualities = lr.predict(test_x)
  52. (rmse, mae, r2) = eval_metrics(test_y, predicted_qualities)
  53. print(f"Elasticnet model (alpha={alpha:f}, l1_ratio={l1_ratio:f}):")
  54. print(f" RMSE: {rmse}")
  55. print(f" MAE: {mae}")
  56. print(f" R2: {r2}")
  57. mlflow.log_param("alpha", alpha)
  58. mlflow.log_param("l1_ratio", l1_ratio)
  59. mlflow.log_metric("rmse", rmse)
  60. mlflow.log_metric("r2", r2)
  61. mlflow.log_metric("mae", mae)
  62. predictions = lr.predict(train_x)
  63. signature = infer_signature(train_x, predictions)
  64. # For remote server - dagshub
  65. remote_server_uri = os.getenv("MLFLOW_TRACKING_URI")
  66. mlflow.set_tracking_uri(remote_server_uri)
  67. tracking_url_type_store = urlparse(mlflow.get_tracking_uri()).scheme
  68. # Model registry does not work with file store
  69. if tracking_url_type_store != "file":
  70. # Register the model
  71. # There are other ways to use the Model Registry, which depends on the use case,
  72. # please refer to the doc for more information:
  73. # https://mlflow.org/docs/latest/model-registry.html#api-workflow
  74. mlflow.sklearn.log_model(
  75. lr, "model", registered_model_name="ElasticnetWineModel", signature=signature
  76. )
  77. else:
  78. mlflow.sklearn.log_model(lr, "model", signature=signature)
Tip!

Press p or the left arrow key to see the previous file, or n or the right arrow key to see the next file.

Comments

Loading...