multiclass_state.py
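The script reads its preprocessing settings from a params.yaml file alongside it. Only the keys are implied by the code below; the values shown here are illustrative assumptions, not the project's actual settings:

preprocessing:
  max_min_docs:
    smallest: 2    # passed to TfidfVectorizer min_df (illustrative value)
    largest: 0.8   # passed to TfidfVectorizer max_df (illustrative value)
  n_grams:
    min: 1         # lower bound of ngram_range (illustrative value)
    max: 2         # upper bound of ngram_range (illustrative value)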

#!/usr/bin/env python
# coding: utf-8
# %%
import joblib
import pandas as pd
import yaml
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.metrics import classification_report
from sklearn.model_selection import train_test_split
from sklearn.pipeline import Pipeline
from sklearn.svm import LinearSVC

# Load preprocessing parameters
with open("params.yaml", "r") as fd:
    params = yaml.safe_load(fd)

docs = params["preprocessing"]["max_min_docs"]
ngrams = params["preprocessing"]["n_grams"]

# Load in data
tweets = pd.read_csv("data/tweets.csv")

# Drop tweets not in English
tweets = tweets.loc[tweets["language"] == "en"]

# Strip URLs and drop tweets that are left empty
tweets["tweet"] = tweets["tweet"].str.replace(r"http\S+", "", regex=True)
tweets = tweets.loc[tweets["tweet"] != ""]
tweets = tweets.reset_index(drop=True)

# Reshape the elected-officials table so each Twitter handle becomes its own row
states = pd.read_csv("data/elected_officials.csv")
states = states.melt(
    id_vars=["State", "StateAbbr", "Name", "Party", "Inauguration", "Title", "office"],
    value_vars=["officialTwitter", "campaignTwitter", "othertwitter"],
    var_name="account_type",
    value_name="twitter",
)
states["twitter"] = states["twitter"].str.lower()
tweets = tweets.merge(states, left_on="username", right_on="twitter")

# Create numeric labels based on state names
# Merge labels into the tweets data frame
labels = pd.DataFrame(tweets["State"].unique()).reset_index()
# Add one because zero indexed
labels["index"] = labels["index"] + 1
labels.columns = ["label", "State"]
tweets = tweets.merge(labels, on="State")

# Select labels as targets
y = tweets["label"]
# Select text column as features
X = tweets["tweet"]

# Train/test split (50/50)
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.5)

# Preprocess text
vectorizer = TfidfVectorizer(
    min_df=docs["smallest"],
    max_df=docs["largest"],
    stop_words="english",
    ngram_range=(ngrams["min"], ngrams["max"]),
)

# Create pipeline with preprocessing and linear SVC
pipe = Pipeline([
    ("preprocess", vectorizer),
    ("LinearSVC", LinearSVC()),
])

# Fit pipe to training data
fitted_pipe = pipe.fit(X_train, y_train)

# Export pickled pipe
joblib.dump(fitted_pipe, "outputs/mc_state_pipe.pkl")

# Generate predictions
y_pred = pipe.predict(X_test)

# Output weighted-average metrics to JSON
metrics = pd.DataFrame(classification_report(y_test, y_pred, output_dict=True))
metrics["weighted avg"].to_json("metrics/mc_state_metrics.json")
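Once the run completes, the fitted pipeline saved at outputs/mc_state_pipe.pkl can be reloaded for inference without refitting. A minimal sketch follows; the example tweet strings are made up, and the predicted numbers map back to state names via the labels frame built above:

import joblib

# Reload the fitted TF-IDF + LinearSVC pipeline exported by the training script
pipe = joblib.load("outputs/mc_state_pipe.pkl")

# Predict numeric state labels for new tweet text (example strings are illustrative)
new_tweets = [
    "Proud to sign this bill supporting our state's teachers today.",
    "Join us at the capitol for the budget town hall next week.",
]
predicted_labels = pipe.predict(new_tweets)
print(predicted_labels)  # numeric labels; join against the labels frame to recover state names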