Skip to content

Commit

Permalink
Adding option to create overviews on workflow outputs (#220)
Browse files Browse the repository at this point in the history
* start adding option to create overviews on outputs

* add overview script and cli

* remove debug print

* add enums for default overview types

* reorg test unwrap, make the failing test for conncomp

* fix unw test

* add ovr tests

* default ovr True
  • Loading branch information
scottstanie authored Feb 8, 2024
1 parent 3b5c359 commit 9bb183d
Show file tree
Hide file tree
Showing 10 changed files with 391 additions and 98 deletions.
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -90,6 +90,7 @@ ignore = [
"PTH123", # `open()` should be replaced by `Path.open()`
"PTH207", # "Replace `glob` with `Path.glob` or `Path.rglob`
"ISC001", # The following rules may cause conflicts when used with the formatter
"TRY003", # Avoid specifying long messages outside the exception
]

exclude = ["scripts"]
Expand Down
175 changes: 175 additions & 0 deletions src/dolphin/_overviews.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,175 @@
import argparse
import logging
from enum import Enum
from os import fspath
from pathlib import Path
from typing import Sequence

from osgeo import gdal
from tqdm.contrib.concurrent import thread_map

from dolphin._types import PathOrStr

gdal.UseExceptions()

logger = logging.getLogger(__name__)

DEFAULT_LEVELS = [4, 8, 16, 32, 64]


class Resampling(Enum):
"""GDAL resampling algorithm."""

NEAREST = "nearest"
AVERAGE = "average"


class ImageType(Enum):
"""Types of images produced by dolphin."""

UNWRAPPED = "unwrapped"
INTERFEROGRAM = "interferogram"
CORRELATION = "correlation"
CONNCOMP = "conncomp"
PS = "ps"


IMAGE_TYPE_TO_RESAMPLING = {
ImageType.UNWRAPPED: Resampling.AVERAGE,
ImageType.INTERFEROGRAM: Resampling.AVERAGE,
ImageType.CORRELATION: Resampling.AVERAGE,
ImageType.CONNCOMP: Resampling.NEAREST,
# No max in resampling, yet, which would be best
# https://github.com/OSGeo/gdal/issues/3683
ImageType.PS: Resampling.AVERAGE,
}


def create_image_overviews(
file_path: Path | str,
levels: Sequence[int] = DEFAULT_LEVELS,
image_type: ImageType | None = None,
resampling: Resampling | None = None,
external: bool = False,
compression: str = "LZW",
):
"""Add GDAL compressed overviews to an existing file.
Parameters
----------
file_path : Path
Path to the file to process.
levels : Sequence[int]
List of overview levels to add.
Default = [4, 8, 16, 32, 64]
image_type : ImageType, optional
If provided, looks up the default resampling algorithm
most appropriate for this type of raster.
resampling : str or Resampling
GDAL resampling algorithm for overviews. Required
if not specifying `image_type`.
external : bool, default = False
Use external overviews (.ovr files).
compression: str, default = "LZW"
Compression algorithm to use for overviews.
See https://gdal.org/programs/gdaladdo.html for options.
"""
if image_type is None and resampling is None:
raise ValueError("Must provide `image_type` or `resampling`")
if image_type is not None:
resampling = IMAGE_TYPE_TO_RESAMPLING[ImageType(image_type)]
else:
resampling = Resampling(resampling)

flags = gdal.GA_Update if not external else gdal.GA_ReadOnly
ds = gdal.Open(fspath(file_path), flags)
if ds.GetRasterBand(1).GetOverviewCount() > 0:
logger.info("%s already has overviews. Skipping.")
return

gdal.SetConfigOption("COMPRESS_OVERVIEW", compression)
gdal.SetConfigOption("GDAL_NUM_THREADS", "2")
ds.BuildOverviews(resampling.value, levels)


def create_overviews(
file_paths: Sequence[PathOrStr],
levels: Sequence[int] = DEFAULT_LEVELS,
image_type: ImageType | None = None,
resampling: Resampling = Resampling.AVERAGE,
max_workers: int = 5,
) -> None:
"""Process many files to add GDAL overviews and compression.
Parameters
----------
file_paths : Sequence[PathOrStr]
Sequence of file paths to process.
levels : Sequence[int]
Sequence of overview levels to add.
Default = [4, 8, 16, 32, 64]
image_type : ImageType, optional
If provided, looks up the default resampling algorithm
resampling : str or Resampling
GDAL resampling algorithm for overviews. Required
if not specifying `image_type`.
max_workers : int, default = 5
Number of parallel threads to run.
"""
thread_map(
lambda file_path: create_image_overviews(
Path(file_path),
levels=list(levels),
image_type=image_type,
resampling=resampling,
),
file_paths,
max_workers=max_workers,
)


def run():
"""Add compressed GDAL overviews to files."""
parser = argparse.ArgumentParser(
description="Add compressed GDAL overviews to files.",
formatter_class=argparse.ArgumentDefaultsHelpFormatter,
)
parser.add_argument("file_paths", nargs="+", type=str, help="Path to files")
parser.add_argument(
"--levels",
"-l",
nargs="*",
default=[4, 8, 16, 32, 64],
type=int,
help="Overview levels to add.",
)
parser.add_argument(
"--resampling",
"-r",
default=Resampling("nearest"),
choices=[r.value for r in Resampling],
type=Resampling,
help="Resampling algorithm to use when building overviews",
)
parser.add_argument(
"--max-workers",
"-n",
default=5,
type=int,
help="Number of parallel files to process",
)

args = parser.parse_args()

# Convert resampling argument from string to Resampling Enum
resampling_enum = Resampling(args.resampling)

create_overviews(
file_paths=args.file_paths,
levels=args.levels,
resampling=resampling_enum,
max_workers=args.max_workers,
overwrite=args.overwrite,
)
2 changes: 1 addition & 1 deletion src/dolphin/io/_readers.py
Original file line number Diff line number Diff line change
Expand Up @@ -1051,7 +1051,7 @@ def iter_blocks(
self.queue_read(rows, cols)
queued_slices.append((rows, cols))

logger.info(f"Processing {self._block_shape} sized blocks... {tqdm_kwargs}")
logger.info(f"Processing {self._block_shape} sized blocks...")
for _ in trange(len(queued_slices), **tqdm_kwargs):
cur_block, (rows, cols) = self.get_data()
logger.debug(f"got data for {rows, cols}: {cur_block.shape}")
Expand Down
7 changes: 5 additions & 2 deletions src/dolphin/unwrap/_unwrap.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,8 +99,10 @@ def run(

output_path = Path(output_path)

ifg_suffixes = [full_suffix(f) for f in ifg_filenames]
all_out_files = [
(output_path / Path(f).name).with_suffix(UNW_SUFFIX) for f in ifg_filenames
(output_path / Path(f).name.replace(suf, UNW_SUFFIX))
for f, suf in zip(ifg_filenames, ifg_suffixes)
]
in_files, out_files = [], []
for inf, outf in zip(ifg_filenames, all_out_files):
Expand Down Expand Up @@ -137,7 +139,8 @@ def run(
for ifg_file, out_file, cor_file in zip(in_files, out_files, cor_filenames)
]
for fut in tqdm(as_completed(futures)):
fut.result()
# We're not passing all the unw files in, so we need to tally up below
_unw_path, _cc_path = fut.result()

conncomp_files = [
Path(str(outf).replace(UNW_SUFFIX, CONNCOMP_SUFFIX)) for outf in all_out_files
Expand Down
12 changes: 12 additions & 0 deletions src/dolphin/workflows/config/_common.py
Original file line number Diff line number Diff line change
Expand Up @@ -287,6 +287,18 @@ class OutputOptions(BaseModel, extra="forbid"):
list(DEFAULT_TIFF_OPTIONS),
description="GDAL creation options for GeoTIFF files",
)
add_overviews: bool = Field(
True,
description=(
"Whether to add overviews to the output GeoTIFF files. This will "
"increase file size, but can be useful for visualizing the data with "
"web mapping tools. See https://gdal.org/programs/gdaladdo.html for more."
),
)
overview_levels: list[int] = Field(
[4, 8, 16, 32, 64],
description="List of overview levels to create (if `add_overviews=True`).",
)

# validators
@field_validator("bounds", mode="after")
Expand Down
8 changes: 8 additions & 0 deletions src/dolphin/workflows/stitching_bursts.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@

from dolphin import stitching
from dolphin._log import get_log, log_runtime
from dolphin._overviews import ImageType, create_image_overviews, create_overviews
from dolphin._types import Bbox
from dolphin.interferogram import estimate_interferometric_correlations

Expand Down Expand Up @@ -100,6 +101,13 @@ def run(
out_bounds_epsg=output_options.bounds_epsg,
)

if output_options.add_overviews:
logger.info("Creating overviews for stitched images")
create_overviews(stitched_ifg_paths, image_type=ImageType.INTERFEROGRAM)
create_overviews(interferometric_corr_paths, image_type=ImageType.CORRELATION)
create_image_overviews(stitched_ps_file, image_type=ImageType.PS)
create_image_overviews(stitched_temp_coh_file, image_type=ImageType.CORRELATION)

return (
stitched_ifg_paths,
interferometric_corr_paths,
Expand Down
11 changes: 11 additions & 0 deletions src/dolphin/workflows/unwrapping.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@

from dolphin import io, stitching, unwrap
from dolphin._log import get_log, log_runtime
from dolphin._overviews import ImageType, create_overviews
from dolphin._types import PathOrStr

from .config import UnwrapOptions
Expand All @@ -20,6 +21,7 @@ def run(
nlooks: float,
unwrap_options: UnwrapOptions,
mask_file: PathOrStr | None = None,
add_overviews: bool = True,
) -> tuple[list[Path], list[Path]]:
"""Run the displacement workflow on a stack of SLCs.
Expand All @@ -37,6 +39,9 @@ def run(
mask_file : PathOrStr, optional
Path to boolean mask indicating nodata areas.
1 indicates valid data, 0 indicates missing data.
add_overviews : bool, default = True
If True, creates overviews of the unwrapped phase and connected component
labels.
Returns
-------
Expand Down Expand Up @@ -80,6 +85,12 @@ def run(
scratchdir=unwrap_scratchdir,
)

if add_overviews:
logger.info("Creating overviews for unwrapped images")
create_overviews(unwrapped_paths, image_type=ImageType.UNWRAPPED)
create_overviews(conncomp_paths, image_type=ImageType.CONNCOMP)
create_overviews(unwrapped_paths, image_type=ImageType.CORRELATION)

return (unwrapped_paths, conncomp_paths)


Expand Down
36 changes: 36 additions & 0 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -427,3 +427,39 @@ def dem_file(tmp_path, slc_stack):
ds.GetRasterBand(1).WriteArray(dem)
ds = None
return fname


# For unwrapping/overviews
@pytest.fixture()
def list_of_gtiff_ifgs(tmp_path, raster_100_by_200):
ifg_list = []
for i in range(3):
# Create a copy of the raster in the same directory
f = tmp_path / f"ifg_{i}.int.tif"
write_arr(
arr=np.ones((100, 200), dtype=np.complex64),
output_name=f,
like_filename=raster_100_by_200,
driver="GTiff",
)
ifg_list.append(f)

return ifg_list


@pytest.fixture()
def list_of_envi_ifgs(tmp_path, raster_100_by_200):
ifg_list = []
for i in range(3):
# Create a copy of the raster in the same directory
f = tmp_path / f"ifg_{i}.int"
ifg_list.append(f)
write_arr(
arr=np.ones((100, 200), dtype=np.complex64),
output_name=f,
like_filename=raster_100_by_200,
driver="ENVI",
)
ifg_list.append(f)

return ifg_list
55 changes: 55 additions & 0 deletions tests/test_overviews.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
import pytest
import rasterio as rio

from dolphin._overviews import (
ImageType,
Resampling,
create_image_overviews,
create_overviews,
)

# Dataset has no geotransform, gcps, or rpcs. The identity matrix will be returned.
pytestmark = pytest.mark.filterwarnings(
"ignore::rasterio.errors.NotGeoreferencedWarning",
"ignore:.*io.FileIO.*:pytest.PytestUnraisableExceptionWarning",
)


def get_overviews(filename, band=1):
with rio.open(filename) as src:
return src.overviews(band)


def test_create_image_overviews(list_of_gtiff_ifgs):
f = list_of_gtiff_ifgs[0]
assert len(get_overviews(f)) == 0
create_image_overviews(f, image_type=ImageType.INTERFEROGRAM)
assert len(get_overviews(f)) > 0


def test_create_image_overviews_envi(list_of_envi_ifgs):
f = list_of_envi_ifgs[0]
assert len(get_overviews(f)) == 0
create_image_overviews(f, image_type=ImageType.INTERFEROGRAM)
assert len(get_overviews(f)) > 0


@pytest.mark.parametrize("resampling", list(Resampling))
def test_resamplings(list_of_gtiff_ifgs, resampling):
f = list_of_gtiff_ifgs[0]
assert len(get_overviews(f)) == 0
create_image_overviews(f, resampling=resampling)
assert len(get_overviews(f)) > 0


@pytest.mark.parametrize("levels", [[2, 4], [4, 8, 18]])
def test_levels(list_of_gtiff_ifgs, levels):
f = list_of_gtiff_ifgs[0]
assert len(get_overviews(f)) == 0
create_image_overviews(f, resampling="nearest", levels=levels)
assert len(get_overviews(f)) == len(levels)


def test_create_overviews(list_of_gtiff_ifgs):
create_overviews(list_of_gtiff_ifgs, image_type=ImageType.INTERFEROGRAM)
assert all(len(get_overviews(f)) > 0 for f in list_of_gtiff_ifgs)
Loading

0 comments on commit 9bb183d

Please sign in to comment.