Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Update dolphin filter cli to accept fill_value and threshold #510

Open
wants to merge 6 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
52 changes: 33 additions & 19 deletions src/dolphin/filtering.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,12 @@
import numpy as np
from numpy.typing import ArrayLike, NDArray
from scipy import fft, ndimage
from scipy.ndimage import gaussian_filter
from tqdm import tqdm

from dolphin import io
from dolphin._overviews import Resampling, create_image_overviews
from dolphin.utils import DummyProcessPoolExecutor


def filter_long_wavelength(
Expand Down Expand Up @@ -157,9 +163,11 @@ def filter_rasters(
cor_filenames: list[Path] | None = None,
conncomp_filenames: list[Path] | None = None,
temporal_coherence_filename: Path | None = None,
output_dir: Path | None = None,
wavelength_cutoff: float = 50_000,
fill_value: float | None = None,
correlation_cutoff: float = 0.5,
output_dir: Path | None = None,
temporal_coherence_cutoff: float = 0.5,
max_workers: int = 4,
) -> list[Path]:
"""Filter a list of unwrapped interferogram files using a long-wavelength filter.
Expand All @@ -186,6 +194,13 @@ def filter_rasters(
correlation_cutoff : float, optional
Threshold of correlation (if passing `cor_filenames`) to use to ignore pixels
during filtering.
temporal_coherence_cutoff : float, optional
Threshold of temporal_coherence to use to ignore pixels during filtering.
Default is 0.5.
fill_value : float, optional
Value to place in output pixels which were masked.
If `None`, masked pixels are filled with the ramp value fitted
before filtering to suppress outliers.
output_dir : Path | None, optional
Directory to save the filtered results.
If None, saves in the same location as inputs with .filt.tif extension.
Expand All @@ -197,30 +212,31 @@ def filter_rasters(
list[Path]
Output filtered rasters.

Notes
-----
- If temporal_coherence_filename is provided, pixels with coherence < 0.5 are masked

"""
from dolphin import io

bad_pixel_mask = np.zeros(
io.get_raster_xysize(unw_filenames[0])[::-1], dtype="bool"
)
bad_pixel_mask = np.zeros(io.get_raster_shape(unw_filenames[0])[-2:], dtype="bool")
if temporal_coherence_filename:
bad_pixel_mask = bad_pixel_mask | (
io.load_gdal(temporal_coherence_filename) < 0.5
io.load_gdal(temporal_coherence_filename) < temporal_coherence_cutoff
)

if output_dir is None:
assert unw_filenames
output_dir = unw_filenames[0].parent
output_dir.mkdir(exist_ok=True)
ctx = mp.get_context("spawn")

with ProcessPoolExecutor(max_workers, mp_context=ctx) as pool:
num_parallel = min(max_workers, len(unw_filenames))
Executor = ProcessPoolExecutor if num_parallel > 1 else DummyProcessPoolExecutor
ctx = mp.get_context("spawn")
tqdm.set_lock(ctx.RLock())

with Executor(
max_workers=max_workers,
mp_context=ctx,
initializer=tqdm.set_lock,
initargs=(tqdm.get_lock(),),
) as exc:
return list(
pool.map(
exc.map(
_filter_and_save,
unw_filenames,
cor_filenames or repeat(None),
Expand All @@ -229,6 +245,7 @@ def filter_rasters(
repeat(wavelength_cutoff),
repeat(bad_pixel_mask),
repeat(correlation_cutoff),
repeat(fill_value),
)
)

Expand All @@ -241,11 +258,9 @@ def _filter_and_save(
wavelength_cutoff: float,
bad_pixel_mask: NDArray[np.bool_],
correlation_cutoff: float = 0.5,
fill_value: float | None = None,
) -> Path:
"""Filter one interferogram (wrapper for multiprocessing)."""
from dolphin import io
from dolphin._overviews import Resampling, create_image_overviews

# Average for the pixel spacing for filtering
_, x_res, _, _, _, y_res = io.get_raster_gt(unw_filename)
pixel_spacing = (abs(x_res) + abs(y_res)) / 2
Expand All @@ -262,6 +277,7 @@ def _filter_and_save(
bad_pixel_mask=bad_pixel_mask,
pixel_spacing=pixel_spacing,
workers=1,
fill_value=fill_value,
)
io.round_mantissa(filt_arr, keep_bits=9)
output_name = output_dir / Path(unw_filename).with_suffix(".filt.tif").name
Expand Down Expand Up @@ -300,8 +316,6 @@ def gaussian_filter_nan(
Filtered version of `image`.

"""
from scipy.ndimage import gaussian_filter

if np.sum(np.isnan(image)) == 0:
return gaussian_filter(image, sigma=sigma, mode=mode, **kwargs)

Expand Down
11 changes: 10 additions & 1 deletion src/dolphin/io/_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@
"get_raster_gt",
"get_raster_metadata",
"get_raster_nodata",
"get_raster_shape",
"get_raster_units",
"get_raster_xysize",
"load_gdal",
Expand Down Expand Up @@ -257,13 +258,21 @@ def copy_projection(src_file: Filename, dst_file: Filename) -> None:


def get_raster_xysize(filename: Filename) -> tuple[int, int]:
"""Get the xsize/ysize of a GDAL-readable raster."""
"""Get the xsize, ysize of a GDAL-readable raster."""
ds = _get_gdal_ds(filename)
xsize, ysize = ds.RasterXSize, ds.RasterYSize
ds = None
return xsize, ysize


def get_raster_shape(filename: Filename) -> tuple[int, int, int]:
"""Get the (number of bands, rows, columns) of a GDAL-readable raster."""
import rasterio as rio

with rio.open(filename) as src:
return src.shape


def get_raster_nodata(filename: Filename, band: int = 1) -> Optional[float]:
"""Get the nodata value from a file.

Expand Down
Loading