Skip to content

Commit

Permalink
serialization fix
Browse files Browse the repository at this point in the history
  • Loading branch information
mekhlakapoor committed Oct 22, 2024
1 parent 22a1e29 commit af3de2f
Show file tree
Hide file tree
Showing 3 changed files with 31 additions and 5 deletions.
4 changes: 2 additions & 2 deletions docs/source/conf.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
rig,
session,
subject,
quality_control
quality_control,
)

dummy_object = [
Expand All @@ -34,7 +34,7 @@
rig,
session,
subject,
quality_control
quality_control,
] # A temporary workaround to bypass "Imported but unused" error

INSTITUTE_NAME = "Allen Institute for Neural Dynamics"
Expand Down
6 changes: 3 additions & 3 deletions src/aind_data_schema/core/rig.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,9 +90,9 @@ class Rig(AindCoreModel):
notes: Optional[str] = Field(default=None, title="Notes")

@field_serializer("modalities", when_used="json")
def serialize_modalities(modalities: Set[Modality.ONE_OF]):
"""sort modalities by name when serializing to JSON"""
return sorted(modalities, key=lambda x: x.name)
def serialize_modalities(self, modalities: Set[Modality.ONE_OF]):
"""Dynamically serialize modalities based on their type."""
return sorted(modalities, key=lambda x: x.get("name") if isinstance(x, dict) else x.name)

@model_validator(mode="after")
def validate_cameras_other(self):
Expand Down
26 changes: 26 additions & 0 deletions tests/test_rig.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,13 @@
""" test Rig """

import unittest
import json
from datetime import date, datetime

from aind_data_schema_models.modalities import Modality
from aind_data_schema_models.organizations import Organization
from pydantic import ValidationError
from pydantic_core import PydanticSerializationError

from aind_data_schema.components.devices import (
Calibration,
Expand Down Expand Up @@ -821,6 +823,30 @@ def test_rig_id_validator(self):
calibrations=[calibration],
)

def test_serialize_modalities(self):
"""Tests that modalities serializer can handle different types"""
expected_modalities = [{"name": "Extracellular electrophysiology", "abbreviation": "ecephys"}]
# Case 1: Modality is a class instance
rig_instance_modality = Rig.model_construct(
modalities=[Modality.ECEPHYS] # Example with a valid Modality instance
)
rig_json = rig_instance_modality.model_dump_json()
rig_data = json.loads(rig_json)
self.assertEqual(rig_data["modalities"], expected_modalities)

# Case 2: Modality is a dictionary when Rig is constructed from JSON
rig_dict_modality = Rig.model_construct(**rig_data)
rig_dict_json = rig_dict_modality.model_dump_json()
rig_dict_data = json.loads(rig_dict_json)
self.assertEqual(rig_dict_data["modalities"], expected_modalities)

# Case 3: Modality is an unknown type
with self.assertRaises(PydanticSerializationError) as context:
rig_unknown_modality = Rig.model_construct(modalities={"UnknownModality"})

rig_unknown_modality.model_dump_json()
self.assertIn("Error calling function `serialize_modalities`", str(context.exception))


if __name__ == "__main__":
unittest.main()

0 comments on commit af3de2f

Please sign in to comment.