Register
Login
Resources
Docs Blog Datasets Glossary Case Studies Tutorials & Webinars
Product
Data Engine LLMs Platform Enterprise
Pricing Explore
Connect to our Discord channel

utils.py 2.2 KB

You have to be logged in to leave a comment. Sign In
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
  1. import os
  2. import os.path as osp
  3. import sys
  4. import time
  5. from collections import defaultdict
  6. import matplotlib
  7. import numpy as np
  8. import soundfile as sf
  9. import torch
  10. from torch import nn
  11. import jiwer
  12. import matplotlib.pylab as plt
  13. def calc_wer(target, pred, ignore_indexes=[0]):
  14. target_chars = drop_duplicated(list(filter(lambda x: x not in ignore_indexes, map(str, list(target)))))
  15. pred_chars = drop_duplicated(list(filter(lambda x: x not in ignore_indexes, map(str, list(pred)))))
  16. target_str = ' '.join(target_chars)
  17. pred_str = ' '.join(pred_chars)
  18. error = jiwer.wer(target_str, pred_str)
  19. return error
  20. def drop_duplicated(chars):
  21. ret_chars = [chars[0]]
  22. for prev, curr in zip(chars[:-1], chars[1:]):
  23. if prev != curr:
  24. ret_chars.append(curr)
  25. return ret_chars
  26. def build_criterion(critic_params={}):
  27. criterion = {
  28. "ce": nn.CrossEntropyLoss(ignore_index=-1),
  29. "ctc": torch.nn.CTCLoss(**critic_params.get('ctc', {})),
  30. }
  31. return criterion
  32. def get_data_path_list(train_path=None, val_path=None):
  33. base_path = os.path.dirname(train_path)
  34. base_path_val = os.path.dirname(val_path)
  35. if train_path is None:
  36. train_path = "Data/train_list.txt"
  37. if val_path is None:
  38. val_path = "Data/val_list.txt"
  39. with open(train_path, 'r', encoding="utf8") as f:
  40. train_list = f.readlines()
  41. with open(val_path, 'r', encoding="utf8") as f:
  42. val_list = f.readlines()
  43. n_train_list = []
  44. n_val_list = []
  45. for i in train_list:
  46. path = i.split("|")[0]
  47. path = os.path.join(base_path, path)
  48. phon = i.split("|")[1]
  49. voice = i.split("|")[2]
  50. n_train_list.append(f"{path}|{phon}|{voice}")
  51. for i in val_list:
  52. path = i.split("|")[0]
  53. path = os.path.join(base_path, path)
  54. phon = i.split("|")[1]
  55. voice = i.split("|")[2]
  56. n_val_list.append(f"{path}|{phon}|{voice}")
  57. return n_train_list, n_val_list
  58. def plot_image(image):
  59. fig, ax = plt.subplots(figsize=(10, 2))
  60. im = ax.imshow(image, aspect="auto", origin="lower",
  61. interpolation='none')
  62. fig.canvas.draw()
  63. plt.close()
  64. return fig
Tip!

Press p or ← to see the previous file, or n or → to see the next file

Comments

Loading...