Register
Login
Resources
Docs Blog Datasets Glossary Case Studies Tutorials & Webinars
Product
Data Engine LLMs Platform Enterprise
Pricing Explore
Connect to our Discord channel

inpaint_mask.py 4.8 KB

You have to be logged in to leave a comment. Sign In
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
  1. import sys
  2. import modules.config
  3. import numpy as np
  4. import torch
  5. from extras.GroundingDINO.util.inference import default_groundingdino
  6. from extras.sam.predictor import SamPredictor
  7. from rembg import remove, new_session
  8. from segment_anything import sam_model_registry
  9. from segment_anything.utils.amg import remove_small_regions
  10. class SAMOptions:
  11. def __init__(self,
  12. # GroundingDINO
  13. dino_prompt: str = '',
  14. dino_box_threshold=0.3,
  15. dino_text_threshold=0.25,
  16. dino_erode_or_dilate=0,
  17. dino_debug=False,
  18. # SAM
  19. max_detections=2,
  20. model_type='vit_b'
  21. ):
  22. self.dino_prompt = dino_prompt
  23. self.dino_box_threshold = dino_box_threshold
  24. self.dino_text_threshold = dino_text_threshold
  25. self.dino_erode_or_dilate = dino_erode_or_dilate
  26. self.dino_debug = dino_debug
  27. self.max_detections = max_detections
  28. self.model_type = model_type
  29. def optimize_masks(masks: torch.Tensor) -> torch.Tensor:
  30. """
  31. removes small disconnected regions and holes
  32. """
  33. fine_masks = []
  34. for mask in masks.to('cpu').numpy(): # masks: [num_masks, 1, h, w]
  35. fine_masks.append(remove_small_regions(mask[0], 400, mode="holes")[0])
  36. masks = np.stack(fine_masks, axis=0)[:, np.newaxis]
  37. return torch.from_numpy(masks)
  38. def generate_mask_from_image(image: np.ndarray, mask_model: str = 'sam', extras=None,
  39. sam_options: SAMOptions | None = SAMOptions) -> tuple[np.ndarray | None, int | None, int | None, int | None]:
  40. dino_detection_count = 0
  41. sam_detection_count = 0
  42. sam_detection_on_mask_count = 0
  43. if image is None:
  44. return None, dino_detection_count, sam_detection_count, sam_detection_on_mask_count
  45. if extras is None:
  46. extras = {}
  47. if 'image' in image:
  48. image = image['image']
  49. if mask_model != 'sam' or sam_options is None:
  50. result = remove(
  51. image,
  52. session=new_session(mask_model, **extras),
  53. only_mask=True,
  54. **extras
  55. )
  56. return result, dino_detection_count, sam_detection_count, sam_detection_on_mask_count
  57. detections, boxes, logits, phrases = default_groundingdino(
  58. image=image,
  59. caption=sam_options.dino_prompt,
  60. box_threshold=sam_options.dino_box_threshold,
  61. text_threshold=sam_options.dino_text_threshold
  62. )
  63. H, W = image.shape[0], image.shape[1]
  64. boxes = boxes * torch.Tensor([W, H, W, H])
  65. boxes[:, :2] = boxes[:, :2] - boxes[:, 2:] / 2
  66. boxes[:, 2:] = boxes[:, 2:] + boxes[:, :2]
  67. sam_checkpoint = modules.config.download_sam_model(sam_options.model_type)
  68. sam = sam_model_registry[sam_options.model_type](checkpoint=sam_checkpoint)
  69. sam_predictor = SamPredictor(sam)
  70. final_mask_tensor = torch.zeros((image.shape[0], image.shape[1]))
  71. dino_detection_count = boxes.size(0)
  72. if dino_detection_count > 0:
  73. sam_predictor.set_image(image)
  74. if sam_options.dino_erode_or_dilate != 0:
  75. for index in range(boxes.size(0)):
  76. assert boxes.size(1) == 4
  77. boxes[index][0] -= sam_options.dino_erode_or_dilate
  78. boxes[index][1] -= sam_options.dino_erode_or_dilate
  79. boxes[index][2] += sam_options.dino_erode_or_dilate
  80. boxes[index][3] += sam_options.dino_erode_or_dilate
  81. if sam_options.dino_debug:
  82. from PIL import ImageDraw, Image
  83. debug_dino_image = Image.new("RGB", (image.shape[1], image.shape[0]), color="black")
  84. draw = ImageDraw.Draw(debug_dino_image)
  85. for box in boxes.numpy():
  86. draw.rectangle(box.tolist(), fill="white")
  87. return np.array(debug_dino_image), dino_detection_count, sam_detection_count, sam_detection_on_mask_count
  88. transformed_boxes = sam_predictor.transform.apply_boxes_torch(boxes, image.shape[:2])
  89. masks, _, _ = sam_predictor.predict_torch(
  90. point_coords=None,
  91. point_labels=None,
  92. boxes=transformed_boxes,
  93. multimask_output=False,
  94. )
  95. masks = optimize_masks(masks)
  96. sam_detection_count = len(masks)
  97. if sam_options.max_detections == 0:
  98. sam_options.max_detections = sys.maxsize
  99. sam_objects = min(len(logits), sam_options.max_detections)
  100. for obj_ind in range(sam_objects):
  101. mask_tensor = masks[obj_ind][0]
  102. final_mask_tensor += mask_tensor
  103. sam_detection_on_mask_count += 1
  104. final_mask_tensor = (final_mask_tensor > 0).to('cpu').numpy()
  105. mask_image = np.dstack((final_mask_tensor, final_mask_tensor, final_mask_tensor)) * 255
  106. mask_image = np.array(mask_image, dtype=np.uint8)
  107. return mask_image, dino_detection_count, sam_detection_count, sam_detection_on_mask_count
Tip!

Press p to see the previous file, or n to see the next file

Comments

Loading...