diff --git a/src/aind_data_schema/components/stimulus.py b/src/aind_data_schema/components/stimulus.py index c7dcedd42..81e4454eb 100644 --- a/src/aind_data_schema/components/stimulus.py +++ b/src/aind_data_schema/components/stimulus.py @@ -5,7 +5,7 @@ from typing import List, Literal, Optional from aind_data_schema_models.units import ConcentrationUnit, FrequencyUnit, PowerUnit, TimeUnit -from pydantic import Field +from pydantic import Field, model_validator from aind_data_schema.base import AindGeneric, AindGenericType, AindModel @@ -124,7 +124,7 @@ class AuditoryStimulation(AindModel): """Description of an auditory stimulus""" stimulus_type: Literal["Auditory Stimulation"] = "Auditory Stimulation" - sitmulus_name: str = Field(..., title="Stimulus name") + stimulus_name: str = Field(..., title="Stimulus name") sample_frequency: Decimal = Field(..., title="Sample frequency") amplitude_modulation_frequency: Optional[int] = Field(default=None, title="Amplitude modulation frequency") frequency_unit: FrequencyUnit = Field(default=FrequencyUnit.HZ, title="Tone frequency unit") @@ -133,3 +133,10 @@ class AuditoryStimulation(AindModel): bandpass_filter_type: Optional[FilterType] = Field(default=None, title="Bandpass filter type") bandpass_order: Optional[int] = Field(default=None, title="Bandpass order") notes: Optional[str] = Field(default=None, title="Notes") + + @model_validator(mode="before") + def correct_typo(cls, values): + """Correct 'sitmulus_name' typo.""" + if "sitmulus_name" in values: + values["stimulus_name"] = values.pop("sitmulus_name") + return values diff --git a/tests/test_components_stimulus.py b/tests/test_components_stimulus.py new file mode 100644 index 000000000..853005065 --- /dev/null +++ b/tests/test_components_stimulus.py @@ -0,0 +1,21 @@ +"""Test components.stimulus""" + +import unittest +from aind_data_schema.components.stimulus import AuditoryStimulation + + +class StimulusTests(unittest.TestCase): + """tests components.stimulus""" + + def test_typo(self): + """tests that the sitmulus typo is corrected""" + a = AuditoryStimulation( + stimulus_type="Auditory Stimulation", + stimulus_name="test", + sample_frequency=0.5, + ) + + a_dict = a.model_dump() + a_dict["sitmulus_name"] = a_dict.pop("stimulus_name") + + self.assertEqual(a.model_dump(), AuditoryStimulation(**a_dict).model_dump())