Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Feat: Added detection metadata #1589

Open
wants to merge 12 commits into
base: develop
Choose a base branch
from
24 changes: 22 additions & 2 deletions supervision/detection/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,8 +32,10 @@
extract_ultralytics_masks,
get_data_item,
is_data_equal,
is_metadata_equal,
mask_to_xyxy,
merge_data,
merge_metadata,
process_roboflow_result,
xywh_to_xyxy,
)
Expand Down Expand Up @@ -125,6 +127,9 @@ class simplifies data manipulation and filtering, providing a uniform API for
data (Dict[str, Union[np.ndarray, List]]): A dictionary containing additional
data where each key is a string representing the data type, and the value
is either a NumPy array or a list of corresponding data.
metadata (Dict[str, Any]): A dictionary containing collection-level metadata
that applies to the entire set of detections. This may include information such
as the video name, camera parameters, timestamp, or other global metadata.
""" # noqa: E501 // docs

xyxy: np.ndarray
Expand All @@ -133,6 +138,7 @@ class simplifies data manipulation and filtering, providing a uniform API for
class_id: Optional[np.ndarray] = None
tracker_id: Optional[np.ndarray] = None
data: Dict[str, Union[np.ndarray, List]] = field(default_factory=dict)
metadata: Dict[str, Any] = field(default_factory=dict)

def __post_init__(self):
validate_detections_fields(
souhhmm marked this conversation as resolved.
Show resolved Hide resolved
Expand Down Expand Up @@ -185,6 +191,7 @@ def __eq__(self, other: Detections):
np.array_equal(self.confidence, other.confidence),
np.array_equal(self.tracker_id, other.tracker_id),
is_data_equal(self.data, other.data),
is_metadata_equal(self.metadata, other.metadata),
]
)

Expand Down Expand Up @@ -958,7 +965,7 @@ def from_ncnn(cls, ncnn_results) -> Detections:
)

@classmethod
def empty(cls) -> Detections:
def empty(cls, metadata: Optional[Dict[str, Any]] = None) -> Detections:
"""
Create an empty Detections object with no bounding boxes,
confidences, or class IDs.
Expand All @@ -973,10 +980,14 @@ def empty(cls) -> Detections:
empty_detections = Detections.empty()
```
"""
if metadata is not None and not isinstance(metadata, dict):
raise TypeError("Metadata must be a dictionary.")

return cls(
xyxy=np.empty((0, 4), dtype=np.float32),
confidence=np.array([], dtype=np.float32),
class_id=np.array([], dtype=int),
metadata=metadata if metadata is not None else {},
)

def is_empty(self) -> bool:
Expand Down Expand Up @@ -1041,12 +1052,16 @@ def merge(cls, detections_list: List[Detections]) -> Detections:
array([0.1, 0.2, 0.3])
```
"""
metadata_list = [detections.metadata for detections in detections_list]

detections_list = [
detections for detections in detections_list if not detections.is_empty()
]

metadata = merge_metadata(metadata_list)

if len(detections_list) == 0:
return Detections.empty()
return Detections.empty(metadata=metadata)

for detections in detections_list:
validate_detections_fields(
Expand Down Expand Up @@ -1085,6 +1100,7 @@ def stack_or_none(name: str):
class_id=class_id,
tracker_id=tracker_id,
data=data,
metadata=metadata,
)

def get_anchors_coordinates(self, anchor: Position) -> np.ndarray:
Expand Down Expand Up @@ -1198,6 +1214,7 @@ def __getitem__(
class_id=self.class_id[index] if self.class_id is not None else None,
tracker_id=self.tracker_id[index] if self.tracker_id is not None else None,
data=get_data_item(self.data, index),
metadata=self.metadata,
)

def __setitem__(self, key: str, value: Union[np.ndarray, List]):
Expand Down Expand Up @@ -1459,13 +1476,16 @@ def merge_inner_detection_object_pair(
else:
winning_detection = detections_2

metadata = merge_metadata([detections_1.metadata, detections_2.metadata])

return Detections(
xyxy=merged_xyxy,
mask=merged_mask,
confidence=merged_confidence,
class_id=winning_detection.class_id,
tracker_id=winning_detection.tracker_id,
data=winning_detection.data,
metadata=metadata,
)


Expand Down
57 changes: 56 additions & 1 deletion supervision/detection/utils.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from itertools import chain
from typing import Dict, List, Optional, Tuple, Union
from typing import Any, Dict, List, Optional, Tuple, Union

import cv2
import numpy as np
Expand Down Expand Up @@ -808,6 +808,25 @@ def is_data_equal(data_a: Dict[str, np.ndarray], data_b: Dict[str, np.ndarray])
)


def is_metadata_equal(metadata_a: Dict[str, Any], metadata_b: Dict[str, Any]) -> bool:
    """
    Compares the metadata payloads of two Detections instances.

    Two payloads are equal when they contain exactly the same keys and the values
    stored under each key are equal. When at least one value of a pair is a NumPy
    array, the pair is compared with `np.array_equal`, so the result is always a
    single boolean (a plain `==` on arrays is element-wise and cannot be used as
    a truth value for multi-element arrays).

    Args:
        metadata_a, metadata_b: The metadata payloads of the instances.

    Returns:
        True if the metadata payloads are equal, False otherwise.
    """
    # Dict key views compare like sets, so this checks "same keys" without
    # building intermediate sets.
    if metadata_a.keys() != metadata_b.keys():
        return False

    for key, value_a in metadata_a.items():
        value_b = metadata_b[key]
        if isinstance(value_a, np.ndarray) or isinstance(value_b, np.ndarray):
            # Covers array-vs-array as well as array-vs-list/scalar; shape
            # mismatches yield False rather than a broadcasting error.
            if not np.array_equal(value_a, value_b):
                return False
        elif value_a != value_b:
            return False

    return True


def merge_data(
data_list: List[Dict[str, Union[npt.NDArray[np.generic], List]]],
) -> Dict[str, Union[npt.NDArray[np.generic], List]]:
Expand Down Expand Up @@ -866,6 +885,42 @@ def merge_data(
return merged_data


def merge_metadata(metadata_list: List[Dict[str, Any]]) -> Dict[str, Any]:
    """
    Merge metadata from a list of metadata dictionaries.

    This function combines the metadata dictionaries. All dictionaries must share
    the same set of keys, and the values stored under each key must be identical
    across dictionaries for the merge to succeed. Values that are NumPy arrays
    are compared with `np.array_equal`, so array-valued metadata can be merged
    when the arrays hold the same contents.

    Args:
        metadata_list (List[Dict[str, Any]]): A list of metadata dictionaries to merge.

    Returns:
        Dict[str, Any]: A single merged metadata dictionary. An empty input list
            yields an empty dictionary.

    Raises:
        ValueError: If there are conflicting values for the same key or if
            dictionaries have different keys.
    """
    if not metadata_list:
        return {}

    reference_keys = set(metadata_list[0])
    if any(set(metadata) != reference_keys for metadata in metadata_list):
        raise ValueError("All metadata dictionaries must have the same keys to merge.")

    merged_metadata: Dict[str, Any] = {}
    for metadata in metadata_list:
        for key, value in metadata.items():
            if key not in merged_metadata:
                merged_metadata[key] = value
                continue

            existing = merged_metadata[key]
            if isinstance(existing, np.ndarray) or isinstance(value, np.ndarray):
                # `!=` on arrays is element-wise and has no single truth value;
                # np.array_equal collapses the comparison to one boolean.
                conflicting = not np.array_equal(existing, value)
            else:
                conflicting = existing != value

            if conflicting:
                raise ValueError(f"Conflicting metadata for key: '{key}'.")

    return merged_metadata


def get_data_item(
data: Dict[str, Union[np.ndarray, List]],
index: Union[int, slice, List[int], np.ndarray],
Expand Down