Model schema (updated to use DataProcess) #1166

Merged
20 commits merged on Dec 9, 2024
72 changes: 72 additions & 0 deletions src/aind_data_schema/core/model.py
@@ -0,0 +1,72 @@
""" schema describing an analysis model """

from typing import Any, List, Literal, Optional

from aind_data_schema_models.modalities import Modality
from aind_data_schema_models.organizations import Organization
from aind_data_schema_models.system_architecture import ModelBackbone
from pydantic import Field

from aind_data_schema.base import AindCoreModel, AindGeneric, AindGenericType, AindModel
from aind_data_schema.components.devices import Software
from aind_data_schema.core.processing import DataProcess, ProcessName


class ModelArchitecture(AindModel):
"""Description of model architecture"""

backbone: ModelBackbone = Field(..., title="Backbone", description="Core network architecture")
software: List[Software] = Field(default=[], title="Software frameworks")
layers: Optional[int] = Field(default=None, title="Layers")
parameters: AindGenericType = Field(default=AindGeneric(), title="Parameters")
notes: Optional[str] = Field(default=None, title="Notes")


class PerformanceMetric(AindModel):
"""Description of a performance metric"""

name: str = Field(..., title="Metric name")
value: Any = Field(..., title="Metric value")


class ModelEvaluation(DataProcess):
"""Description of model evaluation"""

name: ProcessName = Field(ProcessName.MODEL_EVALUATION, title="Process name")
performance: List[PerformanceMetric] = Field(..., title="Evaluation performance")


class ModelTraining(DataProcess):
"""Description of model training"""

name: ProcessName = Field(ProcessName.MODEL_TRAINING, title="Process name")
train_performance: List[PerformanceMetric] = Field(
..., title="Training performance", description="Performance on training set"
)
test_performance: Optional[List[PerformanceMetric]] = Field(
default=None, title="Test performance", description="Performance on untrained data, evaluated during training"
)
test_data: Optional[str] = Field(
default=None, title="Test data", description="Path or cross-validation/split approach"
)


class Model(AindCoreModel):
"""Description of an analysis model"""

_DESCRIBED_BY_URL = AindCoreModel._DESCRIBED_BY_BASE_URL.default + "aind_data_schema/core/model.py"
describedBy: str = Field(_DESCRIBED_BY_URL, json_schema_extra={"const": _DESCRIBED_BY_URL})
schema_version: Literal["0.0.1"] = Field("0.0.1")

name: str = Field(..., title="Name")
license: str = Field(..., title="License")
developer_full_name: Optional[List[str]] = Field(default=None, title="Name of developer")
developer_institution: Optional[Organization.ONE_OF] = Field(default=None, title="Institute where developed")
modality: Modality.ONE_OF = Field(..., title="Modality")
architecture: ModelArchitecture = Field(..., title="Model architecture")
intended_use: str = Field(..., title="Intended model use", description="Semantic description of intended use")
limitations: Optional[str] = Field(default=None, title="Model limitations")
pretrained_source_url: Optional[str] = Field(default=None, title="Pretrained source URL")
training: Optional[List[ModelTraining]] = Field(default=[], title="Training")
evaluations: Optional[List[ModelEvaluation]] = Field(default=[], title="Evaluations")
notes: Optional[str] = Field(default=None, title="Notes")
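
A minimal construction sketch (not part of the diff) showing how the new classes compose. Only the required Model fields are supplied; the file name, license string, and intended_use text are placeholder values, and the imports mirror those at the top of model.py above:

from aind_data_schema_models.modalities import Modality
from aind_data_schema_models.system_architecture import ModelBackbone

from aind_data_schema.core.model import Model, ModelArchitecture

# Optional fields (developer_full_name, training, evaluations, etc.) fall back
# to their defaults; ModelArchitecture only requires a backbone.
example = Model(
    name="example_model.h5",
    license="CC-BY-4.0",
    modality=Modality.SPIM,
    architecture=ModelArchitecture(backbone=ModelBackbone.RESNET),
    intended_use="Placeholder description of intended use",
)
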
11 changes: 6 additions & 5 deletions src/aind_data_schema/core/processing.py
@@ -1,7 +1,7 @@
""" schema for processing """

from enum import Enum
from typing import List, Literal, Optional
from typing import List, Literal, Optional, Union

from aind_data_schema_models.process_names import ProcessName
from aind_data_schema_models.units import MemoryUnit, UnitlessUnit
@@ -57,15 +57,16 @@ class DataProcess(AindModel):
"""Description of a single processing step"""

name: ProcessName = Field(..., title="Name")
-    software_version: str = Field(..., description="Version of the software used", title="Version")
+    software_version: Optional[str] = Field(default=None, description="Version of the software used", title="Version")
start_date_time: AwareDatetimeWithDefault = Field(..., title="Start date time")
end_date_time: AwareDatetimeWithDefault = Field(..., title="End date time")
-    input_location: str = Field(..., description="Path to data inputs", title="Input location")
+    # allowing multiple input locations, to be replaced by CompositeData object in future
+    input_location: Union[str, List[str]] = Field(..., description="Path(s) to data inputs", title="Input location")
output_location: str = Field(..., description="Path to data outputs", title="Output location")
code_url: str = Field(..., description="Path to code repository", title="Code URL")
code_version: Optional[str] = Field(default=None, description="Version of the code", title="Code version")
-    parameters: AindGenericType = Field(..., title="Parameters")
-    outputs: AindGenericType = Field(AindGeneric(), description="Output parameters", title="Outputs")
+    parameters: AindGenericType = Field(default=AindGeneric(), title="Parameters")
+    outputs: AindGenericType = Field(default=AindGeneric(), description="Output parameters", title="Outputs")
notes: Optional[str] = Field(default=None, title="Notes", validate_default=True)
resources: Optional[ResourceUsage] = Field(default=None, title="Process resource usage")

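A small sketch (not part of the diff) illustrating the relaxed DataProcess constructor after this change. The S3-style paths, repository URL, and note text are placeholders, and ProcessName.OTHER is assumed to be the enum member for the "Other" process name used in the tests below:

from datetime import datetime, timezone

from aind_data_schema.core.processing import DataProcess, ProcessName

# software_version, code_version, parameters, and outputs can now be omitted,
# and input_location accepts either a single path or a list of paths.
step = DataProcess(
    name=ProcessName.OTHER,
    start_date_time=datetime.now(timezone.utc),
    end_date_time=datetime.now(timezone.utc),
    input_location=["s3://bucket/raw-a", "s3://bucket/raw-b"],
    output_location="s3://bucket/derived",
    code_url="https://github.com/example/analysis-code",
    notes="Placeholder description of this processing step",
)
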
91 changes: 91 additions & 0 deletions tests/test_model.py
@@ -0,0 +1,91 @@
""" tests for Model """

import datetime
import unittest

import pydantic
from aind_data_schema_models.modalities import Modality
from aind_data_schema_models.organizations import Organization
from aind_data_schema_models.system_architecture import ModelBackbone

from aind_data_schema.components.devices import Software
from aind_data_schema.core.model import Model, ModelArchitecture, ModelEvaluation, ModelTraining, PerformanceMetric


class ModelTests(unittest.TestCase):
"""tests for model"""

def test_constructors(self):
"""try building model"""

with self.assertRaises(pydantic.ValidationError):
Model()

now = datetime.datetime.now()

m = Model(
name="2024_01_01_ResNet18_SmartSPIM.h5",
license="CC-BY-4.0",
developer_full_name=["Joe Schmoe"],
developer_institution=Organization.AIND,
modality=Modality.SPIM,
pretrained_source_url="url pretrained weights are from",
architecture=ModelArchitecture(
backbone=ModelBackbone.RESNET,
layers=18,
parameters={
"downsample": 1,
"input_shape": [14, 14, 26],
},
software=[
Software(
name="tensorflow",
version="2.11.0",
)
],
),
intended_use="Cell counting for 488 channel of SmartSPIM data",
limitations="Only trained on 488 channel",
training=[
ModelTraining(
input_location=["s3 path to eval 1", "s3 path to eval 2"],
output_location="s3 path to trained model asset",
code_url="url for training code repo",
start_date_time=now,
end_date_time=now,
train_performance=[
PerformanceMetric(name="precision", value=0.9),
PerformanceMetric(name="recall", value=0.85),
],
test_performance=[
PerformanceMetric(name="precision", value=0.8),
PerformanceMetric(name="recall", value=0.8),
],
test_data="4:1 train/test split",
parameters={
"learning_rate": 0.0001,
"batch_size": 32,
"augmentation": True,
},
notes="note on training data selection",
)
],
evaluations=[
ModelEvaluation(
input_location=["s3 path to eval 1", "s3 path to eval 2"],
output_location="s3 path (output asset or trained model asset if no output)",
code_url="url for evaluation code repo (or capsule?)",
start_date_time=now,
end_date_time=now,
performance=[PerformanceMetric(name="precision", value=0.8)],
)
],
)

Model.model_validate_json(m.model_dump_json())

self.assertIsNotNone(m)


if __name__ == "__main__":
unittest.main()
8 changes: 1 addition & 7 deletions tests/test_processing.py
@@ -37,10 +37,7 @@ def test_constructors(self):
DataProcess(name="Other", notes="")

expected_exception = (
"8 validation errors for DataProcess\n"
"software_version\n"
" Field required [type=missing, input_value={'name': 'Other', 'notes': ''}, input_type=dict]\n"
f" For further information visit https://errors.pydantic.dev/{PYD_VERSION}/v/missing\n"
"6 validation errors for DataProcess\n"
"start_date_time\n"
" Field required [type=missing, input_value={'name': 'Other', 'notes': ''}, input_type=dict]\n"
f" For further information visit https://errors.pydantic.dev/{PYD_VERSION}/v/missing\n"
@@ -56,9 +53,6 @@
"code_url\n"
" Field required [type=missing, input_value={'name': 'Other', 'notes': ''}, input_type=dict]\n"
f" For further information visit https://errors.pydantic.dev/{PYD_VERSION}/v/missing\n"
"parameters\n"
" Field required [type=missing, input_value={'name': 'Other', 'notes': ''}, input_type=dict]\n"
f" For further information visit https://errors.pydantic.dev/{PYD_VERSION}/v/missing\n"
"notes\n"
" Value error, Notes cannot be empty if 'name' is Other. Describe the process name in the notes field."
" [type=value_error, input_value='', input_type=str]\n"