Register
Login
Resources
Docs Blog Datasets Glossary Case Studies Tutorials & Webinars
Product
Data Engine LLMs Platform Enterprise
Pricing Explore
Connect to our Discord channel

multiclass_office.py 2.6 KB

You have to be logged in to leave a comment. Sign In
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
  1. #!/usr/bin/env python
  2. # coding: utf-8
  3. # %%
  4. # %%
  5. import joblib
  6. import numpy as np
  7. import pandas as pd
  8. import re
  9. from sklearn.feature_extraction.text import TfidfVectorizer
  10. from sklearn.metrics import classification_report
  11. from sklearn.model_selection import train_test_split
  12. from sklearn.pipeline import Pipeline
  13. from sklearn.svm import LinearSVC
  14. import yaml
  15. with open("params.yaml", "r") as fd:
  16. params = yaml.safe_load(fd)
  17. docs = params["preprocessing"]["max_min_docs"]
  18. ngrams = params['preprocessing']['n_grams']
  19. #Load in data
  20. tweets = pd.read_csv('data/tweets.csv')
  21. #Drop tweets not in english
  22. tweets = tweets.loc[tweets['language'] == 'en']
  23. tweets['tweet'] = tweets['tweet'].str.replace(r'http\S+', '')
  24. tweets = tweets.loc[tweets['tweet'] != '']
  25. tweets = tweets.reset_index(drop=True)
  26. states = pd.read_csv('data/elected_officials.csv')
  27. states = states.melt(id_vars = ['State',
  28. 'StateAbbr',
  29. 'Name',
  30. 'Party',
  31. 'Inauguration',
  32. 'Title',
  33. 'office'],
  34. value_vars = ['officialTwitter',
  35. 'campaignTwitter',
  36. 'othertwitter'],
  37. var_name = 'account_type',
  38. value_name = 'twitter')
  39. states['twitter'] = states['twitter'].str.lower()
  40. tweets = tweets.merge(states, left_on = 'username', right_on = 'twitter')
  41. #Create numeric labels based on state names
  42. #Merge labels into MTG data frame
  43. labels = pd.DataFrame(tweets['office'].unique()).reset_index()
  44. #Add one because zero indexed
  45. labels['index'] = labels['index']+1
  46. labels.columns = ['label', 'office']
  47. tweets = tweets.merge(labels, on = 'office')
  48. #Select labels as targets
  49. y = tweets['label']
  50. #Select text columns as features
  51. X = tweets["tweet"]
  52. #Training test split 70/30
  53. X_train, X_test, y_train, y_test = train_test_split(X, y, test_size = 0.5)
  54. #Preprocess text
  55. vectorizer = TfidfVectorizer(
  56. min_df=docs['smallest'],
  57. max_df=docs['largest'],
  58. stop_words="english",
  59. ngram_range = (ngrams['min'],ngrams['max'])
  60. )
  61. #Create pipeline with preprocessing and linear SVC
  62. pipe = Pipeline([
  63. ('preprocess', vectorizer),
  64. ('LinearSVC', LinearSVC())
  65. ])
  66. #Fit pipe to training data
  67. fitted_pipe = pipe.fit(X_train, y_train)
  68. #Export pickeled pipe
  69. joblib.dump(fitted_pipe, 'outputs/mc_office_pipe.pkl')
  70. #Generate predictions
  71. y_pred = pipe.predict(X_test)
  72. #Output metrics to JSON
  73. metrics = pd.DataFrame(classification_report(y_test, y_pred, output_dict=True))
  74. metrics["weighted avg"].to_json("metrics/mc_office_metrics.json")
Tip!

Press p to see the previous file, or n to see the next file

Comments

Loading...