Register
Login
Resources
Docs Blog Datasets Glossary Case Studies Tutorials & Webinars
Product
Data Engine LLMs Platform Enterprise
Pricing Explore
Connect to our Discord channel

create_train_test.py 2.3 KB

You have to be logged in to leave a comment. Sign In
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
  1. import pandas as pd
  2. def check_sample_distribution(df, sample, diff_thrsh=.05, check_cols=[], verbose=True):
  3. '''
  4. check if the distribution of the sample's col and df's col is sufficiently
  5. close. Default tolerance is 1% difference
  6. '''
  7. if not check_cols:
  8. check_cols = df.columns
  9. pop_n = len(df)
  10. s_n = len(sample)
  11. sample_miss = {}
  12. big_pct_diff = {}
  13. for col in check_cols:
  14. pop_group = df[col].value_counts(dropna=False)/pop_n
  15. s_group = sample[col].value_counts(dropna=False)/s_n
  16. temp_miss = {}
  17. temp_diff = {}
  18. for k in pop_group.keys():
  19. if k not in s_group.keys():
  20. if verbose:
  21. print('{0} group for {2} column is missing entirely from the sample while population has {1}'.format(k, pop_group[k], col))
  22. temp_miss[k] = pop_group[k]
  23. else:
  24. pct_diff = abs(pop_group[k] - s_group[k])/pop_group[k]
  25. if pct_diff > diff_thrsh:
  26. temp_diff[k] = pct_diff
  27. if temp_miss:
  28. sample_miss[col] = temp_miss
  29. if temp_diff:
  30. big_pct_diff[col] = temp_diff
  31. if sample_miss or big_pct_diff:
  32. print("There is a sampling concern")
  33. def check_not_same_loans(tr, te):
  34. return bool(len(set(tr['id']).intersection(set(te['id']))) == 0)
  35. def check_all_loans_accounted(tr, te, to):
  36. return bool(tr.shape[0] + te.shape[0] == to.shape[0])
  37. def check_same_n_instances(df1, df2):
  38. return bool(df1.shape[0] == df2.shape[0])
  39. def check_same_n_cols(df1, df2):
  40. return bool(df1.shape[1] == df2.shape[1])
  41. def check_train_test_testable(train, test, testable, train1, test1, testable1):
  42. '''
  43. First set for loan_info, second set for eval_loan_info
  44. '''
  45. print(train.shape, test.shape, testable.shape, train1.shape, test1.shape, testable1.shape)
  46. assert check_not_same_loans(train, test)
  47. assert check_all_loans_accounted(train, test, testable)
  48. assert check_not_same_loans(train1, test1)
  49. assert check_all_loans_accounted(train1, test1, testable1)
  50. assert check_same_n_instances(train, train1)
  51. assert check_same_n_instances(test, test1)
  52. assert check_same_n_instances(testable, testable1)
  53. assert check_same_n_cols(train, test)
  54. assert check_same_n_cols(train1, test1)
  55. return True
Tip!

Press p to see the previous file, or n to see the next file

Comments

Loading...