Register
Login
Resources
Docs Blog Datasets Glossary Case Studies Tutorials & Webinars
Product
Data Engine LLMs Platform Enterprise
Pricing Explore
Connect to our Discord channel

split_train_test.py 1.9 KB

You have to be logged in to leave a comment. Sign In
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
  1. import sys
  2. import dask
  3. import dask.distributed
  4. import pandas as pd
  5. from sklearn.model_selection import train_test_split
  6. import conf
  7. client = dask.distributed.Client('localhost:8786')
  8. INPUT_PATH = conf.source_tsv
  9. TRAIN = conf.train_tsv
  10. TEST = conf.test_tsv
  11. @dask.delayed
  12. def workflow(seed, test_ratio, input_path, train, test):
  13. def sub_df_by_ids(df, ids):
  14. df_train_order = pd.DataFrame(data={'id': ids})
  15. return df.merge(df_train_order, on='id')
  16. def train_test_split_df(df, ids, test_ratio, seed):
  17. train_ids, test_ids = train_test_split(
  18. ids, test_size=test_ratio, random_state=seed)
  19. return sub_df_by_ids(df, train_ids), sub_df_by_ids(df, test_ids)
  20. df = pd.read_csv(
  21. input_path,
  22. encoding='utf-8',
  23. header=None,
  24. delimiter='\t',
  25. names=['id', 'label', 'text']
  26. )
  27. df_positive = df[df['label'] == 1]
  28. df_negative = df[df['label'] == 0]
  29. sys.stderr.write('Positive size {}, negative size {}\n'.format(
  30. df_positive.shape[0],
  31. df_negative.shape[0]
  32. ))
  33. df_pos_train, df_pos_test = train_test_split_df(
  34. df, df_positive.id, test_ratio, seed)
  35. df_neg_train, df_neg_test = train_test_split_df(
  36. df, df_negative.id, test_ratio, seed)
  37. df_train = pd.concat([df_pos_train, df_neg_train])
  38. df_test = pd.concat([df_pos_test, df_neg_test])
  39. df_train.to_csv(train, sep='\t', header=False, index=False)
  40. df_test.to_csv(test, sep='\t', header=False, index=False)
  41. if len(sys.argv) != 3:
  42. sys.stderr.write('Arguments error. Usage:\n')
  43. sys.stderr.write('\tpython split_train_test.py TEST_RATIO SEED\n')
  44. sys.stderr.write(
  45. '\t\tTEST_RATIO - train set ratio (double). Example: 0.3\n')
  46. sys.stderr.write('\t\tSEED - random state (integer). Example: 20170423\n')
  47. sys.exit(1)
  48. test_ratio = float(sys.argv[1])
  49. seed = int(sys.argv[2])
  50. workflow(seed, test_ratio, INPUT_PATH, TRAIN, TEST).compute()
Tip!

Press p or ← to see the previous file, or n or → to see the next file

Comments

Loading...