Register
Login
Resources
Docs Blog Datasets Glossary Case Studies Tutorials & Webinars
Product
Data Engine LLMs Platform Enterprise
Pricing Explore
Connect to our Discord channel

mutox_text.py 2.6 KB

You have to be logged in to leave a comment. Sign In
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
  1. # Copyright (c) Meta Platforms, Inc. and affiliates
  2. # All rights reserved.
  3. #
  4. # This source code is licensed under the license found in the
  5. # MIT_LICENSE file in the root directory of this source tree.
  6. import argparse
  7. import sys
  8. import torch
  9. from seamless_communication.toxicity.mutox.loader import load_mutox_model
  10. from sonar.inference_pipelines.text import TextToEmbeddingModelPipeline
  11. import logging
  12. logging.basicConfig(
  13. level=logging.INFO,
  14. format="%(asctime)s %(levelname)s -- %(name)s: %(message)s",
  15. )
  16. CPU_DEVICE = torch.device("cpu")
  17. def main() -> None:
  18. parser = argparse.ArgumentParser(
  19. description="Mutox Text will compute a toxicity score for each sentence it is passed."
  20. )
  21. parser.add_argument(
  22. "lang",
  23. type=str,
  24. help="Language of the input text, nllb format with script.",
  25. )
  26. parser.add_argument(
  27. "input", nargs="?", type=argparse.FileType("r"), default=sys.stdin
  28. )
  29. parser.add_argument(
  30. "output", nargs="?", type=argparse.FileType("w"), default=sys.stdout
  31. )
  32. parser.add_argument(
  33. "--batch_size",
  34. type=int,
  35. help="Inference batch size.",
  36. default=4,
  37. )
  38. parser.add_argument(
  39. "--device",
  40. type=str,
  41. help="name of the device to use with torch.",
  42. required=False,
  43. )
  44. args, _unknown = parser.parse_known_args()
  45. if args.device is not None:
  46. device = torch.device(args.device)
  47. dtype = torch.float32
  48. if device.type == "cuda":
  49. dtype = torch.float16
  50. elif torch.cuda.is_available():
  51. device = torch.device("cuda:0")
  52. dtype = torch.float16
  53. else:
  54. device = torch.device("cpu")
  55. dtype = torch.float32
  56. t2vec_model = TextToEmbeddingModelPipeline(
  57. encoder="text_sonar_basic_encoder",
  58. tokenizer="text_sonar_basic_encoder",
  59. device=device,
  60. )
  61. classifier = load_mutox_model(
  62. "mutox",
  63. device=device,
  64. dtype=dtype,
  65. ).eval()
  66. def write_result(batch):
  67. emb = t2vec_model.predict(batch, source_lang=args.lang)
  68. scores = classifier(emb.half())
  69. for s, t in zip(scores, batch):
  70. print(t, s.item(), sep="\t", file=args.output)
  71. with torch.inference_mode():
  72. print("text", "score", sep="\t", file=args.output)
  73. batch = []
  74. for line in args.input:
  75. batch.append(line.rstrip())
  76. if len(batch) >= args.batch_size:
  77. write_result(batch)
  78. batch = []
  79. if len(batch):
  80. write_result(batch)
  81. if __name__ == "__main__":
  82. main()
Tip!

Press p to see the previous file, or n to see the next file

Comments

Loading...