training.py

  1. """
  2. Training Gradient boosting model with Grid Search CV
  3. """
  4. from sklearn.model_selection import GridSearchCV
  5. from sklearn.ensemble import GradientBoostingRegressor
  6. from sklearn.model_selection import KFold
  7. import numpy as np
  8. import pickle
  9. import yaml
  10. params = yaml.safe_load(open("params.yaml"))["training"]
  11. n_estimators = params["n_est"]
  12. max_depth = params["m_depth"]
  13. learning_rate = params["lr"]
  14. min_samples_split = params["min_split"]
  15. min_samples_leaf = params["min_leaf"]
  16. param_grid = {'n_estimators': n_estimators,
  17. 'max_depth': max_depth,
  18. 'learning_rate': learning_rate,
  19. 'min_samples_split': min_samples_split,
  20. 'min_samples_leaf': min_samples_leaf}
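# The "training" section of params.yaml is assumed to look roughly like the
# sketch below; the key names come from the lookups above, while the values
# are purely illustrative and not taken from the repository:
#
#   training:
#     n_est: 100
#     m_depth: 3
#     lr: 0.1
#     min_split: 2
#     min_leaf: 1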

def training():
    print("Training GBRT model with Grid Search")
    print("Loading scaled features and labels")
    x_train = np.load("data/processed_data/x_train.npy")
    y_train = np.load("data/processed_data/y_train.npy")
    # Re-use the scaler fitted in the preprocessing stage
    scaling_model = pickle.load(open("data/scaling_model.pkl", "rb"))
    x_tr_scale = scaling_model.transform(x_train)
    print("done")
    model = GradientBoostingRegressor(n_estimators=n_estimators,
                                      min_samples_split=min_samples_split,
                                      max_depth=max_depth,
                                      learning_rate=learning_rate,
                                      min_samples_leaf=min_samples_leaf)
    model.fit(x_tr_scale, y_train)
    # print("Cross Validation Started")
    # kfold = KFold(n_splits=10)
    # grid_search = GridSearchCV(model, param_grid, cv=kfold, scoring='neg_mean_squared_error')
    # grid_search.fit(x_tr_scale, y_train)
    # print("done")
    # Persist the fitted model for downstream pipeline stages
    with open("data/gbrt_model.pkl", "wb") as x_f:
        pickle.dump(model, x_f)


if __name__ == '__main__':
    training()
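
For reference, a minimal sketch of how a later pipeline stage could consume the artifacts this script writes (data/scaling_model.pkl and data/gbrt_model.pkl). The test-split paths and the choice of metric are assumptions for illustration, not part of the repository:

import pickle

import numpy as np
from sklearn.metrics import mean_squared_error

# Artifacts produced by training.py
with open("data/scaling_model.pkl", "rb") as f:
    scaler = pickle.load(f)
with open("data/gbrt_model.pkl", "rb") as f:
    model = pickle.load(f)

# Assumed test split from an earlier preprocessing stage (hypothetical paths)
x_test = np.load("data/processed_data/x_test.npy")
y_test = np.load("data/processed_data/y_test.npy")

# Apply the training-time scaling, then score the held-out data
preds = model.predict(scaler.transform(x_test))
print("Test MSE:", mean_squared_error(y_test, preds))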