dataloader.py
# Copyright (c) Meta Platforms, Inc. and affiliates
# All rights reserved.
#
# This source code is licensed under the license found in the
# MIT_LICENSE file in the root directory of this source tree.

import json
import logging
from dataclasses import dataclass
from typing import Any, Dict, Iterable, List, Optional, Tuple

import numpy as np
import torch
import torchaudio
from datasets import Dataset
from datasets.distributed import split_dataset_by_node
from fairseq2.data.audio import WaveformToFbankConverter
from fairseq2.data.text import TextTokenEncoder
from fairseq2.models.nllb import NllbTokenizer
from torch import Tensor
from torch.nn.functional import pad as pad_tensor
from torch.utils.data import DataLoader

from seamless_communication.datasets.datatypes import LangPairSample
from seamless_communication.models.unity.unit_tokenizer import (
    UnitTokenEncoder,
    UnitTokenizer,
)

logger = logging.getLogger(__name__)

@dataclass
class SeqsBatch:
    src_tokens: Optional[Tensor]
    src_lengths: Optional[Tensor]
    target_tokens: Optional[Tensor]
    prev_output_tokens: Optional[Tensor]
    target_lengths: Optional[Tensor]

    def __del__(self) -> None:
        """Explicitly delete tensors to force GPU memory cleanup."""
        for tensor in [
            self.src_tokens,
            self.src_lengths,
            self.target_tokens,
            self.prev_output_tokens,
            self.target_lengths,
        ]:
            if tensor is not None:
                del tensor

@dataclass
class MultimodalSeqsBatch:
    speech_to_text: SeqsBatch
    text_to_units: SeqsBatch

    def __del__(self) -> None:
        del self.speech_to_text
        del self.text_to_units

@dataclass
class BatchingConfig:
    fbank_feats_pad_idx: int = 0
    """The pad index to use in fbanks batching."""

    batch_size: int = 5
    """Fixed batch size to use."""

    max_audio_length_sec: float = 15.0
    """Drop samples whose source audio is longer than this threshold (seconds)."""

    rank: int = 0
    """The rank of this worker in the process group."""

    world_size: int = 1
    """The world size of the process group."""

    num_workers: int = 2
    """Parallelism in dataset preparation."""

    float_dtype: torch.dtype = torch.float16
    """Select between fp16/fp32 for float tensors."""

def worker_init_fn(worker_id: int) -> None:
    # Re-seed numpy in each DataLoader worker so workers do not share an
    # identical random state inherited from the parent process.
    np.random.seed(np.random.get_state()[1][0] + worker_id)  # type: ignore

class UnitYDataLoader:
    SAMPLE_RATE = 16_000

    def __init__(
        self,
        text_tokenizer: NllbTokenizer,
        unit_tokenizer: UnitTokenizer,
        dataset_manifest_path: str,
        batching_config: BatchingConfig,
        max_src_tokens_per_batch: int = 100000,
    ):
        self.text_tokenizer = text_tokenizer
        self.text_encoders_per_lang: Dict[str, TextTokenEncoder] = {}
        self.unit_tokenizer = unit_tokenizer
        self.unit_encoders_per_lang: Dict[str, UnitTokenEncoder] = {}
        self.batching_config = batching_config
        self._fbank_extract_params = {
            "num_mel_bins": 80,
            "waveform_scale": 32768,
            "channel_last": True,
            "standardize": True,
            "device": torch.device("cpu"),
            "dtype": self.batching_config.float_dtype,
        }
        self.dataset = self._load_manifest(dataset_manifest_path)
        self.max_src_tokens_per_batch = max_src_tokens_per_batch

    def get_dataloader(self) -> DataLoader[SeqsBatch]:
        subset = split_dataset_by_node(
            self.dataset,
            rank=self.batching_config.rank,
            world_size=self.batching_config.world_size,
        )
        data_loader = DataLoader(
            dataset=subset,
            batch_size=self.batching_config.batch_size,
            shuffle=True,
            num_workers=self.batching_config.num_workers,
            collate_fn=self._prepare_batch,
            worker_init_fn=worker_init_fn,
        )
        return data_loader

    def __iter__(self) -> Iterable[MultimodalSeqsBatch]:
        return self.get_dataloader().__iter__()

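    # Illustrative note (not in the original file): split_dataset_by_node gives
    # every rank a disjoint shard of the manifest, so with shuffle=True each
    # worker iterates over its own randomized slice of the data per epoch.
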
    def _get_source_fbank(self, sample: LangPairSample) -> Tensor:
        wav, sample_rate = torchaudio.load(sample.source.audio_local_path)
        assert (
            int(sample_rate) == self.SAMPLE_RATE
        ), f"sample rate != {self.SAMPLE_RATE}, please resample"
        assert len(wav.shape) in (1, 2)
        if len(wav.shape) == 1:
            wav = wav.unsqueeze(-1)
        elif wav.shape[0] <= 2:  # channel dim is first, converter expects it last
            wav = wav.transpose(0, 1)
        return WaveformToFbankConverter(**self._fbank_extract_params)(  # type: ignore
            {
                "waveform": wav,
                "sample_rate": self.SAMPLE_RATE,
            }
        )["fbank"]

    def _get_tokenized_target_text(self, sample: LangPairSample) -> Tensor:
        """Expected sequence is [<eos>, <lang_tok>, ..text tokens.., <eos>]"""
        target_lang = sample.target.lang
        if target_lang not in self.text_encoders_per_lang:
            self.text_encoders_per_lang[target_lang] = (
                self.text_tokenizer.create_encoder(lang=target_lang, mode="target")
            )
        tokens = self.text_encoders_per_lang[target_lang](sample.target.text)
        eos_idx = self.text_tokenizer.vocab_info.eos_idx
        tokens = torch.concat([tokens, torch.LongTensor([eos_idx])])
        return tokens

    def _get_tokenized_units(self, sample: LangPairSample) -> Optional[Tensor]:
        """Expected sequence is [<eos>, <lang_tok>, ..unit tokens.., <eos>]"""
        if sample.target.units is None:
            return None
        target_lang = sample.target.lang
        if target_lang not in self.unit_encoders_per_lang:
            self.unit_encoders_per_lang[target_lang] = (
                self.unit_tokenizer.create_encoder(lang=target_lang)
            )
        tokens = self.unit_encoders_per_lang[target_lang](
            torch.LongTensor(sample.target.units).unsqueeze(0)
        )
        eos_idx = self.unit_tokenizer.vocab_info.eos_idx
        tokens = torch.concat([tokens.squeeze(0), torch.LongTensor([eos_idx])])
        return tokens

    def _batch_tensors(self, tensors: List[Tensor], pad_value: Any) -> Tensor:
        padding_size = max(tensor.shape[0] for tensor in tensors)
        dims = len(tensors[0].shape)
        padded_tensors = []
        for tensor in tensors:
            padding = [0] * 2 * dims
            padding[-1] = padding_size - tensor.shape[0]
            padded_tensors.append(pad_tensor(tensor, padding, "constant", pad_value))
        return torch.stack(padded_tensors, dim=0)

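    # Illustrative note (not in the original file): given fbank tensors of
    # shapes (3, 80) and (5, 80), _batch_tensors right-pads the shorter one
    # along dim 0 with `pad_value` and stacks them into a (2, 5, 80) batch.
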
    def _is_long_src_audio(self, sample: LangPairSample) -> bool:
        # HACK: also returns True for audios that fail to load, so broken
        # samples get excluded from the batch; this side effect is easy to miss.
        try:
            wav, sample_rate = torchaudio.load(sample.source.audio_local_path)
            length_s: float = max(wav.shape) / sample_rate
            return length_s > self.batching_config.max_audio_length_sec
        except Exception:
            logger.exception(
                f"Failed to load sample path: {sample.source.audio_local_path}"
            )
            return True

    def _drop_overflow_samples(
        self, samples_with_fbanks: List[Tuple[LangPairSample, torch.Tensor]]
    ) -> List[Tuple[LangPairSample, torch.Tensor]]:
        # Sort by fbank length, longest first; the longest sample determines
        # the padded width of every row in the batch.
        samples_with_fbanks = sorted(
            samples_with_fbanks, key=lambda sb: -sb[1].shape[0]
        )
        batch_width = samples_with_fbanks[0][1].shape[0]
        max_samples_for_batch = max(1, self.max_src_tokens_per_batch // batch_width)
        if max_samples_for_batch < len(samples_with_fbanks):
            samples_with_fbanks = samples_with_fbanks[:max_samples_for_batch]
        return samples_with_fbanks

    def _prepare_batch(self, raw_samples: List[Dict[str, Any]]) -> MultimodalSeqsBatch:
        samples = [LangPairSample.from_json(sample) for sample in raw_samples]
        # input speech
        #  - filter long audio samples
        filtered_samples = [
            sample for sample in samples if not self._is_long_src_audio(sample)
        ]
        samples = (
            filtered_samples if filtered_samples else [samples[0]]
        )  # keep at least one sample
        with_fbanks = [(sample, self._get_source_fbank(sample)) for sample in samples]
        #  - filter NaNs in fbanks
        filtered = [
            (sample, fbank)
            for sample, fbank in with_fbanks
            if not fbank.isnan().any().item()
        ]
        filtered = self._drop_overflow_samples(filtered)
        samples = [sample for sample, _ in filtered]
        src_tokens_list = [src_tokens for _, src_tokens in filtered]
        assert len(samples) > 0
        src_tokens = self._batch_tensors(
            src_tokens_list, pad_value=self.batching_config.fbank_feats_pad_idx
        ).to(self.batching_config.float_dtype)
        src_lengths = torch.LongTensor(
            [fbank.shape[0] for fbank in src_tokens_list]
        )
        # output text
        text_tokens_list = [
            self._get_tokenized_target_text(sample) for sample in samples
        ]
        text_pad_idx = self.text_tokenizer.vocab_info.pad_idx
        prev_outputs_tokens = self._batch_tensors(
            [tokens[:-1] for tokens in text_tokens_list], pad_value=text_pad_idx
        )
        target_tokens = self._batch_tensors(
            [tokens[1:] for tokens in text_tokens_list], pad_value=text_pad_idx
        )
        tokens_lengths = torch.LongTensor(
            [tokens.shape[0] - 1 for tokens in text_tokens_list]
        )
        # output units
        units_list_raw = [self._get_tokenized_units(sample) for sample in samples]
        if None in units_list_raw:
            prev_outputs_units = None
            target_units = None
            units_lengths = None
        else:
            units_list: List[Tensor] = [
                value for value in units_list_raw if value is not None
            ]
            units_pad_idx = self.unit_tokenizer.vocab_info.pad_idx
            prev_outputs_units = self._batch_tensors(
                [tokens[:-1] for tokens in units_list], pad_value=units_pad_idx
            )
            target_units = self._batch_tensors(
                [tokens[1:] for tokens in units_list], pad_value=units_pad_idx
            )
            units_lengths = torch.LongTensor(
                [tokens.shape[0] - 1 for tokens in units_list]
            )
        return MultimodalSeqsBatch(
            speech_to_text=SeqsBatch(
                src_tokens=src_tokens,
                src_lengths=src_lengths,
                target_tokens=target_tokens,
                prev_output_tokens=prev_outputs_tokens,
                target_lengths=tokens_lengths,
            ),
            text_to_units=SeqsBatch(
                src_tokens=None,
                src_lengths=None,
                target_tokens=target_units,
                prev_output_tokens=prev_outputs_units,
                target_lengths=units_lengths,
            ),
        )

    def _load_manifest(self, dataset_manifest_path: str) -> Dataset:
        # The manifest is a JSON-lines file: one serialized sample per line.
        with open(dataset_manifest_path) as fp_in:
            dataset = [json.loads(line) for line in fp_in]
        return Dataset.from_list(dataset)
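
# ---------------------------------------------------------------------------
# Usage sketch (not part of the original file): wiring the loader into a
# fine-tuning script. It assumes the tokenizer loader helpers exposed by
# seamless_communication.models.unity; the model name and manifest path below
# are placeholders.

if __name__ == "__main__":
    from seamless_communication.models.unity import (
        load_unity_text_tokenizer,
        load_unity_unit_tokenizer,
    )

    text_tokenizer = load_unity_text_tokenizer("seamlessM4T_medium")
    unit_tokenizer = load_unity_unit_tokenizer("seamlessM4T_medium")
    loader = UnitYDataLoader(
        text_tokenizer=text_tokenizer,
        unit_tokenizer=unit_tokenizer,
        dataset_manifest_path="train_manifest.json",  # placeholder path
        batching_config=BatchingConfig(batch_size=8, rank=0, world_size=1),
    )
    for batch in loader:  # each batch is a MultimodalSeqsBatch
        pass  # forward/backward pass would go here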