diff --git a/scripts/slice_dataset.py b/scripts/slice_dataset.py index 9c7b18a..7ed2b84 100644 --- a/scripts/slice_dataset.py +++ b/scripts/slice_dataset.py @@ -75,7 +75,7 @@ None, help=( 'Comma delimited list of variables to drop. If empty, drop no' - ' variables.' + ' variables. List may include data variables or coords.' ), ) @@ -83,7 +83,7 @@ 'keep_variables', None, help=( - 'Comma delimited list of variables to keep. If empty, use' + 'Comma delimited list of data variables to keep. If empty, use' ' --drop_variables to determine which variables to keep' ), ) @@ -143,7 +143,7 @@ def main(argv: abc.Sequence[str]) -> None: ds, input_chunks = xbeam.open_zarr(INPUT_PATH.value) if DROP_VARIABLES.value: - ds = ds[[v for v in ds if v not in DROP_VARIABLES.value]] + ds = ds.drop_vars(DROP_VARIABLES.value) elif KEEP_VARIABLES.value: ds = ds[KEEP_VARIABLES.value] diff --git a/weatherbench2/evaluation.py b/weatherbench2/evaluation.py index c5c5e7b..9b4ab60 100644 --- a/weatherbench2/evaluation.py +++ b/weatherbench2/evaluation.py @@ -708,6 +708,7 @@ def _evaluate( forecast_pipeline |= 'TemporalMean' >> xbeam.Mean( dim='init_time' if self.data_config.by_init else 'time', fanout=self.fanout, + skipna=False, ) return forecast_pipeline diff --git a/weatherbench2/metrics.py b/weatherbench2/metrics.py index 5e41a9e..98bbb4d 100644 --- a/weatherbench2/metrics.py +++ b/weatherbench2/metrics.py @@ -114,6 +114,7 @@ def compute( forecast: xr.Dataset, truth: xr.Dataset, region: t.Optional[Region] = None, + skipna: bool = False, ) -> xr.Dataset: """Evaluate this metric on datasets with full temporal coverages.""" if "time" in forecast.dims: @@ -124,7 +125,10 @@ def compute( raise ValueError( f"Forecast has neither valid_time or init_time dimension {forecast}" ) - return self.compute_chunk(forecast, truth, region=region).mean(avg_dim) + return self.compute_chunk(forecast, truth, region=region).mean( + avg_dim, + skipna=skipna, + ) def _spatial_average( @@ -553,9 +557,10 @@ def compute( forecast: xr.Dataset, truth: xr.Dataset, region: t.Optional[Region] = None, + skipna: bool = False, ) -> xr.Dataset: """Evaluate this metric on datasets with full temporal coverages.""" - result = super().compute(forecast, truth, region=region) + result = super().compute(forecast, truth, region=region, skipna=skipna) return result.assign_attrs(ensemble_size=forecast[self.ensemble_dim].size) @@ -1406,8 +1411,14 @@ def _compute_chunk_impl( for threshold in threshold_seq: quantile = threshold.quantile threshold = threshold.compute(truth) - truth_probability = xr.where(truth > threshold, 1.0, 0.0) - forecast_probability = xr.where(forecast > threshold, 1.0, 0.0) + truth_probability = xr.where( + truth.isnull(), + np.nan, + xr.where(truth > threshold, 1.0, 0.0), + ) + forecast_probability = xr.where( + forecast.isnull(), np.nan, xr.where(forecast > threshold, 1.0, 0.0) + ) if debias: mse_of_probabilities = _debiased_ensemble_mean_mse( forecast_probability, diff --git a/weatherbench2/metrics_test.py b/weatherbench2/metrics_test.py index 3e30c96..8bfe911 100644 --- a/weatherbench2/metrics_test.py +++ b/weatherbench2/metrics_test.py @@ -932,6 +932,69 @@ def test_ensemble_brier_score(self, error, ens_delta, expected): result['2m_temperature'].values, expected_arr, rtol=1e-4 ) + def test_nan_propagates_to_output(self): + kwargs = { + 'variables_2d': ['2m_temperature'], + 'variables_3d': [], + 'time_start': '2022-01-01', + 'time_stop': '2022-01-03', + } + forecast = schema.mock_forecast_data( + ensemble_size=4, lead_stop='1 day', **kwargs + ) + forecast = ( + # Use settings from test_ensemble_brier_score that result in score=0. + forecast + + 1.0 + + 0.1 * np.arange(-2, 2).reshape((4, 1, 1, 1, 1)) + ) + truth = schema.mock_truth_data(**kwargs) + truth = truth + 1.0 + + forecast_with_nan = xr.where( + forecast.prediction_timedelta < forecast.prediction_timedelta[-1], + np.nan, + forecast, + ) + truth_with_nan = xr.where(truth.time < truth.time[-1], np.nan, truth) + + climatology_mean = truth.isel(time=0, drop=True).expand_dims(dayofyear=366) + climatology_std = ( + truth.isel(time=0, drop=True) + .expand_dims( + dayofyear=366, + ) + .rename({'2m_temperature': '2m_temperature_std'}) + ) + climatology = xr.merge([climatology_mean, climatology_std]) + threshold = thresholds.GaussianQuantileThreshold( + climatology=climatology, quantile=0.2 + ) + + with self.subTest('forecast has nan'): + # When forecast has nan in prediction_timedelta, only that timedelta will + # be NaN. + result = metrics.EnsembleBrierScore(threshold).compute( + forecast_with_nan, truth + ) + expected_arr = np.array([[np.nan, 0.0]]) + np.testing.assert_allclose( + result['2m_temperature'].values, + expected_arr, + ) + + with self.subTest('truth has nan'): + # When truth has nan, the final average over times means the entire + # score is NaN. + result = metrics.EnsembleBrierScore(threshold).compute( + forecast, truth_with_nan + ) + expected_arr = np.array([[np.nan, np.nan]]) + np.testing.assert_allclose( + result['2m_temperature'].values, + expected_arr, + ) + class DebiasedEnsembleBrierScoreTest(parameterized.TestCase):