Register
Login
Resources
Docs Blog Datasets Glossary Case Studies Tutorials & Webinars
Product
Data Engine LLMs Platform Enterprise
Pricing Explore
Connect to our Discord channel

dataset.py 3.2 KB

You have to be logged in to leave a comment. Sign In
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
  1. from mrcnn.utils import Dataset
  2. from os import listdir
  3. from pprint import pprint
  4. import numpy as np
  5. from xml.etree import ElementTree
  def load_train_val(data_config):
      """Build the training and validation kangaroo datasets.

      Args:
          data_config: Mapping providing ``'path'`` (dataset root directory,
              handed to ``KangarooDataset.load_data``) and ``'train_size'``
              (handed to the ``KangarooDataset`` constructor).

      Returns:
          A ``(train_set, val_set)`` tuple; ``prepare()`` has already been
          called on both datasets.
      """
      data_path = data_config['path']
      train_size = data_config['train_size']

      # Both splits scan the same directory; ``is_train`` selects which
      # side of the ``train_size`` threshold is kept.
      train_set = KangarooDataset(train_size=train_size)
      train_set.load_data(data_path)
      train_set.prepare()

      val_set = KangarooDataset(train_size=train_size)
      val_set.load_data(data_path, is_train=False)
      val_set.prepare()

      print('train-size: ', len(train_set.image_ids))
      print('val-size: ', len(val_set.image_ids))
      return train_set, val_set
  def get_xml_int(xml_element, name):
      """Return the text of sub-element ``name`` parsed as an integer.

      Args:
          xml_element: Element to search; ``findtext(name)`` is called on it.
          name: Tag name or path of the sub-element whose text is read.

      Returns:
          The integer value of the matched sub-element's text.

      Raises:
          ValueError: If the sub-element is missing, has empty text, or its
              text is not a valid integer literal. The message names the
              requested tag, the searched element and that element's direct
              children.
      """
      str_value = xml_element.findtext(name)
      if str_value:
          try:
              return int(str_value)
          except ValueError:
              # Non-numeric text: fall through so the caller gets the same
              # contextual message as for a missing or empty value.
              pass

      inner_values = {child.tag: child.text for child in xml_element}
      raise ValueError(
          f'Invalid value for type <int> when searching for <{name}> in <{xml_element.tag}> which contains: {inner_values}'
      )
  def extract_boxes(annotation_path):
      """Read bounding boxes and the image size from an annotation XML file.

      Args:
          annotation_path: Path of the XML file to parse.

      Returns:
          A ``(boxes, width, height)`` tuple. ``boxes`` holds one
          ``[xmin, ymin, xmax, ymax]`` list per ``bndbox`` element found in
          the document; ``width`` and ``height`` are read from
          ``size/width`` and ``size/height``.

      Raises:
          ValueError: Propagated from ``get_xml_int`` when a required value
              cannot be read as an integer.
      """
      root = ElementTree.parse(annotation_path).getroot()

      corner_tags = ('xmin', 'ymin', 'xmax', 'ymax')
      boxes = [
          [get_xml_int(bndbox, tag) for tag in corner_tags]
          for bndbox in root.iter('bndbox')
      ]

      width = get_xml_int(root, './/size/width')
      height = get_xml_int(root, './/size/height')
      return boxes, width, height
  class KangarooDataset(Dataset):
      """Kangaroo dataset with rectangular masks built from bounding boxes.

      Images are read from ``<dataset_dir>/images/`` and annotations from
      ``<dataset_dir>/annots/``. File names are expected to be numeric ids
      (e.g. ``00001.jpg``); ids below ``train_size`` form the training
      split and the remaining ids form the validation split.
      """

      def __init__(self, train_size, debug=False, *args, **kwargs):
          """Create an empty dataset.

          Args:
              train_size: Numeric image-id threshold separating the training
                  split (``id < train_size``) from the validation split.
              debug: When true, only ids up to ``debug_limit`` are loaded and
                  diagnostic output is printed.
              *args: Forwarded to the base class constructor.
              **kwargs: Forwarded to the base class constructor.
          """
          super().__init__(*args, **kwargs)
          self.debug = debug
          # Highest numeric image id that is still loaded in debug mode.
          self.debug_limit = 2
          # Image ids that are always skipped.
          self.bad_images = ['00090']
          self.source = 'dataset'
          self.train_size = train_size

      def load_data(self, dataset_dir, is_train=True):
          """Register the images of one split with the dataset.

          Args:
              dataset_dir: Root directory containing ``images/`` and
                  ``annots/`` sub-directories.
              is_train: Keep ids below ``train_size`` when true, ids at or
                  above it when false.
          """
          self.add_class(self.source, 1, 'kangaroo')
          images_dir = dataset_dir + '/images/'
          annotations_dir = dataset_dir + '/annots/'

          # listdir() returns entries in arbitrary order; sorting makes the
          # registration order reproducible between runs and machines.
          for filename in sorted(listdir(images_dir)):
              # Strip the extension whatever its length ('.jpg', '.jpeg', ...).
              image_id = filename.rpartition('.')[0]
              # Ignore stray entries (hidden files, notes, ...) that do not
              # follow the numeric naming scheme instead of crashing on int().
              if not image_id.isdecimal():
                  continue
              numeric_id = int(image_id)

              # Skip rather than stop: a later entry may still be in range.
              if self.debug and numeric_id > self.debug_limit:
                  continue
              if image_id in self.bad_images:
                  continue

              # training/validation
              in_train_split = numeric_id < self.train_size
              if in_train_split != bool(is_train):
                  continue

              self.add_image(
                  self.source,
                  image_id=image_id,
                  path=images_dir + filename,
                  annotation=annotations_dir + image_id + '.xml',
              )

          if self.debug:
              print('-- Image info --')
              pprint(self.image_info)

      def extract_boxes(self, annotation_path):
          """Return ``(boxes, width, height)`` for an annotation file.

          Delegates to the module-level ``extract_boxes`` function and prints
          the result when ``debug`` is set.
          """
          boxes, width, height = extract_boxes(annotation_path)
          if self.debug:
              print(f'width: {width}, height: {height}, boxes: ')
              pprint(boxes)
          return boxes, width, height

      def load_mask(self, image_id):
          """Build one rectangular mask per annotated bounding box.

          Args:
              image_id: Index into ``self.image_info``.

          Returns:
              A ``(masks, class_ids)`` tuple: ``masks`` is a ``uint8`` array
              of shape ``[height, width, number_of_boxes]`` with ones inside
              each box, and ``class_ids`` is an ``int32`` array holding the
              'kangaroo' class index once per box.
          """
          info = self.image_info[image_id]
          boxes, w, h = self.extract_boxes(info['annotation'])

          masks = np.zeros([h, w, len(boxes)], dtype=np.uint8)
          for i, (xmin, ymin, xmax, ymax) in enumerate(boxes):
              masks[ymin:ymax, xmin:xmax, i] = 1

          # Every box has the same class, so look the index up only once.
          kangaroo_class = self.class_names.index('kangaroo')
          class_ids = [kangaroo_class] * len(boxes)
          return masks, np.asarray(class_ids, dtype=np.int32)

      def image_reference(self, image_id):
          """Return the file path recorded for ``image_id``."""
          info = self.image_info[image_id]
          return info['path']
Tip!

Press p or ← to see the previous file, or n or → to see the next file

Comments

Loading...