Register
Login
Resources
Docs Blog Datasets Glossary Case Studies Tutorials & Webinars
Product
Data Engine LLMs Platform Enterprise
Pricing Explore
Connect to our Discord channel

ElasticNet.py 3.6 KB

You have to be logged in to leave a comment. Sign In
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
  1. import os
  2. import warnings
  3. import sys
  4. # import dagshub
  5. import pandas as pd
  6. import numpy as np
  7. from sklearn.metrics import mean_squared_error, mean_absolute_error, r2_score
  8. from sklearn.model_selection import train_test_split
  9. from sklearn.linear_model import ElasticNet
  10. from urllib.parse import urlparse
  11. import mlflow
  12. import mlflow.sklearn
  13. from mlflow.models import infer_signature
  14. from itertools import product
  15. import logging
  16. from sys import version_info
  17. import cloudpickle
  18. import mlflow.pyfunc
  19. # os.environ['MLFLOW_TRACKING_USERNAME'] = 'bindusara007'
  20. # os.environ['MLFLOW_TRACKING_PASSWORD'] = 'e8b66976bf29ade7fc08e9ebb6c3b0bda784edd4'
  21. logging.basicConfig(level=logging.WARN)
  22. logger = logging.getLogger(__name__)
  23. #dagshub.init(repo_owner='bindusara007', repo_name='E2E', mlflow=True)
  24. #mlflow.set_tracking_uri('https://dagshub.com/bindusara007/E2E.mlflow')
  25. # mlflow.set_tracking_uri('http://ec2-3-88-101-45.compute-1.amazonaws.com:5000/') # For aws instance
  26. def eval_metrics(actual, pred):
  27. rmse = np.sqrt(mean_squared_error(actual, pred))
  28. mae = mean_absolute_error(actual, pred)
  29. r2 = r2_score(actual, pred)
  30. return rmse, mae, r2
  31. if __name__ == "__main__":
  32. warnings.filterwarnings("ignore")
  33. np.random.seed(40)
  34. # Read the wine-quality csv file from the URL
  35. csv_url = (
  36. "https://raw.githubusercontent.com/mlflow/mlflow/master/tests/datasets/winequality-red.csv"
  37. )
  38. try:
  39. data = pd.read_csv(csv_url, sep=";")
  40. except Exception as e:
  41. logger.exception(
  42. "Unable to download training & test CSV, check your internet connection. Error: %s", e
  43. )
  44. # Split the data into training and test sets. (0.75, 0.25) split.
  45. train, test = train_test_split(data)
  46. # The predicted column is "quality" which is a scalar from [3, 9]
  47. train_x = train.drop(["quality"], axis=1)
  48. test_x = test.drop(["quality"], axis=1)
  49. train_y = train[["quality"]]
  50. test_y = test[["quality"]]
  51. alpha = [0.2, 0.6, 0.8]
  52. l1_ratio = [0.5, 0.7]
  53. mlflow.autolog()
  54. with mlflow.start_run(run_name="Hyperparameter_Tuning") as parent_run:
  55. mlflow.log_param("tuning_method", "grid_search")
  56. param_combinations = list(product(alpha, l1_ratio))
  57. for alpha, l1_ratio in param_combinations:
  58. with mlflow.start_run(nested=True) as child_run:
  59. lr = ElasticNet(alpha=alpha, l1_ratio=l1_ratio, random_state=42)
  60. lr.fit(train_x, train_y)
  61. predicted_qualities = lr.predict(test_x)
  62. (rmse, mae, r2) = eval_metrics(test_y, predicted_qualities)
  63. print("Elasticnet model (alpha={:f}, l1_ratio={:f}):".format(alpha, l1_ratio))
  64. print(" RMSE: %s" % rmse)
  65. print(" MAE: %s" % mae)
  66. print(" R2: %s" % r2)
  67. mlflow.log_param("alpha", alpha)
  68. mlflow.log_param("l1_ratio", l1_ratio)
  69. mlflow.log_metric("rmse", rmse)
  70. mlflow.log_metric("r2", r2)
  71. mlflow.log_metric("mae", mae)
  72. tracking_url_type_store = urlparse(mlflow.get_tracking_uri()).scheme
  73. # Model registry does not work with file store
  74. if tracking_url_type_store != "file":
  75. # Register the model
  76. # There are other ways to use the Model Registry, which depends on the use case,
  77. # please refer to the doc for more information:
  78. # https://mlflow.org/docs/latest/model-registry.html#api-workflow
  79. mlflow.sklearn.log_model(lr, "model", registered_model_name="ElasticnetWineModel")
  80. else:
  81. mlflow.sklearn.log_model(lr, "model")
Tip!

Press p or ← to see the previous file, or n or → to see the next file

Comments

Loading...