Register
Login
Resources
Docs Blog Datasets Glossary Case Studies Tutorials & Webinars
Product
Data Engine LLMs Platform Enterprise
Pricing Explore
Connect to our Discord channel

split_train_test.py 2.4 KB

You have to be logged in to leave a comment. Sign In
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
  1. """
  2. Split the dataset into train and test set.
  3. Routine Listings
  4. ----------------
  5. get_params()
  6. Get the DVC stage parameters.
  7. split(seed, test_ratio, input_path, train, test)
  8. Split dataset into train and test set.
  9. """
  10. import sys
  11. import dask
  12. import dask.distributed
  13. import pandas as pd
  14. from sklearn.model_selection import train_test_split
  15. import conf
  def get_params():
      """Return the parameters of this DVC stage as a dict.

      The dict holds ``test_ratio`` (fraction of samples assigned to the
      test set) and ``seed`` (random state used for the split).
      """
      params = dict(test_ratio=0.33, seed=42)
      return params
  @dask.delayed
  def split(seed, test_ratio, input_path, train, test):
      """Read a labelled TSV and write a per-label train/test split.

      ``input_path`` is read as a header-less, tab-separated, UTF-8 file
      with the columns ``id``, ``label`` and ``text``.  The ids of the rows
      labelled 1 and of the rows labelled 0 are split separately, both with
      the same ``test_ratio`` and ``seed``.  The selected rows are then
      concatenated (label-1 part first) and written to ``train`` and
      ``test`` as tab-separated files without header or index.
      """
      def partition(ids):
          """Split ``ids`` and return the matching (train, test) rows."""
          kept_ids, held_out_ids = train_test_split(
              ids, test_size=test_ratio, random_state=seed)
          # Rows are picked from the full frame by an inner merge on 'id'.
          train_rows = frame.merge(
              pd.DataFrame(data={'id': kept_ids}), on='id')
          test_rows = frame.merge(
              pd.DataFrame(data={'id': held_out_ids}), on='id')
          return train_rows, test_rows

      frame = pd.read_csv(
          input_path,
          sep='\t',
          header=None,
          names=['id', 'label', 'text'],
          encoding='utf-8',
      )

      positive_ids = frame.loc[frame['label'] == 1, 'id']
      negative_ids = frame.loc[frame['label'] == 0, 'id']

      # Report the class balance before splitting.
      sys.stderr.write('Positive size {}, negative size {}\n'.format(
          len(positive_ids), len(negative_ids)))

      pos_train, pos_test = partition(positive_ids)
      neg_train, neg_test = partition(negative_ids)

      pd.concat([pos_train, neg_train]).to_csv(
          train, sep='\t', header=False, index=False)
      pd.concat([pos_test, neg_test]).to_csv(
          test, sep='\t', header=False, index=False)
  if __name__ == '__main__':
      # Local import: only the script entry point needs pathlib.
      from pathlib import Path

      # Connect to the Dask scheduler; the object is not used directly below.
      # NOTE(review): presumably this makes the .compute() calls run on that
      # cluster -- confirm against the deployment.
      client = dask.distributed.Client('localhost:8786')

      INPUT_DATASET_TSV_PATH = conf.data_dir/'xml_to_tsv'/'Posts.tsv'

      # The stage name is the script's file name without directory or suffix.
      # str.strip('.py') is not used here: it removes any of the characters
      # '.', 'p', 'y' from both ends and keeps directory parts, so
      # './split_train_test.py' would become '/split_train_test'.
      dvc_stage_name = Path(__file__).stem
      print(f'dvc_stage_name: {dvc_stage_name}')

      # Create the stage's output directory before writing into it.
      STAGE_OUTPUT_PATH = conf.data_dir/dvc_stage_name
      conf.remote_mkdir(STAGE_OUTPUT_PATH).compute()

      OUTPUT_TRAIN_TSV_PATH = STAGE_OUTPUT_PATH/'Posts-train.tsv'
      OUTPUT_TEST_TSV_PATH = STAGE_OUTPUT_PATH/'Posts-test.tsv'

      config = get_params()
      TEST_RATIO = config['test_ratio']
      SEED = config['seed']

      # split() is a dask.delayed task; .compute() triggers its execution.
      split(SEED, TEST_RATIO, INPUT_DATASET_TSV_PATH,
            OUTPUT_TRAIN_TSV_PATH, OUTPUT_TEST_TSV_PATH).compute()
Tip!

Press p to see the previous file, or n to see the next file.

Comments

Loading...