Register
Login
Resources
Docs Blog Datasets Glossary Case Studies Tutorials & Webinars
Product
Data Engine LLMs Platform Enterprise
Pricing Explore
Connect to our Discord channel

08_train.py 6.4 KB

You have to be logged in to leave a comment. Sign In
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
  1. import os
  2. import sys
  3. import argparse
  4. import pickle
  5. import joblib
  6. import numpy as np
  7. import pandas as pd
  8. from sklearn.linear_model import LogisticRegression
  9. from catboost import CatBoostRegressor, CatBoostClassifier
  10. from lendingclub import config, utils
  11. import j_utils.munging as mg
  def prepare_data(model_n, data, proc=None, ds_type='train'):
      '''
      Return the processed data for a model.

      Processing can differ between model types (e.g. whether the model can
      handle categoricals natively). For train data this also returns the
      artifacts needed to process valid/test data in the same manner.

      Parameters
      ----------
      model_n : str
          Model name. 'baseline' and 'A'-'G' get their data back untouched;
          every other name is processed with j_utils.munging.
      data :
          Feature data for the split being prepared.
      proc : tuple, optional
          Artifacts returned by a previous ds_type='train' call. Required
          (and must be non-empty) for 'valid'/'test' on non-baseline models.
      ds_type : str
          One of 'train', 'valid', 'test'.

      Returns
      -------
      - baseline / 'A'-'G': ``(data, None)`` for every ds_type.
      - other models, 'train': ``(processed, artifacts)`` where processed is
        the first item of ``mg.train_proc(data)`` and artifacts the rest.
      - other models, 'valid'/'test': whatever
        ``mg.val_test_proc(data, *proc)`` returns (not a pair).

      Raises
      ------
      ValueError
          If ds_type is not one of the allowed values, or if proc is missing
          for valid/test data.
      '''
      # Explicit raises instead of asserts: asserts are stripped under -O, and
      # ``assert ..., print(...)`` used None as the assertion message.
      if ds_type not in ('train', 'valid', 'test'):
          raise ValueError('ds_type invalid: {0!r}'.format(ds_type))
      if model_n in ('baseline', 'A', 'B', 'C', 'D', 'E', 'F', 'G'):
          # NOTE(review): for 'valid'/'test' this is still a (data, None) pair,
          # unlike the non-baseline valid/test path below.
          return data, None
      if ds_type == 'train':
          temp = mg.train_proc(data)
          # first item is the processed data, the remainder are the artifacts
          # needed to process valid/test data identically
          return temp[0], temp[1:]
      # ds_type is 'valid' or 'test'
      if not proc:
          raise ValueError('must pass data processing artifacts')
      return mg.val_test_proc(data, *proc)
  def _cat_feature_indices(X):
      '''Return the positional indices of X's object/datetime columns (CatBoost cat_features).'''
      obj_cols = X.select_dtypes(['object', 'datetime']).columns
      return [X.columns.get_loc(col) for col in obj_cols]


  def train_model(model_n, X_train, y_train, X_valid=None, y_valid=None):
      '''
      Fit the model named model_n and return it.

      - 'baseline' and 'A'-'G': nothing is fitted; the constant 42 is returned.
      - 'logistic_regr': sklearn LogisticRegression(class_weight='balanced').
      - 'catboost_regr' / 'catboost_clf': CatBoost regressor / classifier with
        object/datetime columns passed as categorical features, early
        stopping on the eval set (od_type='Iter', od_wait=300) and
        task_type='GPU'. These require X_valid and y_valid.

      Raises
      ------
      ValueError
          If model_n is not recognised, or a CatBoost model is requested
          without validation data.
      '''
      if model_n in ('baseline', 'A', 'B', 'C', 'D', 'E', 'F', 'G'):
          return 42
      if model_n == 'logistic_regr':
          lr_model = LogisticRegression(class_weight='balanced')
          lr_model.fit(X_train, y_train)
          return lr_model
      if model_n not in ('catboost_regr', 'catboost_clf'):
          # previously fell through and returned None, so nothing got exported
          raise ValueError('unknown model_n: {0!r}'.format(model_n))
      if X_valid is None or y_valid is None:
          # use_best_model and the overfitting detector both need an eval set
          raise ValueError('{0} requires X_valid and y_valid'.format(model_n))

      if model_n == 'catboost_regr':
          # basic params for regressor
          params = {
              'iterations': 100000,
              'one_hot_max_size': 10,
              # 'learning_rate': 0.01,
              # 'has_time': True,
              'depth': 7,
              'l2_leaf_reg': .5,
              'random_strength': 5,
              'loss_function': 'RMSE',
              'eval_metric': 'RMSE',  # 'Recall',
              'random_seed': 42,
              'use_best_model': True,
              'task_type': 'GPU',
              # 'boosting_type': 'Ordered',
              # 'loss_function': 'Log',
              'custom_metric': ['MAE', 'RMSE', 'MAPE', 'Quantile'],
              'od_type': 'Iter',
              'od_wait': 300,
          }
          model = CatBoostRegressor(**params)
      else:
          # basic params for classifier
          params = {
              'iterations': 100000,
              'one_hot_max_size': 10,
              'learning_rate': 0.01,
              'depth': 7,
              'l2_leaf_reg': .5,
              'random_strength': 5,
              # 'has_time': True,
              'eval_metric': 'Logloss',  # 'Recall',
              'random_seed': 42,
              # NOTE(review): fit() below passes logging_level='Verbose', which
              # conflicts with this 'Silent'; confirm which one is intended.
              'logging_level': 'Silent',
              'use_best_model': True,
              'task_type': 'GPU',
              # 'boosting_type': 'Ordered',
              # 'loss_function': 'Log',
              'custom_metric': ['F1', 'Precision', 'Recall', 'Accuracy', 'AUC'],
              'od_type': 'Iter',
              'od_wait': 300,
          }
          model = CatBoostClassifier(**params)

      model.fit(X_train, y_train, cat_features=_cat_feature_indices(X_train),
                eval_set=(X_valid, y_valid,), logging_level='Verbose', plot=True)
      return model
  def export_models(m, model_n):
      '''
      Save the fitted model m for model_n into config.modeling_dir.

      - 'baseline' / 'A'-'G': pickled to '<model_n>_model.pkl'.
      - 'logistic_regr': joblib-dumped to '<model_n>_model.pkl'.
      - 'catboost_clf' / 'catboost_regr': m.save_model to '<model_n>_model.cb'.

      Raises
      ------
      ValueError
          If model_n is not recognised (otherwise nothing would be saved and
          the run would look successful).
      '''
      if model_n in ('baseline', 'A', 'B', 'C', 'D', 'E', 'F', 'G'):
          path = os.path.join(config.modeling_dir, '{0}_model.pkl'.format(model_n))
          with open(path, 'wb') as file:
              pickle.dump(m, file)
      elif model_n == 'logistic_regr':
          joblib.dump(m, os.path.join(config.modeling_dir, '{0}_model.pkl'.format(model_n)))
      elif model_n in ('catboost_clf', 'catboost_regr'):
          m.save_model(os.path.join(config.modeling_dir, '{0}_model.cb'.format(model_n)))
      else:
          raise ValueError('unknown model_n: {0!r}'.format(model_n))
  def export_data_processing(proc_arti, model_n):
      '''
      Save the data-processing artifacts for model_n into config.modeling_dir
      as '<model_n>_model_proc_arti.pkl'.

      'baseline' / 'A'-'G' use pickle; 'logistic_regr', 'catboost_clf' and
      'catboost_regr' use joblib.

      Raises
      ------
      ValueError
          If model_n is not recognised (otherwise nothing would be saved).
      '''
      path = os.path.join(config.modeling_dir, '{0}_model_proc_arti.pkl'.format(model_n))
      if model_n in ('baseline', 'A', 'B', 'C', 'D', 'E', 'F', 'G'):
          with open(path, 'wb') as file:
              pickle.dump(proc_arti, file)
      elif model_n in ('logistic_regr', 'catboost_clf', 'catboost_regr'):
          joblib.dump(proc_arti, path)
      else:
          raise ValueError('unknown model_n: {0!r}'.format(model_n))
  def main(argv=None):
      '''
      Train each requested model on a time-series split and export it.

      Models are given via --model/-m as a whitespace-separated list, e.g.
      ``-m "logistic_regr catboost_clf"``; defaults to ``logistic_regr``.
      The model and processing artifacts from the last split are exported to
      config.modeling_dir.
      '''
      parser = argparse.ArgumentParser()
      parser.add_argument('--model', '-m', help='specify model(s) to train')
      args = parser.parse_args(argv)
      # An empty/whitespace --model falls back to the default instead of
      # leaving ``models`` undefined or empty.
      models = (args.model or '').split() or ['logistic_regr']  # , 'A', 'B', 'C', 'D', 'E', 'F', 'G'

      os.makedirs(config.modeling_dir, exist_ok=True)

      tr_val_base_data, tr_val_eval_data, _ = utils.load_dataset(ds_type='train')
      # ensure ordering is correct for time series split
      tr_val_base_data, tr_val_eval_data = mg.sort_train_eval(
          tr_val_base_data, tr_val_eval_data, 'id', 'issue_d')

      for model_n in models:
          print('training {0}'.format(model_n))
          # classifiers use the 'target_loose' column, all other models '0.07'
          if model_n in ('logistic_regr', 'catboost_clf'):
              target_col = 'target_loose'
          else:
              target_col = '0.07'
          # time-series CV split on issue_d.
          # NOTE(review): an earlier comment said "3 steps, valid size at 5%
          # (20 splits)" but the call passes (20, 1); confirm the semantics of
          # mg.time_series_data_split's last two arguments.
          tscv = mg.time_series_data_split(tr_val_eval_data, 'issue_d', 20, 1)
          m = proc_arti = None
          trained = False
          for tr_idx, val_idx in tscv:
              # split out validation from train_data
              y_train = tr_val_eval_data.loc[tr_idx, target_col]
              y_valid = tr_val_eval_data.loc[val_idx, target_col]
              X_train, proc_arti = prepare_data(
                  model_n, tr_val_base_data.loc[tr_idx], ds_type='train')
              X_valid = prepare_data(
                  model_n, tr_val_base_data.loc[val_idx], proc=proc_arti, ds_type='valid')
              m = train_model(model_n, X_train, y_train, X_valid, y_valid)
              trained = True
          if not trained:
              raise RuntimeError('no train/valid splits produced for {0}'.format(model_n))
          # save stuff: only the model/artifacts from the final split are kept
          export_models(m, model_n)
          export_data_processing(proc_arti, model_n)


  if __name__ == '__main__':
      main()
Tip!

Press p or ← to see the previous file, or n or → to see the next file

Comments

Loading...