Register
Login
Resources
Docs Blog Datasets Glossary Case Studies Tutorials & Webinars
Product
Data Engine LLMs Platform Enterprise
Pricing Explore
Connect to our Discord channel

models.py 2.3 KB

You have to be logged in to leave a comment. Sign In
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
  1. from dataclasses import *
  2. from functools import cached_property
  3. import lightgbm as lgb
  4. from lightgbm import Booster
  5. from sklearn.metrics import *
  6. from yspecies.utils import *
  7. @dataclass
  8. class Metrics:
  9. '''
  10. Class to store metrics
  11. '''
  12. @staticmethod
  13. def combine(metrics: List['Metrics']) -> pd.DataFrame:
  14. mts = pd.DataFrame(np.zeros([len(metrics), 3]), columns=["R^2", "MSE", "MAE"])
  15. for i, m in enumerate(metrics):
  16. mts.iloc[i] = m.to_numpy
  17. return mts
  18. @staticmethod
  19. def calculate(prediction, ground_truth) -> 'Metrics':
  20. return Metrics(
  21. r2_score(ground_truth, prediction),
  22. mean_squared_error(ground_truth, prediction),
  23. mean_absolute_error(ground_truth, prediction))
  24. R2: float
  25. MSE: float
  26. MAE: float
  27. @cached_property
  28. def to_numpy(self):
  29. return np.array([self.R2, self.MSE, self.MAE])
  30. @dataclass
  31. class ModelFactory:
  32. parameters: Dict = field(default_factory=lambda: {
  33. 'boosting_type': 'gbdt',
  34. 'objective': 'regression',
  35. 'metric': {'l2', 'l1'},
  36. 'max_leaves': 20,
  37. 'max_depth': 3,
  38. 'learning_rate': 0.07,
  39. 'feature_fraction': 0.8,
  40. 'bagging_fraction': 1,
  41. 'min_data_in_leaf': 6,
  42. 'lambda_l1': 0.9,
  43. 'lambda_l2': 0.9,
  44. "verbose": -1
  45. })
  46. def regression_model(self, X_train, X_test, y_train, y_test, categorical=None, num_boost_round:int = 500, params: dict = None) -> Booster:
  47. '''
  48. trains a regression model
  49. :param X_train:
  50. :param X_test:
  51. :param y_train:
  52. :param y_test:
  53. :param categorical:
  54. :param params:
  55. :return:
  56. '''
  57. parameters = self.parameters if params is None else params
  58. cat = categorical if len(categorical) >0 else "auto"
  59. lgb_train = lgb.Dataset(X_train, y_train, categorical_feature=cat)
  60. lgb_eval = lgb.Dataset(X_test, y_test, reference=lgb_train)
  61. evals_result = {}
  62. gbm = lgb.train(parameters,
  63. lgb_train,
  64. num_boost_round=num_boost_round,
  65. valid_sets=lgb_eval,
  66. evals_result=evals_result,
  67. verbose_eval=num_boost_round)
  68. return gbm
Tip!

Press p or ← to see the previous file, or n or → to see the next file

Comments

Loading...