Skip to content

Commit

Permalink
feat: validate fieldnames in AindGeneric (#1134)
Browse files Browse the repository at this point in the history
* feat: docdb util to check corrupt keys

* feat: validate fieldnames in AindGeneric

* refactor: move is_dict_corrupt to base module

* fix: ValidationError in tests
  • Loading branch information
helen-m-lin authored Nov 1, 2024
1 parent 07cb5b2 commit 7b2a2a4
Show file tree
Hide file tree
Showing 2 changed files with 109 additions and 3 deletions.
48 changes: 47 additions & 1 deletion src/aind_data_schema/base.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
""" generic base class with supporting validators and fields for basic AIND schema """

import json
import re
from pathlib import Path
from typing import Any, Generic, Optional, TypeVar
Expand All @@ -14,6 +15,7 @@
ValidationError,
ValidatorFunctionWrapHandler,
create_model,
model_validator,
)
from pydantic.functional_validators import WrapValidator
from typing_extensions import Annotated
Expand All @@ -31,13 +33,57 @@ def _coerce_naive_datetime(v: Any, handler: ValidatorFunctionWrapHandler) -> Awa
# pydantic AwareDatetime whose validation is wrapped by _coerce_naive_datetime.
# NOTE(review): presumably this lets naive datetimes be coerced instead of
# rejected -- confirm against the body of _coerce_naive_datetime.
AwareDatetimeWithDefault = Annotated[AwareDatetime, WrapValidator(_coerce_naive_datetime)]


def is_dict_corrupt(input_dict: dict) -> bool:
    """
    Checks that dictionary keys, including nested keys, do not contain
    forbidden characters ("$" and ".").

    Parameters
    ----------
    input_dict : dict
        Object to inspect. Dictionaries and lists nested inside it are
        searched recursively; values of any other type are not inspected.

    Returns
    -------
    bool
        True if input_dict is not a dict, or if any string key (at any
        nesting level) contains a forbidden character. False otherwise.
        Non-string keys cannot contain the forbidden characters and are
        therefore never reported as corrupt on their own.
    """

    def has_corrupt_keys(node: object) -> bool:
        """Recursively checks nested dictionaries and lists"""
        if isinstance(node, dict):
            for key, value in node.items():
                # Guard on the key type: `"$" in key` raises TypeError for
                # non-string keys (e.g. ints), which would crash the check.
                if isinstance(key, str) and ("$" in key or "." in key):
                    return True
                if has_corrupt_keys(value):
                    return True
            return False
        if isinstance(node, list):
            return any(has_corrupt_keys(item) for item in node)
        # Scalars (and any other container type) carry no keys to check.
        return False

    # Top-level input must be a dictionary
    if not isinstance(input_dict, dict):
        return True
    return has_corrupt_keys(input_dict)


class AindGeneric(BaseModel, extra="allow"):
    """Base class for generic types that can be used in AIND schema"""

    # extra="allow" is needed because BaseModel by default drops extra parameters.
    # Alternatively, consider using 'SerializeAsAny' once this issue is resolved
    # https://github.com/pydantic/pydantic/issues/6423

    @model_validator(mode="after")
    def validate_fieldnames(self):
        """
        Ensure that field names do not contain forbidden characters.

        The model is serialized to JSON (using aliases) and parsed back so
        that the check runs on plain dictionaries and lists, then the
        result is handed to is_dict_corrupt.

        Raises
        ------
        ValueError
            If any field name, at any nesting level, contains '.' or '$'.
        """
        serialized = self.model_dump_json(by_alias=True)
        if is_dict_corrupt(json.loads(serialized)):
            raise ValueError("Field names cannot contain '.' or '$'")
        return self


# Type variable bounded by AindGeneric, i.e. AindGeneric or any subclass of it.
AindGenericType = TypeVar("AindGenericType", bound=AindGeneric)
Expand Down
64 changes: 62 additions & 2 deletions tests/test_base.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,14 @@
""" tests for Subject """

import json
import unittest
from datetime import datetime, timezone
from pathlib import Path
from unittest.mock import MagicMock, call, mock_open, patch

from pydantic import create_model
from pydantic import ValidationError, create_model

from aind_data_schema.base import AwareDatetimeWithDefault
from aind_data_schema.base import AindGeneric, AwareDatetimeWithDefault, is_dict_corrupt
from aind_data_schema.core.subject import Subject


Expand Down Expand Up @@ -55,6 +56,65 @@ def test_aware_datetime_with_setting(self):
expected_json = '{"dt":"2020-10-10T01:02:03Z"}'
self.assertEqual(expected_json, model_instance.model_dump_json())

def test_is_dict_corrupt(self):
    """Tests is_dict_corrupt on clean dicts, corrupt dicts and non-dict inputs"""
    # Dictionaries whose keys are all free of "$" and "."; note that
    # forbidden characters inside *values* ("f.valid") are acceptable.
    clean_dicts = [
        {"a": 1, "b": {"c": 2, "d": 3}},
        {"a": 1, "b": {"c": 2, "d": 3}, "e": ["f", "g"]},
        {"a": 1, "b": {"c": 2, "d": 3}, "e": ["f.valid", "g"]},
        {"a": 1, "b": {"c": {"d": 2}, "e": 3}},
        {"a": 1, "b": [{"c": 2}, {"d": 3}], "e": 4},
    ]
    # Dictionaries with a forbidden character in a key at some nesting level.
    corrupt_dicts = [
        {"a.1": 1, "b": {"c": 2, "d": 3}},
        {"a": 1, "b": {"c": 2, "$d": 3}},
        {"a": 1, "b": {"c": {"d": 2}, "$e": 3}},
        {"a": 1, "b": {"c": 2, "d": 3, "e.csv": 4}},
        {"a": 1, "b": [{"c": 2}, {"d.csv": 3}], "e": 4},
    ]
    # Anything that is not a dict at the top level is reported as corrupt.
    non_dict_inputs = [
        json.dumps({"a": 1, "b": {"c": 2, "d": 3}}),
        [{"a": 1}, {"b": {"c": 2, "d": 3}}],
        1,
        None,
    ]

    for contents in clean_dicts:
        with self.subTest(contents=contents):
            self.assertFalse(is_dict_corrupt(contents))
    # Corrupt dicts and non-dict inputs share the same expectation.
    for contents in corrupt_dicts + non_dict_inputs:
        with self.subTest(contents=contents):
            self.assertTrue(is_dict_corrupt(contents))

def test_aind_generic_constructor(self):
    """Tests that AindGeneric builds with no fields and with extra fields"""
    # No arguments: serializes to an empty JSON object.
    empty_model = AindGeneric()
    self.assertEqual("{}", empty_model.model_dump_json())

    # Arbitrary extra fields are kept and serialized.
    populated_model = AindGeneric(foo="bar")
    self.assertEqual('{"foo":"bar"}', populated_model.model_dump_json())

def test_aind_generic_validate_fieldnames(self):
    """
    Tests that fieldnames are validated in AindGeneric.

    Each invalid parameter set is checked through both the keyword
    constructor and model_validate, and must raise a ValidationError
    whose repr contains the expected message.
    """
    expected_error = (
        "1 validation error for AindGeneric\n"
        " Value error, Field names cannot contain '.' or '$' "
    )
    invalid_params = [
        {"$foo": "bar"},
        {"foo": {"foo.name": "bar"}},
    ]
    for params in invalid_params:
        # subTest keeps the remaining cases running (and labels the
        # failing one) if an earlier parameter set fails, matching the
        # style used in test_is_dict_corrupt.
        with self.subTest(params=params):
            with self.assertRaises(ValidationError) as e:
                AindGeneric(**params)
            self.assertIn(expected_error, repr(e.exception))
            with self.assertRaises(ValidationError) as e:
                AindGeneric.model_validate(params)
            self.assertIn(expected_error, repr(e.exception))


# Run the tests via unittest when this file is executed as a script.
if __name__ == "__main__":
    unittest.main()

0 comments on commit 7b2a2a4

Please sign in to comment.