Register
Login
Resources
Docs Blog Datasets Glossary Case Studies Tutorials & Webinars
Product
Data Engine LLMs Platform Enterprise
Pricing Explore
Connect to our Discord channel

detokenizer.py 2.7 KB

You have to be logged in to leave a comment. Sign In
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
  1. # Copyright (c) Meta Platforms, Inc. and affiliates
  2. # All rights reserved.
  3. #
  4. # This source code is licensed under the license found in the
  5. # MIT_LICENSE file in the root directory of this source tree.
  6. from __future__ import annotations
  7. from argparse import ArgumentParser, Namespace
  8. from typing import Any, Dict
  9. from simuleval.agents import TextToTextAgent
  10. from simuleval.agents.actions import Action, ReadAction, WriteAction
  11. from seamless_communication.streaming.agents.common import (
  12. AgentStates,
  13. NoUpdateTargetMixin,
  14. )
  15. from seamless_communication.streaming.agents.online_text_decoder import (
  16. UnitYTextDecoderOutput,
  17. )
  18. from simuleval.data.segments import Segment, EmptySegment
  19. class DetokenizerAgent(NoUpdateTargetMixin, TextToTextAgent): # type: ignore
  20. def __init__(self, args: Namespace):
  21. super().__init__(args)
  22. self.detokenize_only = args.detokenize_only
  23. @classmethod
  24. def from_args(cls, args: Namespace, **kwargs: Dict[str, Any]) -> DetokenizerAgent:
  25. return cls(args)
  26. def add_args(parser: ArgumentParser) -> None:
  27. parser.add_argument(
  28. "--detokenize-only",
  29. action="store_true",
  30. default=True,
  31. help="Run detokenization without waiting for a new token.",
  32. )
  33. def policy(self, states: AgentStates) -> Action:
  34. possible_full_words = self.decode(" ".join([x for x in states.source]))
  35. if self.detokenize_only and len(states.source) > 0:
  36. states.source = []
  37. if len(possible_full_words) == 0 and not states.source_finished:
  38. return ReadAction()
  39. else:
  40. return WriteAction(possible_full_words, states.source_finished)
  41. if states.source_finished:
  42. return WriteAction(possible_full_words, True)
  43. elif len(possible_full_words.split()) > 1:
  44. full_word = possible_full_words.split()[0]
  45. states.source = states.source[-1:]
  46. return WriteAction(full_word, finished=False)
  47. else:
  48. return ReadAction()
  49. def decode(self, x: str) -> str:
  50. return x.replace(" ", "").replace("\u2581", " ").strip()
  51. class UnitYDetokenizerAgentStates(AgentStates):
  52. def update_source(self, segment: Segment) -> None:
  53. """
  54. Extract tokens from UnitYTextDecoderOutput
  55. """
  56. self.source_finished = segment.finished
  57. if isinstance(segment, EmptySegment):
  58. return
  59. # TextSegment
  60. segment_content: UnitYTextDecoderOutput = segment.content
  61. token = segment_content.tokens
  62. self.source += token
  63. class UnitYDetokenizerAgent(DetokenizerAgent):
  64. def build_states(self) -> UnitYDetokenizerAgentStates:
  65. return UnitYDetokenizerAgentStates()
Tip!

Press p to see the previous file, or n to see the next file

Comments

Loading...