refactor: giving QCEvaluation a state (#1164)
* refactor: giving QCEvaluation a state

- Rename status() -> evaluate_status()
- Add state, auto-populated on validation or by evaluation.latest_status = evaluation.evaluate_status()

* fix: ensure backward compatibility

* fix: don't exclude latest_status (didn't know what that did)

* tests: generate examples

* tests: missing test coverage on deprecated status function
dbirman authored Nov 25, 2024
1 parent 3f0d66d commit 52378a6
Showing 3 changed files with 65 additions and 38 deletions.
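
A minimal sketch of the reworked API, pieced together from the diff below. The constructor arguments (the QCStatus fields, the Stage and Modality values, the metric itself) are illustrative assumptions, not taken from this commit:

from datetime import datetime, timezone

from aind_data_schema_models.modalities import Modality
from aind_data_schema.core.quality_control import QCEvaluation, QCMetric, QCStatus, Stage, Status

# Hypothetical metric. A fixed past timestamp keeps the import-time default
# date of evaluate_status() ahead of the recorded history.
metric = QCMetric(
    name="example_metric",
    value=1,
    status_history=[
        QCStatus(evaluator="Automated", status=Status.PASS, timestamp=datetime(2022, 11, 22, tzinfo=timezone.utc)),
    ],
)

evaluation = QCEvaluation(
    name="Example evaluation",
    modality=Modality.ECEPHYS,
    stage=Stage.PROCESSING,
    metrics=[metric],
)

# latest_status is auto-populated by the new model validator on construction
assert evaluation.latest_status == Status.PASS

# after mutating metrics in place, refresh the stored state explicitly
evaluation.metrics[0].status_history[0].status = Status.FAIL
evaluation.latest_status = evaluation.evaluate_status()
assert evaluation.latest_status == Status.FAIL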
9 changes: 6 additions & 3 deletions examples/quality_control.json

@@ -82,7 +82,8 @@
          ],
          "tags": null,
          "notes": "",
-         "allow_failed_metrics": false
+         "allow_failed_metrics": false,
+         "latest_status": "Pending"
       },
       {
          "modality": {
@@ -124,7 +125,8 @@
          ],
          "tags": null,
          "notes": "Pass when video_1_num_frames==video_2_num_frames",
-         "allow_failed_metrics": false
+         "allow_failed_metrics": false,
+         "latest_status": "Pass"
       },
       {
          "modality": {
@@ -180,7 +182,8 @@
          ],
          "tags": null,
          "notes": null,
-         "allow_failed_metrics": false
+         "allow_failed_metrics": false,
+         "latest_status": "Pass"
       }
    ],
    "notes": null
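
The "don't exclude latest_status" fix in the commit message is what lets the computed state appear in serialized files like the example above. A round-trip sketch, reusing the hypothetical evaluation from the previous snippet:

# latest_status round-trips through pydantic serialization; the model
# validator simply recomputes the same value on load
blob = evaluation.model_dump_json()
restored = QCEvaluation.model_validate_json(blob)
assert restored.latest_status == evaluation.latest_status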
65 changes: 42 additions & 23 deletions src/aind_data_schema/core/quality_control.py

@@ -3,6 +3,7 @@
 from datetime import datetime, timezone
 from enum import Enum
 from typing import Any, List, Literal, Optional, Union
+import warnings

 from aind_data_schema_models.modalities import Modality
 from pydantic import BaseModel, Field, SkipValidation, field_validator, model_validator
@@ -91,8 +92,47 @@ class QCEvaluation(AindModel):
             " will allow individual metrics to fail while still passing the evaluation."
         ),
     )
+    latest_status: Status = Field(default=None, title="Evaluation status")
+
+    def status(self, date: datetime = datetime.now(tz=timezone.utc)) -> Status:
+        """DEPRECATED
+        Replace with QCEvaluation.latest_status or QCEvaluation.evaluate_status()
+        """
+        warnings.warn(
+            "The status method is deprecated. Please use QCEvaluation.latest_status or QCEvaluation.evaluate_status()",
+            DeprecationWarning,
+        )
+        return self.evaluate_status(date)
+
+    @property
+    def failed_metrics(self) -> Optional[List[QCMetric]]:
+        """Return any metrics that are failing
+
+        Returns None if allow_failed_metrics is False
+
+        Returns
+        -------
+        list[QCMetric]
+            Metrics that fail
+        """
+        if not self.allow_failed_metrics:
+            return None
+        else:
+            failing_metrics = []
+            for metric in self.metrics:
+                if metric.status.status == Status.FAIL:
+                    failing_metrics.append(metric)
+
+            return failing_metrics
+
+    @model_validator(mode="after")
+    def compute_latest_status(self):
+        """Compute the status of the evaluation based on the status of its metrics"""
+        self.latest_status = self.evaluate_status()
+        return self
+
-    def status(self, date: datetime = datetime.now(tz=timezone.utc)) -> Status:
+    def evaluate_status(self, date: datetime = datetime.now(tz=timezone.utc)) -> Status:
         """Loop through all metrics and return the evaluation's status

         Any fail -> FAIL
@@ -123,27 +163,6 @@ def status(self, date: datetime = datetime.now(tz=timezone.utc)) -> Status:

         return Status.PASS

-    @property
-    def failed_metrics(self) -> Optional[List[QCMetric]]:
-        """Return any metrics that are failing
-
-        Returns None if allow_failed_metrics is False
-
-        Returns
-        -------
-        list[QCMetric]
-            Metrics that fail
-        """
-        if not self.allow_failed_metrics:
-            return None
-        else:
-            failing_metrics = []
-            for metric in self.metrics:
-                if metric.status.status == Status.FAIL:
-                    failing_metrics.append(metric)
-
-            return failing_metrics
-
     @model_validator(mode="after")
     def validate_multi_asset(cls, v):
         """Ensure that the evaluated_assets field in any attached metrics is set correctly"""
@@ -192,7 +211,7 @@ def status(
         All PASS -> PASS
         """
         if not modality and not stage and not tag:
-            eval_statuses = [evaluation.status(date=date) for evaluation in self.evaluations]
+            eval_statuses = [evaluation.evaluate_status(date=date) for evaluation in self.evaluations]
         else:
             if modality and not isinstance(modality, list):
                 modality = [modality]
@@ -202,7 +221,7 @@ def status(
                 tag = [tag]

             eval_statuses = [
-                evaluation.status(date=date)
+                evaluation.evaluate_status(date=date)
                 for evaluation in self.evaluations
                 if (not modality or any(evaluation.modality == mod for mod in modality))
                 and (not stage or any(evaluation.stage == sta for sta in stage))
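
Two behaviors in this file deserve a usage sketch: the backward-compatibility shim on the old method name, and the renamed call inside QualityControl.status(). The objects are the hypothetical ones from the first snippet:

import warnings

from aind_data_schema.core.quality_control import QualityControl

# the old spelling still answers, but now emits a DeprecationWarning
with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    assert evaluation.status() == evaluation.latest_status
    assert any(issubclass(w.category, DeprecationWarning) for w in caught)

qc = QualityControl(evaluations=[evaluation])

# no filters: aggregate evaluate_status() across every evaluation
overall = qc.status()

# filtered: modality, stage, and tag each accept a single value or a list
ecephys_overall = qc.status(modality=Modality.ECEPHYS)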
29 changes: 17 additions & 12 deletions tests/test_quality_control.py

@@ -15,8 +15,7 @@ class QualityControlTests(unittest.TestCase):
     def test_constructors(self):
         """testing constructors"""

-        with self.assertRaises(ValidationError):
-            q = QualityControl()
+        self.assertRaises(ValidationError, QualityControl)

         test_eval = QCEvaluation(
             name="Drift map",
@@ -76,7 +75,7 @@ def test_overall_status(self):
         )

         # check that evaluation status gets auto-set if it has never been set before
-        self.assertEqual(test_eval.status(), Status.PASS)
+        self.assertEqual(test_eval.latest_status, Status.PASS)

         q = QualityControl(
             evaluations=[test_eval, test_eval],
@@ -147,7 +146,7 @@ def test_evaluation_status(self):
             ],
         )

-        self.assertEqual(evaluation.status(), Status.PASS)
+        self.assertEqual(evaluation.latest_status, Status.PASS)

         # Add a pending metric, evaluation should now evaluate to pending
         evaluation.metrics.append(
@@ -163,8 +162,9 @@
                 ],
             )
         )
+        evaluation.latest_status = evaluation.evaluate_status()

-        self.assertEqual(evaluation.status(), Status.PENDING)
+        self.assertEqual(evaluation.latest_status, Status.PENDING)

         # Add a failing metric, evaluation should now evaluate to fail
         evaluation.metrics.append(
@@ -178,8 +178,10 @@
                 ],
             )
         )
+        evaluation.latest_status = evaluation.evaluate_status()

-        self.assertEqual(evaluation.status(), Status.FAIL)
+        self.assertEqual(evaluation.latest_status, Status.FAIL)
+        self.assertEqual(evaluation.status(), evaluation.latest_status)

     def test_allowed_failed_metrics(self):
         """Test that if you set the flag to allow failures that evaluations pass"""
@@ -218,14 +220,17 @@ def test_allowed_failed_metrics(self):

         evaluation.allow_failed_metrics = True

-        self.assertEqual(evaluation.status(), Status.PENDING)
+        self.assertEqual(evaluation.latest_status, Status.PENDING)

         # Replace the pending metric with a fail, evaluation should now evaluate to pass
         evaluation.metrics[1].status_history[0].status = Status.FAIL
+        evaluation.latest_status = evaluation.evaluate_status()

-        self.assertEqual(evaluation.status(), Status.PASS)
+        self.assertEqual(evaluation.latest_status, Status.PASS)

         metric2.status_history[0].status = Status.FAIL
+        evaluation.latest_status = evaluation.evaluate_status()

         self.assertEqual(evaluation.failed_metrics, [metric2])

     def test_metric_history_order(self):
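
The hunk above pins down how allow_failed_metrics interacts with the failed_metrics property; in sketch form, continuing with the same hypothetical objects as earlier:

# with the flag set, FAIL metrics no longer sink the evaluation...
evaluation.allow_failed_metrics = True
evaluation.latest_status = evaluation.evaluate_status()
assert evaluation.latest_status == Status.PASS

# ...and failed_metrics reports which ones were waived
assert evaluation.failed_metrics == [evaluation.metrics[0]]

# with the flag unset, failed_metrics is None rather than an empty list
evaluation.allow_failed_metrics = False
assert evaluation.failed_metrics is None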
@@ -484,10 +489,10 @@ def test_status_date(self):
             metrics=[metric],
         )

-        self.assertRaises(ValueError, test_eval.status, date=t0_5)
-        self.assertEqual(test_eval.status(date=t1_5), Status.FAIL)
-        self.assertEqual(test_eval.status(date=t2_5), Status.PENDING)
-        self.assertEqual(test_eval.status(date=t3_5), Status.PASS)
+        self.assertRaises(ValueError, test_eval.evaluate_status, date=t0_5)
+        self.assertEqual(test_eval.evaluate_status(date=t1_5), Status.FAIL)
+        self.assertEqual(test_eval.evaluate_status(date=t2_5), Status.PENDING)
+        self.assertEqual(test_eval.evaluate_status(date=t3_5), Status.PASS)

         qc = QualityControl(evaluations=[test_eval])
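
test_status_date (the final hunk) pins down the date semantics of evaluate_status(): each metric contributes whatever status its status_history held at the requested moment, and asking about a time before the first recorded status raises ValueError. A sketch under the same assumptions as the earlier snippets:

from datetime import timedelta

t0 = datetime(2023, 1, 1, tzinfo=timezone.utc)
t1 = datetime(2023, 6, 1, tzinfo=timezone.utc)
evaluation.metrics[0].status_history = [
    QCStatus(evaluator="Automated", status=Status.FAIL, timestamp=t0),
    QCStatus(evaluator="Automated", status=Status.PASS, timestamp=t1),
]

# before t0 there is no recorded status to evaluate against
try:
    evaluation.evaluate_status(date=t0 - timedelta(days=1))
except ValueError:
    pass  # expected

assert evaluation.evaluate_status(date=t0 + timedelta(days=1)) == Status.FAIL
assert evaluation.evaluate_status(date=t1 + timedelta(days=1)) == Status.PASS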
