diff --git a/.github/workflows/test-build-push.yml b/.github/workflows/test-build-push.yml index 724a861d..caeb5a83 100644 --- a/.github/workflows/test-build-push.yml +++ b/.github/workflows/test-build-push.yml @@ -33,6 +33,7 @@ jobs: pydantic=2.1 pymp-pypi=0.4.5 pyproj=3.3 + rasterio=1.3 rich=12.0 ruamel_yaml=0.15 scipy=1.5 diff --git a/CHANGELOG.md b/CHANGELOG.md index 42591408..2855ad95 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,4 +1,16 @@ -# [Unreleased](https://github.com/isce-framework/dolphin/compare/v0.8.0...main) +# [Unreleased](https://github.com/isce-framework/dolphin/compare/v0.9.0...main) + +# [v0.9.0](https://github.com/isce-framework/dolphin/compare/v0.8.0...v0.9.0) + +**Added** +- `DatasetReader` and `StackReader` protocols for reading in data from different sources + - `DatasetReader` is for reading in a single dataset, like one raster image. + - `StackReader` is for reading in a stack of datasets, like a stack of SLCs. + - Implementations of these have been done for flat binary files (`BinaryReader`), HDF5 files (`HDF5Reader`), and GDAL rasters (`RasterReader`). + +**Changed** +- The `VRTStack` no longer has an `.iter_blocks` method + - This has been replaced with creating an `EagerLoader` directly and passing it to the `reader` argument # [v0.8.0](https://github.com/isce-framework/dolphin/compare/v0.7.0...v0.8.0) diff --git a/conda-env.yml b/conda-env.yml index a905dc3a..2d50738a 100644 --- a/conda-env.yml +++ b/conda-env.yml @@ -15,6 +15,7 @@ dependencies: - pydantic>=2.1 - pymp-pypi>=0.4.5 - pyproj>=3.3 + - rasterio>=1.3 - rich>=12.0 - ruamel.yaml>=0.15 - scipy>=1.5 # "scipy 0.16+ is required for linear algebra", numba. 1.5 is the oldest version that supports Python 3.7 diff --git a/mkdocs.yml b/mkdocs.yml index c0aa658d..74652e3c 100644 --- a/mkdocs.yml +++ b/mkdocs.yml @@ -13,6 +13,7 @@ theme: plugins: - search # plugin suggestions from here: https://mkdocstrings.github.io/recipes/ +- autorefs - gen-files: scripts: - docs/gen_ref_pages.py diff --git a/pyproject.toml b/pyproject.toml index 91e056e4..14f51b69 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -80,7 +80,8 @@ plugins = ["pydantic.mypy"] [tool.pytest.ini_options] doctest_optionflags = "NORMALIZE_WHITESPACE NUMBER" -addopts = " --cov=dolphin -n auto --maxprocesses=8 --doctest-modules --randomly-seed=1234 --ignore=scripts --ignore=docs --ignore=data --ignore=pkgs" +# addopts = " --cov=dolphin -n auto --maxprocesses=8 --doctest-modules --randomly-seed=1234 --ignore=scripts --ignore=docs --ignore=data --ignore=pkgs" +addopts = " --doctest-modules --randomly-seed=1234 --ignore=scripts --ignore=docs --ignore=data --ignore=pkgs" filterwarnings = [ "error", # DeprecationWarning thrown in pkg_resources for older numba verions and llvmlite diff --git a/requirements.txt b/requirements.txt index e9a11477..87b06c7a 100644 --- a/requirements.txt +++ b/requirements.txt @@ -6,6 +6,7 @@ opera-utils>=0.1.5 pydantic>=2.1 pymp-pypi>=0.4.5 pyproj>=3.3 +rasterio>=1.3 rich>=12.0 ruamel_yaml>=0.15 scipy>=1.5 diff --git a/src/dolphin/_background.py b/src/dolphin/_background.py index 2d264954..adb917b1 100644 --- a/src/dolphin/_background.py +++ b/src/dolphin/_background.py @@ -14,6 +14,13 @@ _DEFAULT_TIMEOUT = 0.5 +__all__ = [ + "BackgroundWorker", + "BackgroundReader", + "BackgroundWriter", + "DummyProcessPoolExecutor", +] + class BackgroundWorker(abc.ABC): """Base class for doing work in a background thread. 
diff --git a/src/dolphin/_readers.py b/src/dolphin/_readers.py
index b632e05c..1baf7ae9 100644
--- a/src/dolphin/_readers.py
+++ b/src/dolphin/_readers.py
@@ -1,21 +1,667 @@
 from __future__ import annotations

+import mmap
+from concurrent.futures import ThreadPoolExecutor
+from dataclasses import dataclass
 from os import fspath
 from pathlib import Path
-from typing import Generator, Optional, Sequence
+from typing import (
+    TYPE_CHECKING,
+    Generator,
+    Optional,
+    Protocol,
+    Sequence,
+    runtime_checkable,
+)

+import h5py
 import numpy as np
+import rasterio as rio
+from numpy.typing import ArrayLike
 from osgeo import gdal

 from dolphin import io, utils
+from dolphin._background import _DEFAULT_TIMEOUT, BackgroundReader
+from dolphin._blocks import iter_blocks
 from dolphin._dates import get_dates, sort_files_by_date
 from dolphin._types import Filename
 from dolphin.stack import logger
+from dolphin.utils import progress

-__all__ = ["VRTStack"]
+__all__ = [
+    "DatasetReader",
+    "BinaryReader",
+    "StackReader",
+    "BinaryStackReader",
+    "VRTStack",
+]

+if TYPE_CHECKING:
+    from builtins import ellipsis

-class VRTStack:
+    Index = ellipsis | slice | int
+
+
+@runtime_checkable
+class DatasetReader(Protocol):
+    """An array-like interface for reading input datasets.
+
+    `DatasetReader` defines the abstract interface that types must conform to in order
+    to be read by functions which iterate in blocks over the input data.
+    Such objects must export NumPy-like `dtype`, `shape`, and `ndim` attributes,
+    and must support NumPy-style slice-based indexing.
+
+    Note that this protocol allows objects to be passed to `dask.array.from_array`,
+    which needs the `.shape`, `.ndim`, and `.dtype` attributes and support for
+    numpy-style slicing.
+    """
+
+    dtype: np.dtype
+    """numpy.dtype : Data-type of the array's elements."""  # noqa: D403
+
+    shape: tuple[int, ...]
+    """tuple of int : Tuple of array dimensions."""  # noqa: D403
+
+    ndim: int
+    """int : Number of array dimensions."""  # noqa: D403
+
+    masked: bool = False
+    """bool : If True, return a masked array with the nodata values masked out."""
+
+    def __getitem__(self, key: tuple[Index, ...], /) -> ArrayLike:
+        """Read a block of data."""
+        ...
+
+
+@runtime_checkable
+class StackReader(DatasetReader, Protocol):
+    """An array-like interface for reading a 3D stack of input datasets.
+
+    `StackReader` defines the abstract interface that types must conform to in order
+    to be valid inputs to functions like [dolphin.ps.create_ps][].
+    It is a specialization of [DatasetReader][] that requires a 3D shape.
+    """
+
+    ndim: int = 3
+    """int : Number of array dimensions."""  # noqa: D403
+
+    shape: tuple[int, int, int]
+    """tuple of int : Tuple of array dimensions."""
+
+    def __len__(self) -> int:
+        """int : Number of images in the stack."""
+        return self.shape[0]
+
+
+def _mask_array(arr: np.ndarray, nodata_value: float | None) -> np.ma.MaskedArray:
+    """Mask an array based on a nodata value."""
+    if nodata_value is None:
+        # No nodata value is set, so there is nothing to mask out.
+        return np.ma.masked_array(arr)
+    if np.isnan(nodata_value):
+        return np.ma.masked_invalid(arr)
+    return np.ma.masked_equal(arr, nodata_value)
+
+
+@dataclass
+class BinaryReader(DatasetReader):
+    """A flat binary file for storing array data.
+
+    See Also
+    --------
+    HDF5Reader
+    RasterReader
+
+    Notes
+    -----
+    This class does not store an open file object. Instead, the file is opened
+    on-demand for reading or writing and closed immediately after each read/write
+    operation. This allows multiple spawned processes to write to the file in
+    coordination (as long as a suitable mutex is used to guard file access.)
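+
+    Examples
+    --------
+    A minimal usage sketch (the file name, shape, and dtype here are
+    illustrative)::
+
+        reader = BinaryReader("slc_20220101.slc", shape=(512, 512), dtype="complex64")
+        block = reader[:128, :128]  # numpy-style slicing; returns an ndarray copy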
+ """ + + filepath: Path + """pathlib.Path : The file path.""" # noqa: D403 + + shape: tuple[int, ...] + """tuple of int : Tuple of array dimensions.""" # noqa: D403 + + dtype: np.dtype + """numpy.dtype : Data-type of the array's elements.""" # noqa: D403 + + nodata: Optional[float] = None + """Optional[float] : Value to use for nodata pixels.""" + + def __post_init__(self): + self.filepath = Path(self.filepath) + if not self.filepath.exists(): + raise FileNotFoundError(f"File {self.filepath} does not exist.") + self.dtype = np.dtype(self.dtype) + + @property + def ndim(self) -> int: # type: ignore[override] + """int : Number of array dimensions.""" # noqa: D403 + return len(self.shape) + + def __getitem__(self, key: tuple[Index, ...], /) -> np.ndarray: + with self.filepath.open("rb") as f: + # Memory-map the entire file. + with mmap.mmap(f.fileno(), 0, access=mmap.ACCESS_READ) as mm: + # In order to safely close the memory-map, there can't be any dangling + # references to it, so we return a copy (not a view) of the requested + # data and decref the array object. + arr = np.frombuffer(mm, dtype=self.dtype).reshape(self.shape) + data = arr[key].copy() + del arr + return _mask_array(data, self.nodata) if self.masked else data + + def __array__(self) -> np.ndarray: + return self[:,] + + @classmethod + def from_gdal( + cls, filename: Filename, band: int = 1, nodata: Optional[float] = None + ) -> BinaryReader: + """Create a BinaryReader from a GDAL-readable file. + + Parameters + ---------- + filename : Filename + Path to the file to read. + band : int, optional + Band to read from the file, by default 1 + nodata : float, optional + Value to use for nodata pixels, by default None + If None passed, will search for a nodata value in the file. + + Returns + ------- + BinaryReader + The BinaryReader object. + """ + with rio.open(filename) as src: + dtype = src.dtypes[band - 1] + shape = src.shape + nodata = src.nodatavals[band - 1] + return cls( + Path(filename), + shape=shape, + dtype=dtype, + nodata=nodata or nodata, + ) + + +@dataclass +class HDF5Reader(DatasetReader): + """A Dataset in an HDF5 file. + + Attributes + ---------- + filepath : pathlib.Path | str + Location of HDF5 file. + dset_name : str + Path to the dataset within the file. + chunks : tuple[int, ...], optional + Chunk shape of the dataset, or None if file is unchunked. + keep_open : bool, optional (default False) + If True, keep the HDF5 file handle open for faster reading. + + + See Also + -------- + BinaryReader + RasterReader + + Notes + ----- + If `keep_open=True`, this class does not store an open file object. + Otherwise, the file is opened on-demand for reading or writing and closed + immediately after each read/write operation. + If passing the `HDF5Reader` to multiple spawned processes, it is recommended + to set `keep_open=False` . + """ + + filepath: Path + """pathlib.Path : The file path.""" + + dset_name: str + """str : The path to the dataset within the file.""" + + nodata: Optional[float] = None + """Optional[float] : Value to use for nodata pixels. + + If None, looks for `_FillValue` or `missing_value` attributes on the dataset. 
+ """ + + keep_open: bool = False + """bool : If True, keep the HDF5 file handle open for faster reading.""" + + def __post_init__(self): + filepath = Path(self.filepath) + + hf = h5py.File(filepath, "r") + dset = hf[self.dset_name] + self.shape = dset.shape + self.dtype = dset.dtype + self.chunks = dset.chunks + if self.nodata is None: + self.nodata = dset.attrs.get("_FillValue", None) + if self.nodata is None: + self.nodata = dset.attrs.get("missing_value", None) + if self.keep_open: + self._hf = hf + self._dset = dset + else: + hf.close() + + @property + def ndim(self) -> int: # type: ignore[override] + """int : Number of array dimensions.""" + return len(self.shape) + + def __array__(self) -> np.ndarray: + return self[:,] + + def __getitem__(self, key: tuple[Index, ...], /) -> np.ndarray: + if self.keep_open: + data = self._dset[key] + else: + with h5py.File(self.filepath, "r") as f: + data = f[self.dset_name][key] + return _mask_array(data, self.nodata) if self.masked else data + + +def _ensure_slices(rows: Index, cols: Index) -> tuple[slice, slice]: + def _parse(key: Index): + if isinstance(key, int): + return slice(key, key + 1) + elif key is ...: + return slice(None) + else: + return key + + return _parse(rows), _parse(cols) + + +@dataclass +class RasterReader(DatasetReader): + """A single raster band of a GDAL-compatible dataset. + + See Also + -------- + BinaryReader + HDF5 + + Notes + ----- + If `keep_open=True`, this class does not store an open file object. + Otherwise, the file is opened on-demand for reading or writing and closed + immediately after each read/write operation. + If passing the `RasterReader` to multiple spawned processes, it is recommended + to set `keep_open=False` . + """ + + filepath: Filename + """Filename : The file path.""" + + band: int + """int : Band index (1-based).""" + + driver: str + """str : Raster format driver name.""" + + crs: rio.crs.CRS + """rio.crs.CRS : The dataset's coordinate reference system.""" + + transform: rio.transform.Affine + """ + rasterio.transform.Affine : The dataset's georeferencing transformation matrix. + + This transform maps pixel row/column coordinates to coordinates in the dataset's + coordinate reference system. 
+ """ + + shape: tuple[int, int] + dtype: np.dtype + + nodata: Optional[float] = None + """Optional[float] : Value to use for nodata pixels.""" + + keep_open: bool = False + """bool : If True, keep the rasterio file handle open for faster reading.""" + + chunks: Optional[tuple[int, int]] = None + """Optional[tuple[int, int]] : Chunk shape of the dataset, or None if file is unchunked.""" + + @classmethod + def from_file( + cls, + filepath: Filename, + band: int = 1, + nodata: Optional[float] = None, + keep_open: bool = False, + **options, + ) -> RasterReader: + with rio.open(filepath, "r", **options) as src: + shape = (src.height, src.width) + dtype = np.dtype(src.dtypes[band - 1]) + driver = src.driver + crs = src.crs + nodata = nodata or src.nodatavals[band - 1] + transform = src.transform + chunks = src.block_shapes[band - 1] + + return cls( + filepath=filepath, + band=band, + driver=driver, + crs=crs, + transform=transform, + shape=shape, + dtype=dtype, + nodata=nodata, + keep_open=keep_open, + chunks=chunks, + ) + + def __post_init__(self): + if self.keep_open: + self._src = rio.open(self.filepath, "r") + + @property + def ndim(self) -> int: # type: ignore[override] + """int : Number of array dimensions.""" + return 2 + + def __array__(self) -> np.ndarray: + return self[:, :] + + def __getitem__(self, key: tuple[Index, ...], /) -> np.ndarray: + import rasterio.windows + + if key is ... or key == (): + key = (slice(None), slice(None)) + + if not isinstance(key, tuple): + raise ValueError("Index must be a tuple of slices or integers.") + + r_slice, c_slice = _ensure_slices(*key[-2:]) + window = rasterio.windows.Window.from_slices( + r_slice, + c_slice, + height=self.shape[0], + width=self.shape[1], + ) + if self.keep_open: + out = self._src.read(self.band, window=window) + + with rio.open(self.filepath) as src: + out = src.read(self.band, window=window) + out_masked = _mask_array(out, self.nodata) if self.masked else out + # Note that Rasterio doesn't use the `step` of a slice, so we need to + # manually slice the output array. + r_step, c_step = r_slice.step or 1, c_slice.step or 1 + return out_masked[::r_step, ::c_step].squeeze() + + +def _read_3d( + key: tuple[Index, ...], readers: Sequence[DatasetReader], num_threads: int = 1 +): + # Check that it's a tuple of slices + if not isinstance(key, tuple): + raise ValueError("Index must be a tuple of slices.") + if len(key) not in (1, 3): + raise ValueError("Index must be a tuple of 1 or 3 slices.") + # If only the band is passed (e.g. stack[0]), convert to (0, :, :) + if len(key) == 1: + key = (key[0], slice(None), slice(None)) + # unpack the slices + bands, rows, cols = key + # convert the rows/cols to slices + r_slice, c_slice = _ensure_slices(rows, cols) + + if isinstance(bands, slice): + # convert the bands to -1-indexed list + total_num_bands = len(readers) + band_idxs = list(range(*bands.indices(total_num_bands))) + elif isinstance(bands, int): + band_idxs = [bands] + else: + raise ValueError("Band index must be an integer or slice.") + + # Get only the bands we need + if num_threads == 1: + out = np.stack([readers[i][r_slice, c_slice] for i in band_idxs], axis=0) + else: + with ThreadPoolExecutor(max_workers=num_threads) as executor: + results = executor.map(lambda i: readers[i][r_slice, c_slice], band_idxs) + out = np.stack(list(results), axis=0) + + # TODO: Do i want a "keep_dims" option to not collapse singleton dimensions? 
+ return np.squeeze(out) + + +@dataclass +class BaseStackReader(StackReader): + """Base class for stack readers.""" + + file_list: Sequence[Filename] + readers: Sequence[DatasetReader] + num_threads: int = 1 + nodata: Optional[float] = None + + def __getitem__(self, key: tuple[Index, ...], /) -> np.ndarray: + return _read_3d(key, self.readers, num_threads=self.num_threads) + + @property + def shape_2d(self): + return self.readers[0].shape + + @property + def shape(self): + return (len(self.file_list), *self.shape_2d) + + @property + def dtype(self): + return self.readers[0].dtype + + +@dataclass +class BinaryStackReader(BaseStackReader): + @classmethod + def from_file_list( + cls, file_list: Sequence[Filename], shape_2d: tuple[int, int], dtype: np.dtype + ) -> BinaryStackReader: + """Create a BinaryStackReader from a list of files. + + Parameters + ---------- + file_list : Sequence[Filename] + List of paths to the files to read. + shape_2d : tuple[int, int] + Shape of each file. + dtype : np.dtype + Data type of each file. + + Returns + ------- + BinaryStackReader + The BinaryStackReader object. + """ + readers = [ + BinaryReader(Path(f), shape=shape_2d, dtype=dtype) for f in file_list + ] + return cls(file_list=file_list, readers=readers, num_threads=1) + + @classmethod + def from_gdal( + cls, + file_list: Sequence[Filename], + band: int = 1, + num_threads: int = 1, + nodata: Optional[float] = None, + ) -> BinaryStackReader: + """Create a BinaryStackReader from a list of GDAL-readable files. + + Parameters + ---------- + file_list : Sequence[Filename] + List of paths to the files to read. + band : int, optional + Band to read from the file, by default 1 + num_threads : int, optional (default 1) + Number of threads to use for reading. + nodata : float, optional + Manually set value to use for nodata pixels, by default None + If None passed, will search for a nodata value in the file. + + Returns + ------- + BinaryStackReader + The BinaryStackReader object. + """ + readers = [] + dtypes = set() + shapes = set() + for f in file_list: + with rio.open(f) as src: + dtypes.add(src.dtypes[band - 1]) + shapes.add(src.shape) + if len(dtypes) > 1: + raise ValueError("All files must have the same data type.") + if len(shapes) > 1: + raise ValueError("All files must have the same shape.") + readers.append(BinaryReader.from_gdal(f, band=band)) + return cls( + file_list=file_list, + readers=readers, + num_threads=num_threads, + nodata=nodata, + ) + + +@dataclass +class HDF5StackReader(BaseStackReader): + """A stack of datasets in an HDF5 file. + + See Also + -------- + BinaryStackReader + StackReader + + Notes + ----- + If `keep_open=True`, this class stores an open file object. + Otherwise, the file is opened on-demand for reading or writing and closed + immediately after each read/write operation. + If passing the `HDF5StackReader` to multiple spawned processes, it is recommended + to set `keep_open=False`. + """ + + @classmethod + def from_file_list( + cls, + file_list: Sequence[Filename], + dset_names: str | Sequence[str], + keep_open: bool = False, + num_threads: int = 1, + nodata: Optional[float] = None, + ) -> HDF5StackReader: + """Create a HDF5StackReader from a list of files. + + Parameters + ---------- + file_list : Sequence[Filename] + List of paths to the files to read. + dset_names : str | Sequence[str] + Name of the dataset to read from each file. + If a single string, will be used for all files. 
+ keep_open : bool, optional (default False) + If True, keep the HDF5 file handles open for faster reading. + num_threads : int, optional (default 1) + Number of threads to use for reading. + nodata : float, optional + Manually set value to use for nodata pixels, by default None + If None passed, will search for a nodata value in the file. + + Returns + ------- + HDF5StackReader + The HDF5StackReader object. + """ + if isinstance(dset_names, str): + dset_names = [dset_names] * len(file_list) + + readers = [ + HDF5Reader(Path(f), dset_name=dn, keep_open=keep_open, nodata=nodata) + for (f, dn) in zip(file_list, dset_names) + ] + # Check if nodata values were found in the files + nds = set([r.nodata for r in readers]) + if len(nds) == 1: + nodata = nds.pop() + + return cls(file_list, readers, num_threads=num_threads, nodata=nodata) + + +@dataclass +class RasterStackReader(BaseStackReader): + """A stack of datasets for any GDAL-readable rasters. + + See Also + -------- + BinaryStackReader + HDF5StackReader + + Notes + ----- + If `keep_open=True`, this class stores an open file object. + Otherwise, the file is opened on-demand for reading or writing and closed + immediately after each read/write operation. + """ + + @classmethod + def from_file_list( + cls, + file_list: Sequence[Filename], + bands: int | Sequence[int] = 1, + keep_open: bool = False, + num_threads: int = 1, + nodata: Optional[float] = None, + ) -> RasterStackReader: + """Create a RasterStackReader from a list of files. + + Parameters + ---------- + file_list : Sequence[Filename] + List of paths to the files to read. + bands : int | Sequence[int] + Band to read from each file. + If a single int, will be used for all files. + Default = 1. + keep_open : bool, optional (default False) + If True, keep the rasterio file handles open for faster reading. + num_threads : int, optional (default 1) + Number of threads to use for reading. + nodata : float, optional + Manually set value to use for nodata pixels, by default None + + Returns + ------- + RasterStackReader + The RasterStackReader object. + """ + if isinstance(bands, int): + bands = [bands] * len(file_list) + + readers = [ + RasterReader.from_file(f, band=b, keep_open=keep_open) + for (f, b) in zip(file_list, bands) + ] + # Check if nodata values were found in the files + nds = set([r.nodata for r in readers]) + if len(nds) == 1: + nodata = nds.pop() + return cls(file_list, readers, num_threads=num_threads, nodata=nodata) + + +class VRTStack(StackReader): """Class for creating a virtual stack of raster files. 
     Attributes
@@ -55,6 +701,7 @@ def __init__(
         write_file: bool = True,
         fail_on_overwrite: bool = False,
         skip_size_check: bool = False,
+        num_threads: int = 1,
     ):
         if Path(outfile).exists() and write_file:
             if fail_on_overwrite:
@@ -66,12 +713,14 @@
             else:
                 logger.info(f"Overwriting {outfile}")

-        files: list[Filename] = [Path(f) for f in file_list]
         self._use_abs_path = use_abs_path
         if use_abs_path:
-            files = [utils._resolve_gdal_path(p) for p in files]
+            files = [utils._resolve_gdal_path(p) for p in file_list]
+        else:
+            files = list(file_list)
         # Extract the date/datetimes from the filenames
-        dates = [get_dates(f, fmt=file_date_fmt) for f in files]
+        dates = [get_dates(f, fmt=file_date_fmt) for f in file_list]
         if sort_files:
             files, dates = sort_files_by_date(  # type: ignore
                 files, file_date_fmt=file_date_fmt
@@ -80,6 +729,7 @@
         # Save the attributes
         self.file_list = files
         self.dates = dates
+        self.num_threads = num_threads
         self.outfile = Path(outfile).resolve()

         # Assumes that all files use the same subdataset (if NetCDF)
@@ -192,51 +842,6 @@ def from_vrt_file(cls, vrt_file, new_outfile=None, **kwargs):
             **kwargs,
         )

-    def iter_blocks(
-        self,
-        overlaps: tuple[int, int] = (0, 0),
-        block_shape: tuple[int, int] = (512, 512),
-        skip_empty: bool = True,
-        nodata_mask: Optional[np.ndarray] = None,
-        show_progress: bool = True,
-    ) -> Generator[tuple[np.ndarray, tuple[slice, slice]], None, None]:
-        """Iterate over blocks of the stack.
-
-        Loads all images for one window at a time into memory.
-
-        Parameters
-        ----------
-        overlaps : tuple[int, int], optional
-            Pixels to overlap each block by (rows, cols)
-            By default (0, 0)
-        block_shape : tuple[int, int], optional
-            2D shape of blocks to load at a time.
-            Loads all dates/bands at a time with this shape.
-        skip_empty : bool, optional (default True)
-            Skip blocks that are entirely empty (all NaNs)
-        nodata_mask : bool, optional
-            Optional mask indicating nodata values. If provided, will skip
-            blocks that are entirely nodata.
-            1s are the nodata values, 0s are valid data.
-        show_progress : bool, default=True
-            If true, displays a `rich` ProgressBar.
-
-        Yields
-        ------
-        tuple[np.ndarray, tuple[slice, slice]]
-            Iterator of (data, (slice(row_start, row_stop), slice(col_start, col_stop))
-
-        """
-        self._loader = io.EagerLoader(
-            self.outfile,
-            block_shape=block_shape,
-            overlaps=overlaps,
-            nodata_mask=nodata_mask,
-            skip_empty=skip_empty,
-            show_progress=show_progress,
-        )
-        yield from self._loader.iter_blocks()
-
     @property
     def shape(self):
         """Get the 3D shape of the stack."""
@@ -258,8 +863,6 @@ def __eq__(self, other):
             and self.outfile == other.outfile
         )

-    # To allow VRTStack to be passed to `dask.array.from_array`, we need:
-    # .shape, .ndim, .dtype and support numpy-style slicing.
     @property
     def ndim(self):
         return 3
@@ -281,15 +884,28 @@
             if n < 0:
                 n = len(self) + n
             return self.read_stack(band=n + 1, rows=rows, cols=cols)
+        elif n is ...:
+            n = slice(None)
         bands = list(range(1, 1 + len(self)))[n]
         if len(bands) == len(self):
             # This will use gdal's ds.ReadAsRaster, no iteration needed
             data = self.read_stack(band=None, rows=rows, cols=cols)
         else:
-            data = np.stack(
-                [self.read_stack(band=i, rows=rows, cols=cols) for i in bands], axis=0
-            )
+            # Get only the bands we need
+            if self.num_threads == 1:
+                data = np.stack(
+                    [self.read_stack(band=i, rows=rows, cols=cols) for i in bands],
+                    axis=0,
+                )
+            else:
+                with ThreadPoolExecutor(max_workers=self.num_threads) as executor:
+                    results = executor.map(
+                        lambda i: self.read_stack(band=i, rows=rows, cols=cols), bands
+                    )
+                    data = np.stack(list(results), axis=0)
+
+        return data.squeeze()

     @property
@@ -343,3 +959,81 @@ def _parse_vrt_file(vrt_file):
         filepaths.append(name)

     return filepaths, sds
+
+
+class EagerLoader(BackgroundReader):
+    """Class to pre-fetch data chunks in a background thread."""
+
+    def __init__(
+        self,
+        reader: DatasetReader,
+        block_shape: tuple[int, int],
+        overlaps: tuple[int, int] = (0, 0),
+        skip_empty: bool = True,
+        nodata_value: Optional[float] = None,
+        nodata_mask: Optional[ArrayLike] = None,
+        queue_size: int = 1,
+        timeout: float = _DEFAULT_TIMEOUT,
+        show_progress: bool = True,
+    ):
+        super().__init__(nq=queue_size, timeout=timeout, name="EagerLoader")
+        self.reader = reader
+        # Set up the generator of ((row_start, row_end), (col_start, col_end))
+        # convert the slice generator to a list so we have the size
+        nrows, ncols = self.reader.shape[-2:]
+        self.slices = list(
+            iter_blocks(
+                arr_shape=(nrows, ncols),
+                block_shape=block_shape,
+                overlaps=overlaps,
+            )
+        )
+        if nodata_value is None:
+            nodata_value = getattr(reader, "nodata", None)
+        self._queue_size = queue_size
+        self._skip_empty = skip_empty
+        self._nodata_mask = nodata_mask
+        self._block_shape = block_shape
+        self._nodata = nodata_value
+        self._show_progress = show_progress
+        if self._nodata is None:
+            self._nodata = np.nan
+
+    def read(self, rows: slice, cols: slice) -> tuple[np.ndarray, tuple[slice, slice]]:
+        logger.debug(f"EagerLoader reading {rows}, {cols}")
+        cur_block = self.reader[..., rows, cols]
+        return cur_block, (rows, cols)
+
+    def iter_blocks(
+        self,
+    ) -> Generator[tuple[np.ndarray, tuple[slice, slice]], None, None]:
+        # Queue up all slices to the work queue
+        queued_slices = []
+        for rows, cols in self.slices:
+            # Skip queueing a read if all nodata
+            if self._skip_empty and self._nodata_mask is not None:
+                logger.debug("Checking nodata mask")
+                if self._nodata_mask[rows, cols].all():
+                    logger.debug("Skipping!")
+                    continue
+            self.queue_read(rows, cols)
+            queued_slices.append((rows, cols))

+        s_iter = range(len(queued_slices))
+        desc = f"Processing {self._block_shape} sized blocks..."
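+        # Consume the prefetched blocks in the order they were queued, skipping
+        # any block that turns out to be entirely nodata.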
+ with progress(dummy=not self._show_progress) as p: + for _ in p.track(s_iter, description=desc): + cur_block, (rows, cols) = self.get_data() + logger.debug(f"got data for {rows, cols}: {cur_block.shape}") + + # Otherwise look at the actual block we loaded + if np.isnan(self._nodata): + block_is_nodata = np.isnan(cur_block) + else: + block_is_nodata = cur_block == self._nodata + if np.all(block_is_nodata): + logger.debug("Skipping block since it was all nodata") + continue + yield cur_block, (rows, cols) + + self.notify_finished() diff --git a/src/dolphin/io.py b/src/dolphin/io.py index 0efe386d..5ccbfc32 100644 --- a/src/dolphin/io.py +++ b/src/dolphin/io.py @@ -11,7 +11,7 @@ from dataclasses import dataclass from os import fspath from pathlib import Path -from typing import Any, Generator, Mapping, Optional, Sequence, Union +from typing import Any, Mapping, Optional, Sequence, Union import h5py import numpy as np @@ -19,11 +19,11 @@ from osgeo import gdal from pyproj import CRS -from dolphin._background import _DEFAULT_TIMEOUT, BackgroundReader, BackgroundWriter -from dolphin._blocks import compute_out_shape, iter_blocks +from dolphin._background import BackgroundWriter +from dolphin._blocks import compute_out_shape from dolphin._log import get_log from dolphin._types import Bbox, Filename -from dolphin.utils import gdal_to_numpy_type, numpy_to_gdal_type, progress +from dolphin.utils import gdal_to_numpy_type, numpy_to_gdal_type gdal.UseExceptions() @@ -31,7 +31,7 @@ "load_gdal", "write_arr", "write_block", - "EagerLoader", + "Writer", ] @@ -180,10 +180,11 @@ def format_nc_filename(filename: Filename, ds_name: Optional[str] = None) -> str If `ds_name` is not provided for a .h5 or .nc file. """ # If we've already formatted the filename, return it - if str(filename).startswith("NETCDF:") or str(filename).startswith("HDF5:"): - return str(filename) + fname_clean = fspath(filename).lstrip('"').lstrip("'").rstrip('"').rstrip("'") + if fname_clean.startswith("NETCDF:") or fname_clean.startswith("HDF5:"): + return fspath(filename) - if not (fspath(filename).endswith(".nc") or fspath(filename).endswith(".h5")): + if not (fname_clean.endswith(".nc") or fname_clean.endswith(".h5")): return fspath(filename) # Now we're definitely dealing with an HDF5/NetCDF file @@ -719,6 +720,20 @@ def from_user_inputs( ) +def get_raster_chunk_size(filename: Filename) -> list[int]: + """Get size the raster's chunks on disk. + + This is called blockXsize, blockYsize by GDAL. + """ + ds = gdal.Open(fspath(filename)) + block_size = ds.GetRasterBand(1).GetBlockSize() + for i in range(2, ds.RasterCount + 1): + if block_size != ds.GetRasterBand(i).GetBlockSize(): + logger.warning(f"Warning: {filename} bands have different block shapes.") + break + return block_size + + class Writer(BackgroundWriter): """Class to write data to files in a background thread.""" @@ -727,7 +742,7 @@ def __init__(self, max_queue: int = 0, debug: bool = False, **kwargs): super().__init__(nq=max_queue, name="Writer", **kwargs) else: # Don't start a background thread. 
Just synchronously write data - self.queue_write = lambda *args: write_block(*args) # type: ignore + setattr(self, "queue_write", self.write) def write( self, data: ArrayLike, filename: Filename, row_start: int, col_start: int @@ -756,168 +771,3 @@ def write( def num_queued(self): """Number of items waiting in the queue to be written.""" return self._work_queue.qsize() - - -class EagerLoader(BackgroundReader): - """Class to pre-fetch data chunks in a background thread.""" - - def __init__( - self, - filename: Filename, - block_shape: tuple[int, int], - overlaps: tuple[int, int] = (0, 0), - skip_empty: bool = True, - nodata_mask: Optional[ArrayLike] = None, - queue_size: int = 1, - timeout: float = _DEFAULT_TIMEOUT, - show_progress: bool = True, - ): - super().__init__(nq=queue_size, timeout=timeout, name="EagerLoader") - self.filename = filename - # Set up the generator of ((row_start, row_end), (col_start, col_end)) - xsize, ysize = get_raster_xysize(filename) - # convert the slice generator to a list so we have the size - self.slices = list( - iter_blocks( - arr_shape=(ysize, xsize), - block_shape=block_shape, - overlaps=overlaps, - ) - ) - self._queue_size = queue_size - self._skip_empty = skip_empty - self._nodata_mask = nodata_mask - self._block_shape = block_shape - self._nodata = get_raster_nodata(filename) - self._show_progress = show_progress - if self._nodata is None: - self._nodata = np.nan - - def read(self, rows: slice, cols: slice) -> tuple[np.ndarray, tuple[slice, slice]]: - logger.debug(f"EagerLoader reading {rows}, {cols}") - cur_block = load_gdal(self.filename, rows=rows, cols=cols) - return cur_block, (rows, cols) - - def iter_blocks( - self, - ) -> Generator[tuple[np.ndarray, tuple[slice, slice]], None, None]: - # Queue up all slices to the work queue - queued_slices = [] - for rows, cols in self.slices: - # Skip queueing a read if all nodata - if self._skip_empty and self._nodata_mask is not None: - logger.debug("Checking nodata mask") - if self._nodata_mask[rows, cols].all(): - logger.debug("Skipping!") - continue - self.queue_read(rows, cols) - queued_slices.append((rows, cols)) - - s_iter = range(len(queued_slices)) - desc = f"Processing {self._block_shape} sized blocks..." - with progress(dummy=not self._show_progress) as p: - for _ in p.track(s_iter, description=desc): - cur_block, (rows, cols) = self.get_data() - logger.debug(f"got data for {rows, cols}: {cur_block.shape}") - - # Otherwise look at the actual block we loaded - if np.isnan(self._nodata): - block_nodata = np.isnan(cur_block) - else: - block_nodata = cur_block == self._nodata - if np.all(block_nodata): - logger.debug("Skipping block since it was all nodata") - continue - yield cur_block, (rows, cols) - - self.notify_finished() - - -def get_max_block_shape( - filename: Filename, nstack: int, max_bytes: float = 64e6 -) -> tuple[int, int]: - """Find a block shape to load from `filename` with memory size < `max_bytes`. - - Attempts to get an integer number of chunks ("tiles" for geotiffs) from the - file to avoid partial tiles. - - Parameters - ---------- - filename : str - GDAL-readable file name containing 3D dataset. - nstack: int - Number of bands in dataset. - max_bytes : float, optional - Target size of memory (in Bytes) for each block. - Defaults to 64e6. 
- - Returns - ------- - tuple[int, int]: - (num_rows, num_cols) shape of blocks to load from `vrt_file` - """ - chunk_cols, chunk_rows = get_raster_chunk_size(filename) - xsize, ysize = get_raster_xysize(filename) - # If it's written by line, load at least 16 lines at a time - chunk_cols = min(max(16, chunk_cols), xsize) - chunk_rows = min(max(16, chunk_rows), ysize) - - ds = gdal.Open(fspath(filename)) - shape = (ds.RasterYSize, ds.RasterXSize) - # get the size of the data type from the raster - nbytes = gdal_to_numpy_type(ds.GetRasterBand(1).DataType).itemsize - return _increment_until_max( - max_bytes=max_bytes, - file_chunk_size=[chunk_rows, chunk_cols], - shape=shape, - nstack=nstack, - bytes_per_pixel=nbytes, - ) - - -def get_raster_chunk_size(filename: Filename) -> list[int]: - """Get size the raster's chunks on disk. - - This is called blockXsize, blockYsize by GDAL. - """ - ds = gdal.Open(fspath(filename)) - block_size = ds.GetRasterBand(1).GetBlockSize() - for i in range(2, ds.RasterCount + 1): - if block_size != ds.GetRasterBand(i).GetBlockSize(): - logger.warning(f"Warning: {filename} bands have different block shapes.") - break - return block_size - - -def _increment_until_max( - max_bytes: float, - file_chunk_size: Sequence[int], - shape: tuple[int, int], - nstack: int, - bytes_per_pixel: int = 8, -) -> tuple[int, int]: - """Find size of 3D chunk to load while staying at ~`max_bytes` bytes of RAM.""" - chunk_rows, chunk_cols = file_chunk_size - - # How many chunks can we fit in max_bytes? - chunks_per_block = max_bytes / ( - (nstack * chunk_rows * chunk_cols) * bytes_per_pixel - ) - num_chunks = [1, 1] - cur_block_shape = [chunk_rows, chunk_cols] - - idx = 1 # start incrementing cols - while chunks_per_block > 1 and tuple(cur_block_shape) != tuple(shape): - # Alternate between adding a row and column chunk by flipping the idx - chunk_idx = idx % 2 - nc = num_chunks[chunk_idx] - chunk_size = file_chunk_size[chunk_idx] - - cur_block_shape[chunk_idx] = min(nc * chunk_size, shape[chunk_idx]) - - chunks_per_block = max_bytes / ( - nstack * np.prod(cur_block_shape) * bytes_per_pixel - ) - num_chunks[chunk_idx] += 1 - idx += 1 - return cur_block_shape[0], cur_block_shape[1] diff --git a/src/dolphin/ps.py b/src/dolphin/ps.py index 7657de6b..fe6770f6 100644 --- a/src/dolphin/ps.py +++ b/src/dolphin/ps.py @@ -1,5 +1,4 @@ """Find the persistent scatterers in a stack of SLCS.""" - from __future__ import annotations import shutil @@ -14,7 +13,7 @@ from dolphin import io, utils from dolphin._log import get_log -from dolphin._readers import VRTStack +from dolphin._readers import EagerLoader, StackReader from dolphin._types import Filename gdal.UseExceptions() @@ -28,10 +27,11 @@ def create_ps( *, - slc_vrt_file: Filename, + reader: StackReader, output_file: Filename, output_amp_mean_file: Filename, output_amp_dispersion_file: Filename, + like_filename: Filename, amp_dispersion_threshold: float = 0.25, existing_amp_mean_file: Optional[Filename] = None, existing_amp_dispersion_file: Optional[Filename] = None, @@ -44,14 +44,16 @@ def create_ps( Parameters ---------- - slc_vrt_file : Filename - The VRT file pointing to the stack of SLCs. + reader : StackReader + A dataset reader for the 3D SLC stack. output_file : Filename The output PS file (dtype: Byte) output_amp_dispersion_file : Filename The output amplitude dispersion file. output_amp_mean_file : Filename The output mean amplitude file. + like_filename : Filename + The filename to use for the output files' spatial reference. 
amp_dispersion_threshold : float, optional The threshold for the amplitude dispersion. Default is 0.25. existing_amp_mean_file : Optional[Filename], optional @@ -91,29 +93,28 @@ def create_ps( for fn, dtype, nodata in zip(file_list, FILE_DTYPES, NODATA_VALUES): io.write_arr( arr=None, - like_filename=slc_vrt_file, + like_filename=like_filename, output_name=fn, nbands=1, dtype=dtype, nodata=nodata, ) - vrt_stack = VRTStack.from_vrt_file(slc_vrt_file) - # Initialize the intermediate arrays for the calculation - magnitude = np.zeros((len(vrt_stack), *block_shape), dtype=np.float32) + magnitude = np.zeros((reader.shape[0], *block_shape), dtype=np.float32) skip_empty = nodata_mask is None writer = io.Writer() # Make the generator for the blocks - block_gen = vrt_stack.iter_blocks( + block_gen = EagerLoader( + reader, block_shape=block_shape, - skip_empty=skip_empty, nodata_mask=nodata_mask, + skip_empty=skip_empty, show_progress=show_progress, ) - for cur_data, (rows, cols) in block_gen: + for cur_data, (rows, cols) in block_gen.iter_blocks(): cur_rows, cur_cols = cur_data.shape[-2:] if not (np.all(cur_data == 0) or np.all(np.isnan(cur_data))): diff --git a/src/dolphin/utils.py b/src/dolphin/utils.py index db2fa868..73db9034 100644 --- a/src/dolphin/utils.py +++ b/src/dolphin/utils.py @@ -115,15 +115,15 @@ def _get_path_from_gdal_str(name: Filename) -> Path: def _resolve_gdal_path(gdal_str: Filename) -> Filename: """Resolve the file portion of a gdal-openable string to an absolute path.""" - s = str(gdal_str) + s_clean = str(gdal_str).lstrip('"').lstrip("'").rstrip('"').rstrip("'") prefixes = ["DERIVED_SUBDATASET", "NETCDF", "HDF"] - is_gdal_str = any(s.upper().startswith(pre) for pre in prefixes) - file_part = str(_get_path_from_gdal_str(gdal_str)) + is_gdal_str = any(s_clean.upper().startswith(pre) for pre in prefixes) + file_part = str(_get_path_from_gdal_str(s_clean)) # strip quotes to add back in after file_part = file_part.strip('"').strip("'") file_part_resolved = Path(file_part).resolve() - resolved = s.replace(file_part, str(file_part_resolved)) + resolved = s_clean.replace(file_part, str(file_part_resolved)) return Path(resolved) if not is_gdal_str else resolved diff --git a/src/dolphin/workflows/ps.py b/src/dolphin/workflows/ps.py index 1e61ee6b..2ed694a5 100644 --- a/src/dolphin/workflows/ps.py +++ b/src/dolphin/workflows/ps.py @@ -88,10 +88,11 @@ def run( logger.info(f"Creating persistent scatterer file {ps_output}") dolphin.ps.create_ps( - slc_vrt_file=vrt_stack.outfile, + reader=vrt_stack, output_file=output_file_list[0], output_amp_mean_file=output_file_list[1], output_amp_dispersion_file=output_file_list[2], + like_filename=vrt_stack.outfile, amp_dispersion_threshold=cfg.ps_options.amp_dispersion_threshold, block_shape=cfg.worker_settings.block_shape, ) diff --git a/src/dolphin/workflows/wrapped_phase.py b/src/dolphin/workflows/wrapped_phase.py index 186cc39b..6a6c8744 100644 --- a/src/dolphin/workflows/wrapped_phase.py +++ b/src/dolphin/workflows/wrapped_phase.py @@ -74,7 +74,8 @@ def run( existing_amp = existing_disp = None ps.create_ps( - slc_vrt_file=vrt_stack.outfile, + reader=vrt_stack, + like_filename=vrt_stack.outfile, output_file=ps_output, output_amp_mean_file=cfg.ps_options._amp_mean_file, output_amp_dispersion_file=cfg.ps_options._amp_dispersion_file, diff --git a/tests/conftest.py b/tests/conftest.py index e5044201..64b25c7f 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -223,12 +223,12 @@ def raster_10_by_20(tmp_path, tiled_raster_100_by_200): 
@pytest.fixture -def raster_with_nan(tmpdir, tiled_raster_100_by_200): +def raster_with_nan(tmp_path, tiled_raster_100_by_200): # Raster with one nan pixel start_arr = load_gdal(tiled_raster_100_by_200) nan_arr = start_arr.copy() nan_arr[0, 0] = np.nan - output_name = tmpdir / "with_one_nan.tif" + output_name = tmp_path / "with_one_nan.tif" write_arr( arr=nan_arr, like_filename=tiled_raster_100_by_200, @@ -239,9 +239,9 @@ def raster_with_nan(tmpdir, tiled_raster_100_by_200): @pytest.fixture -def raster_with_nan_block(tmpdir, tiled_raster_100_by_200): +def raster_with_nan_block(tmp_path, tiled_raster_100_by_200): # One full block of 32x32 is nan - output_name = tmpdir / "with_nans.tif" + output_name = tmp_path / "with_nans.tif" nan_arr = load_gdal(tiled_raster_100_by_200) nan_arr[:32, :32] = np.nan write_arr( @@ -254,9 +254,9 @@ def raster_with_nan_block(tmpdir, tiled_raster_100_by_200): @pytest.fixture -def raster_with_zero_block(tmpdir, tiled_raster_100_by_200): +def raster_with_zero_block(tmp_path, tiled_raster_100_by_200): # One full block of 32x32 is nan - output_name = tmpdir / "with_zeros.tif" + output_name = tmp_path / "with_zeros.tif" out_arr = load_gdal(tiled_raster_100_by_200) out_arr[:] = 1.0 diff --git a/tests/test_io.py b/tests/test_io.py index da89e6fa..f36b2906 100644 --- a/tests/test_io.py +++ b/tests/test_io.py @@ -262,159 +262,6 @@ def test_get_raster_block_sizes(raster_100_by_200, tiled_raster_100_by_200): assert io.get_raster_chunk_size(raster_100_by_200) == [200, 1] -def test_get_max_block_shape(raster_100_by_200, tiled_raster_100_by_200): - # for io.get_max_block_shape, the rasters are 8 bytes per pixel - # if we have 1 GB, the whole raster should fit in memory - bs = io.get_max_block_shape(tiled_raster_100_by_200, nstack=1, max_bytes=1e9) - assert bs == (100, 200) - - # for untiled, the block size is one line - bs = io.get_max_block_shape(raster_100_by_200, nstack=1, max_bytes=0) - # The function forces at least 16 lines to be read at a time - assert bs == (16, 200) - bs = io.get_max_block_shape(raster_100_by_200, nstack=1, max_bytes=8 * 17 * 200) - assert bs == (32, 200) - - # Pretend we have a stack of 10 images - nstack = 10 - # one tile should be 8 * 32 * 32 * 10 = 81920 bytes - bytes_per_tile = 8 * 32 * 32 * nstack - bs = io.get_max_block_shape( - tiled_raster_100_by_200, nstack, max_bytes=bytes_per_tile - ) - assert bs == (32, 32) - - # with a little more, we should get 2 tiles - bs = io.get_max_block_shape( - tiled_raster_100_by_200, nstack, max_bytes=1 + bytes_per_tile - ) - assert bs == (32, 64) - - bs = io.get_max_block_shape( - tiled_raster_100_by_200, nstack, max_bytes=4 * bytes_per_tile - ) - assert bs == (64, 64) - - -def test_iter_blocks(tiled_raster_100_by_200): - # Try the whole raster - bs = io.get_max_block_shape(tiled_raster_100_by_200, 1, max_bytes=1e9) - loader = io.EagerLoader(filename=tiled_raster_100_by_200, block_shape=bs) - # `list` should try to load all at once` - block_slice_tuples = list(loader.iter_blocks()) - assert not loader._thread.is_alive() - assert len(block_slice_tuples) == 1 - blocks, slices = zip(*list(block_slice_tuples)) - assert blocks[0].shape == (100, 200) - rows, cols = slices[0] - assert rows == slice(0, 100) - assert cols == slice(0, 200) - - # now one block at a time - max_bytes = 8 * 32 * 32 - bs = io.get_max_block_shape(tiled_raster_100_by_200, 1, max_bytes=max_bytes) - loader = io.EagerLoader(filename=tiled_raster_100_by_200, block_shape=bs) - blocks, slices = zip(*list(loader.iter_blocks())) - - row_blocks = 100 
// 32 + 1 - col_blocks = 200 // 32 + 1 - expected_num_blocks = row_blocks * col_blocks - assert len(blocks) == expected_num_blocks - assert blocks[0].shape == (32, 32) - # at the ends, the block_slice_tuples are smaller - assert blocks[6].shape == (32, 8) - assert blocks[-1].shape == (4, 8) - - -def test_iter_blocks_rowcols(tiled_raster_100_by_200): - # Block size that is a multiple of the raster size - loader = io.EagerLoader(filename=tiled_raster_100_by_200, block_shape=(10, 20)) - blocks, slices = zip(*list(loader.iter_blocks())) - - assert blocks[0].shape == (10, 20) - for rs, cs in slices: - assert rs.stop - rs.start == 10 - assert cs.stop - cs.start == 20 - loader.notify_finished() - - # Non-multiple block size - loader = io.EagerLoader(filename=tiled_raster_100_by_200, block_shape=(32, 32)) - blocks, slices = zip(*list(loader.iter_blocks())) - assert blocks[0].shape == (32, 32) - for b, (rs, cs) in zip(blocks, slices): - assert b.shape == (rs.stop - rs.start, cs.stop - cs.start) - loader.notify_finished() - - -def test_iter_nodata( - raster_with_nan, - raster_with_nan_block, - raster_with_zero_block, - tiled_raster_100_by_200, -): - # load one block at a time - max_bytes = 8 * 32 * 32 - bs = io.get_max_block_shape(tiled_raster_100_by_200, 1, max_bytes=max_bytes) - loader = io.EagerLoader(filename=tiled_raster_100_by_200, block_shape=bs) - blocks, slices = zip(*list(loader.iter_blocks())) - loader.notify_finished() - - row_blocks = 100 // 32 + 1 - col_blocks = 200 // 32 + 1 - expected_num_blocks = row_blocks * col_blocks - assert len(blocks) == expected_num_blocks - assert blocks[0].shape == (32, 32) - - # One nan should be fine, will get loaded - loader = io.EagerLoader(filename=raster_with_nan, block_shape=bs) - blocks, slices = zip(*list(loader.iter_blocks())) - loader.notify_finished() - assert len(blocks) == expected_num_blocks - - # Now check entire block for a skipped block - loader = io.EagerLoader(filename=raster_with_nan_block, block_shape=bs) - blocks, slices = zip(*list(loader.iter_blocks())) - loader.notify_finished() - assert len(blocks) == expected_num_blocks - 1 - - # Now check entire block for a skipped block - loader = io.EagerLoader(filename=raster_with_zero_block, block_shape=bs) - blocks, slices = zip(*list(loader.iter_blocks())) - loader.notify_finished() - assert len(blocks) == expected_num_blocks - 1 - - -@pytest.mark.skip -def test_iter_blocks_nodata_mask(tiled_raster_100_by_200): - # load one block at a time - max_bytes = 8 * 32 * 32 - bs = io.get_max_block_shape(tiled_raster_100_by_200, 1, max_bytes=max_bytes) - blocks = list(io.iter_blocks(tiled_raster_100_by_200, bs, band=1)) - row_blocks = 100 // 32 + 1 - col_blocks = 200 // 32 + 1 - expected_num_blocks = row_blocks * col_blocks - assert len(blocks) == expected_num_blocks - - nodata_mask = np.zeros((100, 200), dtype=bool) - nodata_mask[:5, :5] = True - # non-full-block should still all be loaded nan should be fine, will get loaded - blocks = list( - io.iter_blocks( - tiled_raster_100_by_200, bs, skip_empty=True, nodata_mask=nodata_mask - ) - ) - assert len(blocks) == expected_num_blocks - - nodata_mask[:32, :32] = True - # non-full-block should still all be loaded nan should be fine, will get loaded - blocks = list( - io.iter_blocks( - tiled_raster_100_by_200, bs, skip_empty=True, nodata_mask=nodata_mask - ) - ) - assert len(blocks) == expected_num_blocks - 1 - - def test_format_nc_filename(): expected = 'NETCDF:"/usr/19990101/20200303_20210101.nc":"//variable"' assert ( diff --git a/tests/test_ps.py 
b/tests/test_ps.py index 91920f8c..a7d17e0e 100644 --- a/tests/test_ps.py +++ b/tests/test_ps.py @@ -61,7 +61,8 @@ def test_create_ps(tmp_path, vrt_stack): amp_dispersion_file = tmp_path / "amp_disp.tif" amp_mean_file = tmp_path / "amp_mean.tif" dolphin.ps.create_ps( - slc_vrt_file=vrt_stack.outfile, + reader=vrt_stack, + like_filename=vrt_stack.outfile, output_amp_dispersion_file=amp_dispersion_file, output_amp_mean_file=amp_mean_file, output_file=ps_mask_file, @@ -92,12 +93,12 @@ def test_multilook_ps_file(tmp_path, vrt_stack): amp_dispersion_file = tmp_path / "amp_disp.tif" amp_mean_file = tmp_path / "amp_mean.tif" dolphin.ps.create_ps( - slc_vrt_file=vrt_stack.outfile, + reader=vrt_stack, + like_filename=vrt_stack.outfile, output_amp_dispersion_file=amp_dispersion_file, output_amp_mean_file=amp_mean_file, output_file=ps_mask_file, ) - output_file = dolphin.ps.multilook_ps_mask( strides={"x": 5, "y": 3}, ps_mask_file=ps_mask_file ) diff --git a/tests/test_readers.py b/tests/test_readers.py index 816939de..1e192d0b 100644 --- a/tests/test_readers.py +++ b/tests/test_readers.py @@ -1,15 +1,242 @@ +import warnings from pathlib import Path import numpy as np import numpy.testing as npt import pytest +import rasterio as rio from osgeo import gdal - -from dolphin._readers import VRTStack, _parse_vrt_file +from rasterio.errors import NotGeoreferencedWarning + +from dolphin._readers import ( + BinaryReader, + BinaryStackReader, + EagerLoader, + HDF5Reader, + HDF5StackReader, + RasterReader, + RasterStackReader, + VRTStack, + _parse_vrt_file, +) from dolphin.utils import _get_path_from_gdal_str # Note: uses the fixtures from conftest.py +# Get combinations of slices +slices_to_test = [slice(None), 1, slice(0, 10, 2)] + + +# Filter rasterio georeferencing warnings +@pytest.fixture(autouse=True) +def suppress_not_georeferenced_warning(): + """ + Pytest fixture to suppress NotGeoreferencedWarning in tests. + + This fixture automatically applies to all test functions in the module + where it's defined, suppressing the specified warning. 
+ """ + with warnings.catch_warnings(): + warnings.simplefilter("ignore", category=NotGeoreferencedWarning) + yield + + +@pytest.fixture(scope="module") +def binary_file_list(tmp_path_factory, slc_stack): + """Flat binary files in the ENVI format.""" + import rasterio as rio + from rasterio.errors import NotGeoreferencedWarning + + shape = slc_stack[0].shape + dtype = slc_stack.dtype + tmp_path = tmp_path_factory.mktemp("data") + + # Create a stack of binary files + files = [] + for i, slc in enumerate(slc_stack): + f = tmp_path / f"test_{i}.bin" + # Ignore warning + with pytest.warns(NotGeoreferencedWarning): + with rio.open( + f, + "w", + driver="ENVI", + width=shape[1], + height=shape[0], + count=1, + dtype=dtype, + ) as dst: + dst.write(slc, 1) + files.append(f) + + return files + + +@pytest.fixture +def binary_reader(slc_stack, binary_file_list): + f = BinaryReader( + binary_file_list[0], shape=slc_stack[0].shape, dtype=slc_stack.dtype + ) + assert f.shape == slc_stack[0].shape + assert f.dtype == slc_stack[0].dtype + return f + + +class TestBinary: + def test_binary_file_read(self, binary_reader, slc_stack): + npt.assert_array_almost_equal(binary_reader[()], slc_stack[0]) + # Check the reading of a subset + npt.assert_array_almost_equal( + binary_reader[0:10, 0:10], slc_stack[0][0:10, 0:10] + ) + + @pytest.fixture(scope="module") + def binary_stack(self, slc_stack, binary_file_list): + s = BinaryStackReader.from_file_list( + binary_file_list, shape_2d=slc_stack[0].shape, dtype=slc_stack.dtype + ) + assert s.shape == slc_stack.shape + assert len(s) == len(slc_stack) == len(binary_file_list) + assert s.ndim == 3 + assert s.dtype == slc_stack.dtype + return s + + @pytest.mark.parametrize("dslice", slices_to_test) + @pytest.mark.parametrize("rslice", slices_to_test) + @pytest.mark.parametrize("cslice", slices_to_test) + def test_binary_stack_read_slices( + self, binary_stack, slc_stack, dslice, rslice, cslice + ): + s = binary_stack[dslice, rslice, cslice] + expected = slc_stack[dslice, rslice, cslice] + assert s.shape == expected.shape + npt.assert_array_almost_equal(s, expected) + + +# #### HDF5 Tests #### + + +@pytest.fixture(scope="module") +def hdf5_file_list(tmp_path_factory, slc_stack): + """Flat binary files in the ENVI format.""" + import h5py + + tmp_path = tmp_path_factory.mktemp("data") + + # Create a stack of binary files + files = [] + for i, slc in enumerate(slc_stack): + f = tmp_path / f"test_{i}.h5" + with h5py.File(f, "w") as dst: + dst.create_dataset("data", data=slc) + files.append(f) + + return files + + +@pytest.fixture +def hdf5_reader(hdf5_file_list, slc_stack): + r = HDF5Reader(hdf5_file_list[0], dset_name="data", keep_open=True) + assert r.shape == slc_stack[0].shape + assert r.dtype == slc_stack[0].dtype + return r + + +class TestHDF5: + def test_hdf5_reader_read(self, hdf5_reader, slc_stack): + npt.assert_array_almost_equal(hdf5_reader[()], slc_stack[0]) + # Check the reading of a subset + npt.assert_array_almost_equal(hdf5_reader[0:10, 0:10], slc_stack[0][0:10, 0:10]) + + @pytest.mark.parametrize("keep_open", [True, False]) + def hdf5_stack(self, hdf5_file_list, slc_stack, keep_open): + s = HDF5StackReader.from_file_list( + hdf5_file_list, dset_names="data", keep_open=keep_open + ) + assert s.shape == slc_stack.shape + assert len(s) == len(slc_stack) == len(hdf5_file_list) + assert s.ndim == 3 + assert s.dtype == slc_stack.dtype + return s + + @pytest.mark.parametrize("dslice", slices_to_test) + @pytest.mark.parametrize("rslice", slices_to_test) + 
@pytest.mark.parametrize("cslice", slices_to_test) + @pytest.mark.parametrize("keep_open", [True, False]) + def test_hdf5_stack_read_slices( + self, hdf5_file_list, slc_stack, keep_open, dslice, rslice, cslice + ): + reader = HDF5StackReader.from_file_list( + hdf5_file_list, dset_names="data", keep_open=keep_open + ) + s = reader[dslice, rslice, cslice] + expected = slc_stack[dslice, rslice, cslice] + assert s.shape == expected.shape + npt.assert_array_almost_equal(s, expected) + + +# #### RasterReader Tests #### +@pytest.fixture +def raster_reader(slc_file_list, slc_stack): + # ignore georeferencing warnings + with pytest.warns(rio.errors.NotGeoreferencedWarning): + r = RasterReader.from_file(slc_file_list[0]) + assert r.shape == slc_stack[0].shape + assert r.dtype == slc_stack[0].dtype + assert r.ndim == 2 + assert r.dtype == np.complex64 + return r + + +class TestRaster: + @pytest.mark.parametrize("keep_open", [True, False]) + def test_raster_stack_reader( + self, + slc_file_list, + slc_stack, + keep_open, + ): + with pytest.warns(rio.errors.NotGeoreferencedWarning): + reader = RasterStackReader.from_file_list( + slc_file_list, keep_open=keep_open + ) + assert reader.ndim == 3 + assert reader.shape == slc_stack.shape + assert reader.dtype == slc_stack.dtype + assert len(reader) == len(slc_stack) == len(slc_file_list) + + @pytest.mark.parametrize("dslice", slices_to_test) + @pytest.mark.parametrize("rslice", slices_to_test) + @pytest.mark.parametrize("cslice", slices_to_test) + @pytest.mark.parametrize("keep_open", [True, False]) + def test_raster_stack_read_slices( + self, slc_file_list, slc_stack, keep_open, dslice, rslice, cslice + ): + with pytest.warns(rio.errors.NotGeoreferencedWarning): + reader = RasterStackReader.from_file_list( + slc_file_list, keep_open=keep_open + ) + s = reader[dslice, rslice, cslice] + expected = slc_stack[dslice, rslice, cslice] + assert s.shape == expected.shape + npt.assert_array_almost_equal(s, expected) + + +@pytest.mark.parametrize("rows", slices_to_test) +@pytest.mark.parametrize("cols", slices_to_test) +def test_ellipsis_reads(binary_reader, hdf5_reader, raster_reader, rows, cols): + # Test that the ellipsis works + npt.assert_array_equal(binary_reader[...], binary_reader[()]) + npt.assert_array_equal(hdf5_reader[...], hdf5_reader[()]) + npt.assert_array_equal(raster_reader[...], raster_reader[()]) + # Test we can still do rows/cols with a leading ellipsis + npt.assert_array_equal(binary_reader[..., rows, cols], binary_reader[rows, cols]) + npt.assert_array_equal(hdf5_reader[..., rows, cols], hdf5_reader[rows, cols]) + npt.assert_array_equal(raster_reader[..., rows, cols], raster_reader[rows, cols]) + + +# #### VRT Tests #### + @pytest.fixture def vrt_stack(tmp_path, slc_stack, slc_file_list): @@ -128,22 +355,25 @@ def test_bad_sizes(slc_file_list, raster_10_by_20): def test_iter_blocks(vrt_stack): - blocks, slices = zip(*list(vrt_stack.iter_blocks(block_shape=(5, 5)))) + loader = EagerLoader(reader=vrt_stack, block_shape=(5, 5)) + blocks, slices = zip(*list(loader.iter_blocks())) # (5, 10) total shape, breaks into 5x5 blocks assert len(blocks) == 2 for b in blocks: assert b.shape == (len(vrt_stack), 5, 5) - blocks, slices = zip(*list(vrt_stack.iter_blocks(block_shape=(1, 2)))) - assert len(blocks) == 25 + loader = EagerLoader(reader=vrt_stack, block_shape=(5, 2)) + blocks, slices = zip(*list(loader.iter_blocks())) + assert len(blocks) == 5 for b in blocks: - assert b.shape == (len(vrt_stack), 1, 2) + assert b.shape == (len(vrt_stack), 5, 2) def 
test_tiled_iter_blocks(tmp_path, tiled_file_list): outfile = tmp_path / "stack.vrt" vrt_stack = VRTStack(tiled_file_list, outfile=outfile) - blocks, slices = zip(*list(vrt_stack.iter_blocks(block_shape=(32, 32)))) + loader = EagerLoader(reader=vrt_stack, block_shape=(32, 32)) + blocks, slices = zip(*list(loader.iter_blocks())) # (100, 200) total shape, breaks into 32x32 blocks assert len(blocks) == len(slices) == 28 for i, b in enumerate(blocks, start=1): @@ -155,7 +385,8 @@ def test_tiled_iter_blocks(tmp_path, tiled_file_list): else: assert b.shape == (len(vrt_stack), 32, 8) - blocks, slices = zip(*list(vrt_stack.iter_blocks(block_shape=(50, 100)))) + loader = EagerLoader(reader=vrt_stack, block_shape=(50, 100)) + blocks, slices = zip(*list(loader.iter_blocks())) assert len(blocks) == len(slices) == 4 @@ -212,3 +443,101 @@ def test_parse_vrt(tmp_path, test_vrt): "t087_185684_iw2_20220104.h5", ] assert sds == "data/VV" + + +class TestEagerLoader: + def test_iter_blocks(self, tiled_raster_100_by_200): + # Try the whole raster + bs = (100, 200) + loader = EagerLoader( + reader=RasterReader.from_file(tiled_raster_100_by_200), block_shape=bs + ) + # `list` should try to load all at once` + block_slice_tuples = list(loader.iter_blocks()) + assert not loader._thread.is_alive() + assert len(block_slice_tuples) == 1 + blocks, slices = zip(*list(block_slice_tuples)) + assert blocks[0].shape == (100, 200) + rows, cols = slices[0] + assert rows == slice(0, 100) + assert cols == slice(0, 200) + + # now one block at a time + bs = (32, 32) + reader = RasterReader.from_file(tiled_raster_100_by_200) + loader = EagerLoader(reader=reader, block_shape=bs) + blocks, slices = zip(*list(loader.iter_blocks())) + + row_blocks = 100 // 32 + 1 + col_blocks = 200 // 32 + 1 + expected_num_blocks = row_blocks * col_blocks + assert len(blocks) == expected_num_blocks + assert blocks[0].shape == (32, 32) + # at the ends, the block_slice_tuples are smaller + assert blocks[6].shape == (32, 8) + assert blocks[-1].shape == (4, 8) + + def test_iter_blocks_rowcols(self, tiled_raster_100_by_200): + # Block size that is a multiple of the raster size + reader = RasterReader.from_file(tiled_raster_100_by_200) + loader = EagerLoader(reader=reader, block_shape=(10, 20)) + blocks, slices = zip(*list(loader.iter_blocks())) + + assert blocks[0].shape == (10, 20) + for rs, cs in slices: + assert rs.stop - rs.start == 10 + assert cs.stop - cs.start == 20 + loader.notify_finished() + + # Non-multiple block size + reader = RasterReader.from_file(tiled_raster_100_by_200) + loader = EagerLoader(reader=reader, block_shape=(32, 32)) + blocks, slices = zip(*list(loader.iter_blocks())) + assert blocks[0].shape == (32, 32) + for b, (rs, cs) in zip(blocks, slices): + assert b.shape == (rs.stop - rs.start, cs.stop - cs.start) + loader.notify_finished() + + def test_iter_nodata( + self, + raster_with_nan, + raster_with_nan_block, + raster_with_zero_block, + tiled_raster_100_by_200, + ): + # load one block at a time + bs = (100, 200) + reader = RasterReader.from_file(tiled_raster_100_by_200) + loader = EagerLoader(reader=reader, block_shape=bs) + blocks, slices = zip(*list(loader.iter_blocks())) + loader.notify_finished() + + bs = (32, 32) + row_blocks = 100 // 32 + 1 + col_blocks = 200 // 32 + 1 + expected_num_blocks = row_blocks * col_blocks + loader = EagerLoader(reader=reader, block_shape=bs) + blocks, slices = zip(*list(loader.iter_blocks())) + assert len(blocks) == expected_num_blocks + assert blocks[0].shape == bs + + # One nan should be fine, 
will get loaded + reader = RasterReader.from_file(raster_with_nan) + loader = EagerLoader(reader=reader, block_shape=bs) + blocks, slices = zip(*list(loader.iter_blocks())) + loader.notify_finished() + assert len(blocks) == expected_num_blocks + + # Now check entire block for a skipped block + reader = RasterReader.from_file(raster_with_nan_block) + loader = EagerLoader(reader=reader, block_shape=bs) + blocks, slices = zip(*list(loader.iter_blocks())) + loader.notify_finished() + assert len(blocks) == expected_num_blocks - 1 + + # Now check entire block for a skipped block + reader = RasterReader.from_file(raster_with_zero_block) + loader = EagerLoader(reader=reader, block_shape=bs) + blocks, slices = zip(*list(loader.iter_blocks())) + loader.notify_finished() + assert len(blocks) == expected_num_blocks - 1
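
A sketch of the replacement for the removed `VRTStack.iter_blocks` method, per the CHANGELOG entry above (the SLC file names here are illustrative):

```python
from dolphin._readers import EagerLoader, VRTStack

vrt_stack = VRTStack(["slc_20220101.tif", "slc_20220113.tif"], outfile="stack.vrt")

# Previously: vrt_stack.iter_blocks(block_shape=(512, 512))
loader = EagerLoader(reader=vrt_stack, block_shape=(512, 512))
for block, (rows, cols) in loader.iter_blocks():
    ...  # `block` has shape (num_slcs, block_rows, block_cols)
```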