diff --git a/STACpopulator/extensions/base.py b/STACpopulator/extensions/base.py index 56082a1..ca74821 100644 --- a/STACpopulator/extensions/base.py +++ b/STACpopulator/extensions/base.py @@ -22,6 +22,7 @@ """ +from __future__ import annotations from datetime import datetime import json @@ -54,7 +55,7 @@ LOGGER = logging.getLogger(__name__) -class DataModelHelper(BaseModel): +class ExtensionHelper(BaseModel): """Base class for dataset properties going into the catalog. Subclass this with attributes. @@ -103,7 +104,51 @@ def apply(self, item, add_if_missing=False): return item -class THREDDSCatalogDataModel(BaseModel): +class BaseSTAC(BaseModel): + """Base class for STAC item data models.""" + # STAC item properties + geometry: AnyGeometry | None + bbox: list[float] + start_datetime: datetime + end_datetime: datetime + + extensions: list = [] + + model_config = ConfigDict(populate_by_name=True, extra="ignore", arbitrary_types_allowed=True) + + @property + def uid(self) -> str: + """Return a unique ID. When subclassing, use a combination of properties uniquely identifying a dataset.""" + # TODO: Should this be an abstract method? + import uuid + return str(uuid.uuid4()) + + def stac_item(self) -> dict: + """Create a STAC item, apply the extensions, validate it and return it as a JSON-compatible dict.""" + item = pystac.Item( + id=self.uid, + geometry=self.geometry.model_dump() if self.geometry is not None else None, + bbox=self.bbox, + properties={ + "start_datetime": str(self.start_datetime), + "end_datetime": str(self.end_datetime), + }, + datetime=None, + ) + + # Add extensions + for ext in self.extensions: + getattr(self, ext).apply(item) + + try: + item.validate() + except STACValidationError as e: + raise Exception("Failed to validate STAC item") from e + + return json.loads(json.dumps(item.to_dict())) + + +class THREDDSCatalogDataModel(BaseSTAC): """Base class ingesting attributes loaded by `THREDDSLoader` and creating a STAC item. This is meant to be subclassed for each extension. 
@@ -112,18 +157,11 @@ class THREDDSCatalogDataModel(BaseModel): - pydantic validation using type hints, and - json schema validation. """ - - # STAC item properties - geometry: GeoJSONPolygon - bbox: list[float] - start_datetime: datetime - end_datetime: datetime - # Data from loader data: dict # Extensions classes - properties: DataModelHelper + properties: ExtensionHelper datacube: DataCubeHelper thredds: THREDDSHelper @@ -165,37 +203,6 @@ def thredds_helper(cls, data): data["thredds"] = THREDDSHelper(data['data']["access_urls"]) return data - @property - def uid(self) -> str: - """Return a unique ID. When subclassing, use a combination of properties uniquely identifying a dataset.""" - # TODO: Should this be an abstract method? - import uuid - return str(uuid.uuid4()) - - def stac_item(self) -> "pystac.Item": - """Create a STAC item and add extensions.""" - item = pystac.Item( - id=self.uid, - geometry=self.geometry.model_dump(), - bbox=self.bbox, - properties={ - "start_datetime": str(self.start_datetime), - "end_datetime": str(self.end_datetime), - }, - datetime=None, - ) - - # Add extensions - for ext in self.extensions: - getattr(self, ext).apply(item) - - try: - item.validate() - except STACValidationError as e: - raise Exception("Failed to validate STAC item") from e - - return json.loads(json.dumps(item.to_dict())) - def metacls_extension(name, schema_uri): """Create an extension class dynamically from the properties.""" diff --git a/STACpopulator/extensions/cordex6.py b/STACpopulator/extensions/cordex6.py index 12eed34..e16e14b 100644 --- a/STACpopulator/extensions/cordex6.py +++ b/STACpopulator/extensions/cordex6.py @@ -7,11 +7,11 @@ from importlib import reload import STACpopulator.extensions.base reload(STACpopulator.extensions.base) -from STACpopulator.extensions.base import THREDDSCatalogDataModel, DataModelHelper +from STACpopulator.extensions.base import THREDDSCatalogDataModel, ExtensionHelper # This is generated using datamodel-codegen + manual 
edits -class CordexCmip6(DataModelHelper): +class CordexCmip6(ExtensionHelper): # Fields from schema activity_id: str = Field(..., alias='cordex6:activity_id') contact: str = Field(..., alias='cordex6:contact')