Register
Login
Resources
Docs Blog Datasets Glossary Case Studies Tutorials & Webinars
Product
Data Engine LLMs Platform Enterprise
Pricing Explore
Connect to our Discord channel

SampleGeneratorImageTemporal.py 2.9 KB

You have to be logged in to leave a comment. Sign In
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
  1. import traceback
  2. import numpy as np
  3. import cv2
  4. from utils import iter_utils
  5. from samplelib import SampleType, SampleProcessor, SampleHost, SampleGeneratorBase
  6. '''
  7. output_sample_types = [
  8. [SampleProcessor.TypeFlags, size, (optional)random_sub_size] ,
  9. ...
  10. ]
  11. '''
  12. class SampleGeneratorImageTemporal(SampleGeneratorBase):
  13. def __init__ (self, samples_path, debug, batch_size, temporal_image_count, sample_process_options=SampleProcessor.Options(), output_sample_types=[], **kwargs):
  14. super().__init__(samples_path, debug, batch_size)
  15. self.temporal_image_count = temporal_image_count
  16. self.sample_process_options = sample_process_options
  17. self.output_sample_types = output_sample_types
  18. self.samples = SampleHost.load (SampleType.IMAGE, self.samples_path)
  19. self.generator_samples = [ self.samples ]
  20. self.generators = [iter_utils.ThisThreadGenerator ( self.batch_func, 0 )] if self.debug else \
  21. [iter_utils.SubprocessGenerator ( self.batch_func, 0 )]
  22. self.generator_counter = -1
  23. def __iter__(self):
  24. return self
  25. def __next__(self):
  26. self.generator_counter += 1
  27. generator = self.generators[self.generator_counter % len(self.generators) ]
  28. return next(generator)
  29. def batch_func(self, generator_id):
  30. samples = self.generator_samples[generator_id]
  31. samples_len = len(samples)
  32. if samples_len == 0:
  33. raise ValueError('No training data provided.')
  34. mult_max = 4
  35. samples_sub_len = samples_len - ( (self.temporal_image_count)*mult_max - (mult_max-1) )
  36. if samples_sub_len <= 0:
  37. raise ValueError('Not enough samples to fit temporal line.')
  38. shuffle_idxs = []
  39. while True:
  40. batches = None
  41. for n_batch in range(self.batch_size):
  42. if len(shuffle_idxs) == 0:
  43. shuffle_idxs = [ *range(samples_sub_len) ]
  44. np.random.shuffle (shuffle_idxs)
  45. idx = shuffle_idxs.pop()
  46. temporal_samples = []
  47. mult = np.random.randint(mult_max)+1
  48. for i in range( self.temporal_image_count ):
  49. sample = samples[ idx+i*mult ]
  50. try:
  51. temporal_samples += SampleProcessor.process ([sample], self.sample_process_options, self.output_sample_types, self.debug)[0]
  52. except:
  53. raise Exception ("Exception occured in sample %s. Error: %s" % (sample.filename, traceback.format_exc() ) )
  54. if batches is None:
  55. batches = [ [] for _ in range(len(temporal_samples)) ]
  56. for i in range(len(temporal_samples)):
  57. batches[i].append ( temporal_samples[i] )
  58. yield [ np.array(batch) for batch in batches]
Tip!

Press p or ← to see the previous file, or n or → to see the next file.

Comments

Loading...