diff --git a/docs/source/conf.py b/docs/source/conf.py index 31d9a3bee..42a7ee173 100644 --- a/docs/source/conf.py +++ b/docs/source/conf.py @@ -21,7 +21,7 @@ rig, session, subject, - quality_control + quality_control, ) dummy_object = [ @@ -34,7 +34,7 @@ rig, session, subject, - quality_control + quality_control, ] # A temporary workaround to bypass "Imported but unused" error INSTITUTE_NAME = "Allen Institute for Neural Dynamics" diff --git a/src/aind_data_schema/__init__.py b/src/aind_data_schema/__init__.py index c02cc6a60..5d3d4791e 100755 --- a/src/aind_data_schema/__init__.py +++ b/src/aind_data_schema/__init__.py @@ -1,4 +1,4 @@ """ imports for AindModel subclasses """ -__version__ = "1.1.0" +__version__ = "1.1.1" diff --git a/src/aind_data_schema/core/rig.py b/src/aind_data_schema/core/rig.py index ef26f34bf..cc354406e 100644 --- a/src/aind_data_schema/core/rig.py +++ b/src/aind_data_schema/core/rig.py @@ -90,9 +90,9 @@ class Rig(AindCoreModel): notes: Optional[str] = Field(default=None, title="Notes") @field_serializer("modalities", when_used="json") - def serialize_modalities(modalities: Set[Modality.ONE_OF]): - """sort modalities by name when serializing to JSON""" - return sorted(modalities, key=lambda x: x.name) + def serialize_modalities(self, modalities: Set[Modality.ONE_OF]): + """Dynamically serialize modalities based on their type.""" + return sorted(modalities, key=lambda x: x.get("name") if isinstance(x, dict) else x.name) @model_validator(mode="after") def validate_cameras_other(self): diff --git a/tests/test_rig.py b/tests/test_rig.py index d1d86f454..32b243125 100644 --- a/tests/test_rig.py +++ b/tests/test_rig.py @@ -1,11 +1,13 @@ """ test Rig """ import unittest +import json from datetime import date, datetime from aind_data_schema_models.modalities import Modality from aind_data_schema_models.organizations import Organization from pydantic import ValidationError +from pydantic_core import PydanticSerializationError from aind_data_schema.components.devices import ( Calibration, @@ -821,6 +823,30 @@ def test_rig_id_validator(self): calibrations=[calibration], ) + def test_serialize_modalities(self): + """Tests that modalities serializer can handle different types""" + expected_modalities = [{"name": "Extracellular electrophysiology", "abbreviation": "ecephys"}] + # Case 1: Modality is a class instance + rig_instance_modality = Rig.model_construct( + modalities=[Modality.ECEPHYS] # Example with a valid Modality instance + ) + rig_json = rig_instance_modality.model_dump_json() + rig_data = json.loads(rig_json) + self.assertEqual(rig_data["modalities"], expected_modalities) + + # Case 2: Modality is a dictionary when Rig is constructed from JSON + rig_dict_modality = Rig.model_construct(**rig_data) + rig_dict_json = rig_dict_modality.model_dump_json() + rig_dict_data = json.loads(rig_dict_json) + self.assertEqual(rig_dict_data["modalities"], expected_modalities) + + # Case 3: Modality is an unknown type + with self.assertRaises(PydanticSerializationError) as context: + rig_unknown_modality = Rig.model_construct(modalities={"UnknownModality"}) + + rig_unknown_modality.model_dump_json() + self.assertIn("Error calling function `serialize_modalities`", str(context.exception)) + if __name__ == "__main__": unittest.main()