diff --git a/src/aind_data_schema/core/metadata.py b/src/aind_data_schema/core/metadata.py index 9d086545..e428da66 100644 --- a/src/aind_data_schema/core/metadata.py +++ b/src/aind_data_schema/core/metadata.py @@ -3,7 +3,7 @@ import inspect import json import logging -from datetime import datetime +from datetime import datetime, timezone from enum import Enum from typing import Dict, List, Literal, Optional, get_args from uuid import UUID, uuid4 @@ -16,11 +16,12 @@ SkipValidation, ValidationError, ValidationInfo, + field_serializer, field_validator, model_validator, ) -from aind_data_schema.base import AindCoreModel, is_dict_corrupt +from aind_data_schema.base import AindCoreModel, is_dict_corrupt, AwareDatetimeWithDefault from aind_data_schema.core.acquisition import Acquisition from aind_data_schema.core.data_description import DataDescription from aind_data_schema.core.instrument import Instrument @@ -83,15 +84,13 @@ class Metadata(AindCoreModel): description="Name of the data asset.", title="Data Asset Name", ) - # We'll set created and last_modified defaults using the root_validator - # to ensure they're synced on creation - created: datetime = Field( - default_factory=datetime.utcnow, + created: AwareDatetimeWithDefault = Field( + default_factory=lambda: datetime.now(tz=timezone.utc), title="Created", description="The utc date and time the data asset created.", ) - last_modified: datetime = Field( - default_factory=datetime.utcnow, + last_modified: AwareDatetimeWithDefault = Field( + default_factory=lambda: datetime.now(tz=timezone.utc), title="Last Modified", description="The utc date and time that the data asset was last modified.", ) @@ -157,6 +156,16 @@ def validate_core_fields(cls, value, info: ValidationInfo): core_model = value return core_model + @field_validator("last_modified", mode="after") + def validate_last_modified(cls, value, info: ValidationInfo): + """Convert last_modified field to UTC from other timezones""" + return value.astimezone(timezone.utc) + + @field_serializer("last_modified") + def serialize_last_modified(value) -> str: + """Serialize last_modified field""" + return value.isoformat().replace("+00:00", "Z") + @model_validator(mode="after") def validate_metadata(self): """Validator for metadata""" diff --git a/tests/test_metadata.py b/tests/test_metadata.py index 78f5428d..506ea3d9 100644 --- a/tests/test_metadata.py +++ b/tests/test_metadata.py @@ -5,6 +5,7 @@ import unittest from datetime import datetime, time, timezone from unittest.mock import MagicMock, call, patch +import uuid from aind_data_schema_models.modalities import Modality from aind_data_schema_models.organizations import Organization @@ -536,6 +537,51 @@ def test_create_from_core_jsons_corrupt(self, mock_is_dict_corrupt: MagicMock, m any_order=True, ) + def test_last_modified(self): + """Test that the last_modified field enforces timezones""" + m = Metadata.model_construct( + name="name", + location="location", + id=uuid.uuid4(), + ) + m_dict = m.model_dump(by_alias=True) + + # Test that naive datetime is coerced to timezone-aware datetime + date = "2022-11-22T08:43:00" + date_with_timezone = datetime.fromisoformat(date).astimezone() + m_dict["last_modified"] = "2022-11-22T08:43:00" + m2 = Metadata(**m_dict) + self.assertIsNotNone(m2) + self.assertEqual(m2.last_modified, date_with_timezone) + + # Also check that last_modified is now in UTC + self.assertEqual(m2.last_modified.tzinfo, timezone.utc) + + # Test that timezone-aware datetime is not coerced + date_minus = "2022-11-22T08:43:00-07:00" + m_dict["last_modified"] = date_minus + m3 = Metadata(**m_dict) + self.assertIsNotNone(m3) + self.assertEqual(m3.last_modified, datetime.fromisoformat(date_minus)) + + # Test that UTC datetime is not coerced + date_utc = "2022-11-22T08:43:00+00:00" + m_dict["last_modified"] = date_utc + m4 = Metadata(**m_dict) + self.assertIsNotNone(m4) + self.assertEqual(m4.last_modified, datetime.fromisoformat(date_utc)) + + def roundtrip_lm(model): + """Helper function to roundtrip last_modified field""" + model_json = model.model_dump_json(by_alias=True) + model_dict = json.loads(model_json) + return model_dict["last_modified"] + + # Test that the output looks right + self.assertEqual(m.last_modified.isoformat().replace("+00:00", "Z"), roundtrip_lm(m)) + self.assertEqual("2022-11-22T15:43:00Z", roundtrip_lm(m3)) + self.assertEqual("2022-11-22T08:43:00Z", roundtrip_lm(m4)) + if __name__ == "__main__": unittest.main()