Skip to content

Commit

Permalink
Add option to pad bounding boxes
Browse files Browse the repository at this point in the history
  • Loading branch information
thompsonmj committed Oct 9, 2024
1 parent 50fbdc2 commit 8af0232
Show file tree
Hide file tree
Showing 5 changed files with 116 additions and 17 deletions.
8 changes: 8 additions & 0 deletions src/wing_segmenter/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,10 @@ def main():
default='area',
help='Interpolation method to use when resizing. For upscaling, "lanczos4" is recommended.')

# Bounding box padding option
bbox_group = segment_parser.add_argument_group('Bounding Box Options')
bbox_group.add_argument('--bbox-padding', type=int, default=None,
help='Padding to add to bounding boxes in pixels. Defaults to no padding.')

# Output options within mutually exclusive group
output_group = segment_parser.add_mutually_exclusive_group()
Expand Down Expand Up @@ -103,6 +107,10 @@ def main():
if args.custom_output_dir and args.outputs_base_dir:
parser.error('Cannot specify both --outputs-base-dir and --custom-output-dir. Choose one.')

# Validate bbox-padding
if args.bbox_padding is not None and args.bbox_padding < 0:
parser.error('--bbox-padding must be a non-negative integer.')

# Execute the subcommand
if args.command == 'segment':
from wing_segmenter.segmenter import Segmenter
Expand Down
58 changes: 43 additions & 15 deletions src/wing_segmenter/image_processor.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,8 @@ def process_image(segmenter, image_path):
working_image_rgb,
segmenter.sam_processor,
segmenter.sam_model,
segmenter.device
segmenter.device,
segmenter.bbox_padding
)

if mask is not None:
Expand Down Expand Up @@ -134,16 +135,40 @@ def process_image(segmenter, image_path):
logging.error(f"Error processing image {image_path}: {e}")
raise ImageProcessingError(f"Error processing image {image_path}: {e}")

def get_mask_SAM(result, image, processor, model, device):
def pad_bounding_boxes(boxes, padding, image_width, image_height):
"""
Generate mask using SAM model.
Pads bounding boxes by the specified number of pixels.
Parameters:
- boxes (list of list): List of bounding boxes [x1, y1, x2, y2].
- padding (int): Number of pixels to pad.
- image_width (int): Width of the image.
- image_height (int): Height of the image.
Returns:
- padded_boxes (list of list): List of padded bounding boxes.
"""
padded_boxes = []
for box in boxes:
x1, y1, x2, y2 = box
x1_padded = max(x1 - padding, 0)
y1_padded = max(y1 - padding, 0)
x2_padded = min(x2 + padding, image_width - 1)
y2_padded = min(y2 + padding, image_height - 1)
padded_boxes.append([x1_padded, y1_padded, x2_padded, y2_padded])
return padded_boxes

def get_mask_SAM(result, image, processor, model, device, bbox_padding=None):
"""
Generate mask using SAM model with optional bounding box padding.
Parameters:
- result: YOLO prediction result.
- image (np.array): The input image in RGB format.
- processor (SamProcessor): SAM processor.
- model (SamModel): SAM model.
- device (str): 'cpu' or 'cuda'.
- bbox_padding (int or None): Padding to add to bounding boxes.
Returns:
- img_mask (np.array): The generated mask.
Expand All @@ -153,6 +178,10 @@ def get_mask_SAM(result, image, processor, model, device):
bboxes_xyxy = result.boxes.xyxy
input_boxes = [bbox.cpu().numpy().tolist()[:4] for bbox in bboxes_xyxy]

if bbox_padding is not None and bbox_padding > 0:
image_height, image_width = image.shape[:2]
input_boxes = pad_bounding_boxes(input_boxes, bbox_padding, image_width, image_height)

if len(bboxes_xyxy) == 0:
logging.warning("No bounding boxes detected by YOLO.")
return None
Expand Down Expand Up @@ -233,6 +262,8 @@ def crop_and_save_by_class(segmenter, image, mask, relative_path):
- mask (np.array): The segmentation mask.
- relative_path (str): The relative path of the image for maintaining directory structure.
"""
image_height, image_width = image.shape[:2]

for class_id, class_name in CLASSES.items():
if class_id == 0:
continue # Skip background class for cropping
Expand All @@ -244,22 +275,21 @@ def crop_and_save_by_class(segmenter, image, mask, relative_path):
contours, _ = cv2.findContours(class_mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)

if not contours:
# logging.info(f"No instances found for class '{class_name}' in image '{relative_path}'.") # uncomment for debugging
continue # No mask for this class

for c in contours:
x, y, w, h = cv2.boundingRect(c)

# Apply padding?
padding = 0 # could make this configurable
x = max(x - padding, 0)
y = max(y - padding, 0)
w = min(w + 2 * padding, image.shape[1] - x)
h = min(h + 2 * padding, image.shape[0] - y)
# Apply bounding box padding if specified
padding = segmenter.bbox_padding if segmenter.bbox_padding is not None else 0
x_padded = max(x - padding, 0)
y_padded = max(y - padding, 0)
x2_padded = min(x + w + padding, image_width)
y2_padded = min(y + h + padding, image_height)

# Crop the image
cropped_image = image[y:y+h, x:x+w]
cropped_mask = class_mask[y:y+h, x:x+w]
# Crop the image using the padded bounding box
cropped_image = image[y_padded:y2_padded, x_padded:x2_padded]
cropped_mask = class_mask[y_padded:y2_padded, x_padded:x2_padded]

# Prepare save path for cropped classes
crop_relative_path = os.path.join(class_name, relative_path)
Expand All @@ -268,7 +298,6 @@ def crop_and_save_by_class(segmenter, image, mask, relative_path):

# Save cropped image
cv2.imwrite(crop_save_path, cropped_image)
# logging.info(f"Cropped '{class_name}' saved to '{crop_save_path}'.") # uncomment for debugging

# Background removal from cropped image if specified
if segmenter.remove_crops_background:
Expand All @@ -280,7 +309,6 @@ def crop_and_save_by_class(segmenter, image, mask, relative_path):

# Save the background-removed cropped image
cv2.imwrite(crop_bg_removed_save_path, cropped_image_bg_removed)
# logging.info(f"Cropped '{class_name}' with background removed saved to '{crop_bg_removed_save_path}'.") # uncomment for debugging

def remove_background(image, mask, bg_color='black'):
"""
Expand Down
54 changes: 54 additions & 0 deletions src/wing_segmenter/process_manager.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
import json
import os

def load_metadata(metadata_path):
if os.path.exists(metadata_path):
with open(metadata_path, 'r') as f:
metadata = json.load(f)
return metadata
else:
return None

def determine_additional_steps(old_params, new_params):
"""
Determines which additional steps need to be performed based on the difference
between old and new run parameters.
Returns:
- reprocess_segmentation (bool): Whether segmentation needs to be reprocessed.
- additional_steps (list): List of additional steps to perform.
"""
reprocess_segmentation = False
additional_steps = []

# Parameters that affect the main segmentation process
segmentation_params = ['sam_model_name', 'yolo_model_name', 'resize_mode', 'size', 'padding_color', 'interpolation']

# Check if any main segmentation parameters have changed
for param in segmentation_params:
old_value = old_params.get(param)
new_value = new_params.get(param)
if old_value != new_value:
reprocess_segmentation = True
break

# Parameters for additional processing steps
processing_flags = ['visualize_segmentation', 'crop_by_class', 'remove_crops_background', 'remove_full_background', 'background_color']

# Determine additional steps to perform
for flag in processing_flags:
old_value = old_params.get(flag, False)
new_value = new_params.get(flag, False)
if new_value and not old_value:
additional_steps.append(flag)

return reprocess_segmentation, additional_steps

def update_metadata_run_parameters(metadata_path, new_params):
with open(metadata_path, 'r') as f:
metadata = json.load(f)

metadata['run_parameters'].update(new_params)

with open(metadata_path, 'w') as f:
json.dump(metadata, f, indent=4)
8 changes: 7 additions & 1 deletion src/wing_segmenter/run_scanner.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,12 +50,13 @@ def scan_runs(dataset_path, output_base_dir=None, custom_output_dir=None):
table.add_column("Resize Dims", justify="center", no_wrap=False, width=11)
table.add_column("Resize Mode", justify="center", no_wrap=False, width=7)
table.add_column("Interp", justify="center", no_wrap=True, min_width=13)
table.add_column("BBox Pad", justify="right", no_wrap=False, width=4)
table.add_column("Errors", justify="center", no_wrap=True, min_width=6)

for idx, run_dir in enumerate(run_dirs, 1):
metadata_path = os.path.join(run_dir, 'metadata.json')
if not os.path.exists(metadata_path):
table.add_row(str(idx), f"Missing metadata.json", "", "", "", "", "", "")
table.add_row(str(idx), f"Missing metadata.json", "", "", "", "", "", "", "")
continue

with open(metadata_path, 'r') as f:
Expand All @@ -81,6 +82,10 @@ def scan_runs(dataset_path, output_base_dir=None, custom_output_dir=None):

interpolation = str(metadata['run_parameters'].get('interpolation', 'None'))

# Extract bbox_padding
bbox_padding = metadata['run_parameters'].get('bbox_padding', None)
bbox_padding_str = str(bbox_padding) if bbox_padding is not None else "null"

# Extract any errors
errors = str(metadata['run_status'].get('errors', 'None'))

Expand All @@ -96,6 +101,7 @@ def scan_runs(dataset_path, output_base_dir=None, custom_output_dir=None):
resize_dims_str,
resize_mode,
interpolation,
bbox_padding_str,
errors
)

Expand Down
5 changes: 4 additions & 1 deletion src/wing_segmenter/segmenter.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ def __init__(self, config):
self.background_color = config.background_color if config.background_color else 'black'
else:
self.background_color = None
self.bbox_padding = config.bbox_padding
self.segmentation_info = []
self.output_base_dir = os.path.abspath(config.outputs_base_dir) if config.outputs_base_dir else None
self.custom_output_dir = os.path.abspath(config.custom_output_dir) if config.custom_output_dir else None
Expand All @@ -42,6 +43,7 @@ def __init__(self, config):
'yolo_model_name': self.config.yolo_model,
'resize_mode': self.resize_mode,
'size': self.size,
'bbox_padding': self.bbox_padding,
})

setup_paths(self)
Expand Down Expand Up @@ -87,7 +89,8 @@ def process_dataset(self):
'crop_by_class': self.crop_by_class,
'remove_crops_background': self.remove_crops_background,
'remove_full_background': self.remove_full_background,
'background_color': self.background_color
'background_color': self.background_color,
'bbox_padding': self.bbox_padding,
},
'run_hardware': get_run_hardware_info(self.device),
'run_status': {
Expand Down

0 comments on commit 8af0232

Please sign in to comment.