diff --git a/src/toast/intervals.py b/src/toast/intervals.py index 4b54f2727..f00855fd7 100644 --- a/src/toast/intervals.py +++ b/src/toast/intervals.py @@ -69,15 +69,18 @@ def __init__(self, timestamps, intervals=None, timespans=None, samplespans=None) raise RuntimeError( "If constructing from intervals, other spans should be None" ) - timespans = [(x.start, x.stop) for x in intervals] - indices = self._find_indices(timespans) - self.data = np.array( - [ - (self.timestamps[x[0]], self.timestamps[x[1]], x[0], x[1]) - for x in indices - ], - dtype=interval_dtype, - ).view(np.recarray) + if len(intervals) == 0: + self.data = np.zeros(0, dtype=interval_dtype).view(np.recarray) + else: + timespans = [(x.start, x.stop) for x in intervals] + indices = self._find_indices(timespans) + self.data = np.array( + [ + (self.timestamps[x[0]], self.timestamps[x[1]], x[0], x[1]) + for x in indices + ], + dtype=interval_dtype, + ).view(np.recarray) elif timespans is not None: if samplespans is not None: raise RuntimeError("Cannot construct from both time and sample spans") diff --git a/src/toast/ops/CMakeLists.txt b/src/toast/ops/CMakeLists.txt index 2624f3fcc..34fcdf3a5 100644 --- a/src/toast/ops/CMakeLists.txt +++ b/src/toast/ops/CMakeLists.txt @@ -8,6 +8,7 @@ install(FILES delete.py copy.py reset.py + fill_gaps.py arithmetic.py memory_counter.py simple_deglitch.py diff --git a/src/toast/ops/__init__.py b/src/toast/ops/__init__.py index b0021d3a1..5cc6801da 100644 --- a/src/toast/ops/__init__.py +++ b/src/toast/ops/__init__.py @@ -15,6 +15,7 @@ from .delete import Delete from .demodulation import Demodulate, StokesWeightsDemod from .elevation_noise import ElevationNoise +from .fill_gaps import FillGaps from .filterbin import FilterBin, combine_observation_matrix from .flag_intervals import FlagIntervals from .flag_sso import FlagSSO diff --git a/src/toast/ops/azimuth_intervals.py b/src/toast/ops/azimuth_intervals.py index 8cf87b9ef..59c3428a1 100644 --- a/src/toast/ops/azimuth_intervals.py +++ b/src/toast/ops/azimuth_intervals.py @@ -1,4 +1,4 @@ -# Copyright (c) 2023-2023 by the parties listed in the AUTHORS file. +# Copyright (c) 2023-2024 by the parties listed in the AUTHORS file. # All rights reserved. Use of this source code is governed by # a BSD-style license that can be found in the LICENSE file. 
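Note on the intervals.py hunk above: it guards against constructing an IntervalList from an empty list of intervals, returning a well-formed zero-length recarray instead of running the index search on an empty span list. A minimal standalone sketch of the same pattern, using a hypothetical interval_dtype that mirrors the (start, stop, first, last) fields in the hunk:

import numpy as np

# Hypothetical stand-in for the real interval_dtype: time span plus sample range.
interval_dtype = np.dtype(
    [("start", "f8"), ("stop", "f8"), ("first", "i8"), ("last", "i8")]
)


def build_interval_data(timestamps, index_pairs):
    # Short-circuit the empty case, as the hunk does, so callers always
    # receive a recarray with the expected fields.
    if len(index_pairs) == 0:
        return np.zeros(0, dtype=interval_dtype).view(np.recarray)
    rows = [
        (timestamps[first], timestamps[last], first, last)
        for first, last in index_pairs
    ]
    return np.array(rows, dtype=interval_dtype).view(np.recarray)


stamps = np.linspace(0.0, 9.0, 10)
print(build_interval_data(stamps, []).shape)  # (0,)
print(build_interval_data(stamps, [(2, 5)]).start)  # [2.]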
@@ -16,7 +16,7 @@ from ..observation import default_values as defaults from ..timing import Timer, function_timer from ..traits import Bool, Float, Int, Quantity, Unicode, trait_docs -from ..utils import Environment, Logger, rate_from_times +from ..utils import Environment, Logger, rate_from_times, flagged_noise_fill from ..vis import set_matplotlib_backend from .flag_intervals import FlagIntervals from .operator import Operator @@ -109,7 +109,7 @@ class AzimuthIntervals(Operator): help="Bit mask value for bad azimuth pointing", ) - window_seconds = Float(0.25, help="Smoothing window in seconds") + window_seconds = Float(0.5, help="Smoothing window in seconds") debug_root = Unicode( None, @@ -142,30 +142,31 @@ def _exec(self, data, detectors=None, **kwargs): stable_times = None stable_leftright_times = None stable_rightleft_times = None + have_scanning = True # Sample rate stamps = obs.shared[self.times].data (rate, dt, dt_min, dt_max, dt_std) = rate_from_times(stamps) + # Smoothing window in samples + window = int(rate * self.window_seconds) + if obs.comm_col_rank == 0: - # Smoothing window in samples - window = int(rate * self.window_seconds) + # The azimuth angle + azimuth = np.array(obs.shared[self.azimuth].data) - # The scan velocity - scan_vel = np.gradient(obs.shared[self.azimuth].data) + # The azimuth flags + flags = np.array(obs.shared[self.shared_flags].data) + flags &= self.shared_flag_mask - # Smooth with moving window - wscan_vel = uniform_filter1d(scan_vel, size=window, mode="nearest") + # Scan velocity + scan_vel = self._gradient(azimuth, window, flags=flags) # The peak to peak range of the scan velocity - vel_range = np.amax(wscan_vel) - np.amin(wscan_vel) + vel_range = np.amax(scan_vel) - np.amin(scan_vel) - # The smoothed scan acceleration - scan_accel = uniform_filter1d( - np.gradient(wscan_vel), - size=window, - mode="nearest", - ) + # Scan acceleration + scan_accel = self._gradient(scan_vel, window) # Peak to peak acceleration range accel_range = np.amax(scan_accel) - np.amin(scan_accel) @@ -176,8 +177,10 @@ def _exec(self, data, detectors=None, **kwargs): stable = (np.absolute(scan_accel) < 0.1 * accel_range) * np.ones( len(scan_accel), dtype=np.int8 ) - stable *= np.absolute(wscan_vel) > 0.1 * vel_range + stable *= np.absolute(scan_vel) > 0.1 * vel_range + # The first estimate of the samples where stable pointing + # begins and ends. begin_stable = np.where(stable[1:] - stable[:-1] == 1)[0] end_stable = np.where(stable[:-1] - stable[1:] == 1)[0] @@ -187,110 +190,134 @@ def _exec(self, data, detectors=None, **kwargs): msg += f" change the filter window. Flagging all samples" msg += f" as unstable pointing." log.warning(msg) - continue - - if begin_stable[0] > end_stable[0]: - # We start in the middle of a scan - begin_stable = np.concatenate(([0], begin_stable)) - if begin_stable[-1] > end_stable[-1]: - # We end in the middle of a scan - end_stable = np.concatenate((end_stable, [obs.n_local_samples])) - stable_times = [ - (stamps[x[0]], stamps[x[1] - 1]) - for x in zip(begin_stable, end_stable) - ] - - # In some situations there are very short stable scans detected at the - # beginning and end of observations. Here we cut any short throw and - # stable periods. 
- if self.cut_short: - stable_spans = np.array([(x[1] - x[0]) for x in stable_times]) - try: - # First try short limit as time - stable_bad = stable_spans < self.short_limit.to_value(u.s) - except: - # Try short limit as fraction - median_stable = np.median(stable_spans) - stable_bad = stable_spans < self.short_limit * median_stable - begin_stable = np.array( - [x for (x, y) in zip(begin_stable, stable_bad) if not y] - ) - end_stable = np.array( - [x for (x, y) in zip(end_stable, stable_bad) if not y] - ) - stable_times = [ - x for (x, y) in zip(stable_times, stable_bad) if not y - ] - if self.cut_long: - stable_spans = np.array([(x[1] - x[0]) for x in stable_times]) - try: - # First try long limit as time - stable_bad = stable_spans > self.long_limit.to_value(u.s) - except: - # Try long limit as fraction - median_stable = np.median(stable_spans) - stable_bad = stable_spans > self.long_limit * median_stable - begin_stable = np.array( - [x for (x, y) in zip(begin_stable, stable_bad) if not y] - ) - end_stable = np.array( - [x for (x, y) in zip(end_stable, stable_bad) if not y] - ) - stable_times = [ - x for (x, y) in zip(stable_times, stable_bad) if not y - ] + have_scanning = False + + if have_scanning: + # Refine our list of stable periods + if begin_stable[0] > end_stable[0]: + # We start in the middle of a scan + begin_stable = np.concatenate(([0], begin_stable)) + if begin_stable[-1] > end_stable[-1]: + # We end in the middle of a scan + end_stable = np.concatenate((end_stable, [obs.n_local_samples])) + + # In some situations there are very short stable scans detected at + # the beginning and end of observations. Here we cut any short + # throw and stable periods. + cut_threshold = 4 + if (self.cut_short or self.cut_long) and ( + len(begin_stable) >= cut_threshold + ): + if self.cut_short: + stable_timespans = np.array( + [ + stamps[y - 1] - stamps[x] + for x, y in zip(begin_stable, end_stable) + ] + ) + try: + # First try short limit as time + stable_bad = ( + stable_timespans < self.short_limit.to_value(u.s) + ) + except: + # Try short limit as fraction + median_stable = np.median(stable_timespans) + stable_bad = ( + stable_timespans < self.short_limit * median_stable + ) + begin_stable = np.array( + [x for (x, y) in zip(begin_stable, stable_bad) if not y] + ) + end_stable = np.array( + [x for (x, y) in zip(end_stable, stable_bad) if not y] + ) + if self.cut_long: + stable_timespans = np.array( + [ + stamps[y - 1] - stamps[x] + for x, y in zip(begin_stable, end_stable) + ] + ) + try: + # First try long limit as time + stable_bad = ( + stable_timespans > self.long_limit.to_value(u.s) + ) + except: + # Try long limit as fraction + median_stable = np.median(stable_timespans) + stable_bad = ( + stable_timespans > self.long_limit * median_stable + ) + begin_stable = np.array( + [x for (x, y) in zip(begin_stable, stable_bad) if not y] + ) + end_stable = np.array( + [x for (x, y) in zip(end_stable, stable_bad) if not y] + ) + if len(begin_stable) == 0: + have_scanning = False # The "throw" intervals extend from one turnaround to the next. # We start the first throw at the beginning of the first stable scan # and then find the sample between stable scans where the turnaround # happens. This reduces false detections of turnarounds before or # after the stable scanning within the observation. + # + # If no turnaround is found between stable scans, we log a warning + # and choose the sample midway between stable scans to be the throw + # boundary. 
+ if have_scanning: + begin_throw = [begin_stable[0]] + end_throw = list() + vel_switch = list() + for start_turn, end_turn in zip(end_stable[:-1], begin_stable[1:]): + # Fit a quadratic polynomial and find the velocity change sample + vel_turn = self._find_turnaround(scan_vel[start_turn:end_turn]) + if vel_turn is None: + msg = f"{obs.name}: Turnaround not found between" + msg += " end of stable scan at" + msg += f" sample {start_turn} and next start at" + msg += f" {end_turn}. Selecting midpoint as turnaround." + log.warning(msg) + half_gap = (end_turn - start_turn) // 2 + end_throw.append(start_turn + half_gap) + else: + end_throw.append(start_turn + vel_turn) + vel_switch.append(end_throw[-1]) + begin_throw.append(end_throw[-1] + 1) + end_throw.append(end_stable[-1]) + begin_throw = np.array(begin_throw) + end_throw = np.array(end_throw) + vel_switch = np.array(vel_switch) - begin_throw = [begin_stable[0]] - end_throw = list() - for start_turn, end_turn in zip(end_stable[:-1], begin_stable[1:]): - vel_switch = np.where( - wscan_vel[start_turn : end_turn - 1] - * wscan_vel[start_turn + 1 : end_turn] - < 0 - )[0] - if len(vel_switch) > 1: - msg = f"{obs.name}: Multiple turnarounds between end of " - msg += "stable scan at" - msg += f" sample {start_turn} and next start at {end_turn}." - msg += " Cutting ." - log.warning(msg) - break - end_throw.append(start_turn + vel_switch[0]) - begin_throw.append(end_throw[-1] + 1) - end_throw.append(end_stable[-1]) - begin_throw = np.array(begin_throw) - end_throw = np.array(end_throw) - - throw_times = [ - (stamps[x[0]], stamps[x[1] - 1]) - for x in zip(begin_throw, end_throw) - ] + stable_times = [ + (stamps[x[0]], stamps[x[1] - 1]) + for x in zip(begin_stable, end_stable) + ] + throw_times = [ + (stamps[x[0]], stamps[x[1] - 1]) + for x in zip(begin_throw, end_throw) + ] - # Split scans into left and right-going intervals - stable_leftright_times = [] - stable_rightleft_times = [] - throw_leftright_times = [] - throw_rightleft_times = [] - - for iscan, (first, last) in enumerate(zip(begin_stable, end_stable)): - # Check the velocity at the middle of the scan - mid = first + (last - first) // 2 - if wscan_vel[mid] > 0: - stable_leftright_times.append(stable_times[iscan]) - throw_leftright_times.append(throw_times[iscan]) - elif wscan_vel[mid] < 0: - stable_rightleft_times.append(stable_times[iscan]) - throw_rightleft_times.append(throw_times[iscan]) - else: - msg = "Velocity is zero in the middle of scan" - msg += f" samples {first} ... 
{last}" - raise RuntimeError(msg) + throw_leftright_times = list() + throw_rightleft_times = list() + stable_leftright_times = list() + stable_rightleft_times = list() + + # Split scans into left and right-going intervals + for iscan, (first, last) in enumerate( + zip(begin_stable, end_stable) + ): + # Check the velocity at the middle of the scan + mid = first + (last - first) // 2 + if scan_vel[mid] >= 0: + stable_leftright_times.append(stable_times[iscan]) + throw_leftright_times.append(throw_times[iscan]) + else: + stable_rightleft_times.append(stable_times[iscan]) + throw_rightleft_times.append(throw_times[iscan]) if self.debug_root is not None: set_matplotlib_backend() @@ -298,87 +325,188 @@ def _exec(self, data, detectors=None, **kwargs): import matplotlib.pyplot as plt # Dump some plots - out_file = f"{self.debug_root}_{obs.comm_row_rank}.pdf" - if len(end_throw) >= 5: - # Plot a few scans - n_plot = end_throw[4] + out_file = f"{self.debug_root}_{obs.name}_{obs.comm_row_rank}.pdf" + if have_scanning: + if len(end_throw) >= 5: + # Plot a few scans + plot_start = 0 + n_plot = end_throw[4] + else: + # Plot it all + plot_start = 0 + n_plot = obs.n_local_samples + pslc = slice(plot_start, plot_start + n_plot, 1) + px = np.arange(plot_start, plot_start + n_plot, 1) + + swplot = vel_switch[ + np.logical_and( + vel_switch <= plot_start + n_plot, + vel_switch >= plot_start, + ) + ] + bstable = begin_stable[ + np.logical_and( + begin_stable <= plot_start + n_plot, + begin_stable >= plot_start, + ) + ] + estable = end_stable[ + np.logical_and( + end_stable <= plot_start + n_plot, + end_stable >= plot_start, + ) + ] + bthrow = begin_throw[ + np.logical_and( + begin_throw <= plot_start + n_plot, + begin_throw >= plot_start, + ) + ] + ethrow = end_throw[ + np.logical_and( + end_throw <= plot_start + n_plot, + end_throw >= plot_start, + ) + ] + + fig = plt.figure(dpi=100, figsize=(8, 16)) + + ax = fig.add_subplot(4, 1, 1) + ax.plot(px, azimuth[pslc], "-", label="Azimuth") + ax.legend(loc="best") + ax.set_xlabel("Samples") + ax.set_ylabel("Azimuth (Radians)") + + ax = fig.add_subplot(4, 1, 2) + ax.plot(px, stable[pslc], "-", label="Stable Pointing") + ax.plot(px, flags[pslc], color="black", label="Flags") + ax.vlines( + bstable, + ymin=-1, + ymax=2, + color="green", + label="Begin Stable", + ) + ax.vlines( + estable, ymin=-1, ymax=2, color="red", label="End Stable" + ) + ax.vlines( + bthrow, ymin=-2, ymax=1, color="cyan", label="Begin Throw" + ) + ax.vlines( + ethrow, ymin=-2, ymax=1, color="purple", label="End Throw" + ) + ax.legend(loc="best") + ax.set_xlabel("Samples") + ax.set_ylabel("Stable Scan / Throw") + + ax = fig.add_subplot(4, 1, 3) + ax.plot(px, scan_vel[pslc], "-", label="Velocity") + ax.vlines( + swplot, + ymin=np.amin(scan_vel), + ymax=np.amax(scan_vel), + color="red", + label="Velocity Switch", + ) + ax.legend(loc="best") + ax.set_xlabel("Samples") + ax.set_ylabel("Scan Velocity (Radians / s)") + + ax = fig.add_subplot(4, 1, 4) + ax.plot(px, scan_accel[pslc], "-", label="Acceleration") + ax.legend(loc="best") + ax.set_xlabel("Samples") + ax.set_ylabel("Scan Acceleration") else: - # Plot it all n_plot = obs.n_local_samples - - swplot = vel_switch[vel_switch <= n_plot] - bstable = begin_stable[begin_stable <= n_plot] - estable = end_stable[end_stable <= n_plot] - bthrow = begin_throw[begin_throw <= n_plot] - ethrow = end_throw[end_throw <= n_plot] - - fig = plt.figure(dpi=100, figsize=(8, 16)) - - ax = fig.add_subplot(4, 1, 1) - ax.plot( - np.arange(n_plot), - 
obs.shared[self.azimuth].data[:n_plot], - "-", - ) - ax.set_xlabel("Samples") - ax.set_ylabel("Azimuth") - - ax = fig.add_subplot(4, 1, 2) - ax.plot(np.arange(n_plot), stable[:n_plot], "-") - ax.vlines(bstable, ymin=-1, ymax=2, color="green") - ax.vlines(estable, ymin=-1, ymax=2, color="red") - ax.vlines(bthrow, ymin=-2, ymax=1, color="cyan") - ax.vlines(ethrow, ymin=-2, ymax=1, color="purple") - ax.set_xlabel("Samples") - ax.set_ylabel("Stable Scan / Throw") - - ax = fig.add_subplot(4, 1, 3) - ax.plot(np.arange(n_plot), scan_vel[:n_plot], "-") - ax.plot(np.arange(n_plot), wscan_vel[:n_plot], "-") - ax.vlines( - swplot, - ymin=np.amin(scan_vel), - ymax=np.amax(scan_vel), - ) - ax.set_xlabel("Samples") - ax.set_ylabel("Scan Velocity") - - ax = fig.add_subplot(4, 1, 4) - ax.plot(np.arange(n_plot), scan_accel[:n_plot], "-") - ax.set_xlabel("Samples") - ax.set_ylabel("Scan Acceleration") - + fig = plt.figure(dpi=100, figsize=(8, 12)) + + ax = fig.add_subplot(3, 1, 1) + ax.plot( + np.arange(n_plot), + azimuth[:n_plot], + "-", + ) + ax.set_xlabel("Samples") + ax.set_ylabel("Azimuth") + + ax = fig.add_subplot(3, 1, 2) + ax.plot(np.arange(n_plot), scan_vel[:n_plot], "-") + ax.vlines( + swplot, + ymin=np.amin(scan_vel), + ymax=np.amax(scan_vel), + ) + ax.set_xlabel("Samples") + ax.set_ylabel("Scan Velocity") + + ax = fig.add_subplot(3, 1, 3) + ax.plot(np.arange(n_plot), scan_accel[:n_plot], "-") + ax.set_xlabel("Samples") + ax.set_ylabel("Scan Acceleration") plt.savefig(out_file) plt.close() # Now create the intervals across each process column + if obs.comm_col is not None: + have_scanning = obs.comm_col.bcast(have_scanning, root=0) - # The throw intervals are between turnarounds - obs.intervals.create_col( - self.throw_interval, throw_times, stamps, fromrank=0 - ) - obs.intervals.create_col( - self.throw_leftright_interval, throw_leftright_times, stamps, fromrank=0 - ) - obs.intervals.create_col( - self.throw_rightleft_interval, throw_rightleft_times, stamps, fromrank=0 - ) - - # Stable scanning intervals - obs.intervals.create_col( - self.scanning_interval, stable_times, stamps, fromrank=0 - ) - obs.intervals.create_col( - self.scan_leftright_interval, stable_leftright_times, stamps, fromrank=0 - ) - obs.intervals.create_col( - self.scan_rightleft_interval, stable_rightleft_times, stamps, fromrank=0 - ) - - # Turnarounds are the inverse of stable scanning - obs.intervals[self.turnaround_interval] = ~obs.intervals[ - self.scanning_interval - ] + if have_scanning: + # The throw intervals are between turnarounds + obs.intervals.create_col( + self.throw_interval, throw_times, stamps, fromrank=0 + ) + obs.intervals.create_col( + self.throw_leftright_interval, + throw_leftright_times, + stamps, + fromrank=0, + ) + obs.intervals.create_col( + self.throw_rightleft_interval, + throw_rightleft_times, + stamps, + fromrank=0, + ) + + # Stable scanning intervals + obs.intervals.create_col( + self.scanning_interval, stable_times, stamps, fromrank=0 + ) + obs.intervals.create_col( + self.scan_leftright_interval, + stable_leftright_times, + stamps, + fromrank=0, + ) + obs.intervals.create_col( + self.scan_rightleft_interval, + stable_rightleft_times, + stamps, + fromrank=0, + ) + + # Turnarounds are the inverse of stable scanning + obs.intervals[self.turnaround_interval] = ~obs.intervals[ + self.scanning_interval + ] + else: + # Flag all samples as unstable + if self.shared_flags not in obs.shared: + obs.shared.create_column( + self.shared_flags, + shape=(obs.n_local_samples,), + dtype=np.uint8, + ) + if 
obs.comm_col_rank == 0: + obs.shared[self.shared_flags].set( + np.zeros_like(obs.shared[self.shared_flags].data), + offset=(0,), + fromrank=0, + ) + else: + obs.shared[self.shared_flags].set(None, offset=(0,), fromrank=0) # Additionally flag turnarounds as unstable pointing flag_intervals = FlagIntervals( @@ -390,6 +518,42 @@ def _exec(self, data, detectors=None, **kwargs): ) flag_intervals.apply(data, detectors=None) + def _find_turnaround(self, vel): + """Fit a polynomial and find the turnaround sample.""" + x = np.arange(len(vel)) + fit_poly = np.polynomial.polynomial.Polynomial.fit(x, vel, 5) + fit_vel = fit_poly(x) + vel_switch = np.where(fit_vel[:-1] * fit_vel[1:] < 0)[0] + if len(vel_switch) != 1: + return None + else: + return vel_switch[0] + + def _gradient(self, data, window, flags=None): + """Compute the numerical derivative with smoothing. + + Args: + data (array): The local data buffer to process. + window (int): The number of samples in the smoothing window. + flags (array): The optional array of sample flags. + + Returns: + (array): The result. + + """ + if flags is not None: + # Fill flags with noise + flagged_noise_fill(data, flags, window // 4, poly_order=5) + # Smooth the data + smoothed = uniform_filter1d( + data, + size=window, + mode="nearest", + ) + # Derivative + result = np.gradient(smoothed) + return result + def _finalize(self, data, **kwargs): return diff --git a/src/toast/ops/demodulation.py b/src/toast/ops/demodulation.py index 8a0c16d90..792698ad8 100644 --- a/src/toast/ops/demodulation.py +++ b/src/toast/ops/demodulation.py @@ -309,18 +309,13 @@ def _exec(self, data, detectors=None, **kwargs): self.demod_data.obs.append(demod_obs) if self.purge: - if self.shared_flags is not None: - del obs.shared[self.shared_flags] - for det_data in self.det_data.split(";"): - del obs.detdata[det_data] - if self.det_flags is not None: - del obs.detdata[self.det_flags] - if self.noise_model is not None: - del obs[self.noise_model] + obs.clear() log.debug_rank( "Demodulated observation in", comm=data.comm.comm_group, timer=timer ) + if self.purge: + data.clear() return diff --git a/src/toast/ops/fill_gaps.py b/src/toast/ops/fill_gaps.py new file mode 100644 index 000000000..930480543 --- /dev/null +++ b/src/toast/ops/fill_gaps.py @@ -0,0 +1,183 @@ +# Copyright (c) 2024-2024 by the parties listed in the AUTHORS file. +# All rights reserved. Use of this source code is governed by +# a BSD-style license that can be found in the LICENSE file. +import os + +import numpy as np +import traitlets +from astropy import units as u + +from ..observation import default_values as defaults +from ..timing import Timer, function_timer +from ..traits import Int, Quantity, Unicode, trait_docs +from ..utils import Environment, Logger, flagged_noise_fill +from .operator import Operator + + +@trait_docs +class FillGaps(Operator): + """Operator that fills flagged samples with noise. + + Currently this operator just fills flagged samples with a simple polynomial + plus white noise. It is mostly used for visualization. No attempt is made + yet to fill the gaps with a constrained noise realization. 
+ + """ + + # Class traits + + API = Int(0, help="Internal interface version for this operator") + + times = Unicode(defaults.times, help="Observation shared key for timestamps") + + det_data = Unicode( + defaults.det_data, + help="Observation detdata key", + ) + + det_mask = Int( + defaults.det_mask_invalid, + help="Bit mask value for per-detector flagging", + ) + + shared_flags = Unicode( + defaults.shared_flags, + allow_none=True, + help="Observation shared key for telescope flags to use", + ) + + shared_flag_mask = Int( + defaults.shared_mask_invalid, + help="Bit mask value for optional shared flagging", + ) + + det_flags = Unicode( + defaults.det_flags, + allow_none=True, + help="Observation detdata key for flags to use", + ) + + det_flag_mask = Int( + defaults.det_mask_invalid, + help="Bit mask value for detector sample flagging", + ) + + buffer = Quantity( + 1.0 * u.s, + help="Buffer of time on either side of each gap", + ) + + poly_order = Int( + 1, + help="Order of the polynomial to fit across each gap", + ) + + @traitlets.validate("poly_order") + def _check_poly_order(self, proposal): + check = proposal["value"] + if check <= 0: + raise traitlets.TraitError("poly_order should be >= 1") + return check + + @traitlets.validate("det_mask") + def _check_det_mask(self, proposal): + check = proposal["value"] + if check < 0: + raise traitlets.TraitError("Det mask should be a positive integer") + return check + + @traitlets.validate("det_flag_mask") + def _check_det_flag_mask(self, proposal): + check = proposal["value"] + if check < 0: + raise traitlets.TraitError("Flag mask should be a positive integer") + return check + + @traitlets.validate("shared_flag_mask") + def _check_shared_flag_mask(self, proposal): + check = proposal["value"] + if check < 0: + raise traitlets.TraitError("Flag mask should be a positive integer") + return check + + def __init__(self, **kwargs): + super().__init__(**kwargs) + + @function_timer + def _exec(self, data, detectors=None, **kwargs): + env = Environment.get() + log = Logger.get() + + for ob in data.obs: + timer = Timer() + timer.start() + + # Sample rate for this observation + rate = ob.telescope.focalplane.sample_rate.to_value(u.Hz) + + # The buffer size in samples + buf_samp = int(self.buffer.to_value(u.second) * rate) + + # Check that parameters make sense + if self.poly_order > buf_samp + 1: + msg = f"Cannot fit an order {self.poly_order} polynomial " + msg += f"to {buf_samp} samples" + raise RuntimeError(msg) + + if buf_samp > ob.n_local_samples // 4: + msg = f"Using {buf_samp} samples of buffer around gaps is" + msg += f" not reasonable for an observation with {ob.n_local_samples}" + msg += " local samples" + raise RuntimeError(msg) + + # Local detectors we are considering + local_dets = ob.select_local_detectors(flagmask=self.det_mask) + n_dets = len(local_dets) + + # The shared flags + if self.shared_flags is None: + shared_flags = np.zeros(ob.n_local_samples, dtype=bool) + else: + shared_flags = ( + ob.shared[self.shared_flags].data & self.shared_flag_mask + ) != 0 + + for idet, det in enumerate(local_dets): + if self.det_flags is None: + flags = shared_flags + else: + flags = np.logical_or( + shared_flags, + (ob.detdata[self.det_flags][det, :] & self.det_flag_mask) != 0, + ) + flagged_noise_fill( + ob.detdata[self.det_data][det], + flags, + buf_samp, + poly_order=self.poly_order, + ) + msg = f"FillGaps {ob.name}: completed in" + log.debug_rank(msg, comm=data.comm.comm_group, timer=timer) + + def _finalize(self, data, **kwargs): + return + + def 
_requires(self): + # Note that the timestamps are not strictly used by this + # operator- they are listed only for completeness. + req = { + "shared": [self.times], + "detdata": [self.det_data], + } + if self.shared_flags is not None: + req["shared"].append(self.shared_flags) + if self.det_flags is not None: + req["detdata"].append(self.det_flags) + return req + + def _provides(self): + prov = { + "meta": [], + "detdata": [self.det_data], + } + return prov diff --git a/src/toast/ops/flag_intervals.py b/src/toast/ops/flag_intervals.py index 9c3a630eb..e65bb3ec4 100644 --- a/src/toast/ops/flag_intervals.py +++ b/src/toast/ops/flag_intervals.py @@ -105,10 +105,15 @@ def _exec(self, data, detectors=None, **kwargs): if ob.comm_col_rank == 0: new_flags = np.array(ob.shared[self.shared_flags]) for vname, vmask in self.view_mask: - views = ob.view[vname] - for vw in views: - # Note that a View acts like a slice - new_flags[vw] |= vmask + if vname in ob.intervals: + views = ob.view[vname] + for vw in views: + # Note that a View acts like a slice + new_flags[vw] |= vmask + else: + msg = f"Intervals '{vname}' does not exist in {ob.name}" + msg += ", skipping flagging" + log.warning(msg) ob.shared[self.shared_flags].set(new_flags, offset=(0,), fromrank=0) def _finalize(self, data, **kwargs): diff --git a/src/toast/ops/hwpss_model.py b/src/toast/ops/hwpss_model.py index 67edf0446..717880837 100644 --- a/src/toast/ops/hwpss_model.py +++ b/src/toast/ops/hwpss_model.py @@ -19,7 +19,7 @@ from ..observation import default_values as defaults from ..timing import Timer, function_timer from ..traits import Bool, Int, Quantity, Unicode, Float, trait_docs -from ..utils import Environment, Logger +from ..utils import Environment, Logger, flagged_noise_fill from .operator import Operator @@ -126,7 +126,7 @@ class HWPSynchronousModel(Operator): time_drift = Bool(False, help="If True, include time drift terms in the model") - fill_gaps = Bool(False, help="If True, fit a simple line across gaps") + fill_gaps = Bool(False, help="If True, fill gaps with a simple noise model") debug = Unicode( None, @@ -334,7 +334,15 @@ def _exec(self, data, detectors=None, **kwargs): dc = np.mean(ob.detdata[self.det_data][det][good]) ob.detdata[self.det_data][det][good] -= dc if self.fill_gaps: - self._fill_gaps(ob, det, det_flags[det]) + rate = ob.telescope.focalplane.sample_rate.to_value(u.Hz) + # 1 second buffer + buffer = int(rate) + flagged_noise_fill( + ob.detdata[self.det_data][det], + det_flags[det], + buffer, + poly_order=1, + ) if self.relcal_continuous is not None: ob.detdata[self.relcal_continuous][det, :] = cal_center / det_mag @@ -925,33 +933,6 @@ def _stopped_flags(self, obs): stopped = np.array(unstable, dtype=np.uint8) return stopped - def _fill_gaps(self, obs, det, flags): - # Fill gaps with a line, just to kill large artifacts in flagged - # regions after removal of the HWPSS. This is mostly just for visualization. - # Downstream codes should ignore these flagged samples anyway.
- sig = obs.detdata[self.det_data][det] - flag_indx = np.arange(len(flags), dtype=np.int64)[np.nonzero(flags)] - flag_groups = np.split(flag_indx, np.where(np.diff(flag_indx) != 1)[0] + 1) - for grp in flag_groups: - if len(grp) == 0: - continue - bad_first = grp[0] - bad_last = grp[-1] - if bad_first == 0: - # Starting bad samples - sig[: bad_last + 1] = sig[bad_last + 1] - elif bad_last == len(flags) - 1: - # Ending bad samples - sig[bad_first:] = sig[bad_first - 1] - else: - int_first = bad_first - 1 - int_last = bad_last + 1 - sig[bad_first : bad_last + 1] = np.interp( - np.arange(bad_first, bad_last + 1, 1, dtype=np.int32), - [int_first, int_last], - [sig[int_first], sig[int_last]], - ) - def _finalize(self, data, **kwargs): return diff --git a/src/toast/ops/simple_deglitch.py b/src/toast/ops/simple_deglitch.py index ae17cfda0..21761189c 100644 --- a/src/toast/ops/simple_deglitch.py +++ b/src/toast/ops/simple_deglitch.py @@ -17,7 +17,7 @@ from ..observation import default_values as defaults from ..timing import Timer, function_timer from ..traits import Bool, Float, Instance, Int, Quantity, Unicode, trait_docs -from ..utils import Environment, Logger, name_UID +from ..utils import Environment, Logger, name_UID, flagged_noise_fill from .operator import Operator @@ -48,7 +48,7 @@ class SimpleDeglitch(Operator): ) reset_det_flags = Bool( - True, + False, help="Replace existing detector flags", ) @@ -204,17 +204,15 @@ def _exec(self, data, detectors=None, **kwargs): continue bad_view = np.isnan(sig_view) det_flags[ind][bad_view] |= self.glitch_mask - if self.fill_gaps: - nbad = np.sum(bad_view) - corrected_signal = sig[ind].copy() - corrected_signal[bad_view] = trend[bad_view] - corrected_signal[bad_view] += np.random.randn(nbad) * rms - # DEBUG begin - # import pdb - # import matplotlib.pyplot as plt - # pdb.set_trace() - # DEBUG end - sig[ind] = corrected_signal + if self.fill_gaps: + # 1 second buffer + buffer = int(focalplane.sample_rate.to_value(u.Hz)) + flagged_noise_fill( + sig, + det_flags, + buffer, + poly_order=1, + ) return diff --git a/src/toast/ops/simple_jumpcorrect.py b/src/toast/ops/simple_jumpcorrect.py index 96d7468a6..26f765a3c 100644 --- a/src/toast/ops/simple_jumpcorrect.py +++ b/src/toast/ops/simple_jumpcorrect.py @@ -100,6 +100,18 @@ class SimpleJumpCorrect(Operator): "the detector and time stream will be flagged as invalid.", ) + save_jumps = Unicode( + None, + allow_none=True, + help="Save the jump corrections to a dictionary of values per observation", + ) + + apply_jumps = Unicode( + None, + allow_none=True, + help="Do not compute jumps, instead apply the specified dictionary of values", + ) + @traitlets.validate("det_mask") def _check_det_mask(self, proposal): check = proposal["value"] @@ -179,7 +191,7 @@ def _find_peaks(self, toi, flag, flag_out, lim=3.0, tol=1e4, sigma_in=None): # Only one jump per iteration # And skip remaining if find more than `njump_limit` jumps - while (npeak > 0) and (len(peaks) <= self.njump_limit) : + while (npeak > 0) and (len(peaks) <= self.njump_limit): imax = np.argmax(np.abs(mytoi)) amplitude = mytoi[imax] significance = np.abs(amplitude) / sigma @@ -212,7 +224,7 @@ def _get_sigma(self, toi, flag, tol): sigmas = [] nn = len(toi) # Ignore tol samples at the edge - for start in range(tol, nn - 3*tol + 1, 2*tol): + for start in range(tol, nn - 3 * tol + 1, 2 * tol): stop = start + 2 * tol ind = slice(start, stop) x = toi[ind][full_flag[ind] == 0] @@ -239,13 +251,17 @@ def _remove_jumps(self, signal, flag, peaks, tol): 
corrected_signal[peak:] -= amplitude pstart = max(0, peak - tol) pstop = min(nsample, peak + tol) - flag_out[pstart : pstop] = True + flag_out[pstart:pstop] = True return corrected_signal, flag_out @function_timer def _exec(self, data, detectors=None, **kwargs): log = Logger.get() + if self.save_jumps is not None and self.apply_jumps is not None: + msg = "Cannot both save to and apply pre-existing jumps" + raise RuntimeError(msg) + stepfilter = self._get_stepfilter(self.filterlen) for ob in data.obs: @@ -258,7 +274,11 @@ def _exec(self, data, detectors=None, **kwargs): local_dets = ob.select_local_detectors(flagmask=self.det_mask) shared_flags = ob.shared[self.shared_flags].data & self.shared_flag_mask + if self.save_jumps is not None: + jump_props = dict() for name in local_dets: + if self.save_jumps is not None: + jump_dets = list() sig = ob.detdata[self.det_data][name] det_flags = ob.detdata[self.det_flags][name] if self.reset_det_flags: @@ -267,35 +287,48 @@ def _exec(self, data, detectors=None, **kwargs): shared_flags != 0, (det_flags & self.det_flag_mask) != 0, ) - for iview, view in enumerate(views): - nsample = view.last - view.first + 1 - ind = slice(view.first, view.last + 1) - sig_view = sig[ind].copy() - bad_view = bad[ind] - bad_view_out = bad_view.copy() - sig_filtered = convolve(sig_view, stepfilter, mode="same") - peaks = self._find_peaks( - sig_filtered, - bad_view, - bad_view_out, - lim=self.jump_limit, - tol=self.filterlen // 2, - ) - - njump = len(peaks) - if njump == 0: - continue - if njump > self.njump_limit: - ob._detflags[name] |= self.det_mask - det_flags[ind] |= self.det_flag_mask - continue - + if self.apply_jumps is not None: corrected_signal, flag_out = self._remove_jumps( - sig_view, bad_view, peaks, self.jump_radius + sig, bad, ob[self.apply_jumps][name], self.jump_radius ) - sig[ind] = corrected_signal - det_flags[ind][flag_out] |= self.jump_mask - + sig[:] = corrected_signal + det_flags[flag_out] |= self.jump_mask + else: + for iview, view in enumerate(views): + nsample = view.last - view.first + 1 + ind = slice(view.first, view.last + 1) + sig_view = sig[ind].copy() + bad_view = bad[ind] + bad_view_out = bad_view.copy() + sig_filtered = convolve(sig_view, stepfilter, mode="same") + peaks = self._find_peaks( + sig_filtered, + bad_view, + bad_view_out, + lim=self.jump_limit, + tol=self.filterlen // 2, + ) + if self.save_jumps is not None: + jump_dets.extend( + [(x + view.first, y, z) for x, y, z in peaks] + ) + njump = len(peaks) + if njump == 0: + continue + if njump > self.njump_limit: + ob._detflags[name] |= self.det_mask + det_flags[ind] |= self.det_flag_mask + continue + + corrected_signal, flag_out = self._remove_jumps( + sig_view, bad_view, peaks, self.jump_radius + ) + sig[ind] = corrected_signal + det_flags[ind][flag_out] |= self.jump_mask + if self.save_jumps is not None: + jump_props[name] = jump_dets + if self.save_jumps is not None: + ob[self.save_jumps] = jump_props return def _finalize(self, data, **kwargs): diff --git a/src/toast/tests/CMakeLists.txt b/src/toast/tests/CMakeLists.txt index 3225342ab..54062fb58 100644 --- a/src/toast/tests/CMakeLists.txt +++ b/src/toast/tests/CMakeLists.txt @@ -76,8 +76,10 @@ install(FILES ops_yield_cut.py ops_elevation_noise.py ops_signal_diff_noise.py + ops_azimuth_intervals.py ops_loader.py accelerator.py ops_example_ground.py + ops_fill_gaps.py DESTINATION ${PYTHON_SITE}/toast/tests ) diff --git a/src/toast/tests/ops_azimuth_intervals.py b/src/toast/tests/ops_azimuth_intervals.py new file mode 100644 index 
000000000..e32c47f3b --- /dev/null +++ b/src/toast/tests/ops_azimuth_intervals.py @@ -0,0 +1,173 @@ +# Copyright (c) 2024-2024 by the parties listed in the AUTHORS file. +# All rights reserved. Use of this source code is governed by +# a BSD-style license that can be found in the LICENSE file. + +import os +from datetime import datetime + +import numpy as np +from astropy import units as u + +from .. import ops as ops +from ..data import Data +from ..instrument import Focalplane, GroundSite, Telescope +from ..instrument_sim import fake_hexagon_focalplane +from ..mpi import MPI, Comm +from ..observation import Observation +from ..observation import default_values as defaults +from ..pixels_io_healpix import write_healpix_fits +from ..schedule import GroundSchedule +from ..schedule_sim_ground import run_scheduler +from ..vis import set_matplotlib_backend +from ..ops.sim_ground_utils import scan_between + +from ._helpers import close_data, create_comm, create_outdir, create_ground_telescope +from .mpi import MPITestCase + + +class AzimuthIntervalsTest(MPITestCase): + def setUp(self): + fixture_name = os.path.splitext(os.path.basename(__file__))[0] + self.outdir = create_outdir(self.comm, fixture_name) + + def create_fake_data(self): + np.random.seed(123456) + # Just one group with all processes + toastcomm = create_comm(self.comm, single_group=True) + + rate = 100.0 * u.Hz + + telescope = create_ground_telescope( + toastcomm.group_size, + sample_rate=rate, + pixel_per_process=1, + fknee=None, + freqs=None, + width=5.0 * u.degree, + ) + + data = Data(toastcomm) + + # 8 minutes + n_samp = int(8 * 60 * rate.to_value(u.Hz)) + n_parked = int(0.1 * n_samp) + n_scan = int(0.12 * n_samp) + + ob = Observation(toastcomm, telescope, n_samples=n_samp, name="aztest") + # Create shared objects for timestamps, common flags, boresight, position, + # and velocity. 
+ ob.shared.create_column( + defaults.times, + shape=(ob.n_local_samples,), + dtype=np.float64, + ) + ob.shared.create_column( + defaults.shared_flags, + shape=(ob.n_local_samples,), + dtype=np.uint8, + ) + ob.shared.create_column( + defaults.azimuth, + shape=(ob.n_local_samples,), + dtype=np.float64, + ) + ob.shared.create_column( + defaults.elevation, + shape=(ob.n_local_samples,), + dtype=np.float64, + ) + + # Rank zero of each grid column creates the data + stamps = None + azimuth = None + elevation = None + flags = None + scans = None + if ob.comm_col_rank == 0: + start_time = 0.0 + float(ob.local_index_offset) / rate.to_value(u.Hz) + stop_time = start_time + float(ob.n_local_samples - 1) / rate.to_value(u.Hz) + stamps = np.linspace( + start_time, + stop_time, + num=ob.n_local_samples, + endpoint=True, + dtype=np.float64, + ) + + scans = (n_samp - n_parked) // n_scan + sim_scans = scans + 1 + + azimuth = np.zeros(ob.n_local_samples, dtype=np.float64) + elevation = np.radians(45.0) * np.ones(ob.n_local_samples, dtype=np.float64) + + azimuth[:n_parked] = np.pi / 4 + + for iscan in range(sim_scans): + first_samp = iscan * n_scan + n_parked + if iscan % 2 == 0: + azstart = np.pi / 4 + azstop = 3 * np.pi / 4 + else: + azstart = 3 * np.pi / 4 + azstop = np.pi / 4 + _, az, el = scan_between( + stamps[first_samp], + azstart, + np.pi / 4, + azstop, + np.pi / 4, + np.radians(1.0), # rad / s + np.radians(0.25), # rad / s^2 + np.radians(1.0), # rad / s + np.radians(0.25), # rad / s^2 + nstep=n_scan, + ) + if iscan == scans: + azimuth[first_samp:] = az[: n_samp - first_samp] + elevation[first_samp:] = el[: n_samp - first_samp] + else: + azimuth[first_samp : first_samp + n_scan] = az + elevation[first_samp : first_samp + n_scan] = el + + # Add some noise + scale = 0.00005 + azimuth[:] += np.random.normal(loc=0, scale=scale, size=ob.n_local_samples) + elevation[:] += np.random.normal( + loc=0, scale=scale, size=ob.n_local_samples + ) + + # Periodic flagged samples. Add garbage spikes there. 
+ flags = np.zeros(ob.n_local_samples, dtype=np.uint8) + for fspan in range(5): + flags[:: 1000 + fspan] = defaults.shared_mask_invalid + bad_samps = flags != 0 + azimuth[bad_samps] = 10.0 + elevation[bad_samps] = -10.0 + + if ob.comm_col is not None: + scans = ob.comm_col.bcast(scans, root=0) + + ob.shared[defaults.times].set(stamps, offset=(0,), fromrank=0) + ob.shared[defaults.azimuth].set(azimuth, offset=(0,), fromrank=0) + ob.shared[defaults.elevation].set(elevation, offset=(0,), fromrank=0) + ob.shared[defaults.shared_flags].set(flags, offset=(0,), fromrank=0) + data.obs.append(ob) + return data, scans + + def test_exec(self): + data, num_scans = self.create_fake_data() + + azint = ops.AzimuthIntervals( + debug_root=os.path.join(self.outdir, "az_intervals"), + window_seconds=5.0, + ) + azint.apply(data) + + for ob in data.obs: + n_scans = len(ob.intervals[defaults.scanning_interval]) + if n_scans != num_scans + 1: + msg = f"Found {n_scans} scanning intervals instead of {num_scans + 1}" + print(msg, flush=True) + self.assertTrue(False) + + close_data(data) diff --git a/src/toast/tests/ops_demodulate.py b/src/toast/tests/ops_demodulate.py index 64372805c..d539072a7 100644 --- a/src/toast/tests/ops_demodulate.py +++ b/src/toast/tests/ops_demodulate.py @@ -113,7 +113,7 @@ def test_demodulate(self): # Demodulate downsample = 3 - demod = ops.Demodulate(stokes_weights=weights, nskip=downsample, purge=True) + demod = ops.Demodulate(stokes_weights=weights, nskip=downsample, purge=False) demod_data = demod.apply(data) # Map again diff --git a/src/toast/tests/ops_fill_gaps.py b/src/toast/tests/ops_fill_gaps.py new file mode 100644 index 000000000..96e433477 --- /dev/null +++ b/src/toast/tests/ops_fill_gaps.py @@ -0,0 +1,160 @@ +# Copyright (c) 2024-2024 by the parties listed in the AUTHORS file. +# All rights reserved. Use of this source code is governed by +# a BSD-style license that can be found in the LICENSE file. + +import os + +import numpy as np +from astropy import units as u +- +from .. import ops as ops +from ..observation import default_values as defaults +from ._helpers import ( + close_data, + create_ground_data, + create_outdir, + fake_flags, +) +from .mpi import MPITestCase + + +class FillGapsTest(MPITestCase): + def setUp(self): + fixture_name = os.path.splitext(os.path.basename(__file__))[0] + self.outdir = create_outdir(self.comm, fixture_name) + np.random.seed(123456) + if ( + ("CONDA_BUILD" in os.environ) + or ("CIBUILDWHEEL" in os.environ) + or ("CI" in os.environ) + ): + self.make_plots = False + else: + self.make_plots = True + + def test_gap_fill(self): + # Create some test data. Disable HWPSS, since we are not demodulating + # in this example. + data, input_rms = self.create_test_data() + + # Make a copy for later comparison + ops.Copy(detdata=[(defaults.det_data, "input")]).apply(data) + + # Linear fit plus noise + filler = ops.FillGaps( + shared_flag_mask=defaults.shared_mask_nonscience, + buffer=1.0 * u.s, + poly_order=1, + ) + filler.apply(data) + + # Diagnostic plots of one detector on each process.
+ if self.make_plots: + import matplotlib.pyplot as plt + + for ob in data.obs: + det = ob.select_local_detectors(flagmask=defaults.det_mask_nonscience)[ + 0 + ] + n_all_samp = ob.n_all_samples + n_plot = 2 + fig_height = 6 * n_plot + pltsamp = 200 + + for first, last in [ + (0, n_all_samp), + (n_all_samp // 2 - pltsamp, n_all_samp // 2 + pltsamp), + ]: + plot_slc = slice(first, last, 1) + outfile = os.path.join( + self.outdir, + f"filled_{ob.name}_{det}_{first}-{last}.pdf", + ) + + times = ob.shared[defaults.times].data + samp_indx = np.arange(n_all_samp) + input = ob.detdata["input"][det] + signal = ob.detdata[defaults.det_data][det] + detflags = ob.detdata[defaults.det_flags][det] + shflags = ob.shared[defaults.shared_flags].data + + fig = plt.figure(figsize=(12, fig_height), dpi=72) + ax = fig.add_subplot(n_plot, 1, 1, aspect="auto") + # Plot signal + ax.plot( + samp_indx[plot_slc], + input[plot_slc], + color="black", + label=f"{det} Input", + ) + ax.plot( + samp_indx[plot_slc], + signal[plot_slc], + color="red", + label=f"{det} Filled", + ) + ax.legend(loc="best") + # Plot flags + ax = fig.add_subplot(n_plot, 1, 2, aspect="auto") + ax.plot( + samp_indx[plot_slc], + shflags[plot_slc], + color="blue", + label="Shared Flags", + ) + ax.plot( + samp_indx[plot_slc], + detflags[plot_slc], + color="red", + label=f"{det} Flags", + ) + ax.legend(loc="best") + fig.suptitle(f"Obs {ob.name}: {first} - {last}") + fig.savefig(outfile) + plt.close(fig) + + close_data(data) + + def create_test_data(self): + # Slightly slower than 0.5 Hz + hwp_rpm = 29.0 + hwp_rate = 2 * np.pi * hwp_rpm / 60.0 # rad/s + + sample_rate = 30 * u.Hz + ang_per_sample = hwp_rate / sample_rate.to_value(u.Hz) + + # Create a fake ground observations set for testing. + data = create_ground_data( + self.comm, + sample_rate=sample_rate, + hwp_rpm=hwp_rpm, + pixel_per_process=1, + single_group=True, + fp_width=5.0 * u.degree, + ) + + # Create an uncorrelated noise model from focalplane detector properties + default_model = ops.DefaultNoiseModel(noise_model="noise_model") + default_model.apply(data) + + # Simulate fake instrumental noise + sim_noise = ops.SimNoise(noise_model="noise_model") + sim_noise.apply(data) + + # Create flagged samples + fake_flags(data) + + # Now we will increase the noise amplitude of flagged samples to + # make it easier to check that we have filled gaps with something + # reasonable. + rms = dict() + for ob in data.obs: + for det in ob.local_detectors: + input = np.std(ob.detdata[defaults.det_data][det]) + rms[det] = input + flags = np.array(ob.shared[defaults.shared_flags].data) + flags[:] |= ob.detdata[defaults.det_flags][det, :] + bad = flags != 0 + ob.detdata[defaults.det_data][det, bad] *= 20 + + return data, rms diff --git a/src/toast/tests/runner.py b/src/toast/tests/runner.py index 9fec8663e..6c386f9bd 100644 --- a/src/toast/tests/runner.py +++ b/src/toast/tests/runner.py @@ -26,12 +26,14 @@ from . import math_misc as test_math_misc from . import noise as test_noise from . import observation as test_observation +from . import ops_azimuth_intervals as test_ops_azimuth_intervals from . import ops_cadence_map as test_ops_cadence_map from . import ops_common_mode_noise as test_ops_common_mode_noise from . import ops_crosslinking as test_ops_crosslinking from . import ops_demodulate as test_ops_demodulate from . import ops_elevation_noise as test_ops_elevation_noise from . import ops_example_ground as test_ops_example_ground +from . import ops_fill_gaps as test_ops_fill_gaps from . 
import ops_filterbin as test_ops_filterbin from . import ops_flag_sso as test_ops_flag_sso from . import ops_gainscrambler as test_ops_gainscrambler @@ -183,6 +185,7 @@ def test(name=None, verbosity=2): suite.addTest(loader.loadTestsFromModule(test_dist)) suite.addTest(loader.loadTestsFromModule(test_config)) + suite.addTest(loader.loadTestsFromModule(test_ops_azimuth_intervals)) suite.addTest(loader.loadTestsFromModule(test_ops_sim_satellite)) suite.addTest(loader.loadTestsFromModule(test_ops_sim_ground)) suite.addTest(loader.loadTestsFromModule(test_ops_memory_counter)) @@ -211,6 +214,7 @@ def test(name=None, verbosity=2): suite.addTest(loader.loadTestsFromModule(test_ops_gainscrambler)) suite.addTest(loader.loadTestsFromModule(test_ops_sim_gaindrifts)) suite.addTest(loader.loadTestsFromModule(test_ops_polyfilter)) + suite.addTest(loader.loadTestsFromModule(test_ops_fill_gaps)) suite.addTest(loader.loadTestsFromModule(test_ops_groundfilter)) suite.addTest(loader.loadTestsFromModule(test_ops_hwpfilter)) suite.addTest(loader.loadTestsFromModule(test_ops_hwpss_model)) diff --git a/src/toast/utils.py b/src/toast/utils.py index aab66a869..d212dfaff 100644 --- a/src/toast/utils.py +++ b/src/toast/utils.py @@ -831,3 +831,79 @@ def is_empty(self): if (len(v) > 0) and not all(x is None for x in v): return False return True + + +def flagged_noise_fill(data, flags, buffer, poly_order=1): + """Fill flagged samples with noise. + + This finds contiguous flagged samples and fills each gap with a polynomial + trend using nearby samples plus gaussian white noise. + + Args: + data (array): The local data buffer to process. + flags (array): The array of sample flags. + buffer (int): Number of samples to use on either side of flagged regions. + poly_order (int): The polynomial order to fit across the gap. 
+ + Returns: + None + + """ + n_samp = len(data) + if len(flags) != n_samp: + msg = "Data and flag array lengths should be the same" + raise RuntimeError(msg) + + if buffer <= 0 or buffer > n_samp // 4: + msg = "The buffer size around flagged regions should be large enough" + msg += " to estimate nearby noise properties, but small enough to fit" + msg += " within the local data" + raise RuntimeError(msg) + + flag_indx = np.arange(n_samp, dtype=np.int64)[flags != 0] + flag_groups = np.split(flag_indx, np.where(np.diff(flag_indx) != 1)[0] + 1) + nfgroup = len(flag_groups) + + # Merge groups that are closer than the buffer length + groups = list() + igrp = 0 + while igrp < nfgroup: + grp = flag_groups[igrp] + if len(grp) == 0: + igrp += 1 + continue + first = grp[0] + last = grp[-1] + 1 + while igrp + 1 < nfgroup and last + buffer > flag_groups[igrp + 1][0]: + igrp += 1 + last = flag_groups[igrp][-1] + 1 + groups.append((int(first), int(last))) + igrp += 1 + + for igrp, (bad_first, bad_last) in enumerate(groups): + full_first = bad_first - buffer + if full_first < 0: + full_first = 0 + full_last = bad_last + buffer + if full_last > n_samp: + full_last = n_samp + fit_n_samps = full_last - full_first + fit_samps = np.arange(fit_n_samps) + fit_flags = flags[full_first:full_last] + fit_good = fit_flags == 0 + fit_bad = np.logical_not(fit_good) + in_fit_x = fit_samps[fit_good] + in_fit_y = data[full_first:full_last][fit_good] + fit_poly = np.polynomial.polynomial.Polynomial.fit( + in_fit_x, in_fit_y, poly_order + ) + fit_curve = fit_poly(in_fit_x) + + rms = np.std(in_fit_y - fit_curve) + + # Fill the gaps with noise plus the fit polynomial + full_fit = fit_poly(fit_samps) + n_bad = np.count_nonzero(fit_bad) + data[full_first:full_last][fit_bad] = full_fit[fit_bad] + np.random.normal( + scale=rms, size=n_bad + )
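Usage note for the new flagged_noise_fill helper above: it merges flagged spans separated by less than the buffer, fits a low-order polynomial to the unflagged samples in the buffered region around each gap, and overwrites the flagged samples in place with that trend plus white noise at the locally measured RMS. A small self-contained demo, assuming a TOAST build that includes this patch (the synthetic signal and flag layout are invented for illustration):

import numpy as np

from toast.utils import flagged_noise_fill

rng = np.random.default_rng(12345)
n_samp = 2000
# A slow linear drift plus white noise.
signal = 0.01 * np.arange(n_samp) + rng.normal(scale=0.5, size=n_samp)

# Two flagged spans separated by less than the buffer, so they should be
# merged and filled as a single gap.
flags = np.zeros(n_samp, dtype=np.uint8)
flags[900:950] = 1
flags[960:980] = 1

# Corrupt the flagged samples so the fill is easy to verify.
signal[flags != 0] = 1.0e3

# Fill in place, using 100 buffer samples on each side and a linear trend.
flagged_noise_fill(signal, flags, 100, poly_order=1)

# The filled samples should now track the local trend rather than 1e3.
assert np.all(np.abs(signal[flags != 0] - 0.01 * 925) < 50.0)

The FillGaps operator added earlier in this patch wraps the same helper for whole observations, deriving the buffer length in samples from its buffer time trait and the focalplane sample rate.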