Register
Login
Resources
Docs Blog Datasets Glossary Case Studies Tutorials & Webinars
Product
Data Engine LLMs Platform Enterprise
Pricing Explore
Connect to our Discord channel

SampleGeneratorFace.py 4.7 KB

You have to be logged in to leave a comment. Sign In
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
  1. import multiprocessing
  2. import pickle
  3. import time
  4. import traceback
  5. import cv2
  6. import numpy as np
  7. from core import mplib
  8. from core.joblib import SubprocessGenerator, ThisThreadGenerator
  9. from facelib import LandmarksProcessor
  10. from samplelib import (SampleGeneratorBase, SampleLoader, SampleProcessor,
  11. SampleType)
'''
Format of the `output_sample_types` argument:

output_sample_types = [
    [SampleProcessor.TypeFlags, size, (optional) {} opts],
    ...
]
'''
  19. class SampleGeneratorFace(SampleGeneratorBase):
  20. def __init__ (self, samples_path, debug=False, batch_size=1,
  21. random_ct_samples_path=None,
  22. sample_process_options=SampleProcessor.Options(),
  23. output_sample_types=[],
  24. add_sample_idx=False,
  25. generators_count=4,
  26. raise_on_no_data=True,
  27. **kwargs):
  28. super().__init__(debug, batch_size)
  29. self.sample_process_options = sample_process_options
  30. self.output_sample_types = output_sample_types
  31. self.add_sample_idx = add_sample_idx
  32. if self.debug:
  33. self.generators_count = 1
  34. else:
  35. self.generators_count = max(1, generators_count)
  36. samples = SampleLoader.load (SampleType.FACE, samples_path)
  37. self.samples_len = len(samples)
  38. self.initialized = False
  39. if self.samples_len == 0:
  40. if raise_on_no_data:
  41. raise ValueError('No training data provided.')
  42. else:
  43. return
  44. index_host = mplib.IndexHost(self.samples_len)
  45. if random_ct_samples_path is not None:
  46. ct_samples = SampleLoader.load (SampleType.FACE, random_ct_samples_path)
  47. ct_index_host = mplib.IndexHost( len(ct_samples) )
  48. else:
  49. ct_samples = None
  50. ct_index_host = None
  51. pickled_samples = pickle.dumps(samples, 4)
  52. ct_pickled_samples = pickle.dumps(ct_samples, 4) if ct_samples is not None else None
  53. if self.debug:
  54. self.generators = [ThisThreadGenerator ( self.batch_func, (pickled_samples, index_host.create_cli(), ct_pickled_samples, ct_index_host.create_cli() if ct_index_host is not None else None) )]
  55. else:
  56. self.generators = [SubprocessGenerator ( self.batch_func, (pickled_samples, index_host.create_cli(), ct_pickled_samples, ct_index_host.create_cli() if ct_index_host is not None else None), start_now=False ) \
  57. for i in range(self.generators_count) ]
  58. SubprocessGenerator.start_in_parallel( self.generators )
  59. self.generator_counter = -1
  60. self.initialized = True
  61. #overridable
  62. def is_initialized(self):
  63. return self.initialized
  64. def __iter__(self):
  65. return self
  66. def __next__(self):
  67. if not self.initialized:
  68. return []
  69. self.generator_counter += 1
  70. generator = self.generators[self.generator_counter % len(self.generators) ]
  71. return next(generator)
  72. def batch_func(self, param ):
  73. pickled_samples, index_host, ct_pickled_samples, ct_index_host = param
  74. samples = pickle.loads(pickled_samples)
  75. ct_samples = pickle.loads(ct_pickled_samples) if ct_pickled_samples is not None else None
  76. bs = self.batch_size
  77. while True:
  78. batches = None
  79. indexes = index_host.multi_get(bs)
  80. ct_indexes = ct_index_host.multi_get(bs) if ct_samples is not None else None
  81. t = time.time()
  82. for n_batch in range(bs):
  83. sample_idx = indexes[n_batch]
  84. sample = samples[sample_idx]
  85. ct_sample = None
  86. if ct_samples is not None:
  87. ct_sample = ct_samples[ct_indexes[n_batch]]
  88. try:
  89. x, = SampleProcessor.process ([sample], self.sample_process_options, self.output_sample_types, self.debug, ct_sample=ct_sample)
  90. except:
  91. raise Exception ("Exception occured in sample %s. Error: %s" % (sample.filename, traceback.format_exc() ) )
  92. if batches is None:
  93. batches = [ [] for _ in range(len(x)) ]
  94. if self.add_sample_idx:
  95. batches += [ [] ]
  96. i_sample_idx = len(batches)-1
  97. for i in range(len(x)):
  98. batches[i].append ( x[i] )
  99. if self.add_sample_idx:
  100. batches[i_sample_idx].append (sample_idx)
  101. yield [ np.array(batch) for batch in batches]
Tip!

Press p to see the previous file, or n to see the next file

Comments

Loading...