Support modality and stage filters (#1124)
* feat: support modality and stage filters

* chore: lint

* feat: upgraded filtering, tests broken

* feat: adding date checks

* fix: code got lost

* fix: more lost code

* fix: missing f"

* tests: typo

* fix: bug in status calculation based on time

* chore: lint

* chore: fix for py3.8
dbirman authored Nov 7, 2024
1 parent bb33e44 commit f6ef0c3
Showing 4 changed files with 182 additions and 32 deletions.
47 changes: 39 additions & 8 deletions src/aind_data_schema/core/quality_control.py
@@ -1,9 +1,10 @@
""" Schemas for Quality Metrics """

from datetime import datetime, timezone
from enum import Enum
from typing import Any, List, Literal, Optional
from typing import Any, List, Literal, Optional, Union

from aind_data_schema_models.modalities import Modality
from aind_data_schema_models.modalities import Modality, ModalityModel
from pydantic import BaseModel, Field, SkipValidation, field_validator, model_validator

from aind_data_schema.base import AindCoreModel, AindModel, AwareDatetimeWithDefault
@@ -91,8 +92,7 @@ class QCEvaluation(AindModel):
),
)

@property
def status(self) -> Status:
def status(self, date: datetime = datetime.now(tz=timezone.utc)) -> Status:
"""Loop through all metrics and return the evaluation's status
Any fail -> FAIL
@@ -104,7 +104,17 @@ def status(self) -> Status:
Status
Current status of the evaluation
"""
latest_metric_statuses = [metric.status.status for metric in self.metrics]
latest_metric_statuses = []

for metric in self.metrics:
# loop backwards through metric statuses until you find one that is before the provided date
for status in reversed(metric.status_history):
if status.timestamp <= date:
latest_metric_statuses.append(status.status)
break

if not latest_metric_statuses:
raise ValueError(f"No status existed prior to the provided date {date.isoformat()}")

if (not self.allow_failed_metrics) and any(status == Status.FAIL for status in latest_metric_statuses):
return Status.FAIL
@@ -168,15 +178,36 @@ class QualityControl(AindCoreModel):
evaluations: List[QCEvaluation] = Field(..., title="Evaluations")
notes: Optional[str] = Field(default=None, title="Notes")

@property
def status(self) -> Status:
def status(
self,
modality: Union[ModalityModel, List[ModalityModel], None] = None,
stage: Union[Stage, List[Stage], None] = None,
tag: Union[str, List[str], None] = None,
date: datetime = datetime.now(tz=timezone.utc),
) -> Status:
"""Loop through all evaluations and return the overall status
Any FAIL -> FAIL
If no fails, then any PENDING -> PENDING
All PASS -> PASS
"""
eval_statuses = [evaluation.status for evaluation in self.evaluations]
if not modality and not stage and not tag:
eval_statuses = [evaluation.status(date=date) for evaluation in self.evaluations]
else:
if modality and not isinstance(modality, list):
modality = [modality]
if stage and not isinstance(stage, list):
stage = [stage]
if tag and not isinstance(tag, list):
tag = [tag]

eval_statuses = [
evaluation.status(date=date)
for evaluation in self.evaluations
if (not modality or any(evaluation.modality == mod for mod in modality))
and (not stage or any(evaluation.stage == sta for sta in stage))
and (not tag or (evaluation.tags and any(t in evaluation.tags for t in tag)))
]

if any(status == Status.FAIL for status in eval_statuses):
return Status.FAIL
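
For readers skimming the diff, a minimal usage sketch of the new interface (illustrative only, not part of this commit): status is now a method rather than a property, and QualityControl.status accepts optional modality, stage, tag, and date arguments. Field values below are made up; class signatures follow the tests further down.

from datetime import datetime, timezone

from aind_data_schema_models.modalities import Modality

from aind_data_schema.core.quality_control import (
    QCEvaluation,
    QCMetric,
    QCStatus,
    QualityControl,
    Stage,
    Status,
)

# One metric with a single PASS entry in its status history
metric = QCMetric(
    name="Example metric",
    value=0.95,
    status_history=[
        QCStatus(
            evaluator="Bob",
            timestamp=datetime(2024, 1, 1, tzinfo=timezone.utc),
            status=Status.PASS,
        )
    ],
)

evaluation = QCEvaluation(
    name="Example evaluation",
    modality=Modality.ECEPHYS,
    stage=Stage.PROCESSING,
    metrics=[metric],
)

qc = QualityControl(evaluations=[evaluation])

# status is now a method call, optionally evaluated "as of" a date
evaluation.status()                                            # Status.PASS
evaluation.status(date=datetime(2024, 6, 1, tzinfo=timezone.utc))

# QualityControl.status can filter evaluations before aggregating
qc.status()                                                    # Status.PASS
qc.status(modality=Modality.ECEPHYS)                           # only ecephys evaluations
qc.status(stage=[Stage.PROCESSING, Stage.RAW])                 # list filters are allowed
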
5 changes: 1 addition & 4 deletions tests/test_base.py
@@ -99,10 +99,7 @@ def test_aind_generic_constructor(self):

def test_aind_generic_validate_fieldnames(self):
"""Tests that fieldnames are validated in AindGeneric"""
expected_error = (
"1 validation error for AindGeneric\n"
" Value error, Field names cannot contain '.' or '$' "
)
expected_error = "1 validation error for AindGeneric\n" " Value error, Field names cannot contain '.' or '$' "
invalid_params = [
{"$foo": "bar"},
{"foo": {"foo.name": "bar"}},
13 changes: 3 additions & 10 deletions tests/test_metadata.py
@@ -379,8 +379,7 @@ def test_validate_rig_session_compatibility(self):
)

def test_validate_old_schema_version(self):
"""Tests that old schema versions are ignored during validation
"""
"""Tests that old schema versions are ignored during validation"""
m = Metadata.model_construct(
name="name",
location="location",
@@ -498,16 +497,10 @@ def test_create_from_core_jsons_invalid(self, mock_warning: MagicMock):

@patch("logging.warning")
@patch("aind_data_schema.core.metadata.is_dict_corrupt")
def test_create_from_core_jsons_corrupt(
self,
mock_is_dict_corrupt: MagicMock,
mock_warning: MagicMock
):
def test_create_from_core_jsons_corrupt(self, mock_is_dict_corrupt: MagicMock, mock_warning: MagicMock):
"""Tests metadata json creation ignores corrupt core jsons"""
# mock corrupt procedures and processing
mock_is_dict_corrupt.side_effect = lambda x: (
x == self.procedures_json or x == self.processing_json
)
mock_is_dict_corrupt.side_effect = lambda x: (x == self.procedures_json or x == self.processing_json)
core_jsons = {
"subject": self.subject_json,
"data_description": None,
149 changes: 139 additions & 10 deletions tests/test_quality_control.py
@@ -6,7 +6,7 @@
from aind_data_schema_models.modalities import Modality
from pydantic import ValidationError

from aind_data_schema.core.quality_control import QCEvaluation, QualityControl, QCMetric, Stage, Status, QCStatus
from aind_data_schema.core.quality_control import QCEvaluation, QCMetric, QCStatus, QualityControl, Stage, Status


class QualityControlTests(unittest.TestCase):
@@ -76,14 +76,14 @@ def test_overall_status(self):
)

# check that evaluation status gets auto-set if it has never been set before
self.assertEqual(test_eval.status, Status.PASS)
self.assertEqual(test_eval.status(), Status.PASS)

q = QualityControl(
evaluations=[test_eval, test_eval],
)

# check that overall status gets auto-set if it has never been set before
self.assertEqual(q.status, Status.PASS)
self.assertEqual(q.status(), Status.PASS)

# Add a pending metric to the first evaluation
q.evaluations[0].metrics.append(
@@ -100,7 +100,7 @@ def test_overall_status(self):
)
)

self.assertEqual(q.status, Status.PENDING)
self.assertEqual(q.status(), Status.PENDING)

# Add a failing metric to the first evaluation
q.evaluations[0].metrics.append(
@@ -115,7 +115,7 @@ def test_overall_status(self):
)
)

self.assertEqual(q.status, Status.FAIL)
self.assertEqual(q.status(), Status.FAIL)

def test_evaluation_status(self):
"""test that evaluation status goes to pass/pending/fail correctly"""
@@ -147,7 +147,7 @@ def test_evaluation_status(self):
],
)

self.assertEqual(evaluation.status, Status.PASS)
self.assertEqual(evaluation.status(), Status.PASS)

# Add a pending metric, evaluation should now evaluate to pending
evaluation.metrics.append(
@@ -164,7 +164,7 @@ def test_evaluation_status(self):
)
)

self.assertEqual(evaluation.status, Status.PENDING)
self.assertEqual(evaluation.status(), Status.PENDING)

# Add a failing metric, evaluation should now evaluate to fail
evaluation.metrics.append(
@@ -179,7 +179,7 @@ def test_evaluation_status(self):
)
)

self.assertEqual(evaluation.status, Status.FAIL)
self.assertEqual(evaluation.status(), Status.FAIL)

def test_allowed_failed_metrics(self):
"""Test that if you set the flag to allow failures that evaluations pass"""
@@ -218,12 +218,12 @@ def test_allowed_failed_metrics(self):

evaluation.allow_failed_metrics = True

self.assertEqual(evaluation.status, Status.PENDING)
self.assertEqual(evaluation.status(), Status.PENDING)

# Replace the pending metric with a fail; since failures are allowed, the evaluation should still evaluate to pass
evaluation.metrics[1].status_history[0].status = Status.FAIL

self.assertEqual(evaluation.status, Status.PASS)
self.assertEqual(evaluation.status(), Status.PASS)

metric2.status_history[0].status = Status.FAIL
self.assertEqual(evaluation.failed_metrics, [metric2])
@@ -367,6 +367,135 @@ def test_multi_session(self):

self.assertTrue("is in a multi-asset QCEvaluation and must have evaluated_assets" in repr(context.exception))

def test_status_filters(self):
"""Test that QualityControl.status(modality, stage) filters correctly"""

test_eval = QCEvaluation(
name="Drift map",
modality=Modality.ECEPHYS,
stage=Stage.PROCESSING,
metrics=[
QCMetric(
name="Multiple values example",
value={"stuff": "in_a_dict"},
status_history=[
QCStatus(evaluator="Bob", timestamp=datetime.fromisoformat("2020-10-10"), status=Status.PASS)
],
),
QCMetric(
name="Drift map pass/fail",
value=False,
description="Manual evaluation of whether the drift map looks good",
reference="s3://some-data-somewhere",
status_history=[
QCStatus(evaluator="Bob", timestamp=datetime.fromisoformat("2020-10-10"), status=Status.PASS)
],
),
],
)
test_eval2 = QCEvaluation(
name="Drift map",
modality=Modality.BEHAVIOR,
stage=Stage.RAW,
metrics=[
QCMetric(
name="Multiple values example",
value={"stuff": "in_a_dict"},
status_history=[
QCStatus(evaluator="Bob", timestamp=datetime.fromisoformat("2020-10-10"), status=Status.FAIL)
],
),
QCMetric(
name="Drift map pass/fail",
value=False,
description="Manual evaluation of whether the drift map looks good",
reference="s3://some-data-somewhere",
status_history=[
QCStatus(evaluator="Bob", timestamp=datetime.fromisoformat("2020-10-10"), status=Status.PASS)
],
),
],
)
test_eval3 = QCEvaluation(
name="Drift map",
modality=Modality.BEHAVIOR_VIDEOS,
tags=["tag1"],
stage=Stage.RAW,
metrics=[
QCMetric(
name="Multiple values example",
value={"stuff": "in_a_dict"},
status_history=[
QCStatus(evaluator="Bob", timestamp=datetime.fromisoformat("2020-10-10"), status=Status.PENDING)
],
),
QCMetric(
name="Drift map pass/fail",
value=False,
description="Manual evaluation of whether the drift map looks good",
reference="s3://some-data-somewhere",
status_history=[
QCStatus(evaluator="Bob", timestamp=datetime.fromisoformat("2020-10-10"), status=Status.PASS)
],
),
],
)

# Confirm that the status filters work
q = QualityControl(
evaluations=[test_eval, test_eval2, test_eval3],
)

self.assertEqual(q.status(), Status.FAIL)
self.assertEqual(q.status(modality=Modality.BEHAVIOR), Status.FAIL)
self.assertEqual(q.status(modality=Modality.ECEPHYS), Status.PASS)
self.assertEqual(q.status(modality=[Modality.ECEPHYS, Modality.BEHAVIOR]), Status.FAIL)
self.assertEqual(q.status(stage=Stage.RAW), Status.FAIL)
self.assertEqual(q.status(stage=Stage.PROCESSING), Status.PASS)
self.assertEqual(q.status(tag="tag1"), Status.PENDING)

def test_status_date(self):
"""QualityControl.status(date=) / QCEvaluation.status(date=)
should return the correct status for the given date
"""

t0_5 = datetime.fromisoformat("0500-01-01 00:00:00+00:00")
t1 = datetime.fromisoformat("1000-01-01 00:00:00+00:00")
t1_5 = datetime.fromisoformat("1500-01-01 00:00:00+00:00")
t2 = datetime.fromisoformat("2000-01-01 00:00:00+00:00")
t2_5 = datetime.fromisoformat("2500-01-01 00:00:00+00:00")
t3 = datetime.fromisoformat("3000-01-01 00:00:00+00:00")
t3_5 = datetime.fromisoformat("3500-01-01 00:00:00+00:00")

metric = QCMetric(
name="Drift map pass/fail",
value=False,
status_history=[
QCStatus(evaluator="Bob", timestamp=t1, status=Status.FAIL),
QCStatus(evaluator="Bob", timestamp=t2, status=Status.PENDING),
QCStatus(evaluator="Bob", timestamp=t3, status=Status.PASS),
],
)

test_eval = QCEvaluation(
name="Drift map",
modality=Modality.ECEPHYS,
stage=Stage.PROCESSING,
metrics=[metric],
)

self.assertRaises(ValueError, test_eval.status, date=t0_5)
self.assertEqual(test_eval.status(date=t1_5), Status.FAIL)
self.assertEqual(test_eval.status(date=t2_5), Status.PENDING)
self.assertEqual(test_eval.status(date=t3_5), Status.PASS)

qc = QualityControl(evaluations=[test_eval])

self.assertRaises(ValueError, qc.status, date=t0_5)
self.assertEqual(qc.status(date=t1_5), Status.FAIL)
self.assertEqual(qc.status(date=t2_5), Status.PENDING)
self.assertEqual(qc.status(date=t3_5), Status.PASS)


if __name__ == "__main__":
unittest.main()
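
As an aside, here is a standalone sketch (not from this commit) of the aggregation rule that both status methods apply once the relevant statuses are collected, ignoring the allow_failed_metrics escape hatch: any FAIL wins, otherwise any PENDING, otherwise PASS. The rollup helper name is made up for illustration.

from typing import List

from aind_data_schema.core.quality_control import Status

def rollup(statuses: List[Status]) -> Status:
    """Aggregate statuses: FAIL takes precedence, then PENDING, else PASS."""
    if any(s == Status.FAIL for s in statuses):
        return Status.FAIL
    if any(s == Status.PENDING for s in statuses):
        return Status.PENDING
    return Status.PASS

rollup([Status.PASS, Status.PENDING])               # -> Status.PENDING
rollup([Status.PASS, Status.PENDING, Status.FAIL])  # -> Status.FAIL
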
