Register
Login
Resources
Docs Blog Datasets Glossary Case Studies Tutorials & Webinars
Product
Data Engine LLMs Platform Enterprise
Pricing Explore
Connect to our Discord channel

train.py 1.1 KB

You have to be logged in to leave a comment. Sign In
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
  1. from sklearn.model_selection import train_test_split
  2. import matplotlib.pyplot as plt
  3. import pandas as pd
  4. import json
  5. import dagshub
  6. from xgboost import XGBClassifier
  7. # Load data
  8. df = pd.read_csv("heartDisease.csv")
  9. y = df['num']
  10. x = df.drop(['num'], axis=1)
  11. x_train, x_test, y_train, y_test = train_test_split(x, y, test_size=0.3, random_state=42)
  12. def main():
  13. model = XGBClassifier().fit(x_train, y_train)
  14. train_accuracy = model.score(x_train, y_train)
  15. test_accuracy = model.score(x_test, y_test)
  16. with open('metrics.json','w') as of:
  17. json.dump({ "accuracy": test_accuracy}, of)
  18. of.close()
  19. with dagshub.dagshub_logger() as logger:
  20. logger.log_hyperparams(model_class=type(model).__name__)
  21. logger.log_hyperparams({'model': model.get_params()})
  22. logger.log_metrics({f'accuracy':round(test_accuracy,3)})
  23. importance = model.feature_importances_
  24. for i,v in enumerate(importance):
  25. print('Feature: %0d, Score: %.5f' % (i,v))
  26. # plot feature importance
  27. plt.barh(x.columns, importance)
  28. plt.savefig('feature_importance.png')
  29. if __name__ == '__main__':
  30. main()
Tip!

Press p to see the previous file, or n to see the next file.

Comments

Loading...