Register
Login
Resources
Docs Blog Datasets Glossary Case Studies Tutorials & Webinars
Product
Data Engine LLMs Platform Enterprise
Pricing Explore
Connect to our Discord channel

silero_vad.py 13 KB

You have to be logged in to leave a comment. Sign In
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
  1. # Copyright (c) Meta Platforms, Inc. and affiliates
  2. # All rights reserved.
  3. #
  4. # This source code is licensed under the license found in the
  5. # MIT_LICENSE file in the root directory of this source tree.
  6. from __future__ import annotations
  7. import logging
  8. from pathlib import Path
  9. import queue
  10. import random
  11. import time
  12. from argparse import ArgumentParser, Namespace
  13. from os import SEEK_END
  14. from typing import Any, List, Optional, Union
  15. import numpy as np
  16. import torch
  17. import soundfile
  18. from seamless_communication.streaming.agents.common import (
  19. AgentStates,
  20. EarlyStoppingMixin,
  21. )
  22. from simuleval.agents import SpeechToSpeechAgent
  23. from simuleval.agents.actions import Action, ReadAction, WriteAction
  24. from simuleval.data.segments import EmptySegment, Segment, SpeechSegment
  25. logging.basicConfig(
  26. level=logging.INFO,
  27. format="%(asctime)s %(levelname)s -- %(name)s: %(message)s",
  28. )
  29. logger = logging.getLogger(__name__)
  30. SPEECH_PROB_THRESHOLD = 0.6
  31. class SileroVADStates(EarlyStoppingMixin, AgentStates): # type: ignore
  32. def __init__(self, args: Namespace) -> None:
  33. self.model, utils = torch.hub.load(
  34. repo_or_dir="snakers4/silero-vad",
  35. model="silero_vad",
  36. force_reload=False,
  37. onnx=False,
  38. )
  39. (
  40. self.get_speech_timestamps,
  41. self.save_audio,
  42. self.read_audio,
  43. self.VADIterator,
  44. self.collect_chunks,
  45. ) = utils
  46. self.silence_limit_ms = args.silence_limit_ms
  47. self.speech_soft_limit_ms = args.speech_soft_limit_ms
  48. self.window_size_samples = args.window_size_samples
  49. self.chunk_size_samples = args.chunk_size_samples
  50. self.sample_rate = args.sample_rate
  51. self.init_speech_prob = args.init_speech_prob
  52. self.debug = args.debug
  53. self.test_input_segments_wav = None
  54. self.debug_log(args)
  55. self.input_queue: queue.Queue[Segment] = queue.Queue()
  56. self.next_input_queue: queue.Queue[Segment] = queue.Queue()
  57. super().__init__()
  58. def clear_queues(self) -> None:
  59. while not self.input_queue.empty():
  60. self.input_queue.get_nowait()
  61. self.input_queue.task_done()
  62. # move everything from next_input_queue to input_queue
  63. while not self.next_input_queue.empty():
  64. chunk = self.next_input_queue.get_nowait()
  65. self.next_input_queue.task_done()
  66. self.input_queue.put_nowait(chunk)
  67. def reset(self) -> None:
  68. super().reset()
  69. # TODO: in seamless_server, report latency for each new segment
  70. self.first_input_ts: Optional[float] = None
  71. self.silence_acc_ms = 0
  72. self.speech_acc_ms = 0
  73. self.input_chunk: np.ndarray[Any, np.dtype[np.int16]] = np.empty(
  74. 0, dtype=np.int16
  75. )
  76. self.is_fresh_state = True
  77. self.clear_queues()
  78. self.model.reset_states()
  79. self.consecutive_silence_decay_count = 0
  80. def reset_early(self) -> None:
  81. """
  82. Don't reset state before EOS
  83. """
  84. pass
  85. def get_speech_prob_from_np_float32(
  86. self, segment: np.ndarray[Any, np.dtype[np.float32]]
  87. ) -> List[Any]:
  88. t = torch.from_numpy(segment)
  89. speech_probs = []
  90. # TODO: run self.model in batch?
  91. for i in range(0, len(t), self.window_size_samples):
  92. chunk = t[i : i + self.window_size_samples]
  93. if len(chunk) < self.window_size_samples:
  94. break
  95. speech_prob = self.model(chunk, self.sample_rate).item()
  96. speech_probs.append(speech_prob)
  97. return speech_probs
  98. def debug_log(self, m: Any) -> None:
  99. if self.debug:
  100. logger.info(m)
  101. def process_speech(
  102. self,
  103. segment: Union[np.ndarray[Any, np.dtype[np.float32]], Segment],
  104. tgt_lang: Optional[str] = None,
  105. ) -> None:
  106. """
  107. Process a full or partial speech chunk
  108. """
  109. queue = self.input_queue
  110. if self.source_finished:
  111. # current source is finished, but next speech starts to come in already
  112. self.debug_log("use next_input_queue")
  113. queue = self.next_input_queue
  114. if self.first_input_ts is None:
  115. self.first_input_ts = time.time() * 1000
  116. while len(segment) > 0:
  117. # add chunks to states.buffer
  118. i = self.chunk_size_samples - len(self.input_chunk)
  119. self.input_chunk = np.concatenate((self.input_chunk, segment[:i]))
  120. segment = segment[i:]
  121. self.is_fresh_state = False
  122. if len(self.input_chunk) == self.chunk_size_samples:
  123. queue.put_nowait(
  124. SpeechSegment(
  125. content=self.input_chunk, finished=False, tgt_lang=tgt_lang
  126. )
  127. )
  128. self.input_chunk = np.empty(0, dtype=np.int16)
  129. def check_silence_acc(self, tgt_lang: Optional[str] = None) -> None:
  130. silence_limit_ms = self.silence_limit_ms
  131. if self.speech_acc_ms >= self.speech_soft_limit_ms:
  132. self.debug_log("increase speech threshold")
  133. silence_limit_ms = self.silence_limit_ms // 2
  134. self.debug_log(f"silence_acc_ms: {self.silence_acc_ms}")
  135. if self.silence_acc_ms >= silence_limit_ms:
  136. self.debug_log("=== end of segment")
  137. # source utterance finished
  138. self.silence_acc_ms = 0
  139. self.speech_acc_ms = 0
  140. if self.input_chunk.size > 0:
  141. # flush partial input_chunk
  142. self.input_queue.put_nowait(
  143. SpeechSegment(
  144. content=self.input_chunk, tgt_lang=tgt_lang, finished=True
  145. )
  146. )
  147. self.input_chunk = np.empty(0, dtype=np.int16)
  148. self.input_queue.put_nowait(EmptySegment(finished=True))
  149. self.source_finished = True
  150. self.debug_write_wav(np.empty(0, dtype=np.int16), finished=True)
  151. def decay_silence_acc_ms(self) -> None:
  152. if self.consecutive_silence_decay_count <= 2:
  153. self.silence_acc_ms = self.silence_acc_ms // 2
  154. self.consecutive_silence_decay_count += 1
  155. def update_source(
  156. self, segment: Union[np.ndarray[Any, np.dtype[np.float32]], Segment]
  157. ) -> None:
  158. """
  159. Default value for the segment in the update_source method is a segment
  160. Class, for some reason this interface didn't align with other interfaces
  161. Adding this change here to support both np.ndarray and Segment class
  162. """
  163. tgt_lang = None
  164. if isinstance(segment, SpeechSegment):
  165. self.sample_rate = segment.sample_rate
  166. if hasattr(segment, "tgt_lang") and segment.tgt_lang is not None:
  167. tgt_lang = segment.tgt_lang
  168. if isinstance(segment.content, np.ndarray):
  169. segment = np.array(segment.content, dtype=np.float32)
  170. else:
  171. segment = segment.content
  172. speech_probs = self.get_speech_prob_from_np_float32(segment)
  173. chunk_size_ms = len(segment) * 1000 / self.sample_rate
  174. window_size_ms = self.window_size_samples * 1000 / self.sample_rate
  175. consecutive_silence_decay = False
  176. if self.is_fresh_state and self.init_speech_prob > 0:
  177. threshold = SPEECH_PROB_THRESHOLD + self.init_speech_prob
  178. else:
  179. threshold = SPEECH_PROB_THRESHOLD
  180. if all(i <= threshold for i in speech_probs):
  181. if self.source_finished:
  182. return
  183. self.debug_log("got silent chunk")
  184. if not self.is_fresh_state:
  185. self.silence_acc_ms += chunk_size_ms
  186. self.check_silence_acc(tgt_lang)
  187. return
  188. elif speech_probs[-1] <= threshold:
  189. self.debug_log("=== start of silence chunk")
  190. # beginning = speech, end = silence
  191. # pass to process_speech and accumulate silence
  192. self.speech_acc_ms += chunk_size_ms
  193. consecutive_silence_decay = True
  194. self.decay_silence_acc_ms()
  195. self.process_speech(segment, tgt_lang)
  196. # accumulate contiguous silence
  197. for i in range(len(speech_probs) - 1, -1, -1):
  198. if speech_probs[i] > threshold:
  199. break
  200. self.silence_acc_ms += window_size_ms
  201. self.check_silence_acc(tgt_lang)
  202. elif speech_probs[0] <= threshold:
  203. self.debug_log("=== start of speech chunk")
  204. # beginning = silence, end = speech
  205. # accumulate silence , pass next to process_speech
  206. for i in range(0, len(speech_probs)):
  207. if speech_probs[i] > threshold:
  208. break
  209. self.silence_acc_ms += window_size_ms
  210. # try not to split right before speech
  211. self.silence_acc_ms = self.silence_acc_ms // 2
  212. self.check_silence_acc(tgt_lang)
  213. self.speech_acc_ms += chunk_size_ms
  214. self.process_speech(segment, tgt_lang)
  215. else:
  216. self.speech_acc_ms += chunk_size_ms
  217. self.debug_log("======== got speech chunk")
  218. consecutive_silence_decay = True
  219. self.decay_silence_acc_ms()
  220. self.process_speech(segment, tgt_lang)
  221. if not consecutive_silence_decay:
  222. self.consecutive_silence_decay_count = 0
  223. def debug_write_wav(
  224. self, chunk: np.ndarray[Any, Any], finished: bool = False
  225. ) -> None:
  226. if self.test_input_segments_wav is not None:
  227. self.test_input_segments_wav.seek(0, SEEK_END)
  228. self.test_input_segments_wav.write(chunk)
  229. if finished:
  230. MODEL_SAMPLE_RATE = 16_000
  231. debug_ts = f"{time.time()}_{random.randint(1000, 9999)}"
  232. self.test_input_segments_wav = soundfile.SoundFile(
  233. Path(self.test_input_segments_wav.name).parent
  234. / f"{debug_ts}_test_input_segments.wav",
  235. mode="w+",
  236. format="WAV",
  237. samplerate=MODEL_SAMPLE_RATE,
  238. channels=1,
  239. )
  240. class SileroVADAgent(SpeechToSpeechAgent): # type: ignore
  241. def __init__(self, args: Namespace) -> None:
  242. super().__init__(args)
  243. self.chunk_size_samples = args.chunk_size_samples
  244. self.args = args
  245. @staticmethod
  246. def add_args(parser: ArgumentParser) -> None:
  247. parser.add_argument(
  248. "--window-size-samples",
  249. default=512, # sampling_rate // 1000 * 32 => 32 ms at 16000 sample rate
  250. type=int,
  251. help="Window size for passing samples to VAD",
  252. )
  253. parser.add_argument(
  254. "--chunk-size-samples",
  255. default=5120, # sampling_rate // 1000 * 320 => 320 ms at 16000 sample rate
  256. type=int,
  257. help="Chunk size for passing samples to model",
  258. )
  259. parser.add_argument(
  260. "--silence-limit-ms",
  261. default=700,
  262. type=int,
  263. help="send EOS to the input_queue after this amount of silence",
  264. )
  265. parser.add_argument(
  266. "--speech-soft-limit-ms",
  267. default=12_000, # after 15s, increase the speech threshold
  268. type=int,
  269. help="after this amount of speech, decrease the speech threshold (segment more aggressively)",
  270. )
  271. parser.add_argument(
  272. "--init-speech-prob",
  273. default=0.15,
  274. type=float,
  275. help="Increase the initial speech probability threshold by this much at the start of speech",
  276. )
  277. parser.add_argument(
  278. "--debug",
  279. default=False,
  280. type=bool,
  281. help="Enable debug logs",
  282. )
  283. def build_states(self) -> SileroVADStates:
  284. return SileroVADStates(self.args)
  285. def policy(self, states: SileroVADStates) -> Action:
  286. states.debug_log(
  287. f"queue size: {states.input_queue.qsize()}, input_chunk size: {len(states.input_chunk)}"
  288. )
  289. content: np.ndarray[Any, Any] = np.empty(0, dtype=np.int16)
  290. is_finished = states.source_finished
  291. tgt_lang = None
  292. while not states.input_queue.empty():
  293. chunk = states.input_queue.get_nowait()
  294. states.input_queue.task_done()
  295. if tgt_lang is None:
  296. tgt_lang = chunk.tgt_lang
  297. content = np.concatenate((content, chunk.content))
  298. states.debug_write_wav(content)
  299. if len(content) == 0: # empty queue
  300. if not states.source_finished:
  301. return ReadAction()
  302. else:
  303. # NOTE: this should never happen, this logic is a safeguard
  304. segment = EmptySegment(finished=True)
  305. else:
  306. segment = SpeechSegment(
  307. content=content.tolist(),
  308. finished=is_finished,
  309. tgt_lang=tgt_lang,
  310. )
  311. return WriteAction(segment, finished=is_finished)
  312. @classmethod
  313. def from_args(cls, args: Namespace, **kwargs: None) -> SileroVADAgent:
  314. return cls(args)
Tip!

Press p or ← to see the previous file, or n or → to see the next file

Comments

Loading...