Register
Login
Resources
Docs Blog Datasets Glossary Case Studies Tutorials & Webinars
Product
Data Engine LLMs Platform Enterprise
Pricing Explore
Connect to our Discord channel

test_data_splitting.py 2.1 KB

You have to be logged in to leave a comment. Sign In
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
  1. import io
  2. import math
  3. from functools import reduce
  4. import pytest
  5. import numpy as np
  6. import pandas as pd
  7. from deadtrees.utils.data_handling import split_df
  8. TESTDATA = """tile,frac,status
  9. ortho_ms_2019_EPSG3044_032_070_017,0.0,0
  10. ortho_ms_2019_EPSG3044_032_070_018,0.0,0
  11. ortho_ms_2019_EPSG3044_032_070_046,0.23,1
  12. ortho_ms_2019_EPSG3044_032_070_047,0.58,1
  13. ortho_ms_2019_EPSG3044_032_071_032,0.48,1
  14. ortho_ms_2019_EPSG3044_032_071_033,0.01,1
  15. ortho_ms_2019_EPSG3044_032_071_049,0.22,1
  16. ortho_ms_2019_EPSG3044_032_071_050,0.29,1
  17. ortho_ms_2019_EPSG3044_032_071_052,0.3,1
  18. ortho_ms_2019_EPSG3044_032_071_053,0.4,1
  19. ortho_ms_2019_EPSG3044_032_071_056,0.67,1
  20. ortho_ms_2019_EPSG3044_032_071_057,0.39,1
  21. ortho_ms_2019_EPSG3044_032_071_058,1.64,1
  22. """
  23. eps = 1e-7
  24. np.random.seed(42)
  25. class TestSplitDf:
  26. # datasets to check
  27. data_fake = pd.DataFrame(
  28. {
  29. "tile": [f"fake_tile_{i:03d}.tif" for i in range(100)],
  30. "frac": np.random.gamma(9, 0.5, size=100) + eps,
  31. "status": np.ones(100, dtype=int),
  32. }
  33. )
  34. data_bad = pd.read_csv(io.StringIO(TESTDATA))
  35. data = pd.read_csv(io.StringIO(TESTDATA)).query("frac > 0")
  36. @pytest.mark.parametrize("n", [0, 100])
  37. def test_catch_invalid_size(self, n):
  38. with pytest.raises(ValueError):
  39. split_df(self.data, n)
  40. def test_catch_tiles_without_deadtrees(self):
  41. with pytest.raises(ValueError):
  42. split_df(self.data_bad, 3)
  43. def test_total_size_unchanged(self):
  44. result = split_df(self.data, 3)
  45. assert len(reduce(lambda z, y: z + y, result)) == len(self.data)
  46. def test_number_of_partitions_as_requested(self):
  47. result = split_df(self.data, 3)
  48. assert len(result) == math.ceil(len(self.data) / 3)
  49. def test_partitioned_totals_approx_equal(self):
  50. # dodgy, hand-crafted, and should be replaced by something rigid
  51. splits = split_df(self.data_fake, 10)
  52. totals = [
  53. self.data_fake[self.data_fake.tile.isin(s)].frac.sum() for s in splits
  54. ]
  55. assert [45] * len(totals) == pytest.approx(totals, abs=5)
Tip!

Press p to see the previous file, or n to see the next file

Comments

Loading...