Register
Login
Resources
Docs Blog Datasets Glossary Case Studies Tutorials & Webinars
Product
Data Engine LLMs Platform Enterprise
Pricing Explore
Connect to our Discord channel

nms_benchmarking.py 9.2 KB

You have to be logged in to leave a comment. Sign In
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
  1. import infery
  2. import pandas as pd
  3. import numpy as np
  4. from enum import Enum
  5. import time
  6. from datetime import datetime
  7. import os
  8. import sys
  9. import gc
  10. import torch
  11. try:
  12. from super_gradients.training.models.detection_models.yolo_base import YoloPostPredictionCallback
  13. from super_gradients.training.utils.detection_utils import NMS_Type
  14. SG_IMPORTED_SUCCESSFULLY = True
  15. except ImportError as e:
  16. print(f'\n\nWARNING: Failed to import super-gradients - {e}\n\n')
  17. SG_IMPORTED_SUCCESSFULLY = False
  18. BATCH_SIZES = [1, 8, 16, 32]
  19. WARMUP_REPETITIONS = 100
  20. NMS_CONFIDENCE_THRESHOLD = 0.5
  21. NMS_IOU = 0.65
  22. class ModelArchs(Enum):
  23. YOLOX_N = 'yolox_n'
  24. YOLOX_T = 'yolox_t'
  25. YOLOX_S = 'yolox_s'
  26. class NMSTypes(Enum):
  27. FIRST_100 = 'first_100'
  28. NO_NMS = 'no_nms'
  29. BATCHED_NMS = 'batched_nms'
  30. class NMSBenchmarker:
  31. def __init__(self, _model_directory_base):
  32. self._model_directory_base = _model_directory_base
  33. """
  34. Expected Model Directory Structure:
  35. |
  36. -- yolox_s
  37. | |
  38. | ---- first_100.engine
  39. | ---- no_nms.engine
  40. | ---- batched_nms.engine
  41. -- yolox_t
  42. | |
  43. | ---- first_100.engine
  44. | ---- no_nms.engine
  45. | ---- batched_nms.engine
  46. -- yolox_n
  47. |
  48. ---- first_100.engine
  49. ---- no_nms.engine
  50. ---- batched_nms.engine
  51. """
  52. def _benchmark_internal(self, model_archs, batch_sizes, nms_type, nms_device='cpu', torch_nms='iterative'):
  53. self.cleanup()
  54. results_dict = {}
  55. # If NO_NMS is compiled, use the Torch NMS module with the passed NMS device and type
  56. benchmark_method = self._benchmark_helper_in_model_nms
  57. nms_callback = None
  58. if nms_type == NMSTypes.NO_NMS:
  59. benchmark_method = self._benchmark_helper_torch_cpu if nms_device == 'cpu' else \
  60. self._benchmark_helper_torch_gpu
  61. nms_callback = YoloPostPredictionCallback(conf=NMS_CONFIDENCE_THRESHOLD,
  62. iou=NMS_IOU, nms_type=NMS_Type(torch_nms))
  63. for arch in model_archs:
  64. # Prep model and results bookkeeping
  65. results_dict[arch] = {}
  66. model_path = self.path_for_model(model_arch=arch, nms_type=nms_type)
  67. loaded_model = infery.load(model_path=model_path, framework_type='trt')
  68. # Start benchmarking different batch sizes
  69. for bs in batch_sizes:
  70. print(f'{arch} --- {bs}')
  71. data_loader = self.get_coco_data_loader()
  72. results_dict[arch][bs] = {}
  73. # Warmup
  74. dummy_input = np.random.rand(bs, *(loaded_model.input_dims[0])).astype(np.float32)
  75. for _ in range(WARMUP_REPETITIONS):
  76. x = loaded_model.predict(dummy_input, output_device=nms_device)
  77. # Benchmark
  78. times = []
  79. for x in data_loader:
  80. x = (x[0].numpy())
  81. # Bug with setting BS of dataloader dynamically
  82. for i in range(0, 64, bs):
  83. y = x[i:i+bs, :, :, :]
  84. # We're loaded and the input is converted - now time to benchmark (E2E - i.e., CPU -> CPU)
  85. benchmark_method(y, loaded_model, times, nms_callback)
  86. results_dict[arch][bs]['latency'] = sum(times)/len(times)
  87. results_dict[arch][bs]['throughput'] = len(times)*bs/sum(times)
  88. results_dict[arch][bs]['date'] = datetime.now()
  89. self.cleanup()
  90. return results_dict
  91. def _benchmark_helper_torch_gpu(self, x, loaded_model, times, nms_callback):
  92. start = time.perf_counter()
  93. x = loaded_model.predict(x, output_device='gpu')
  94. x = nms_callback(x[-1])
  95. times.append(time.perf_counter() - start)
  96. def _benchmark_helper_torch_cpu(self, x, loaded_model, times, nms_callback):
  97. start = time.perf_counter()
  98. x = loaded_model.predict(x)
  99. x = nms_callback(torch.from_numpy(x[-1]))
  100. times.append(time.perf_counter() - start)
  101. def _benchmark_helper_in_model_nms(self, x, loaded_model, times, nms_callback):
  102. start = time.perf_counter()
  103. x = loaded_model.predict(x)
  104. times.append(time.perf_counter() - start)
  105. def no_nms_first_100_benchmarks(self, model_archs=None, batch_sizes=None):
  106. model_archs, batch_sizes = self._valid_archs_and_batchs(model_archs, batch_sizes)
  107. return self._benchmark_internal(model_archs=model_archs, batch_sizes=batch_sizes, nms_type=NMSTypes.FIRST_100)
  108. def trt_batched_nms(self, model_archs=None, batch_sizes=None):
  109. model_archs, batch_sizes = self._valid_archs_and_batchs(model_archs, batch_sizes)
  110. return self._benchmark_internal(model_archs=model_archs, batch_sizes=batch_sizes, nms_type=NMSTypes.BATCHED_NMS)
  111. def native_torch_on_cpu(self, model_archs=None, batch_sizes=None, torch_nms='iterative'):
  112. model_archs, batch_sizes = self._valid_archs_and_batchs(model_archs, batch_sizes)
  113. return self._benchmark_internal(model_archs=model_archs, batch_sizes=batch_sizes, nms_device='cpu',
  114. nms_type=NMSTypes.NO_NMS)
  115. def native_torch_on_gpu(self, model_archs=None, batch_sizes=None, torch_nms='iterative'):
  116. model_archs, batch_sizes = self._valid_archs_and_batchs(model_archs, batch_sizes)
  117. return self._benchmark_internal(model_archs=model_archs, batch_sizes=batch_sizes, nms_device='gpu',
  118. nms_type=NMSTypes.NO_NMS)
  119. def path_for_model(self, model_arch, nms_type):
  120. return os.path.join(self._model_directory_base, model_arch.value, f'{nms_type.value}.engine')
  121. @staticmethod
  122. def persist_result_dict_to_csv(benchmark_type, results_dict, export_path, append_to_existing=True):
  123. df_dict = {
  124. 'benchmark_type': [],
  125. 'model': [],
  126. 'latency': [],
  127. 'throughput': [],
  128. 'date': [],
  129. 'batch_size': []
  130. }
  131. for arch in results_dict:
  132. for bs in results_dict[arch]:
  133. df_dict['latency'].append(results_dict[arch][bs]['latency'])
  134. df_dict['throughput'].append(results_dict[arch][bs]['throughput'])
  135. df_dict['date'].append(results_dict[arch][bs]['date'])
  136. df_dict['benchmark_type'].append(benchmark_type)
  137. df_dict['model'].append(arch.value)
  138. df_dict['batch_size'].append(bs)
  139. new_results_df = pd.DataFrame(data=df_dict)
  140. if append_to_existing and os.path.exists(export_path):
  141. old_results_df = pd.read_csv(export_path)
  142. new_results_df = new_results_df.append(old_results_df, ignore_index=True)
  143. new_results_df.to_csv(export_path, index=False)
  144. @staticmethod
  145. def get_coco_data_loader():
  146. from super_gradients.training.dataloaders.dataloader_factory import coco2017_val_yolox
  147. return coco2017_val_yolox()
  148. def _valid_archs_and_batchs(self, model_archs, batch_sizes):
  149. # If not specified, benchmark all model architectures and all batch sizes
  150. model_archs = model_archs or list(ModelArchs)
  151. batch_sizes = batch_sizes or BATCH_SIZES
  152. return model_archs, batch_sizes
  153. @staticmethod
  154. def cleanup():
  155. gc.collect()
  156. del gc.garbage[:]
  157. gc.collect()
  158. if __name__ == '__main__':
  159. if len(sys.argv) != 3:
  160. print('USAGE: [PATH_TO_MODEL_DIR] [PATH_TO_RESULTS_FILE]')
  161. exit(1)
  162. # ------------- CONSTANTS ------------- #
  163. benchmarker = NMSBenchmarker(sys.argv[1])
  164. results_path = sys.argv[2]
  165. # ------------- BENCHMARK ------------- #
  166. results_dict = benchmarker.no_nms_first_100_benchmarks()
  167. benchmarker.persist_result_dict_to_csv('first_100',
  168. results_dict=results_dict,
  169. export_path=results_path,
  170. append_to_existing=True)
  171. results_dict = benchmarker.trt_batched_nms()
  172. benchmarker.persist_result_dict_to_csv('trt_batched',
  173. results_dict=results_dict,
  174. export_path=results_path,
  175. append_to_existing=True)
  176. if SG_IMPORTED_SUCCESSFULLY:
  177. results_dict = benchmarker.native_torch_on_cpu(torch_nms=NMS_Type.ITERATIVE)
  178. benchmarker.persist_result_dict_to_csv('torch_cpu',
  179. results_dict=results_dict,
  180. export_path=results_path,
  181. append_to_existing=True)
  182. results_dict = benchmarker.native_torch_on_gpu(torch_nms=NMS_Type.ITERATIVE)
  183. benchmarker.persist_result_dict_to_csv('torch_gpu',
  184. results_dict=results_dict,
  185. export_path=results_path,
  186. append_to_existing=True)
  187. # TODO: SUPPORT MATRIX NMS? CURRENTLY SEEMS BUGGY.
  188. # TODO: results_dict = benchmarker.native_torch_on_gpu(torch_nms=NMS_Type.MATRIX)
Tip!

Press p or to see the previous file or, n or to see the next file

Comments

Loading...