Register
Login
Resources
Docs Blog Datasets Glossary Case Studies Tutorials & Webinars
Product
Data Engine LLMs Platform Enterprise
Pricing Explore
Connect to our Discord channel

common.py 3.3 KB

You have to be logged in to leave a comment. Sign In
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
  1. # Copyright (c) Meta Platforms, Inc. and affiliates
  2. # All rights reserved.
  3. #
  4. # This source code is licensed under the license found in the
  5. # MIT_LICENSE file in the root directory of this source tree.
  6. from contextlib import contextmanager
  7. from typing import Any, Generator, List, Optional, Union
  8. import torch
  9. from fairseq2.data import Collater
  10. from fairseq2.data.audio import WaveformToFbankConverter, WaveformToFbankInput
  11. from fairseq2.typing import DataType, Device
  12. from torch import Tensor
# The default device that tests should use. Note that pytest can change it based
# on the provided command line arguments. The helpers in this module look this
# name up at call time, so a reassignment changes where they materialize
# expected tensors built from lists and which dtype ``get_default_dtype`` reports.
device = Device("cpu")
  16. def assert_close(
  17. a: Tensor,
  18. b: Union[Tensor, List[Any]],
  19. rtol: Optional[float] = None,
  20. atol: Optional[float] = None,
  21. ) -> None:
  22. """Assert that ``a`` and ``b`` are element-wise equal within a tolerance."""
  23. if not isinstance(b, Tensor):
  24. b = torch.tensor(b, device=device, dtype=a.dtype)
  25. torch.testing.assert_close(a, b, rtol=rtol, atol=atol) # type: ignore[attr-defined]
  26. def assert_equal(a: Tensor, b: Union[Tensor, List[Any]]) -> None:
  27. """Assert that ``a`` and ``b`` are element-wise equal."""
  28. if not isinstance(b, Tensor):
  29. b = torch.tensor(b, device=device, dtype=a.dtype)
  30. torch.testing.assert_close(a, b, rtol=0, atol=0) # type: ignore[attr-defined]
  31. def assert_unit_close(
  32. a: Tensor,
  33. b: Union[Tensor, List[Any]],
  34. num_unit_tol: int = 1,
  35. percent_unit_tol: float = 0.0,
  36. ) -> None:
  37. """Assert two unit sequence are equal within a tolerance"""
  38. if not isinstance(b, Tensor):
  39. b = torch.tensor(b, device=device, dtype=a.dtype)
  40. assert (
  41. a.shape == b.shape
  42. ), f"Two shapes are different, one is {a.shape}, the other is {b.shape}"
  43. if percent_unit_tol > 0.0:
  44. num_unit_tol = int(percent_unit_tol * len(a))
  45. num_unit_diff = (a != b).sum()
  46. assert (
  47. num_unit_diff <= num_unit_tol
  48. ), f"The difference is beyond tolerance, {num_unit_diff} units are different, tolerance is {num_unit_tol}"
  49. def has_no_inf(a: Tensor) -> bool:
  50. """Return ``True`` if ``a`` has no positive or negative infinite element."""
  51. return not torch.any(torch.isinf(a))
  52. def has_no_nan(a: Tensor) -> bool:
  53. """Return ``True`` if ``a`` has no NaN element."""
  54. return not torch.any(torch.isnan(a))
  55. @contextmanager
  56. def tmp_rng_seed(device: Device, seed: int = 0) -> Generator[None, None, None]:
  57. """Set a temporary manual RNG seed.
  58. The RNG is reset to its original state once the block is exited.
  59. """
  60. device = Device(device)
  61. if device.type == "cuda":
  62. devices = [device]
  63. else:
  64. devices = []
  65. with torch.random.fork_rng(devices):
  66. torch.manual_seed(seed)
  67. yield
  68. def get_default_dtype() -> DataType:
  69. if device == Device("cpu"):
  70. dtype = torch.float32
  71. else:
  72. dtype = torch.float16
  73. return dtype
  74. def convert_to_collated_fbank(audio_dict: WaveformToFbankInput, dtype: DataType) -> Any:
  75. convert_to_fbank = WaveformToFbankConverter(
  76. num_mel_bins=80,
  77. waveform_scale=2**15,
  78. channel_last=True,
  79. standardize=True,
  80. device=device,
  81. dtype=dtype,
  82. )
  83. collater = Collater(pad_value=1)
  84. feat = collater(convert_to_fbank(audio_dict))["fbank"]
  85. return feat
Tip!

Press p or the left arrow key to see the previous file, or n or the right arrow key to see the next file

Comments

Loading...