Register
Login
Resources
Docs Blog Datasets Glossary Case Studies Tutorials & Webinars
Product
Data Engine LLMs Platform Enterprise
Pricing Explore
Connect to our Discord channel

test_webdataloader.py 4.0 KB

You have to be logged in to leave a comment. Sign In
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
  1. # test the basic data loading stuff
  2. #
  3. import io
  4. import webdataset as wds
  5. import numpy as np
  6. import torch
  7. from PIL import Image
# Path of the tar shard that the tests in this file open via ``wds.Dataset``.
# NOTE(review): the trailing comment below says 128 samples, but the tests in
# this file assert a count of 64 -- confirm which figure is correct.
# test_data = "tests/testdata/testdata.tar" # 13 samples
test_data = "tests/testdata/train-balanced-short-000000.tar" # 128 samples
  10. def count_samples_tuple(source, *args, n=1000):
  11. count = 0
  12. for i, sample in enumerate(iter(source)):
  13. if i >= n:
  14. break
  15. assert isinstance(sample, (tuple, dict, list)), (type(sample), sample)
  16. for f in args:
  17. assert f(sample)
  18. count += 1
  19. return count
  20. def test_dataset():
  21. ds = wds.Dataset(test_data)
  22. assert count_samples_tuple(ds) == 64
  23. def test_dataset_shuffle_extract():
  24. ds = wds.Dataset(test_data).shuffle(5).to_tuple("msk.png rgb.png")
  25. assert count_samples_tuple(ds) == 64
  26. def test_dataset_pipe_cat():
  27. ds = wds.Dataset(f"pipe:cat {test_data}").shuffle(5).to_tuple("msk.png rgb.png")
  28. assert count_samples_tuple(ds) == 64
  29. def test_slice():
  30. ds = wds.Dataset(test_data).slice(10)
  31. assert count_samples_tuple(ds) == 10
  32. def test_rename():
  33. ds = wds.Dataset(test_data).rename(image="rgb.png", mask="msk.png")
  34. sample = next(iter(ds))
  35. assert set(sample.keys()) == {"image", "mask"}
  36. def test_torch_sample_decoder():
  37. def image_decoder(data):
  38. with io.BytesIO(data) as stream:
  39. img = Image.open(stream)
  40. img.load()
  41. img = img.convert("RGB")
  42. result = np.asarray(img)
  43. result = np.array(result.transpose(2, 0, 1))
  44. return torch.tensor(result) / 255.0
  45. def mask_decoder(data):
  46. with io.BytesIO(data) as stream:
  47. img = Image.open(stream)
  48. img.load()
  49. img = img.convert("L")
  50. result = np.asarray(img)
  51. return torch.tensor(result)
  52. def semsegment_decoder(sample):
  53. sample = dict(sample)
  54. sample["rgb.png"] = image_decoder(sample["rgb.png"])
  55. sample["msk.png"] = mask_decoder(sample["msk.png"])
  56. return sample
  57. ds = (
  58. wds.Dataset(test_data)
  59. .map(semsegment_decoder)
  60. .rename(image="rgb.png", mask="msk.png")
  61. .to_tuple("image", "mask")
  62. )
  63. image, mask = next(iter(ds))
  64. assert (image.shape, mask.shape) == ((3, 512, 512), (512, 512))
  65. def test_torch_map_dict_decoder():
  66. def image_decoder(data):
  67. with io.BytesIO(data) as stream:
  68. img = Image.open(stream)
  69. img.load()
  70. img = img.convert("RGB")
  71. result = np.asarray(img)
  72. result = np.array(result.transpose(2, 0, 1))
  73. return torch.tensor(result) / 255.0
  74. def mask_decoder(data):
  75. with io.BytesIO(data) as stream:
  76. img = Image.open(stream)
  77. img.load()
  78. img = img.convert("L")
  79. result = np.asarray(img)
  80. return torch.tensor(result)
  81. ds = (
  82. wds.Dataset(test_data)
  83. .rename(image="rgb.png", mask="msk.png")
  84. .map_dict(image=image_decoder, mask=mask_decoder)
  85. .to_tuple("image", "mask")
  86. )
  87. image, mask = next(iter(ds))
  88. assert (image.shape, mask.shape) == ((3, 512, 512), (512, 512))
  89. def test_torch_map_dict_batched_decoder():
  90. bs = 8
  91. def image_decoder(data):
  92. with io.BytesIO(data) as stream:
  93. img = Image.open(stream)
  94. img.load()
  95. img = img.convert("RGB")
  96. result = np.asarray(img)
  97. result = np.array(result.transpose(2, 0, 1))
  98. return torch.tensor(result) / 255.0
  99. def mask_decoder(data):
  100. with io.BytesIO(data) as stream:
  101. img = Image.open(stream)
  102. img.load()
  103. img = img.convert("L")
  104. result = np.asarray(img)
  105. return torch.tensor(result)
  106. ds = (
  107. wds.Dataset(test_data)
  108. .rename(image="rgb.png", mask="msk.png")
  109. .map_dict(image=image_decoder, mask=mask_decoder)
  110. .to_tuple("image", "mask")
  111. .batched(bs, partial=False)
  112. )
  113. image, mask = next(iter(ds))
  114. assert (image.shape, mask.shape) == ((bs, 3, 512, 512), (bs, 512, 512))
Tip!

Press p or ← to see the previous file, or n or → to see the next file

Comments

Loading...