Skip to content

Commit

Permalink
Merge pull request #1577 from AHuzail/fix/class-agnostic
Browse files Browse the repository at this point in the history
Bugfix: Class-agnostic mAP
  • Loading branch information
LinasKo authored Oct 10, 2024
2 parents 70869c7 + c5dfb47 commit 70bc4b9
Show file tree
Hide file tree
Showing 3 changed files with 16 additions and 4 deletions.
1 change: 0 additions & 1 deletion supervision/metrics/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
from supervision.metrics.core import (
CLASS_ID_NONE,
AveragingMethod,
Metric,
MetricTarget,
Expand Down
3 changes: 0 additions & 3 deletions supervision/metrics/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,6 @@
from enum import Enum
from typing import Any

CLASS_ID_NONE = -1
"""Used by metrics module as class ID, when none is present"""


class Metric(ABC):
"""
Expand Down
16 changes: 16 additions & 0 deletions supervision/metrics/mean_average_precision.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,15 @@ def update(
f" targets ({len(targets)}) during the update must be the same."
)

if self._class_agnostic:
predictions = deepcopy(predictions)
targets = deepcopy(targets)

for prediction in predictions:
prediction.class_id[:] = -1
for target in targets:
target.class_id[:] = -1

self._predictions_list.extend(predictions)
self._targets_list.extend(targets)

Expand Down Expand Up @@ -180,6 +189,7 @@ def _compute(
matches = self._match_detection_batch(
predictions.class_id, targets.class_id, iou, iou_thresholds
)

stats.append(
(
matches,
Expand All @@ -203,6 +213,7 @@ def _compute(

return MeanAveragePrecisionResult(
metric_target=self._metric_target,
is_class_agnostic=self._class_agnostic,
mAP_scores=mAP_scores,
iou_thresholds=iou_thresholds,
matched_classes=unique_classes,
Expand Down Expand Up @@ -245,6 +256,7 @@ def _match_detection_batch(
iou_thresholds.shape[0],
)
correct = np.zeros((num_predictions, num_iou_levels), dtype=bool)

correct_class = target_classes[:, None] == predictions_classes

for i, iou_level in enumerate(iou_thresholds):
Expand Down Expand Up @@ -383,6 +395,8 @@ class MeanAveragePrecisionResult:
Attributes:
metric_target (MetricTarget): the type of data used for the metric -
boxes, masks or oriented bounding boxes.
class_agnostic (bool): When computing class-agnostic results, class ID
is set to `-1`.
mAP_map50_95 (float): the mAP score at IoU thresholds from `0.5` to `0.95`.
mAP_map50 (float): the mAP score at IoU threshold of `0.5`.
mAP_map75 (float): the mAP score at IoU threshold of `0.75`.
Expand All @@ -402,6 +416,7 @@ class and IoU threshold. Shape: `(num_target_classes, num_iou_thresholds)`
"""

metric_target: MetricTarget
is_class_agnostic: bool

@property
def map50_95(self) -> float:
Expand Down Expand Up @@ -436,6 +451,7 @@ def __str__(self) -> str:
out_str = (
f"{self.__class__.__name__}:\n"
f"Metric target: {self.metric_target}\n"
f"Class agnostic: {self.is_class_agnostic}\n"
f"mAP @ 50:95: {self.map50_95:.4f}\n"
f"mAP @ 50: {self.map50:.4f}\n"
f"mAP @ 75: {self.map75:.4f}\n"
Expand Down

0 comments on commit 70bc4b9

Please sign in to comment.