Register
Login
Resources
Docs Blog Datasets Glossary Case Studies Tutorials & Webinars
Product
Data Engine LLMs Platform Enterprise
Pricing Explore
Connect to our Discord channel

dual_vocoder_agent.py 4.0 KB

You have to be logged in to leave a comment. Sign In
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
  1. # Copyright (c) Meta Platforms, Inc. and affiliates.
  2. # All rights reserved.
  3. #
  4. # This source code is licensed under the license found in the
  5. # MIT_LICENSE file in the root directory of this source tree.
  6. from __future__ import annotations
  7. import copy
  8. import logging
  9. from argparse import ArgumentParser, Namespace
  10. from typing import Dict, Any
  11. from simuleval.agents import TextToSpeechAgent
  12. from seamless_communication.streaming.agents.common import AgentStates
  13. from simuleval.data.segments import Segment
  14. from simuleval.agents.actions import Action
  15. from seamless_communication.streaming.agents.pretssel_vocoder import (
  16. PretsselVocoderAgent,
  17. )
  18. from seamless_communication.streaming.agents.online_vocoder import VocoderAgent
  # Configure root logging when this module is imported, then create this
  # module's own logger.
  # NOTE: logging.basicConfig does nothing if the root logger already has
  # handlers (e.g. configured by the host application).
  _LOG_FORMAT = "%(asctime)s %(levelname)s -- %(name)s: %(message)s"
  logging.basicConfig(format=_LOG_FORMAT, level=logging.INFO)
  logger = logging.getLogger(__name__)
  class DualVocoderStates(AgentStates):
      """Composite agent state that wraps one state object per vocoder.

      Resets and source/target updates are forwarded to both wrapped states:
      the plain vocoder's state first, then the expressive vocoder's. ``config``
      is this wrapper's own configuration dict and is emptied on ``reset``.
      """

      def __init__(
          self, vocoder_states: AgentStates, expr_vocoder_states: AgentStates
      ) -> None:
          """Wrap the two per-vocoder states.

          Args:
              vocoder_states: State object for the plain vocoder.
              expr_vocoder_states: State object for the expressive vocoder.
          """
          # The base-class constructor is not invoked; all state lives in the
          # two wrapped objects plus ``config``.
          # NOTE(review): ``target_finished`` below is a read-only property, so
          # any base-class code that assigns it would fail -- check before
          # adding a ``super().__init__()`` call.
          self.vocoder_states = vocoder_states
          self.expr_vocoder_states = expr_vocoder_states
          self.config: Dict[str, Any] = {}

      @property
      def target_finished(self):  # type: ignore
          """Truthy as soon as either wrapped state reports a finished target."""
          vocoder_done = self.vocoder_states.target_finished
          # ``or`` keeps the short-circuit: the expressive state is only
          # consulted when the plain one is not finished.
          return vocoder_done or self.expr_vocoder_states.target_finished

      def reset(self) -> None:
          """Reset both wrapped states, then clear this wrapper's config."""
          for sub_states in (self.vocoder_states, self.expr_vocoder_states):
              sub_states.reset()
          self.config = {}

      def update_source(self, segment: Segment) -> None:
          """Push ``segment``'s config and then the segment itself into each state."""
          for sub_states in (self.vocoder_states, self.expr_vocoder_states):
              sub_states.update_config(segment.config)
              sub_states.update_source(segment)

      def update_target(self, segment: Segment) -> None:
          """Forward ``segment`` to both wrapped states' ``update_target``."""
          for sub_states in (self.vocoder_states, self.expr_vocoder_states):
              sub_states.update_target(segment)
  class DualVocoderAgent(TextToSpeechAgent):  # type: ignore
      """Text-to-speech agent that routes each step to one of two vocoders.

      Holds a plain ``VocoderAgent`` and an expressive ``PretsselVocoderAgent``.
      On every ``policy`` call exactly one of them runs. The choice comes from
      the ``--expressive`` flag and can be overridden by an ``"expressive"``
      entry in the states' ``config``.
      """

      def __init__(
          self,
          args: Namespace,
          vocoder: VocoderAgent,
          expr_vocoder: PretsselVocoderAgent,
      ) -> None:
          """Store both sub-agents and read the default mode from ``args``.

          Args:
              args: Parsed arguments; ``args.expressive`` sets the default mode.
              vocoder: Agent used when expressive mode is off.
              expr_vocoder: Agent used when expressive mode is on.
          """
          # NOTE(review): the sub-agents are assigned before ``super().__init__``,
          # presumably because the base constructor calls ``build_states`` (which
          # reads them). Keep this order unless that is confirmed otherwise.
          self.vocoder = vocoder
          self.expr_vocoder = expr_vocoder
          super().__init__(args)
          self.expressive = args.expressive

      def build_states(self) -> DualVocoderStates:
          """Return a ``DualVocoderStates`` wrapping fresh states of both vocoders."""
          return DualVocoderStates(
              self.vocoder.build_states(), self.expr_vocoder.build_states()
          )

      @classmethod
      def add_args(cls, parser: ArgumentParser) -> None:
          """Register both vocoders' arguments plus the dual-agent flags.

          Adds ``--expr-vocoder-name`` (required) and ``--expressive``
          (a store-true flag).
          """
          PretsselVocoderAgent.add_args(parser)
          VocoderAgent.add_args(parser)
          parser.add_argument(
              "--expr-vocoder-name",
              type=str,
              required=True,
              help="expressive vocoder name - vocoder_pretssel or vocoder_pretssel_16khz",
          )
          parser.add_argument(
              "--expressive",
              action="store_true",
              help="Whether to use expressive vocoder (overridable in segment.config)",
          )

      @classmethod
      def from_args(cls, args: Namespace, **kwargs: Any) -> DualVocoderAgent:
          """Build both vocoders from ``args`` and wrap them in this agent.

          The expressive vocoder is built from a deep copy of ``args`` whose
          ``vocoder_name`` is replaced by ``args.expr_vocoder_name``. The
          caller's namespace is therefore not modified.

          Extra keyword arguments are accepted but not used here.
          """
          vocoder = VocoderAgent.from_args(args)
          expr_args = copy.deepcopy(args)
          expr_args.vocoder_name = args.expr_vocoder_name
          expr_vocoder = PretsselVocoderAgent.from_args(expr_args)
          return cls(args, vocoder, expr_vocoder)

      def policy(self, states: AgentStates) -> Action:
          """Run one step of whichever vocoder is active for ``states``.

          ``states`` is expected to be the ``DualVocoderStates`` produced by
          ``build_states``. It stays annotated as ``AgentStates`` so the
          override signature matches the base class.

          The mode defaults to ``self.expressive``. It is overridden by
          ``states.config["expressive"]`` when that key is present.

          Returns:
              The action returned by the selected vocoder's ``policy``.
          """
          expressive = self.expressive
          if states.config is not None and "expressive" in states.config:
              expressive = states.config["expressive"]

          # ``DualVocoderStates.update_source`` feeds every source segment to
          # both wrapped states. When the active vocoder's source buffer is
          # empty after its step, the idle vocoder's copy is dropped as well,
          # so it does not keep accumulating input it will never consume.
          if expressive:
              # Only the expressive vocoder is handed the upstream states.
              states.expr_vocoder_states.upstream_states = states.upstream_states
              action = self.expr_vocoder.policy(states.expr_vocoder_states)
              if len(states.expr_vocoder_states.source) == 0:
                  states.vocoder_states.source = []
          else:
              action = self.vocoder.policy(states.vocoder_states)
              if len(states.vocoder_states.source) == 0:
                  states.expr_vocoder_states.source = []
          return action
Tip!

Press p or ← to see the previous file, or n or → to see the next file

Comments

Loading...