Register
Login
Resources
Docs Blog Datasets Glossary Case Studies Tutorials & Webinars
Product
Data Engine LLMs Platform Enterprise
Pricing Explore
Connect to our Discord channel

train.py 1.9 KB

You have to be logged in to leave a comment. Sign In
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
  1. import pandas as pd
  2. import numpy as np
  3. from catboost import CatBoostRegressor
  4. from sklearn.metrics import mean_absolute_error
  5. import pickle
  6. from dvc.api import params_show
  7. import json
  def read_file(path):
      """Unpickle and return the object stored in the file at ``path``.

      ``pickle.load`` can execute arbitrary code while deserializing, so only
      call this on files produced by a trusted source (e.g. this pipeline).
      """
      with open(path, mode='rb') as handle:
          return pickle.load(handle)
  def train_catboost(train_data, model_path="model/catboost.pickle"):
      """Fit a CatBoost regressor on ``train_data`` and pickle it to disk.

      Parameters
      ----------
      train_data : tuple
          ``(X_train, y_train)``: a pandas DataFrame of features and the
          matching target values.
      model_path : str, optional
          Where the fitted model is pickled. Defaults to
          ``"model/catboost.pickle"``; parent directories are created if
          missing.

      Returns
      -------
      CatBoostRegressor
          The fitted model (the same object that was pickled).
      """
      import os

      X_train, y_train = train_data
      # Columns with pandas ``object`` dtype are handed to CatBoost as
      # categorical features, identified by positional index.
      categorical_indices = np.where(X_train.dtypes == 'object')[0].tolist()
      cb = CatBoostRegressor(
          n_estimators=200,
          loss_function='RMSE',
          learning_rate=0.1,
          depth=8,
          task_type='CPU',
          random_state=1,
          verbose=False,
      )
      cb.fit(X_train, y_train, cat_features=categorical_indices)

      # Ensure the output directory exists, and close the handle explicitly so
      # the pickle is fully flushed to disk before the function returns.
      model_dir = os.path.dirname(model_path)
      if model_dir:
          os.makedirs(model_dir, exist_ok=True)
      with open(model_path, 'wb') as fp:
          pickle.dump(cb, fp)
      return cb
  def eval_catboost(train_data, test_data, cb, metrics_path='evaluation/metrics.json'):
      """Print train/test MAE of ``cb`` and write them to a JSON metrics file.

      Parameters
      ----------
      train_data, test_data : tuple
          ``(X, y)`` pairs of features and target values.
      cb
          A fitted model exposing ``predict(X)``.
      metrics_path : str, optional
          Destination of the JSON metrics. Defaults to
          ``'evaluation/metrics.json'``; parent directories are created if
          missing.

      Returns
      -------
      dict
          ``{'train_mae': ..., 'test_mae': ...}``, the same values written to
          ``metrics_path``.
      """
      import os

      X_train, y_train = train_data
      X_test, y_test = test_data
      # Compute each metric once and reuse it for both the console and the file.
      train_mae = mean_absolute_error(y_train, cb.predict(X_train))
      test_mae = mean_absolute_error(y_test, cb.predict(X_test))
      print('train MAE: ', train_mae)
      print('test MAE: ', test_mae)
      diz_eval = {'train_mae': train_mae, 'test_mae': test_mae}

      metrics_dir = os.path.dirname(metrics_path)
      if metrics_dir:
          os.makedirs(metrics_dir, exist_ok=True)
      with open(metrics_path, "w") as fd:
          json.dump(diz_eval, fd, indent=4)
      return diz_eval
  if __name__ == "__main__":
      # Read the DVC params once instead of re-parsing them for every lookup.
      params = params_show()
      PATHS = params['PATHS']
      feature_names = params['cb_features']['feature_names']
      target = params['cb_features']['target']

      train_df = pd.read_csv(PATHS['train'])
      test_df = pd.read_csv(PATHS['test'])
      X_train, y_train = train_df[feature_names], train_df[target]
      X_test, y_test = test_df[feature_names], test_df[target]

      cb = train_catboost((X_train, y_train))
      eval_catboost((X_train, y_train), (X_test, y_test), cb)
Tip!

Press p or ← to see the previous file, or n or → to see the next file

Comments

Loading...