Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Feat: Added detection metadata #1589

Open
wants to merge 12 commits into
base: develop
Choose a base branch
from
31 changes: 26 additions & 5 deletions supervision/detection/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@
is_data_equal,
mask_to_xyxy,
merge_data,
merge_metadata,
process_roboflow_result,
xywh_to_xyxy,
)
Expand Down Expand Up @@ -125,6 +126,9 @@ class simplifies data manipulation and filtering, providing a uniform API for
data (Dict[str, Union[np.ndarray, List]]): A dictionary containing additional
data where each key is a string representing the data type, and the value
is either a NumPy array or a list of corresponding data.
metadata (Dict[str, Any]): A dictionary containing collection-level metadata
that applies to the entire set of detections. This may include information such
as the video name, camera parameters, timestamp, or other global metadata.
""" # noqa: E501 // docs

xyxy: np.ndarray
Expand All @@ -133,6 +137,7 @@ class simplifies data manipulation and filtering, providing a uniform API for
class_id: Optional[np.ndarray] = None
tracker_id: Optional[np.ndarray] = None
data: Dict[str, Union[np.ndarray, List]] = field(default_factory=dict)
metadata: Dict[str, Any] = field(default_factory=dict)

def __post_init__(self):
validate_detections_fields(
souhhmm marked this conversation as resolved.
Show resolved Hide resolved
Expand Down Expand Up @@ -185,6 +190,7 @@ def __eq__(self, other: Detections):
np.array_equal(self.confidence, other.confidence),
np.array_equal(self.tracker_id, other.tracker_id),
is_data_equal(self.data, other.data),
self.metadata == other.metadata,
souhhmm marked this conversation as resolved.
Show resolved Hide resolved
]
)

Expand Down Expand Up @@ -958,7 +964,7 @@ def from_ncnn(cls, ncnn_results) -> Detections:
)

@classmethod
def empty(cls) -> Detections:
def empty(cls, metadata: Optional[Dict[str, Any]] = None) -> Detections:
"""
Create an empty Detections object with no bounding boxes,
confidences, or class IDs.
Expand All @@ -977,15 +983,21 @@ def empty(cls) -> Detections:
xyxy=np.empty((0, 4), dtype=np.float32),
confidence=np.array([], dtype=np.float32),
class_id=np.array([], dtype=int),
metadata=metadata if metadata is not None else {},
)

def is_empty(self) -> bool:
"""
Returns `True` if the `Detections` object is considered empty.
"""
empty_detections = Detections.empty()
empty_detections.data = self.data
return self == empty_detections
return (
len(self.xyxy) == 0
and (self.mask is None or len(self.mask) == 0)
and (self.class_id is None or len(self.class_id) == 0)
and (self.confidence is None or len(self.confidence) == 0)
and (self.tracker_id is None or len(self.tracker_id) == 0)
and not self.data
)
souhhmm marked this conversation as resolved.
Show resolved Hide resolved

@classmethod
def merge(cls, detections_list: List[Detections]) -> Detections:
Expand Down Expand Up @@ -1041,12 +1053,16 @@ def merge(cls, detections_list: List[Detections]) -> Detections:
array([0.1, 0.2, 0.3])
```
"""
metadata_list = [detections.metadata for detections in detections_list]

detections_list = [
detections for detections in detections_list if not detections.is_empty()
]

metadata = merge_metadata(metadata_list)

if len(detections_list) == 0:
return Detections.empty()
return Detections.empty(metadata=metadata)

for detections in detections_list:
validate_detections_fields(
Expand Down Expand Up @@ -1085,6 +1101,7 @@ def stack_or_none(name: str):
class_id=class_id,
tracker_id=tracker_id,
data=data,
metadata=metadata,
)

def get_anchors_coordinates(self, anchor: Position) -> np.ndarray:
Expand Down Expand Up @@ -1198,6 +1215,7 @@ def __getitem__(
class_id=self.class_id[index] if self.class_id is not None else None,
tracker_id=self.tracker_id[index] if self.tracker_id is not None else None,
data=get_data_item(self.data, index),
metadata=self.metadata,
)

def __setitem__(self, key: str, value: Union[np.ndarray, List]):
Expand Down Expand Up @@ -1459,13 +1477,16 @@ def merge_inner_detection_object_pair(
else:
winning_detection = detections_2

metadata = merge_metadata([detections_1.metadata, detections_2.metadata])

return Detections(
xyxy=merged_xyxy,
mask=merged_mask,
confidence=merged_confidence,
class_id=winning_detection.class_id,
tracker_id=winning_detection.tracker_id,
data=winning_detection.data,
metadata=metadata,
)


Expand Down
31 changes: 30 additions & 1 deletion supervision/detection/utils.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from itertools import chain
from typing import Dict, List, Optional, Tuple, Union
from typing import Any, Dict, List, Optional, Tuple, Union

import cv2
import numpy as np
Expand Down Expand Up @@ -866,6 +866,35 @@ def merge_data(
return merged_data


def merge_metadata(metadata_list: List[Dict[str, Any]]) -> Dict[str, Any]:
    """
    Merge metadata from a list of metadata dictionaries.

    This function combines the metadata dictionaries. If a key appears in more than
    one dictionary, the values must be identical for the merge to succeed. NumPy
    arrays are compared element-wise with `np.array_equal`, so equal arrays do not
    count as a conflict.

    Args:
        metadata_list (List[Dict[str, Any]]): A list of metadata dictionaries to
            merge. An empty list yields an empty dictionary.

    Returns:
        Dict[str, Any]: A single merged metadata dictionary. For every key, the
            value from the first dictionary that defines it is kept.

    Raises:
        ValueError: If there are conflicting values for the same key.
    """
    merged_metadata: Dict[str, Any] = {}

    for metadata in metadata_list:
        for key, value in metadata.items():
            if key not in merged_metadata:
                merged_metadata[key] = value
            elif not _metadata_values_equal(merged_metadata[key], value):
                raise ValueError(f"Conflicting metadata for key: {key}.")

    return merged_metadata


def _metadata_values_equal(first: Any, second: Any) -> bool:
    """Return True when two metadata values should be treated as identical."""
    # The very same object never conflicts with itself. This also covers values
    # whose `==` is not reflexive or not a plain bool (NaN, nested arrays).
    if first is second:
        return True

    first_is_array = isinstance(first, np.ndarray)
    second_is_array = isinstance(second, np.ndarray)
    if first_is_array or second_is_array:
        # `array != array` is element-wise, and using that result in an `if`
        # raises "truth value of an array is ambiguous"; compare explicitly.
        if not (first_is_array and second_is_array):
            return False
        return bool(np.array_equal(first, second))

    return bool(first == second)


def get_data_item(
data: Dict[str, Union[np.ndarray, List]],
index: Union[int, slice, List[int], np.ndarray],
Expand Down
Loading