Register
Login
Resources
Docs Blog Datasets Glossary Case Studies Tutorials & Webinars
Product
Data Engine LLMs Platform Enterprise
Pricing Explore
Connect to our Discord channel

svm_train.py 3.7 KB

You have to be logged in to leave a comment. Sign In
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
  1. import pandas as pd
  2. import numpy as np
  3. import pickle
  4. from datetime import datetime
  5. from sklearn.metrics import accuracy_score, f1_score
  6. from sklearn.feature_extraction.text import TfidfVectorizer
  7. from sklearn.model_selection import KFold, GridSearchCV
  8. from sklearn.svm import SVC
  9. from sklearn.model_selection import train_test_split
  10. from sklearn.ensemble import BaggingClassifier
  11. from sklearn.multiclass import OneVsRestClassifier
  12. import dagshub
  13. def load_code_blocks(DATASET_PATH, CODE_COLUMN):
  14. df = pd.read_csv(DATASET_PATH, encoding='utf-8', comment='#', sep=',')#, quoting=csv.QUOTE_NONE, error_bad_lines=False)#, sep=','
  15. print(df.head())
  16. code_blocks = df[CODE_COLUMN]
  17. # test_size = 0.1
  18. # test_rows = round(df.shape[0]*test_size)
  19. # train_rows = df.shape[0] - test_rows
  20. # train_code_blocks = df[CODE_COLUMN][0:test_rows]
  21. # test_code_blocks = df[CODE_COLUMN][train_rows:]
  22. return df, code_blocks
  23. def tfidf_fit_transform(code_blocks, params, TFIDF_DIR):
  24. vectorizer = TfidfVectorizer(**params)
  25. tfidf = vectorizer.fit(code_blocks)
  26. pickle.dump(tfidf, open(TFIDF_DIR, "wb"))
  27. print('TF-IDF model has been saved')
  28. code_blocks_tfidf = tfidf.transform(code_blocks)
  29. return code_blocks_tfidf
  30. def SVM_evaluate(df, code_blocks, tfidf_params, TFIDF_DIR, SVM_params):
  31. code_blocks_tfidf = tfidf_fit_transform(code_blocks, tfidf_params, TFIDF_DIR)
  32. X_train, X_test, y_train, y_test = train_test_split(code_blocks_tfidf, df[TAG_TO_PREDICT], test_size=0.3)
  33. # grid = {"C": [100]}
  34. # cv = KFold(n_splits=2, shuffle=True, random_state=241)
  35. model = SVC(kernel="linear", random_state=241)
  36. # gs = GridSearchCV(model, grid, scoring="accuracy", cv=cv, verbose=1, n_jobs=-1)
  37. # gs.fit(X_train[:25000], y_train.ravel()[:25000])
  38. # C = gs.best_params_.get('C')
  39. # model = SVC(**SVM_params)
  40. print("Train SVM params:", model.get_params())
  41. n_estimators = 10
  42. clf = BaggingClassifier(model, max_samples=1.0 / n_estimators, n_estimators=n_estimators)
  43. # clf = model
  44. print("starting training..")
  45. clf.fit(X_train, y_train)
  46. print("saving the model")
  47. pickle.dump(clf, open(MODEL_DIR, 'wb'))
  48. print("predicting on the test..")
  49. y_pred = clf.predict(X_test)
  50. accuracy = accuracy_score(y_test, y_pred)
  51. f1 = f1_score(y_test, y_pred, average='weighted')
  52. # confus_matrix = confusion_matrix(model, X_test, y_test)
  53. metrics = {'test_accuracy': accuracy
  54. , 'test_f1_score': f1}
  55. print(metrics)
  56. return metrics
  57. if __name__ == '__main__':
  58. GRAPH_VERSION = 3.1
  59. DATASET_PATH = './data/code_blocks_regex_graph_v{}.csv'.format(GRAPH_VERSION)
  60. MODEL_DIR = './models/svm_regex_graph_v{}.sav'.format(GRAPH_VERSION)
  61. TFIDF_DIR = './models/tfidf_svm_graph_v{}.pickle'.format(GRAPH_VERSION)
  62. CODE_COLUMN = 'code_block'
  63. TAG_TO_PREDICT = 'preprocessing'
  64. SCRIPT_DIR = __file__
  65. df, code_blocks = load_code_blocks(DATASET_PATH, CODE_COLUMN)
  66. nrows = df.shape[0]
  67. print("loaded")
  68. tfidf_params = {'min_df': 5
  69. , 'max_df': 0.3
  70. , 'smooth_idf': True}
  71. SVM_params = {'C':100
  72. , 'kernel':"linear"
  73. , 'random_state':241}
  74. data_meta = {'DATASET_PATH': DATASET_PATH
  75. ,'nrows': nrows
  76. ,'label': TAG_TO_PREDICT
  77. ,'model': MODEL_DIR
  78. ,'source': SCRIPT_DIR}
  79. with dagshub.dagshub_logger() as logger:
  80. print("evaluating..")
  81. metrics = SVM_evaluate(df, code_blocks, tfidf_params, TFIDF_DIR, SVM_params)
  82. print("saving the results..")
  83. logger.log_hyperparams(data_meta)
  84. logger.log_hyperparams(tfidf_params)
  85. logger.log_hyperparams(SVM_params)
  86. logger.log_metrics(metrics)
  87. print("finished")
Tip!

Press p or ← to see the previous file, or n or → to see the next file

Comments

Loading...