Skip to content

Commit

Permalink
feat(metric): add detection, precision, and recall diarization metrics (
Browse files Browse the repository at this point in the history
  • Loading branch information
hbredin authored Dec 13, 2024
1 parent ffc43ee commit 0b7f933
Show file tree
Hide file tree
Showing 5 changed files with 128 additions and 12 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ Clipping and speaker/source alignment issues in speech separation pipeline have
- feat(utils): add `hidden` option to `ProgressHook`
- feat(utils): add `FilterByNumberOfSpeakers` protocol files filter
- feat(core): add `Calibration` class to calibrate logits/distances into probabilities
- feat(metric): add detection, precision, and recall diarization metrics

### Improvements

Expand Down
6 changes: 6 additions & 0 deletions pyannote/audio/tasks/segmentation/speaker_diarization.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,10 @@
from pyannote.audio.core.task import Problem, Resolution, Specifications
from pyannote.audio.tasks.segmentation.mixins import SegmentationTask
from pyannote.audio.torchmetrics import (
DetectionErrorRate,
DiarizationErrorRate,
DiarizationPrecision,
DiarizationRecall,
FalseAlarmRate,
MissedDetectionRate,
SpeakerConfusionRate,
Expand Down Expand Up @@ -493,6 +496,9 @@ def default_metric(
"DiarizationErrorRate/Confusion": SpeakerConfusionRate(0.5),
"DiarizationErrorRate/Miss": MissedDetectionRate(0.5),
"DiarizationErrorRate/FalseAlarm": FalseAlarmRate(0.5),
"DiarizationErrorRate/Precision": DiarizationPrecision(0.5),
"DiarizationErrorRate/Recall": DiarizationRecall(0.5),
"DiarizationErrorRate/DetectionErrorRate": DetectionErrorRate(0.5),
}

# TODO: no need to compute gradient in this method
Expand Down
8 changes: 7 additions & 1 deletion pyannote/audio/torchmetrics/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,10 @@


from .audio.diarization_error_rate import (
DetectionErrorRate,
DiarizationErrorRate,
DiarizationPrecision,
DiarizationRecall,
FalseAlarmRate,
MissedDetectionRate,
OptimalDiarizationErrorRate,
Expand All @@ -34,13 +37,16 @@
)

__all__ = [
"DetectionErrorRate",
"DiarizationErrorRate",
"DiarizationPrecision",
"DiarizationRecall",
"FalseAlarmRate",
"MissedDetectionRate",
"SpeakerConfusionRate",
"OptimalDiarizationErrorRate",
"OptimalFalseAlarmRate",
"OptimalMissedDetectionRate",
"OptimalSpeakerConfusionRate",
"OptimalDiarizationErrorRateThreshold",
"SpeakerConfusionRate",
]
14 changes: 10 additions & 4 deletions pyannote/audio/torchmetrics/audio/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,10 @@


from .diarization_error_rate import (
DetectionErrorRate,
DiarizationErrorRate,
DiarizationPrecision,
DiarizationRecall,
FalseAlarmRate,
MissedDetectionRate,
OptimalDiarizationErrorRate,
Expand All @@ -34,13 +37,16 @@
)

__all__ = [
"DetectionErrorRate",
"DiarizationErrorRate",
"SpeakerConfusionRate",
"MissedDetectionRate",
"DiarizationPrecision",
"DiarizationRecall",
"FalseAlarmRate",
"MissedDetectionRate",
"OptimalDiarizationErrorRate",
"OptimalSpeakerConfusionRate",
"OptimalMissedDetectionRate",
"OptimalFalseAlarmRate",
"OptimalMissedDetectionRate",
"OptimalSpeakerConfusionRate",
"OptimalDiarizationErrorRateThreshold",
"SpeakerConfusionRate",
]
111 changes: 104 additions & 7 deletions pyannote/audio/torchmetrics/audio/diarization_error_rate.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,12 +38,6 @@ class DiarizationErrorRate(Metric):
----------
threshold : float, optional
Threshold used to binarize predictions. Defaults to 0.5.
Notes
-----
While the pyannote.audio convention is to store speaker activations with
(batch_size, num_frames, num_speakers)-shaped tensors, this torchmetrics metric
expects them to be shaped as (batch_size, num_speakers, num_frames) tensors.
"""

higher_is_better = False
Expand All @@ -68,7 +62,7 @@ def update(
preds: torch.Tensor,
target: torch.Tensor,
) -> None:
"""Compute and accumulate components of diarization error rate
"""Compute and accumulate diarization error rate components
Parameters
----------
Expand All @@ -95,6 +89,8 @@ def update(
self.speech_total += speech_total

def compute(self):
"""Compute diarization error rate from its accumulated components"""

return _der_compute(
self.false_alarm,
self.missed_detection,
Expand All @@ -104,20 +100,113 @@ def compute(self):


class SpeakerConfusionRate(DiarizationErrorRate):
    """Speaker confusion component of the diarization error rate

    Speaker confusion is one of the three summands of the diarization
    error rate; this metric reports it on its own.

    Parameters
    ----------
    threshold : float, optional
        Threshold used to binarize predictions. Defaults to 0.5.
    """

    higher_is_better = False

    def compute(self):
        """Return accumulated speaker confusion divided by accumulated total speech"""
        # the small constant guards against a zero denominator
        total_speech = self.speech_total + 1e-8
        return self.speaker_confusion / total_speech


class DiarizationPrecision(DiarizationErrorRate):
    """Speaker identification precision

    Ratio between the duration of correctly identified speech and the
    duration of correctly detected speech. False alarms do not enter
    this ratio.

    Parameters
    ----------
    threshold : float, optional
        Threshold used to binarize predictions. Defaults to 0.5.
    """

    higher_is_better = True

    def compute(self):
        """Return speaker identification precision from the accumulated components"""
        # detected speech: reference speech that was not missed
        detected = self.speech_total - self.missed_detection
        # identified speech: detected speech not lost to speaker confusion
        identified = detected - self.speaker_confusion
        # the small constant guards against a zero denominator
        return identified / (detected + 1e-8)


class DiarizationRecall(DiarizationErrorRate):
    """Speaker identification recall

    Ratio between the duration of correctly identified speech and the
    total duration of speech in the reference. False alarms do not enter
    this ratio.

    Parameters
    ----------
    threshold : float, optional
        Threshold used to binarize predictions. Defaults to 0.5.
    """

    higher_is_better = True

    def compute(self):
        """Return speaker identification recall from the accumulated components"""
        # detected speech: reference speech that was not missed
        detected = self.speech_total - self.missed_detection
        # identified speech: detected speech not lost to speaker confusion
        identified = detected - self.speaker_confusion
        # the small constant guards against a zero denominator
        return identified / (self.speech_total + 1e-8)


class FalseAlarmRate(DiarizationErrorRate):
    """False alarm component of the diarization error rate

    False alarm is one of the three summands of the diarization error
    rate; this metric reports it on its own.

    Parameters
    ----------
    threshold : float, optional
        Threshold used to binarize predictions. Defaults to 0.5.
    """

    higher_is_better = False

    def compute(self):
        """Return accumulated false alarm divided by accumulated total speech"""
        # the small constant guards against a zero denominator
        total_speech = self.speech_total + 1e-8
        return self.false_alarm / total_speech


class MissedDetectionRate(DiarizationErrorRate):
    """Missed detection component of the diarization error rate

    Missed detection is one of the three summands of the diarization
    error rate; this metric reports it on its own.

    Parameters
    ----------
    threshold : float, optional
        Threshold used to binarize predictions. Defaults to 0.5.
    """

    higher_is_better = False

    def compute(self):
        """Return accumulated missed detection divided by accumulated total speech"""
        # the small constant guards against a zero denominator
        total_speech = self.speech_total + 1e-8
        return self.missed_detection / total_speech


class DetectionErrorRate(DiarizationErrorRate):
    """Detection error rate

    Sum of the false alarm rate and the missed detection rate. Speaker
    confusion is not part of this metric.

    Parameters
    ----------
    threshold : float, optional
        Threshold used to binarize predictions. Defaults to 0.5.
    """

    higher_is_better = False

    def compute(self):
        """Return the detection error rate from the accumulated components"""
        detection_errors = self.false_alarm + self.missed_detection
        # the small constant guards against a zero denominator
        return detection_errors / (self.speech_total + 1e-8)


class OptimalDiarizationErrorRate(Metric):
"""Optimal diarization error rate
Expand Down Expand Up @@ -209,6 +298,8 @@ def compute(self):


class OptimalDiarizationErrorRateThreshold(OptimalDiarizationErrorRate):
higher_is_better = False

def compute(self):
der = _der_compute(
self.FalseAlarm,
Expand All @@ -223,6 +314,8 @@ def compute(self):


class OptimalSpeakerConfusionRate(OptimalDiarizationErrorRate):
higher_is_better = False

def compute(self):
der = _der_compute(
self.FalseAlarm,
Expand All @@ -235,6 +328,8 @@ def compute(self):


class OptimalFalseAlarmRate(OptimalDiarizationErrorRate):
higher_is_better = False

def compute(self):
der = _der_compute(
self.FalseAlarm,
Expand All @@ -247,6 +342,8 @@ def compute(self):


class OptimalMissedDetectionRate(OptimalDiarizationErrorRate):
higher_is_better = False

def compute(self):
der = _der_compute(
self.FalseAlarm,
Expand Down

0 comments on commit 0b7f933

Please sign in to comment.