Register
Login
Resources
Docs Blog Datasets Glossary Case Studies Tutorials & Webinars
Product
Data Engine LLMs Platform Enterprise
Pricing Explore
Connect to our Discord channel

online_vocoder.py 2.7 KB

You have to be logged in to leave a comment. Sign In
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
  1. # Copyright (c) Meta Platforms, Inc. and affiliates.
  2. # All rights reserved.
  3. #
  4. # This source code is licensed under the license found in the
  5. # MIT_LICENSE file in the root directory of this source tree.
  6. from __future__ import annotations
  7. import logging
  8. from argparse import ArgumentParser, Namespace
  9. from typing import Any, Dict
  10. import torch
  11. from seamless_communication.models.vocoder.loader import load_vocoder_model
  12. from seamless_communication.streaming.agents.common import AgentStates
  13. from simuleval.agents import TextToSpeechAgent
  14. from simuleval.agents.actions import ReadAction, WriteAction
  15. from simuleval.data.segments import SpeechSegment
  16. logging.basicConfig(
  17. level=logging.INFO,
  18. format="%(asctime)s %(levelname)s -- %(name)s: %(message)s",
  19. )
  20. logger = logging.getLogger(__name__)
  21. class VocoderAgent(TextToSpeechAgent): # type: ignore
  22. def __init__(self, args: Namespace) -> None:
  23. super().__init__(args)
  24. logger.info(
  25. f"Loading the Vocoder model: {args.vocoder_name} on device={args.device}, dtype={args.dtype}"
  26. )
  27. self.vocoder = load_vocoder_model(
  28. args.vocoder_name, device=args.device, dtype=args.dtype
  29. )
  30. self.vocoder.eval()
  31. self.sample_rate = args.sample_rate
  32. self.tgt_lang = args.tgt_lang
  33. self.speaker_id = args.vocoder_speaker_id
  34. @torch.inference_mode()
  35. def policy(self, states: AgentStates) -> WriteAction:
  36. """
  37. The policy is always write if there are units
  38. """
  39. units = states.source
  40. if len(units) == 0 or len(units[0]) == 0:
  41. if states.source_finished:
  42. return WriteAction([], finished=True)
  43. else:
  44. return ReadAction()
  45. tgt_lang = states.tgt_lang if states.tgt_lang else self.tgt_lang
  46. u = units[0][0]
  47. wav = self.vocoder(u, tgt_lang, self.speaker_id, dur_prediction=False)
  48. states.source = []
  49. return WriteAction(
  50. SpeechSegment(
  51. content=wav[0][0].tolist(),
  52. finished=states.source_finished,
  53. sample_rate=self.sample_rate,
  54. tgt_lang=tgt_lang,
  55. ),
  56. finished=states.source_finished,
  57. )
  58. @classmethod
  59. def add_args(cls, parser: ArgumentParser) -> None:
  60. parser.add_argument(
  61. "--vocoder-name",
  62. type=str,
  63. help="Vocoder name.",
  64. default="vocoder_v2",
  65. )
  66. parser.add_argument(
  67. "--vocoder-speaker-id",
  68. type=int,
  69. required=False,
  70. default=-1,
  71. help="Vocoder speaker id",
  72. )
  73. @classmethod
  74. def from_args(cls, args: Namespace, **kwargs: Dict[str, Any]) -> VocoderAgent:
  75. return cls(args)
Tip!

Press p or ← to see the previous file, or n or → to see the next file

Comments

Loading...