From 1ba17b36a7f5e365560724eaca8eba322c22bc9f Mon Sep 17 00:00:00 2001 From: Arielle Leon Date: Mon, 9 Dec 2024 19:48:35 -0800 Subject: [PATCH] Revert to ecd29483bf8cf9a862d294eda331f3a60f87e4c1 --- code/registration.py | 693 ++++++++++++------------------------------- 1 file changed, 189 insertions(+), 504 deletions(-) diff --git a/code/registration.py b/code/registration.py index 32135a9..b8fd388 100644 --- a/code/registration.py +++ b/code/registration.py @@ -8,13 +8,14 @@ import tempfile import warnings from datetime import datetime as dt -from functools import lru_cache, partial +from functools import partial from glob import glob from itertools import product from multiprocessing import Pool from pathlib import Path from time import time from typing import Callable, Optional, Tuple, Union + import cv2 import h5py import matplotlib as mpl @@ -22,16 +23,15 @@ import pandas as pd import suite2p from aind_data_schema.core.processing import ( + Processing, DataProcess, PipelineProcess, - Processing, ProcessName, ) from aind_ophys_utils.array_utils import normalize_array -from aind_ophys_utils.video_utils import downsample_h5_video, encode_video, downsample_array +from aind_ophys_utils.video_utils import downsample_h5_video, encode_video from matplotlib import pyplot as plt # noqa: E402 from PIL import Image -from ScanImageTiffReader import ScanImageTiffReader from scipy.ndimage import median_filter from scipy.stats import sigmaclip from suite2p.registration.nonrigid import make_blocks @@ -60,165 +60,8 @@ def is_S3(file_path: str): ) -def h5py_byteorder_name(h5py_file: h5py.File, h5py_key: str) -> Tuple[str, str]: - """Get the byteorder and name of the dataset in the h5py file. - - Parameters - ---------- - h5py_file : h5py.File - h5py file object - h5py_key : str - key to the dataset - - Returns - ------- - str - byteorder of the dataset - str - name of the dataset - """ - with h5py.File(h5py_file, "r") as f: - byteorder = f[h5py_key].dtype.byteorder - name = f[h5py_key].dtype.name - return byteorder, name - - -def tiff_byteorder_name(tiff_file: Path) -> Tuple[str, str]: - """Get the byteorder and name of the dataset in the tiff file. - - Parameters - ---------- - tiff_file : Path - Location of the tiff file - - Returns - ------- - str - byteorder of the dataset - str - name of the dataset - """ - with ScanImageTiffReader(tiff_file) as reader: - byteorder = reader.data().dtype.byteorder - name = reader.data().dtype.name - return byteorder, name - - -def h5py_to_numpy( - h5py_file: str, - h5py_key: str, - trim_frames_start: int = 0, - trim_frames_end: int = 0, -) -> np.ndarray: - """Converts a h5py dataset to a numpy array - - Parameters - ---------- - h5py_file: str - h5py file path - h5py_key : str - key to the dataset - trim_frames_start : int - Number of frames to disregard from the start of the movie. Default 0. - trim_frames_end : int - Number of frames to disregard from the end of the movie. Default 0. - Returns - ------- - np.ndarray - numpy array - """ - with h5py.File(h5py_file, "r") as f: - n_frames = f[h5py_key].shape[0] - if trim_frames_start > 0 or trim_frames_end > 0: - return f[h5py_key][trim_frames_start : n_frames - trim_frames_end] - else: - return f[h5py_key][:] - -@lru_cache(maxsize=None) -def _tiff_to_numpy(tiff_file: Path) -> np.ndarray: - with ScanImageTiffReader(tiff_file) as reader: - return reader.data() - -def tiff_to_numpy( - tiff_list: List[Path], trim_frames_start: int = 0, trim_frames_end: int = 0 -) -> np.ndarray: - """ - Converts a list of TIFF files to a single numpy array, with optional frame trimming. - - Parameters - ---------- - tiff_list : List[str] - List of TIFF file paths to process - trim_frames_start : int, optional - Number of frames to remove from the start (default: 0) - trim_frames_end : int, optional - Number of frames to remove from the end (default: 0) - - Returns - ------- - np.ndarray - Combined array of all TIFF data with specified trimming - - Raises - ------ - ValueError - If trim values exceed total number of frames or are negative - """ - if trim_frames_start < 0 or trim_frames_end < 0: - raise ValueError("Trim values must be non-negative") - - def get_total_frames(tiff_files: List[Path]) -> int: - """Calculate total number of frames across all TIFF files.""" - total = 0 - for tiff in tiff_files: - with ScanImageTiffReader(tiff) as reader: - total += reader.shape()[0] - return total - - # Validate trim parameters - total_frames = get_total_frames(tiff_list) - if trim_frames_start + trim_frames_end >= total_frames: - raise ValueError( - f"Invalid trim values: start ({trim_frames_start}) + end ({trim_frames_end}) " - f"must be less than total frames ({total_frames})" - ) - - # Initialize variables for frame counting - processed_frames = 0 - arrays_to_stack = [] - - for tiff_path in tiff_list: - with ScanImageTiffReader(tiff_path) as reader: - current_frames = reader.shape()[0] - start_idx = max(0, trim_frames_start - processed_frames) - end_idx = current_frames - - if processed_frames + current_frames > total_frames - trim_frames_end: - end_idx = total_frames - trim_frames_end - processed_frames - - if start_idx < end_idx: # Only process if there are frames to include - data = np.array(_tiff_to_numpy(tiff_path)[start_idx:end_idx]) - arrays_to_stack.append(data) - - processed_frames += current_frames - - # Break if we've processed all needed frames - if processed_frames >= total_frames - trim_frames_end: - break - - # Stack all arrays along the appropriate axis - if not arrays_to_stack: - raise ValueError("No frames remained after trimming") - - return ( - np.concatenate(arrays_to_stack, axis=0) - if len(arrays_to_stack) > 1 - else arrays_to_stack[0] - ) - - def load_initial_frames( - file_path: Union[str, list], + file_path: str, h5py_key: str, n_frames: int, trim_frames_start: int = 0, @@ -245,18 +88,17 @@ def load_initial_frames( time axis. If n_frames > tot_frames, a number of frames equal to tot_frames is returned. """ - if isinstance(file_path, str): - array = h5py_to_numpy(file_path, h5py_key, trim_frames_start, trim_frames_end) - elif isinstance(file_path, list): - array = tiff_to_numpy(file_path, trim_frames_start, trim_frames_end) - else: - raise ValueError("File type not supported") - # Total number of frames in the movie. - tot_frames = array.shape[0] - requested_frames = np.linspace( - 0, tot_frames, 1 + min(n_frames, tot_frames), dtype=int - )[:-1] - frames = array[requested_frames] + with h5py.File(file_path, "r") as hdf5_file: + # Load all frames as fancy indexing is slower than loading the full + # data. + max_frame = hdf5_file[h5py_key].shape[0] - trim_frames_end + frame_window = hdf5_file[h5py_key][trim_frames_start:max_frame] + # Total number of frames in the movie. + tot_frames = frame_window.shape[0] + requested_frames = np.linspace( + 0, tot_frames, 1 + min(n_frames, tot_frames), dtype=int + )[:-1] + frames = frame_window[requested_frames] return frames @@ -500,7 +342,7 @@ def optimize_motion_parameters( ) start_time = time() for param_spatial, param_time in product(smooth_sigmas, smooth_sigma_times): - current_args = suite2p_parser.copy() + current_args = suite2p_args.copy() current_args["smooth_sigma"] = param_spatial current_args["smooth_sigma_time"] = param_time @@ -691,13 +533,13 @@ def add_modify_required_parameters(suite2p_args: dict): suite2p_args : dict Suite2p ops dictionary with potentially missing values. """ - if suite2p_parser.get("1Preg") is None: + if suite2p_args.get("1Preg") is None: suite2p_args["1Preg"] = False - if suite2p_parser.get("bidiphase") is None: + if suite2p_args.get("bidiphase") is None: suite2p_args["bidiphase"] = False - if suite2p_parser.get("nonrigid") is None: + if suite2p_args.get("nonrigid") is None: suite2p_args["nonrigid"] = False - if suite2p_parser.get("norm_frames") is None: + if suite2p_args.get("norm_frames") is None: suite2p_args["norm_frames"] = True # Don't use nonrigid for parameter search. suite2p_args["nonrigid"] = False @@ -737,52 +579,44 @@ def compute_acutance( return (grady**2 + gradx**2).mean() -def check_and_warn_on_datatype( - filepath: Path, logger: Callable, filetype: str = "h5", h5py_key: str = "" -): +def check_and_warn_on_datatype(h5py_name: str, h5py_key: str, logger: Callable): """Suite2p assumes int16 types throughout code. Check that the input data is type int16 else throw a warning. Parameters ---------- - filepath : Path + h5py_name : str Path to the HDF5 containing the data. + h5py_key : str + Name of the dataset to check. logger : Callable Logger to output logger warning to. - filetype : str - Type of file to check. Default is "h5". - h5py_key : str - Name of the dataset to check. Default is "". - """ - if filetype == "h5": - byteorder, name = h5py_byteorder_name(filepath, h5py_key) - elif filetype == "tiff": - byteorder, name = tiff_byteorder_name(filepath) - else: - raise ValueError("File type not supported") - if byteorder == ">": - logger( - "Data byteorder is big-endian which may cause issues in " - "suite2p. This may result in a crash or unexpected " - "results." - ) - if name != "int16": - logger( - f"Data type is {name} and not int16. Suite2p " - "assumes int16 data as input and throughout codebase. " - "Non-int16 data may result in unexpected results or " - "crashes." - ) + with h5py.File(h5py_name, "r") as h5_file: + dataset = h5_file[h5py_key] + if dataset.dtype.byteorder == ">": + logger( + "Data byteorder is big-endian which may cause issues in " + "suite2p. This may result in a crash or unexpected " + "results." + ) + if dataset.dtype.name != "int16": + logger( + f"Data type is {dataset.dtype.name} and not int16. Suite2p " + "assumes int16 data as input and throughout codebase. " + "Non-int16 data may result in unexpected results or " + "crashes." + ) -def _mean_of_batch(i, array): - return array[i : i + 1000].mean(axis=(1, 2)) + +def _mean_of_batch(i, h5py_name, h5py_key): + return h5py.File(h5py_name)[h5py_key][i : i + 1000].mean(axis=(1, 2)) def find_movie_start_end_empty_frames( - filepath: Union[str , list[str]], - h5py_key: str = "", + h5py_name: str, + h5py_key: str, n_sigma: float = 5, logger: Optional[Callable] = None, n_jobs: Optional[int] = None, @@ -796,10 +630,10 @@ def find_movie_start_end_empty_frames( Parameters ---------- - filepath : str | list[str] - File path to HDF5 file or the list of TIFFS to process + h5py_name : str + Name of the HDF5 file to load from. h5py_key : str - Name of the dataset to load from the HDF5 file. Default is "" + Name of the dataset to load from the HDF5 file. n_sigma : float Number of standard deviations beyond which a frame is considered an outlier and "empty". @@ -814,28 +648,22 @@ def find_movie_start_end_empty_frames( Tuple of the number of frames to cut from the start and end of the movie as (n_trim_start, n_trim_end). """ - - if isinstance(filepath, str): - array = h5py_to_numpy(filepath, h5py_key) - elif isinstance(filepath, list): - array = tiff_to_numpy(filepath) - else: - raise ValueError("File type not supported") # Find the midpoint of the movie. - n_frames = array.shape[0] - midpoint = n_frames // 2 - # We discover empty or extrema frames by comparing the mean of each frames - # to the mean of the full movie. - if n_jobs == 1 or n_frames < 2000: - means = array[:].mean(axis=(1, 2)) - else: - means = np.concatenate( - Pool(n_jobs).starmap( - _mean_of_batch, - product(range(0, n_frames, 1000), [array]), + with h5py.File(h5py_name, "r") as f: + n_frames = f[h5py_key].shape[0] + midpoint = n_frames // 2 + # We discover empty or extrema frames by comparing the mean of each frames + # to the mean of the full movie. + if n_jobs == 1 or n_frames < 2000: + means = f[h5py_key][:].mean(axis=(1, 2)) + else: + means = np.concatenate( + Pool(n_jobs).starmap( + _mean_of_batch, + product(range(0, n_frames, 1000), [h5py_name], [h5py_key]), + ) ) - ) - mean_of_frames = means.mean() + mean_of_frames = means.mean() # Compute a robust standard deviation that is not sensitive to the # outliers we are attempting to find. @@ -1038,7 +866,7 @@ def get_frame_rate_platform_json(input_dir: str) -> float: def write_output_metadata( metadata: dict, - raw_movie: Union[str, list], + raw_movie: Union[str, Path], motion_corrected_movie: Union[str, Path], output_dir: Union[str, Path], ) -> None: @@ -1053,10 +881,6 @@ def write_output_metadata( motion_corrected_movie: str path to motion corrected movies """ - if isinstance(raw_movie, Path): - raw_movie = str(raw_movie) - elif isinstance(raw_movie, list): - raw_movie = " ".join(raw_movie) processing = Processing( processing_pipeline=PipelineProcess( processor_full_name="Multplane Ophys Processing Pipeline", @@ -1083,8 +907,7 @@ def write_output_metadata( output_dir = Path(output_dir) print(f"~~~~~~~~~~~~~~Writing output: {output_dir}") processing.write_standard_file(output_directory=output_dir) - - + def check_trim_frames(data): """Make sure that if the user sets auto_remove_empty_frames and timing frames is already requested, raise an error. @@ -1198,7 +1021,7 @@ def make_nonrigid_png( def downsample_normalize( - movie_path: Union[Path, np.array], + movie_path: Path, frame_rate: float, bin_size: float, lower_quantile: float, @@ -1210,9 +1033,9 @@ def downsample_normalize( Parameters ---------- - movie_path: Union[Path, np.array] + movie_path: Path path to an h5 file, containing an (nframes x nrows x ncol) dataset - named 'data' or a numpy array. + named 'data' frame_rate: float frame rate of the movie specified by 'movie_path' bin_size: float @@ -1236,10 +1059,7 @@ def downsample_normalize( consistent visibility. """ - if isinstance(movie_path, Path): - ds = downsample_h5_video(movie_path, input_fps=frame_rate, output_fps=1.0 / bin_size) - else: - ds = downsample_array(movie_path, input_fps=frame_rate, output_fps=1.0 / bin_size) + ds = downsample_h5_video(movie_path, input_fps=frame_rate, output_fps=1.0 / bin_size) avg_projection = ds.mean(axis=0) lower_cutoff, upper_cutoff = np.quantile( avg_projection.flatten(), (lower_quantile, upper_quantile) @@ -1353,15 +1173,13 @@ def multiplane_motion_correction(data_dir: Path, output_dir: Path, debug: bool = frame_rate_hz: float frame rate in Hz """ - data_dir = next(data_dir.rglob("pophys")) - if not data_dir.is_dir(): - raise ValueError("Could not locate 'pophys' directory") try: unique_id = [i for i in data_dir.rglob("*") if "ophys_experiment" in str(i)][ 0 ].name.split("_")[-1] h5_file = [i for i in data_dir.rglob("*") if f"{unique_id}.h5" in str(i)][0] except IndexError: + unique_id = [i for i in data_dir.rglob("*") if i.is_dir()][0].name h5_file = [i for i in data_dir.rglob("*") if f"{unique_id}.h5" in str(i)][0] session_dir = h5_file.parent.parent @@ -1415,15 +1233,9 @@ def update_suite2p_args_reference_image( ) else: - if suite2p_args.get("h5py", None): - file_path = suite2p_args["h5py"] - h5py_key = suite2p_args["h5py_key"] - else: - file_path = suite2p_args["tiff_list"] - h5py_key = None initial_frames = load_initial_frames( - file_path=file_path, - h5py_key=h5py_key, + file_path=suite2p_args["h5py"], + h5py_key=suite2p_args["h5py_key"], n_frames=suite2p_args["nimg_init"], trim_frames_start=args["trim_frames_start"], trim_frames_end=args["trim_frames_end"], @@ -1583,33 +1395,12 @@ def singleplane_motion_correction( return h5_file, output_dir, reference_image_fp -def get_frame_rate(session: dict): - """Attempt to pull frame rate from session.json - Returns none if frame rate not in session.json - - Parameters - ---------- - session: dict - session metadata - - Returns - ------- - frame_rate: float - frame rate in Hz - """ - frame_rate_hz = None - for i in session.get("data_streams", ""): - frame_rate_hz = [j["frame_rate"] for j in i["ophys_fovs"]] - frame_rate_hz = frame_rate_hz[0] - if frame_rate_hz: - break - if isinstance(frame_rate_hz, str): - frame_rate_hz = float(frame_rate_hz) - return frame_rate_hz - +if __name__ == "__main__": # pragma: nocover + # Set the log level and name the logger + logger = logging.getLogger("Suite2P motion correction") + logger.setLevel(logging.INFO) -def parse_arguments(): - """Parse command-line arguments""" + # Create an ArgumentParser object parser = argparse.ArgumentParser(description="Suite2P motion correction") parser.add_argument( @@ -1624,27 +1415,9 @@ def parse_arguments(): ) parser.add_argument( - "-d", "--debug", action="store_true", help="Run with only first 500 frames" - ) - - parser.add_argument( - "--frame-rate", - type=float, - default=31.0, - help="Frame rate of the movie in Hz. If not provided, " - "the frame rate will here will be used.", + "-d", "--debug", action="store_true", help="Run with only partial dset" ) - parser.add_argument( - "--data-type", type=str, default="h5", help="Specify h5 or TIFF input type" - ) - - parser.add_argument( - "--look-one-level-down", - type=bool, - default=False, - help="If True, search for TIFF files in subdirectories " "of the input directory", - ) parser.add_argument( "--tmp_dir", type=str, @@ -1776,56 +1549,45 @@ def parse_arguments(): "steps=1.", ) - return parser.parse_args() - - -if __name__ == "__main__": # pragma: nocover - # Set the log level and name the logger - logger = logging.getLogger("Suite2P motion correction") - logger.setLevel(logging.INFO) - - # Create an ArgumentParser object - parser = parse_arguments() + # Parse command-line arguments + args = parser.parse_args() # General settings - data_dir = Path(parser.input) - output_dir = Path(parser.output_dir) + output_dir = Path(args.output_dir) + data_dir = Path("../data") session_fp = next(data_dir.rglob("session.json")) description_fp = next(data_dir.rglob("data_description.json")) with open(session_fp, "r") as j: session = json.load(j) with open(description_fp, "r") as j: data_description = json.load(j) - frame_rate_hz = get_frame_rate(session) - unique_id = "_".join(str(data_description["name"]).split("_")[-3:]) + for i in session["data_streams"]: + frame_rate_hz = [j["frame_rate"] for j in i["ophys_fovs"]] + if frame_rate_hz: + break + frame_rate_hz = frame_rate_hz[0] + if isinstance(frame_rate_hz, str): + frame_rate_hz = float(frame_rate_hz) reference_image_fp = "" - - if parser.data_type == "TIFF": - input_file = next(data_dir.rglob("pophys")) + unique_id = "_".join(str(data_description["name"]).split("_")[-3:]) + if "Bergamo" in session["rig_id"]: + h5_file, output_dir, reference_image_fp = singleplane_motion_correction( + data_dir, output_dir, session, unique_id, debug=args.debug + ) else: - if "Bergamo" in session.get("rig_id", ""): - h5_file, output_dir, reference_image_fp = singleplane_motion_correction( - data_dir, output_dir, session, unique_id, debug=parser.debug - ) - else: - h5_file, output_dir, frame_rate_hz = multiplane_motion_correction( - data_dir, output_dir, debug=parser.debug - ) - input_file = str(h5_file) + h5_file, output_dir, frame_rate_hz = multiplane_motion_correction( + data_dir, output_dir, debug=args.debug + ) + # We convert to dictionary - args = vars(parser) - if not frame_rate_hz: - frame_rate_hz = parser.frame_rate - logging.warning("Using default frame rate of 31.0 Hz") + args = vars(args) + h5_file = str(h5_file) reference_image = None + meta_jsons = list(data_dir.glob("*/*.json")) args["refImg"] = [] if reference_image_fp: args["refImg"] = [reference_image_fp] # We construct the paths to the outputs - if isinstance(input_file, list): - basename = os.path.basename(input_file[0]) - else: - basename = os.path.basename(input_file) args["movie_frame_rate_hz"] = frame_rate_hz for key, default in ( ("motion_corrected_output", "_registered.h5"), @@ -1836,7 +1598,9 @@ def parse_arguments(): ("motion_correction_preview_output", "_motion_preview.webm"), ("output_json", "_motion_correction_output.json"), ): - args[key] = os.path.join(output_dir, os.path.splitext(basename)[0] + default) + args[key] = os.path.join( + output_dir, os.path.splitext(os.path.basename(h5_file))[0] + default + ) # These are hardcoded parameters of the wrapper. Those are tracked but # not exposed. @@ -1869,24 +1633,19 @@ def parse_arguments(): # This is part of a complex scheme to pass an image that is a bit too # complicated. Will remove when tested. - # if not parser.get("refImg", ""): + # if not args.get("refImg", ""): # args["refImg"] = [] - # Set suite2p parser. + # Set suite2p args. suite2p_args = suite2p.default_ops() # Here we overwrite the parameters for suite2p that will not change in our # processing pipeline. These are parameters that are not exposed to # minimize code length. Those are not set to default. - if parser.data_type == "h5": - suite2p_args["h5py"] = input_file - else: - suite2p_args["data_path"] = str(input_file) - suite2p_args["look_one_level_down"] = True - suite2p_args["tiff_list"] = [str(i) for i in input_file.glob("*.tif*")] + suite2p_args["h5py"] = h5_file suite2p_args["roidetect"] = False suite2p_args["do_registration"] = 1 - # suite2p_args["data_path"] = [] # TODO: remove this if not needed by suite2p + suite2p_args["data_path"] = [] # TODO: remove this if not needed by suite2p suite2p_args["reg_tif"] = False # We save our own outputs here suite2p_args["nimg_init"] = ( 500 # Nb of images to compute reference. This value is a bit high. Suite2p has it at 300 normally @@ -1901,8 +1660,7 @@ def parse_arguments(): 5.0 # Maximum shift allowed in pixels for a block in rigid registration. ) suite2p_args["batch_size"] = 500 # Number of frames to process at once - if suite2p_args.get("h5py", ""): - suite2p_args["h5py_key"] = "data" # h5 path in the file. + suite2p_args["h5py_key"] = "data" # h5 path in the file. suite2p_args["smooth_sigma"] = ( 1.15 # Standard deviation in pixels of the gaussian used to smooth the phase correlation. ) @@ -1920,38 +1678,25 @@ def parse_arguments(): suite2p_args["force_refImg"] = args["force_refImg"] # if data is in a S3 bucket, copy it to /scratch for faster access - if suite2p_args.get("h5py", ""): - if is_S3(suite2p_args["h5py"]): - dst = "/scratch/" + Path(suite2p_args["h5py"]).name - logger.info(f"copying {suite2p_args['h5py']} from S3 bucket to {dst}") - shutil.copy(suite2p_args["h5py"], dst) - suite2p_args["h5py"] = dst - - if suite2p_args.get("tiff_list", ""): - check_and_warn_on_datatype( - filepath=suite2p_args["tiff_list"][0], logger=logger.warning, filetype="tiff" - ) - else: - check_and_warn_on_datatype( - filepath=suite2p_args["h5py"], - logger=logger.warning, - filetype="h5", - h5py_key=suite2p_args["h5py_key"], - ) + if is_S3(suite2p_args["h5py"]): + dst = "/scratch/" + Path(suite2p_args["h5py"]).name + logger.info(f"copying {suite2p_args['h5py']} from S3 bucket to {dst}") + shutil.copy(suite2p_args["h5py"], dst) + suite2p_args["h5py"] = dst + + check_and_warn_on_datatype( + h5py_name=suite2p_args["h5py"], + h5py_key=suite2p_args["h5py_key"], + logger=logger.warning, + ) if args["auto_remove_empty_frames"]: logger.info("Attempting to find empty frames at the start and end of the movie.") - if suite2p_args.get("tiff_list", ""): - lowside, highside = find_movie_start_end_empty_frames( - filepath=suite2p_args["tiff_list"], - logger=logger.warning, - ) - else: - lowside, highside = find_movie_start_end_empty_frames( - filepath=suite2p_args["h5py"], - h5py_key=suite2p_args["h5py_key"], - logger=logger.warning, - ) + lowside, highside = find_movie_start_end_empty_frames( + h5py_name=suite2p_args["h5py"], + h5py_key=suite2p_args["h5py_key"], + logger=logger.warning, + ) args["trim_frames_start"] = lowside args["trim_frames_end"] = highside logger.info(f"Found ({lowside}, {highside}) at the start/end of the movie.") @@ -1966,8 +1711,8 @@ def parse_arguments(): suite2p_args, args, reference_image_fp=reference_image_fp ) - # register with Suite2P - logger.info(f"attempting to motion correct {suite2p_args['h5py']}") + # register with Suite2P + logger.info(f"attempting to motion correct {suite2p_args['h5py']}") # make a tempdir for Suite2P's output tmp_dir = tempfile.TemporaryDirectory(dir=args["tmp_dir"]) tdir = tmp_dir.name @@ -1989,17 +1734,10 @@ def parse_arguments(): if suite2p_args["force_refImg"]: logger.info(f"\tUsing custom reference image: {suite2p_args['refImg']}") - if suite2p_args.get("h5py", ""): - suite2p_args["h5py"] = suite2p_args["h5py"] + suite2p_args["h5py"] = [suite2p_args["h5py"]] suite2p.run_s2p(suite2p_args) - data_path = "" - if suite2p_args.get("h5py", ""): - data_path = suite2p_args["h5py"][0] - else: - data_path = suite2p_args["data_path"][0] + suite2p_args["h5py"] = suite2p_args["h5py"][0] - if not data_path: - raise ValueError("No data path found in suite2p_args") bin_path = list(Path(tdir).rglob("data.bin"))[0] ops_path = list(Path(tdir).rglob("ops.npy"))[0] # Suite2P ops file contains at least the following keys: @@ -2092,12 +1830,8 @@ def parse_arguments(): # make projections mx_proj = projection_process(data, projection="max") av_proj = projection_process(data, projection="avg") - if not suite2p_args.get("h5py", []): - filepath = suite2p_args["tiff_list"] - else: - filepath = suite2p_args["h5py"] write_output_metadata( - args_copy, suite2p_args["h5py"], args["motion_corrected_output"], output_dir + args_copy, Path(suite2p_args["h5py"]), args["motion_corrected_output"], output_dir ) # TODO: normalize here, if desired # save projections @@ -2215,25 +1949,13 @@ def parse_arguments(): lower_quantile=args["movie_lower_quantile"], upper_quantile=args["movie_upper_quantile"], ) - if suite2p_args.get("h5py", ""): - h5_file = suite2p_args["h5py"] - processed_vids = [ - ds_partial(i) - for i in [ - Path(h5_file), - Path(args["motion_corrected_output"]), - ] - ] - else: - tiff_array = tiff_to_numpy(suite2p_args["tiff_list"]) - processed_vids = [ - ds_partial(i) - for i in [ - tiff_array, - Path(args["motion_corrected_output"]), - ] + processed_vids = [ + ds_partial(i) + for i in [ + Path(h5_file), + Path(args["motion_corrected_output"]), ] - + ] logger.info("finished downsampling motion corrected and non-motion corrected movies") # tile into 1 movie, raw on left, motion corrected on right @@ -2247,100 +1969,63 @@ def parse_arguments(): except: logger.info("Could not write motion correction preview") # compute crispness of mean image using raw and registered movie - if suite2p_args.get("h5py", ""): - with ( - h5py.File(h5_file) as f_raw, - h5py.File(args["motion_corrected_output"], "r+") as f, - ): - mov_raw = f_raw["data"] - mov = f["data"] - crispness = [ - np.sqrt(np.sum(np.array(np.gradient(np.mean(m, 0))) ** 2)) - for m in (mov_raw, mov) - ] - logger.info("computed crispness of mean image before and after registration") - - # compute residual optical flow using Farneback method - if f["reg_metrics/regPC"][:].any(): - regPC = f["reg_metrics/regPC"] - flows = np.zeros(regPC.shape[1:] + (2,), np.float32) - for i in range(len(flows)): - pclow, pchigh = regPC[:, i] - flows[i] = cv2.calcOpticalFlowFarneback( - pclow, - pchigh, - None, - pyr_scale=0.5, - levels=3, - winsize=100, - iterations=15, - poly_n=5, - poly_sigma=1.2 / 5, - flags=0, - ) - flows_norm = np.sqrt(np.sum(flows**2, -1)) - farnebackDX = np.transpose([flows_norm.mean((1, 2)), flows_norm.max((1, 2))]) - f.create_dataset("reg_metrics/crispness", data=crispness) - f.create_dataset("reg_metrics/farnebackROF", data=flows) - f.create_dataset("reg_metrics/farnebackDX", data=farnebackDX) - logger.info( - "computed residual optical flow of top PCs using Farneback method" - ) - logger.info( - "appended additional registration metrics to" - f"{args['motion_corrected_output']}" + with ( + h5py.File(h5_file) as f_raw, + h5py.File(args["motion_corrected_output"], "r+") as f, + ): + mov_raw = f_raw["data"] + mov = f["data"] + crispness = [ + np.sqrt(np.sum(np.array(np.gradient(np.mean(m, 0))) ** 2)) + for m in (mov_raw, mov) + ] + logger.info("computed crispness of mean image before and after registration") + + # compute residual optical flow using Farneback method + if f["reg_metrics/regPC"][:].any(): + regPC = f["reg_metrics/regPC"] + flows = np.zeros(regPC.shape[1:] + (2,), np.float32) + for i in range(len(flows)): + pclow, pchigh = regPC[:, i] + flows[i] = cv2.calcOpticalFlowFarneback( + pclow, + pchigh, + None, + pyr_scale=0.5, + levels=3, + winsize=100, + iterations=15, + poly_n=5, + poly_sigma=1.2 / 5, + flags=0, ) - else: - mov_raw = tiff_array - with h5py.File(args["motion_corrected_output"], "r+") as f: - crispness = [ - np.sqrt(np.sum(np.array(np.gradient(np.mean(m, 0))) ** 2)) - for m in (mov_raw, f["data"]) - ] - logger.info("computed crispness of mean image before and after registration") - if f["reg_metrics/regPC"][:].any(): - regPC = f["reg_metrics/regPC"] - flows = np.zeros(regPC.shape[1:] + (2,), np.float32) - for i in range(len(flows)): - pclow, pchigh = regPC[:, i] - flows[i] = cv2.calcOpticalFlowFarneback( - pclow, - pchigh, - None, - pyr_scale=0.5, - levels=3, - winsize=100, - iterations=15, - poly_n=5, - poly_sigma=1.2 / 5, - flags=0, - ) - flows_norm = np.sqrt(np.sum(flows**2, -1)) - farnebackDX = np.transpose([flows_norm.mean((1, 2)), flows_norm.max((1, 2))]) - f.create_dataset("reg_metrics/crispness", data=crispness) - f.create_dataset("reg_metrics/farnebackROF", data=flows) - f.create_dataset("reg_metrics/farnebackDX", data=farnebackDX) - logger.info( - "computed residual optical flow of top PCs using Farneback method" + flows_norm = np.sqrt(np.sum(flows**2, -1)) + farnebackDX = np.transpose([flows_norm.mean((1, 2)), flows_norm.max((1, 2))]) + f.create_dataset("reg_metrics/crispness", data=crispness) + f.create_dataset("reg_metrics/farnebackROF", data=flows) + f.create_dataset("reg_metrics/farnebackDX", data=farnebackDX) + logger.info( + "computed residual optical flow of top PCs using Farneback method" + ) + logger.info( + "appended additional registration metrics to" + f"{args['motion_corrected_output']}" + ) + + # create image of PC_low, PC_high, and the residual optical flow between them + if f["reg_metrics/regDX"][:].any(): + for iPC in set( + ( + np.argmax(f["reg_metrics/regDX"][:, -1]), + np.argmax(farnebackDX[:, -1]), ) - logger.info( - "appended additional registration metrics to" - f"{args['motion_corrected_output']}" + ): + p = Path(args["registration_summary_output"]) + flow_png( + Path(args["motion_corrected_output"]), + str(p.parent / p.stem), + iPC, ) - # create image of PC_low, PC_high, and the residual optical flow between them - if f["reg_metrics/regDX"][:].any(): - for iPC in set( - ( - np.argmax(f["reg_metrics/regDX"][:, -1]), - np.argmax(farnebackDX[:, -1]), - ) - ): - p = Path(args["registration_summary_output"]) - flow_png( - Path(args["motion_corrected_output"]), - str(p.parent / p.stem), - iPC, - ) logger.info(f"created images of PC_low, PC_high, and PC_rof for PC {iPC}") # Clean up temporary directory