Register
Login
Resources
Docs Blog Datasets Glossary Case Studies Tutorials & Webinars
Product
Data Engine LLMs Platform Enterprise
Pricing Explore
Connect to our Discord channel

test_inference.py 2.7 KB

You have to be logged in to leave a comment. Sign In
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
  1. import io
  2. import pytest
  3. import webdataset as wds
  4. import numpy as np
  5. import torch
  6. from deadtrees.deployment.inference import ONNXInference, PyTorchInference
  7. from PIL import Image
  8. # test_data = "tests/testdata/testdata.tar" # 13 samples
  9. test_data = "tests/testdata/train-balanced-short-000000.tar" # 128 samples
  10. bs = 4
  11. def image_decoder(data):
  12. with io.BytesIO(data) as stream:
  13. img = Image.open(stream)
  14. img.load()
  15. img = img.convert("RGB")
  16. result = np.asarray(img)
  17. result = np.array(result.transpose(2, 0, 1))
  18. return torch.tensor(result) / 255.0
  19. def mask_decoder(data):
  20. with io.BytesIO(data) as stream:
  21. img = Image.open(stream)
  22. img.load()
  23. img = img.convert("L")
  24. result = np.asarray(img)
  25. return torch.tensor(result)
  26. def semsegment_decoder(sample):
  27. sample = dict(sample)
  28. sample["rgb.png"] = image_decoder(sample["rgb.png"])
  29. sample["msk.png"] = mask_decoder(sample["msk.png"])
  30. return sample
  31. @pytest.fixture
  32. def sample_image():
  33. ds = (
  34. wds.Dataset(test_data)
  35. .map(semsegment_decoder)
  36. .rename(image="rgb.png", mask="msk.png")
  37. .to_tuple("image", "mask")
  38. )
  39. sample = next(iter(ds))
  40. return sample[0]
  41. @pytest.fixture
  42. def sample_batch():
  43. ds = (
  44. wds.Dataset(test_data)
  45. .map(semsegment_decoder)
  46. .rename(image="rgb.png", mask="msk.png")
  47. .to_tuple("image", "mask")
  48. .batched(bs)
  49. )
  50. sample = next(iter(ds))
  51. return sample[0]
  52. @pytest.fixture
  53. def pytorch_inference():
  54. return PyTorchInference("checkpoints/bestmodel.ckpt")
  55. @pytest.fixture
  56. def onnx_inference():
  57. return ONNXInference("checkpoints/bestmodel.onnx")
  58. def test_inference_single_size(sample_image):
  59. assert sample_image.shape == (3, 512, 512)
  60. def test_inference_batch_size(sample_batch):
  61. assert sample_batch.shape == (bs, 3, 512, 512)
# Expected prediction shapes: (bs, y, x) for a batch, (y, x) for a single image.
  63. def test_inference_pytorch_single_predict_size(pytorch_inference, sample_image):
  64. assert (pytorch_inference.run(sample_image)).shape == (512, 512)
  65. def test_inference_pytorch_batch_predict_size(pytorch_inference, sample_batch):
  66. assert (pytorch_inference.run(sample_batch)).shape == (bs, 512, 512)
  67. def test_inference_onnx_single_predict_size(onnx_inference, sample_image):
  68. sample_image_numpy = sample_image.detach().cpu().numpy()
  69. assert (onnx_inference.run(sample_image_numpy)).shape == (512, 512)
  70. def test_inference_onnx_batch_predict_size(onnx_inference, sample_batch):
  71. sample_batch_numpy = sample_batch.detach().cpu().numpy()
  72. assert (onnx_inference.run(sample_batch_numpy)).shape == (bs, 512, 512)
# TODO: the tests below are still unimplemented placeholders.
# def test_inference_pytorch_batch():
#     pass
# def test_inference_onnx_single():
#     pass
# def test_inference_onnx_batch():
#     pass
Tip!

Press p or ← to see the previous file, or n or → to see the next file

Comments

Loading...