Register
Login
Resources
Docs Blog Datasets Glossary Case Studies Tutorials & Webinars
Product
Data Engine LLMs Platform Enterprise
Pricing Explore
Connect to our Discord channel

evaluate.py 3.2 KB

You have to be logged in to leave a comment. Sign In
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
  1. # Copyright (c) Meta Platforms, Inc. and affiliates
  2. # All rights reserved.
  3. #
  4. # This source code is licensed under the license found in the
  5. # MIT_LICENSE file in the root directory of this source tree.
  6. import argparse
  7. import logging
  8. from fairseq2.assets import asset_store, download_manager
  9. from seamless_communication.cli.streaming.scorers.seamless_quality_scorer import (
  10. SeamlessQualityScorer as SeamlessQualityScorer,
  11. )
  12. from seamless_communication.streaming.agents.seamless_s2st import SeamlessS2STAgent
  13. from seamless_communication.streaming.agents.seamless_streaming_s2st import (
  14. SeamlessStreamingS2STAgent,
  15. )
  16. from seamless_communication.streaming.agents.seamless_streaming_s2t import (
  17. SeamlessStreamingS2TAgent,
  18. )
  19. from simuleval.cli import evaluate
  20. logging.basicConfig(
  21. level=logging.INFO,
  22. format="%(asctime)s %(levelname)s -- %(name)s: %(message)s",
  23. )
  24. logger = logging.getLogger(__name__)
  25. def main() -> None:
  26. parser = argparse.ArgumentParser(
  27. add_help=False,
  28. description="Streaming evaluation of Seamless UnitY models",
  29. conflict_handler="resolve",
  30. )
  31. parser.add_argument(
  32. "--task",
  33. choices=["s2st", "s2tt", "asr"],
  34. required=True,
  35. type=str,
  36. help="Target language to translate/transcribe into.",
  37. )
  38. parser.add_argument(
  39. "--expressive",
  40. action="store_true",
  41. default=False,
  42. help="Expressive streaming S2ST inference",
  43. )
  44. args, _ = parser.parse_known_args()
  45. model_configs = dict(
  46. source_segment_size=320,
  47. device="cuda:0",
  48. dtype="fp16",
  49. min_starting_wait_w2vbert=192,
  50. decision_threshold=0.5,
  51. no_early_stop=True,
  52. max_len_a=0,
  53. max_len_b=100,
  54. )
  55. eval_configs = dict(quality_metrics="SEAMLESS_QUALITY_SCORER")
  56. if args.task == "s2st":
  57. model_configs["min_unit_chunk_size"] = 50
  58. eval_configs["latency_metrics"] = "StartOffset EndOffset"
  59. if args.expressive:
  60. agent_class = SeamlessS2STAgent
  61. else:
  62. agent_class = SeamlessStreamingS2STAgent
  63. elif args.task in ["s2tt", "asr"]:
  64. assert args.expressive is False, "S2TT inference cannot be expressive."
  65. agent_class = SeamlessStreamingS2TAgent
  66. parser.add_argument(
  67. "--unity-model-name",
  68. type=str,
  69. help="Unity model name.",
  70. default="seamless_streaming_unity",
  71. )
  72. args, _ = parser.parse_known_args()
  73. asset_card = asset_store.retrieve_card(name=args.unity_model_name)
  74. tokenizer_uri = asset_card.field("tokenizer").as_uri()
  75. tokenizer_path = download_manager.download_tokenizer(
  76. tokenizer_uri, asset_card.name, force=False, progress=True
  77. )
  78. eval_configs["latency_metrics"] = "AL LAAL"
  79. eval_configs["eval_latency_unit"] = "spm"
  80. eval_configs["eval_latency_spm_model"] = tokenizer_path
  81. base_config = dict(
  82. dataloader="fairseq2_s2tt",
  83. dataloader_class="seamless_communication.streaming.dataloaders.s2tt.SimulEvalSpeechToTextDataloader",
  84. )
  85. evaluate(agent_class, {**base_config, **model_configs, **eval_configs}, parser)
  86. if __name__ == "__main__":
  87. main()
Tip!

Press p or ← to see the previous file, or n or → to see the next file

Comments

Loading...