From e379c031c7523d761d06f2f078a485e99eb18c50 Mon Sep 17 00:00:00 2001 From: Ian Langmore Date: Tue, 8 Oct 2024 10:13:30 -0700 Subject: [PATCH] `--skipna` FLAG added to weatherbench scripts. To `evaluate.py`. Using this means NaN values are skipped in most metrics. * The exceptions are * `RankHistograms` ignores it (and logs if there are NaNs) * SEEPS already effectively uses skipna To other scripts * `compute_averages` * `compute_ensemble_mean` * `resample_in_time` The API convention is: * Public methods have `skipna=False` default * Private methods require `skipna`. Along the way, I updated a number of private methods to require all args. This lets typing ensure I've passed args down from public -> private. PiperOrigin-RevId: 683670091 --- scripts/compute_averages.py | 12 +- scripts/compute_ensemble_mean.py | 10 +- scripts/evaluate.py | 15 +- scripts/resample_in_time.py | 34 +++- weatherbench2/evaluation.py | 35 +++- weatherbench2/metrics.py | 277 ++++++++++++++++++++++++------- weatherbench2/metrics_test.py | 124 ++++++++++---- weatherbench2/test_utils.py | 15 ++ 8 files changed, 412 insertions(+), 110 deletions(-) diff --git a/scripts/compute_averages.py b/scripts/compute_averages.py index 9ca2460..162777e 100644 --- a/scripts/compute_averages.py +++ b/scripts/compute_averages.py @@ -81,6 +81,14 @@ 'If empty, compute on all data_vars of --input_path' ), ) +SKIPNA = flags.DEFINE_boolean( + 'skipna', + False, + help=( + 'Whether to skip NaN data points (in forecasts and observations) when' + ' evaluating.' + ), +) FANOUT = flags.DEFINE_integer( 'fanout', None, @@ -138,7 +146,9 @@ def main(argv: list[str]): ( chunked - | xbeam.Mean(AVERAGING_DIMS.value, skipna=False, fanout=FANOUT.value) + | xbeam.Mean( + AVERAGING_DIMS.value, skipna=SKIPNA.value, fanout=FANOUT.value + ) | xbeam.ChunksToZarr( OUTPUT_PATH.value, template, diff --git a/scripts/compute_ensemble_mean.py b/scripts/compute_ensemble_mean.py index ba3085a..b89f23b 100644 --- a/scripts/compute_ensemble_mean.py +++ b/scripts/compute_ensemble_mean.py @@ -76,6 +76,14 @@ ' all variables are selected.' ), ) +SKIPNA = flags.DEFINE_boolean( + 'skipna', + False, + help=( + 'Whether to skip NaN data points (in forecasts and observations) when' + ' evaluating.' + ), +) # pylint: disable=expression-not-assigned @@ -123,7 +131,7 @@ def main(argv: list[str]): split_vars=True, num_threads=NUM_THREADS.value, ) - | xbeam.Mean(REALIZATION_NAME.value, skipna=False) + | xbeam.Mean(REALIZATION_NAME.value, skipna=SKIPNA.value) | xbeam.ChunksToZarr( OUTPUT_PATH.value, template, diff --git a/scripts/evaluate.py b/scripts/evaluate.py index fd1dcbf..c0e0331 100644 --- a/scripts/evaluate.py +++ b/scripts/evaluate.py @@ -157,6 +157,14 @@ ' "2m_temperature"}' ), ) +SKIPNA = flags.DEFINE_boolean( + 'skipna', + False, + help=( + 'Whether to skip NaN data points (in forecasts and observations) when' + ' evaluating.' + ), +) PRESSURE_LEVEL_SUFFIXES = flags.DEFINE_bool( 'pressure_level_suffixes', False, @@ -630,14 +638,17 @@ def main(argv: list[str]) -> None: eval_configs, runner=RUNNER.value, input_chunks=INPUT_CHUNKS.value, + skipna=SKIPNA.value, fanout=FANOUT.value, num_threads=NUM_THREADS.value, argv=argv, ) else: - evaluation.evaluate_in_memory(data_config, eval_configs) + evaluation.evaluate_in_memory( + data_config, eval_configs, skipna=SKIPNA.value + ) if __name__ == '__main__': app.run(main) - flags.mark_flag_as_required('output_path') + flags.mark_flags_as_required(['output_path', 'obs_path']) diff --git a/scripts/resample_in_time.py b/scripts/resample_in_time.py index c541994..bbb7a19 100644 --- a/scripts/resample_in_time.py +++ b/scripts/resample_in_time.py @@ -142,6 +142,14 @@ ' use the last time in --input_path.' ), ) +SKIPNA = flags.DEFINE_boolean( + 'skipna', + False, + help=( + 'Whether to skip NaN data points (in forecasts and observations) when' + ' evaluating.' + ), +) WORKING_CHUNKS = flag_utils.DEFINE_chunks( 'working_chunks', '', @@ -182,6 +190,7 @@ def resample_in_time_chunk( min_vars: list[str], max_vars: list[str], add_mean_suffix: bool, + skipna: bool = False, ) -> tuple[xbeam.Key, xr.Dataset]: """Resample a data chunk in time and return a requested time statistic. @@ -196,6 +205,8 @@ def resample_in_time_chunk( max_vars: Variables to compute the max of. add_mean_suffix: Whether to add a "_mean" suffix to variables after computing the mean. + skipna: Whether to skip NaN values in both forecasts and observations during + evaluation. Returns: The resampled data chunk and its key. @@ -207,21 +218,23 @@ def resample_in_time_chunk( for chunk_var in chunk.data_vars: if chunk_var in mean_vars: rsmp_chunks.append( - resample_in_time_core(chunk, method, period, 'mean').rename( + resample_in_time_core( + chunk, method, period, 'mean', skipna=skipna + ).rename( {chunk_var: f'{chunk_var}_mean' if add_mean_suffix else chunk_var} ) ) if chunk_var in min_vars: rsmp_chunks.append( - resample_in_time_core(chunk, method, period, 'min').rename( - {chunk_var: f'{chunk_var}_min'} - ) + resample_in_time_core( + chunk, method, period, 'min', skipna=skipna + ).rename({chunk_var: f'{chunk_var}_min'}) ) if chunk_var in max_vars: rsmp_chunks.append( - resample_in_time_core(chunk, method, period, 'max').rename( - {chunk_var: f'{chunk_var}_max'} - ) + resample_in_time_core( + chunk, method, period, 'max', skipna=skipna + ).rename({chunk_var: f'{chunk_var}_max'}) ) return rsmp_key, xr.merge(rsmp_chunks) @@ -232,6 +245,7 @@ def resample_in_time_core( method: str, period: pd.Timedelta, statistic: str, + skipna: bool, ) -> t.Union[xr.Dataset, xr.DataArray]: """Core call to xarray resample or rolling.""" if method == 'rolling': @@ -245,12 +259,12 @@ def resample_in_time_core( {TIME_DIM.value: period // delta_t}, center=False, min_periods=None ), statistic, - )(skipna=False) + )(skipna=skipna) elif method == 'resample': return getattr( chunk.resample({TIME_DIM.value: period}, label='left'), statistic, - )(skipna=False) + )(skipna=skipna) else: raise ValueError(f'Unhandled {method=}') @@ -301,6 +315,7 @@ def main(argv: abc.Sequence[str]) -> None: METHOD.value, period, statistic='mean', + skipna=SKIPNA.value, )[TIME_DIM.value] else: rsmp_times = ds[TIME_DIM.value] @@ -369,6 +384,7 @@ def main(argv: abc.Sequence[str]) -> None: min_vars=min_vars, max_vars=max_vars, add_mean_suffix=ADD_MEAN_SUFFIX.value, + skipna=SKIPNA.value, ) ) | 'RechunkToOutputChunks' diff --git a/weatherbench2/evaluation.py b/weatherbench2/evaluation.py index 9bfbba7..942e414 100644 --- a/weatherbench2/evaluation.py +++ b/weatherbench2/evaluation.py @@ -389,6 +389,7 @@ def _metric_and_region_loop( forecast: xr.Dataset, truth: xr.Dataset, eval_config: config.Eval, + skipna: bool, compute_chunk: bool = False, ) -> xr.Dataset: """Compute metric results looping over metrics and regions in eval config.""" @@ -415,16 +416,18 @@ def _metric_and_region_loop( region_dim = xr.DataArray( [region_name], coords={'region': [region_name]} ) - tmp_result = eval_fn(forecast=forecast, truth=truth, region=region) + tmp_result = eval_fn( + forecast=forecast, truth=truth, region=region, skipna=skipna + ) tmp_results.append( tmp_result.expand_dims({'metric': metric_dim, 'region': region_dim}) ) logging.info(f'Logging region done: {region_name}') result = xr.concat(tmp_results, 'region') else: - result = eval_fn(forecast=forecast, truth=truth).expand_dims( - {'metric': metric_dim} - ) + result = eval_fn( + forecast=forecast, truth=truth, skipna=skipna + ).expand_dims({'metric': metric_dim}) results.append(result) logging.info(f'Logging metric done: {name}') results = xr.merge(results) @@ -435,6 +438,7 @@ def _evaluate_all_metrics( eval_name: str, eval_config: config.Eval, data_config: config.Data, + skipna: bool, ) -> None: """Evaluate a set of eval metrics in memory.""" forecast, truth, climatology = open_forecast_and_truth_datasets( @@ -466,7 +470,7 @@ def _evaluate_all_metrics( if data_config.by_init: truth = truth.sel(time=forecast.valid_time) - results = _metric_and_region_loop(forecast, truth, eval_config) + results = _metric_and_region_loop(forecast, truth, eval_config, skipna=skipna) logging.info(f'Logging Evaluation complete:\n{results}') @@ -478,6 +482,7 @@ def _evaluate_all_metrics( def evaluate_in_memory( data_config: config.Data, eval_configs: dict[str, config.Eval], + skipna: bool = False, ) -> None: """Run evaluation in memory. @@ -501,9 +506,11 @@ def evaluate_in_memory( Args: data_config: config.Data instance. eval_configs: Dictionary of config.Eval instances. + skipna: Whether to skip NaN values in both forecasts and observations during + evaluation. """ for eval_name, eval_config in eval_configs.items(): - _evaluate_all_metrics(eval_name, eval_config, data_config) + _evaluate_all_metrics(eval_name, eval_config, data_config, skipna=skipna) @dataclasses.dataclass @@ -550,13 +557,17 @@ class _EvaluateAllMetrics(beam.PTransform): eval_config: config.Eval instance. data_config: config.Data instance. input_chunks: Chunks to use for input files. + skipna: Whether to skip NaN values in both forecasts and observations during + evaluation. fanout: Fanout parameter for Beam combiners. + num_threads: Number of threads for reading/writing files. """ eval_name: str eval_config: config.Eval data_config: config.Data input_chunks: abc.Mapping[str, int] + skipna: bool fanout: Optional[int] = None num_threads: Optional[int] = None @@ -568,7 +579,11 @@ def _evaluate_chunk( forecast, truth = forecast_and_truth logging.info(f'Logging _evaluate_chunk Key: {key}') results = _metric_and_region_loop( - forecast, truth, self.eval_config, compute_chunk=True + forecast, + truth, + self.eval_config, + compute_chunk=True, + skipna=self.skipna, ) dropped_dims = [dim for dim in key.offsets if dim not in results.dims] result_key = key.with_offsets(**{dim: None for dim in dropped_dims}) @@ -712,7 +727,7 @@ def _evaluate( forecast_pipeline |= 'TemporalMean' >> xbeam.Mean( dim='init_time' if self.data_config.by_init else 'time', fanout=self.fanout, - skipna=False, + skipna=self.skipna, ) return forecast_pipeline @@ -736,6 +751,7 @@ def evaluate_with_beam( fanout: Optional[int] = None, num_threads: Optional[int] = None, argv: Optional[list[str]] = None, + skipna: bool = False, ) -> None: """Run evaluation with a Beam pipeline. @@ -764,6 +780,8 @@ def evaluate_with_beam( fanout: Beam CombineFn fanout. num_threads: Number of threads to use for reading/writing data. argv: Other arguments to pass into the Beam pipeline. + skipna: Whether to skip NaN values in both forecasts and observations during + evaluation. """ with beam.Pipeline(runner=runner, argv=argv) as root: @@ -779,6 +797,7 @@ def evaluate_with_beam( input_chunks, fanout=fanout, num_threads=num_threads, + skipna=skipna, ) | f'save_{eval_name}' >> _SaveOutputs( diff --git a/weatherbench2/metrics.py b/weatherbench2/metrics.py index 7c83bda..b49b2ec 100644 --- a/weatherbench2/metrics.py +++ b/weatherbench2/metrics.py @@ -19,6 +19,7 @@ from collections.abc import Sequence import dataclasses import functools +import logging import typing as t import numpy as np @@ -89,6 +90,7 @@ def compute_chunk( forecast: xr.Dataset, truth: xr.Dataset, region: t.Optional[Region] = None, + skipna: bool = False, ) -> xr.Dataset: """Evaluate this metric on a temporal chunk of data. @@ -103,6 +105,8 @@ def compute_chunk( forecast. region: Region class. .apply() method is called inside before spatial averaging. + skipna: Whether to skip NaN values in both forecasts and observations + during evaluation. Returns: Dataset with metric results for each variable in forecasts/truth, without @@ -126,14 +130,18 @@ def compute( raise ValueError( f"Forecast has neither valid_time or init_time dimension {forecast}" ) - return self.compute_chunk(forecast, truth, region=region).mean( + return self.compute_chunk( + forecast, truth, region=region, skipna=skipna + ).mean( avg_dim, skipna=skipna, ) def _spatial_average( - dataset: xr.Dataset, region: t.Optional[Region] = None, skipna: bool = False + dataset: xr.Dataset, + region: t.Optional[Region], + skipna: bool, ) -> xr.Dataset: """Compute spatial average after applying region mask. @@ -156,7 +164,9 @@ def _spatial_average( def _spatial_average_l2_norm( - dataset: xr.Dataset, region: t.Optional[Region] = None, skipna: bool = False + dataset: xr.Dataset, + region: t.Optional[Region], + skipna: bool, ) -> xr.Dataset: """Helper function to compute sqrt(spatial_average(ds**2)).""" return np.sqrt(_spatial_average(dataset**2, region=region, skipna=skipna)) @@ -181,11 +191,13 @@ def compute_chunk( forecast: xr.Dataset, truth: xr.Dataset, region: t.Optional[Region] = None, + skipna: bool = False, ) -> xr.Dataset: diff = forecast - truth result = _spatial_average( diff[self.u_name] ** 2 + diff[self.v_name] ** 2, region=region, + skipna=skipna, ) return result @@ -213,10 +225,11 @@ def compute_chunk( forecast: xr.Dataset, truth: xr.Dataset, region: t.Optional[Region] = None, + skipna: bool = False, ) -> xr.Dataset: mse = WindVectorMSE( u_name=self.u_name, v_name=self.v_name, vector_name=self.vector_name - ).compute_chunk(forecast, truth, region=region) + ).compute_chunk(forecast, truth, region=region, skipna=skipna) return np.sqrt(mse) @@ -240,12 +253,18 @@ def compute_chunk( forecast: xr.Dataset, truth: xr.Dataset, region: t.Optional[Region] = None, + skipna: bool = False, ) -> xr.Dataset: - results = _spatial_average_l2_norm(forecast - truth, region=region) + results = _spatial_average_l2_norm( + forecast - truth, region=region, skipna=skipna + ) if self.wind_vector_rmse is not None: for wv in self.wind_vector_rmse: results[wv.vector_name] = wv.compute_chunk( - forecast, truth, region=region + forecast, + truth, + region=region, + skipna=skipna, ) return results @@ -266,12 +285,18 @@ def compute_chunk( forecast: xr.Dataset, truth: xr.Dataset, region: t.Optional[Region] = None, + skipna: bool = False, ) -> xr.Dataset: - results = _spatial_average((forecast - truth) ** 2, region=region) + results = _spatial_average( + (forecast - truth) ** 2, region=region, skipna=skipna + ) if self.wind_vector_mse is not None: for wv in self.wind_vector_mse: results[wv.vector_name] = wv.compute_chunk( - forecast, truth, region=region + forecast, + truth, + region=region, + skipna=skipna, ) return results @@ -285,7 +310,9 @@ def compute_chunk( forecast: xr.Dataset, truth: xr.Dataset, region: t.Optional[Region] = None, + skipna: bool = False, ) -> xr.Dataset: + del skipna # Ignored return (forecast - truth) ** 2 @@ -298,8 +325,9 @@ def compute_chunk( forecast: xr.Dataset, truth: xr.Dataset, region: t.Optional[Region] = None, + skipna: bool = False, ) -> xr.Dataset: - return _spatial_average(abs(forecast - truth), region=region) + return _spatial_average(abs(forecast - truth), region=region, skipna=skipna) @dataclasses.dataclass @@ -311,7 +339,9 @@ def compute_chunk( forecast: xr.Dataset, truth: xr.Dataset, region: t.Optional[Region] = None, + skipna: bool = False, ) -> xr.Dataset: + del skipna # Ignored return abs(forecast - truth) @@ -324,8 +354,9 @@ def compute_chunk( forecast: xr.Dataset, truth: xr.Dataset, region: t.Optional[Region] = None, + skipna: bool = False, ) -> xr.Dataset: - return _spatial_average(forecast - truth, region=region) + return _spatial_average(forecast - truth, region=region, skipna=skipna) @dataclasses.dataclass @@ -337,7 +368,9 @@ def compute_chunk( forecast: xr.Dataset, truth: xr.Dataset, region: t.Optional[Region] = None, + skipna: bool = False, ) -> xr.Dataset: + del skipna # Ignored return forecast - truth @@ -356,6 +389,7 @@ def compute_chunk( forecast: xr.Dataset, truth: xr.Dataset, region: t.Optional[Region] = None, + skipna: bool = False, ) -> xr.Dataset: if "init_time" in forecast.dims: time_dim = "valid_time" @@ -371,10 +405,12 @@ def compute_chunk( forecast_anom = forecast - climatology_chunk truth_anom = truth - climatology_chunk return _spatial_average( - forecast_anom * truth_anom, region=region + forecast_anom * truth_anom, + region=region, + skipna=skipna, ) / np.sqrt( - _spatial_average(forecast_anom**2, region=region) - * _spatial_average(truth_anom**2, region=region) + _spatial_average(forecast_anom**2, region=region, skipna=skipna) + * _spatial_average(truth_anom**2, region=region, skipna=skipna) ) @@ -435,7 +471,9 @@ def compute_chunk( forecast: xr.Dataset, truth: xr.Dataset, region: t.Optional[Region] = None, + skipna: bool = False, ) -> xr.Dataset: + del skipna # Ignored, must be effectively True because of p1 mask. forecast_cat = self._convert_precip_to_seeps_cat(forecast) truth_cat = self._convert_precip_to_seeps_cat(truth) @@ -479,9 +517,10 @@ def compute_chunk( forecast: xr.Dataset, truth: xr.Dataset, region: t.Optional[Region] = None, + skipna: bool = False, ) -> xr.Dataset: + del skipna # Ignored, must be effectively True because of p1 mask. result = super().compute_chunk(forecast, truth, region) - # Need skipna = True because of p1 mask return _spatial_average(result, region=region, skipna=True) @@ -494,6 +533,7 @@ def _debiased_ensemble_mean_mse( forecast: xr.Dataset, truth: xr.Dataset, ensemble_dim: str, + skipna: bool, ) -> xr.Dataset: """Debiased estimate of E(forecast.mean() - truth)². @@ -513,12 +553,14 @@ def _debiased_ensemble_mean_mse( forecast: A forecast dataset. truth: A ground truth dataset. ensemble_dim: Dimension indexing ensembles in the forecast. + skipna: Whether to skip NaN values in both forecasts and observations during + evaluation. Returns: Dataset with debiased (forecast - truth)². """ - forecast_mean = forecast.mean(ensemble_dim, skipna=False) - forecast_var = forecast.var(ensemble_dim, skipna=False, ddof=1) + forecast_mean = forecast.mean(ensemble_dim, skipna=skipna) + forecast_var = forecast.var(ensemble_dim, skipna=skipna, ddof=1) biased_mse = (truth - forecast_mean) ** 2 return biased_mse - forecast_var / _get_n_ensemble(forecast, ensemble_dim) @@ -617,12 +659,19 @@ def compute_chunk( forecast: xr.Dataset, truth: xr.Dataset, region: t.Optional[Region] = None, + skipna: bool = False, ) -> xr.Dataset: """CRPS, averaged over space, for a time chunk of data.""" return CRPSSkill(self.ensemble_dim).compute_chunk( - forecast, truth, region=region + forecast, + truth, + region=region, + skipna=skipna, ) - 0.5 * CRPSSpread(self.ensemble_dim).compute_chunk( - forecast, truth, region=region + forecast, + truth, + region=region, + skipna=skipna, ) @@ -635,11 +684,13 @@ def compute_chunk( forecast: xr.Dataset, truth: xr.Dataset, region: t.Optional[Region] = None, + skipna: bool = False, ) -> xr.Dataset: """CRPSSpread, averaged over space, for a time chunk of data.""" return _spatial_average( - _pointwise_crps_spread(forecast, self.ensemble_dim), + _pointwise_crps_spread(forecast, self.ensemble_dim, skipna=skipna), region=region, + skipna=skipna, ) @@ -652,11 +703,15 @@ def compute_chunk( forecast: xr.Dataset, truth: xr.Dataset, region: t.Optional[Region] = None, + skipna: bool = False, ) -> xr.Dataset: """CRPSSkill, averaged over space, for a time chunk of data.""" return _spatial_average( - _pointwise_crps_skill(forecast, truth, self.ensemble_dim), + _pointwise_crps_skill( + forecast, truth, self.ensemble_dim, skipna=skipna + ), region=region, + skipna=skipna, ) @@ -669,12 +724,19 @@ def compute_chunk( forecast: xr.Dataset, truth: xr.Dataset, region: t.Optional[Region] = None, + skipna: bool = False, ) -> xr.Dataset: """CRPS, averaged over space, for a time chunk of data.""" return SpatialCRPSSkill(self.ensemble_dim).compute_chunk( - forecast, truth, region=region + forecast, + truth, + region=region, + skipna=skipna, ) - 0.5 * SpatialCRPSSpread(self.ensemble_dim).compute_chunk( - forecast, truth, region=region + forecast, + truth, + region=region, + skipna=skipna, ) @@ -687,9 +749,10 @@ def compute_chunk( forecast: xr.Dataset, truth: xr.Dataset, region: t.Optional[Region] = None, + skipna: bool = False, ) -> xr.Dataset: """CRPSSpread, averaged over space, for a time chunk of data.""" - return _pointwise_crps_spread(forecast, self.ensemble_dim) + return _pointwise_crps_spread(forecast, self.ensemble_dim, skipna=skipna) @dataclasses.dataclass @@ -701,9 +764,12 @@ def compute_chunk( forecast: xr.Dataset, truth: xr.Dataset, region: t.Optional[Region] = None, + skipna: bool = False, ) -> xr.Dataset: """CRPSSkill, averaged over space, for a time chunk of data.""" - return _pointwise_crps_skill(forecast, truth, self.ensemble_dim) + return _pointwise_crps_skill( + forecast, truth, self.ensemble_dim, skipna=skipna + ) @utils.dataset_safe_lru_cache( @@ -713,7 +779,9 @@ def compute_chunk( maxsize=1, ) def _pointwise_crps_spread( - forecast: xr.Dataset, ensemble_dim: str + forecast: xr.Dataset, + ensemble_dim: str, + skipna: bool, ) -> xr.Dataset: """CRPS spread at each point in truth, averaged over ensemble only.""" n_ensemble = _get_n_ensemble(forecast, ensemble_dim) @@ -738,7 +806,7 @@ def _pointwise_crps_spread( 2 * ( ((2 * rank - n_ensemble - 1) * forecast).mean( - ensemble_dim, skipna=False + ensemble_dim, skipna=skipna ) ) / (n_ensemble - 1) @@ -746,11 +814,14 @@ def _pointwise_crps_spread( def _pointwise_crps_skill( - forecast: xr.Dataset, truth: xr.Dataset, ensemble_dim: str + forecast: xr.Dataset, + truth: xr.Dataset, + ensemble_dim: str, + skipna: bool, ) -> xr.Dataset: """CRPS skill at each point in truth, averaged over ensemble only.""" _get_n_ensemble(forecast, ensemble_dim) # Will raise if no ensembles. - return abs(truth - forecast).mean(ensemble_dim, skipna=False) + return abs(truth - forecast).mean(ensemble_dim, skipna=skipna) def _rank_ds(ds: xr.Dataset, dim: str) -> xr.Dataset: @@ -784,11 +855,13 @@ def compute_chunk( forecast: xr.Dataset, truth: xr.Dataset, region: t.Optional[Region] = None, + skipna: bool = False, ) -> xr.Dataset: """GaussianCRPS, averaged over space, for a time chunk of data.""" return _spatial_average( _pointwise_gaussian_crps(forecast, truth), region=region, + skipna=skipna, ) @@ -844,6 +917,7 @@ def compute_chunk( forecast: xr.Dataset, truth: xr.Dataset, region: t.Optional[Region] = None, + skipna: bool = False, ) -> xr.Dataset: """GaussianVariance, averaged over space, for a time chunk of data.""" del truth # unused @@ -859,6 +933,7 @@ def compute_chunk( return _spatial_average( xr.Dataset(dataset, coords=forecast.coords), region=region, + skipna=skipna, ) @@ -892,6 +967,7 @@ def compute_chunk( forecast: xr.Dataset, truth: xr.Dataset, region: t.Optional[Region] = None, + skipna: bool = False, ) -> xr.Dataset: if isinstance(self.threshold, thresholds.Threshold): @@ -927,7 +1003,9 @@ def compute_chunk( brier_scores.append( _spatial_average( - (forecast_probability - truth_probability) ** 2, region=region + (forecast_probability - truth_probability) ** 2, + region=region, + skipna=skipna, ).expand_dims(dim={"quantile": [quantile]}) ) @@ -963,6 +1041,7 @@ def compute_chunk( forecast: xr.Dataset, truth: xr.Dataset, region: t.Optional[Region] = None, + skipna: bool = False, ) -> xr.Dataset: if isinstance(self.threshold, thresholds.Threshold): @@ -1000,9 +1079,9 @@ def compute_chunk( ) ignorance_scores.append( - _spatial_average(ignorance_score, region=region).expand_dims( - dim={"quantile": [quantile]} - ) + _spatial_average( + ignorance_score, region=region, skipna=skipna + ).expand_dims(dim={"quantile": [quantile]}) ) return xr.merge(ignorance_scores).assign_attrs( @@ -1038,6 +1117,7 @@ def compute_chunk( forecast: xr.Dataset, truth: xr.Dataset, region: t.Optional[Region] = None, + skipna: bool = False, ) -> xr.Dataset: var_list = [] @@ -1062,7 +1142,9 @@ def compute_chunk( forecast_cdf = xr.Dataset(cdf_values, coords=forecast.coords) rps_per_threshold.append((forecast_cdf - truth_ecdf) ** 2) - return _spatial_average(sum(rps_per_threshold), region=region) + return _spatial_average( + sum(rps_per_threshold), region=region, skipna=skipna + ) @dataclasses.dataclass @@ -1094,6 +1176,7 @@ def compute_chunk( forecast: xr.Dataset, truth: xr.Dataset, region: t.Optional[Region] = None, + skipna: bool = False, ) -> xr.Dataset: """Ensemble Stddev, averaged over space, for a time chunk of data.""" del truth # unused @@ -1103,13 +1186,16 @@ def compute_chunk( return xr.zeros_like( # Compute the average, even though we return zeros_like. Why? Because, # this will preserve the scalar values of lat/lon coords correctly. - _spatial_average(forecast, region=region).mean( - self.ensemble_dim, skipna=False + _spatial_average(forecast, region=region, skipna=skipna).mean( + self.ensemble_dim, + skipna=skipna, ) ) else: return _spatial_average_l2_norm( - forecast.std(self.ensemble_dim, ddof=1, skipna=False), region=region + forecast.std(self.ensemble_dim, ddof=1, skipna=skipna), + region=region, + skipna=skipna, ) @@ -1122,6 +1208,7 @@ def compute_chunk( forecast: xr.Dataset, truth: xr.Dataset, region: t.Optional[Region] = None, + skipna: bool = False, ) -> xr.Dataset: """EnsembleVariance, averaged over space, for a time chunk of data.""" del truth # unused @@ -1131,13 +1218,15 @@ def compute_chunk( return xr.zeros_like( # Compute the average, even though we return zeros_like. Why? Because, # this will preserve the scalar values of lat/lon coords correctly. - _spatial_average(forecast, region=region).mean( - self.ensemble_dim, skipna=False + _spatial_average(forecast, region=region, skipna=skipna).mean( + self.ensemble_dim, skipna=skipna ) ) else: return _spatial_average( - forecast.var(self.ensemble_dim, ddof=1, skipna=False), region=region + forecast.var(self.ensemble_dim, ddof=1, skipna=skipna), + region=region, + skipna=skipna, ) @@ -1150,6 +1239,7 @@ def compute_chunk( forecast: xr.Dataset, truth: xr.Dataset, region: t.Optional[Region] = None, + skipna: bool = False, ) -> xr.Dataset: """Ensemble variance, for a time chunk of data.""" del truth # unused @@ -1160,9 +1250,9 @@ def compute_chunk( # Compute the average, even though we return zeros_like. Why? Because, # this will preserve the scalar values of lat/lon coords correctly. forecast - ).mean(self.ensemble_dim, skipna=False) + ).mean(self.ensemble_dim, skipna=skipna) else: - return forecast.var(self.ensemble_dim, ddof=1, skipna=False) + return forecast.var(self.ensemble_dim, ddof=1, skipna=skipna) @dataclasses.dataclass @@ -1194,12 +1284,15 @@ def compute_chunk( forecast: xr.Dataset, truth: xr.Dataset, region: t.Optional[Region] = None, + skipna: bool = False, ) -> xr.Dataset: """EnsembleMeanRMSE, averaged over space, for a time chunk of data.""" _get_n_ensemble(forecast, self.ensemble_dim) # Will raise if no ensembles. return _spatial_average_l2_norm( - truth - forecast.mean(self.ensemble_dim, skipna=False), region=region + truth - forecast.mean(self.ensemble_dim, skipna=skipna), + region=region, + skipna=skipna, ) @@ -1217,13 +1310,15 @@ def compute_chunk( forecast: xr.Dataset, truth: xr.Dataset, region: t.Optional[Region] = None, + skipna: bool = False, ) -> xr.Dataset: """EnsembleMeanRMSE, averaged over space, for a time chunk of data.""" _get_n_ensemble(forecast, self.ensemble_dim) # Will raise if no ensembles. return _spatial_average( - (truth - forecast.mean(self.ensemble_dim, skipna=False)) ** 2, + (truth - forecast.mean(self.ensemble_dim, skipna=skipna)) ** 2, region=region, + skipna=skipna, ) @@ -1243,13 +1338,17 @@ def compute_chunk( forecast: xr.Dataset, truth: xr.Dataset, region: t.Optional[Region] = None, + skipna: bool = False, ) -> xr.Dataset: """DebiasedEnsembleMeanMSE, averaged over space, for one time chunk.""" _get_n_ensemble(forecast, self.ensemble_dim) # Will raise if no ensembles. return _spatial_average( - _debiased_ensemble_mean_mse(forecast, truth, self.ensemble_dim), + _debiased_ensemble_mean_mse( + forecast, truth, self.ensemble_dim, skipna=skipna + ), region=region, + skipna=skipna, ) @@ -1262,11 +1361,12 @@ def compute_chunk( forecast: xr.Dataset, truth: xr.Dataset, region: t.Optional[Region] = None, + skipna: bool = False, ) -> xr.Dataset: """Squared error in the ensemble mean, for a time chunk of data.""" _get_n_ensemble(forecast, self.ensemble_dim) # Will raise if no ensembles. - return (truth - forecast.mean(self.ensemble_dim, skipna=False)) ** 2 + return (truth - forecast.mean(self.ensemble_dim, skipna=skipna)) ** 2 @dataclasses.dataclass @@ -1278,11 +1378,14 @@ def compute_chunk( forecast: xr.Dataset, truth: xr.Dataset, region: t.Optional[Region] = None, + skipna: bool = False, ) -> xr.Dataset: """Squared error in the ensemble mean, for a time chunk of data.""" _get_n_ensemble(forecast, self.ensemble_dim) # Will raise if no ensembles. - return _debiased_ensemble_mean_mse(forecast, truth, self.ensemble_dim) + return _debiased_ensemble_mean_mse( + forecast, truth, self.ensemble_dim, skipna=skipna + ) @dataclasses.dataclass @@ -1334,12 +1437,19 @@ def compute_chunk( forecast: xr.Dataset, truth: xr.Dataset, region: t.Optional[Region] = None, + skipna: bool = False, ) -> xr.Dataset: """Energy score, averaged over space, for a time chunk of data.""" return EnergyScoreSkill(self.ensemble_dim).compute_chunk( - forecast, truth, region=region + forecast, + truth, + region=region, + skipna=skipna, ) - 0.5 * EnergyScoreSpread(self.ensemble_dim).compute_chunk( - forecast, truth, region=region + forecast, + truth, + region=region, + skipna=skipna, ) @@ -1352,6 +1462,7 @@ def compute_chunk( forecast: xr.Dataset, truth: xr.Dataset, region: t.Optional[Region] = None, + skipna: bool = False, ) -> xr.Dataset: """Energy score spread, averaged over space, for a time chunk of data.""" n_ensemble = _get_n_ensemble(forecast, self.ensemble_dim) @@ -1360,15 +1471,18 @@ def compute_chunk( return xr.zeros_like( # Compute the average, even though we return zeros_like. Why? Because, # this will preserve the scalar values of lat/lon coords correctly. - _spatial_average(forecast, region=region).mean( - self.ensemble_dim, skipna=False + _spatial_average(forecast, region=region, skipna=skipna).mean( + self.ensemble_dim, + skipna=skipna, ) ) else: return _spatial_average_l2_norm( self._ensemble_slice(forecast, slice(None, -1)) - self._ensemble_slice(forecast, slice(1, None)), - ).mean(self.ensemble_dim, skipna=False) + region=region, + skipna=skipna, + ).mean(self.ensemble_dim, skipna=skipna) @dataclasses.dataclass @@ -1380,11 +1494,15 @@ def compute_chunk( forecast: xr.Dataset, truth: xr.Dataset, region: t.Optional[Region] = None, + skipna: bool = False, ) -> xr.Dataset: """Energy score skill, averaged over space, for a time chunk of data.""" _get_n_ensemble(forecast, self.ensemble_dim) # Will raise if no ensembles. - return _spatial_average_l2_norm(forecast - truth).mean( - self.ensemble_dim, skipna=False + return _spatial_average_l2_norm( + forecast - truth, region=region, skipna=skipna + ).mean( + self.ensemble_dim, + skipna=skipna, ) @@ -1415,7 +1533,8 @@ def _compute_chunk_impl( debias: bool, forecast: xr.Dataset, truth: xr.Dataset, - region: t.Optional[Region] = None, + region: t.Optional[Region], + skipna: bool, ) -> xr.Dataset: """Common implementation of compute_chunk.""" @@ -1430,23 +1549,29 @@ def _compute_chunk_impl( for threshold in threshold_seq: quantile = threshold.quantile threshold = threshold.compute(truth) + # Notice we allow NaN in truth/forecast probabilities, then skipna during + # computation of BrierScore (which is really just an MSE over the + # probabilities). truth_probability = xr.where( truth.isnull(), np.nan, xr.where(truth > threshold, 1.0, 0.0), ) forecast_probability = xr.where( - forecast.isnull(), np.nan, xr.where(forecast > threshold, 1.0, 0.0) + forecast.isnull(), + np.nan, + xr.where(forecast > threshold, 1.0, 0.0), ) if debias: mse_of_probabilities = _debiased_ensemble_mean_mse( forecast_probability, truth_probability, self.ensemble_dim, + skipna=skipna, ) else: mse_of_probabilities = ( - forecast_probability.mean(self.ensemble_dim, skipna=False) + forecast_probability.mean(self.ensemble_dim, skipna=skipna) - truth_probability ) ** 2 @@ -1454,6 +1579,7 @@ def _compute_chunk_impl( _spatial_average( mse_of_probabilities, region=region, + skipna=skipna, ).expand_dims(dim={"quantile": [quantile]}) ) @@ -1511,9 +1637,14 @@ def compute_chunk( forecast: xr.Dataset, truth: xr.Dataset, region: t.Optional[Region] = None, + skipna: bool = False, ) -> xr.Dataset: return self._compute_chunk_impl( - debias=False, forecast=forecast, truth=truth, region=region + debias=False, + forecast=forecast, + truth=truth, + region=region, + skipna=skipna, ) @@ -1569,9 +1700,14 @@ def compute_chunk( forecast: xr.Dataset, truth: xr.Dataset, region: t.Optional[Region] = None, + skipna: bool = False, ) -> xr.Dataset: return self._compute_chunk_impl( - debias=True, forecast=forecast, truth=truth, region=region + debias=True, + forecast=forecast, + truth=truth, + region=region, + skipna=skipna, ) @@ -1608,6 +1744,7 @@ def compute_chunk( forecast: xr.Dataset, truth: xr.Dataset, region: t.Optional[Region] = None, + skipna: bool = False, ) -> xr.Dataset: if isinstance(self.threshold, thresholds.Threshold): @@ -1624,7 +1761,8 @@ def compute_chunk( truth_probability = xr.where(truth > threshold, 1.0, 0.0) forecast_probability = xr.where(forecast > threshold, 1.0, 0.0) ensemble_forecast_probability = forecast_probability.mean( - self.ensemble_dim, skipna=False + self.ensemble_dim, + skipna=skipna, ) ignorance_score = -xr.where( truth_probability, @@ -1635,6 +1773,7 @@ def compute_chunk( _spatial_average( ignorance_score, region=region, + skipna=skipna, ).expand_dims(dim={"quantile": [quantile]}) ) @@ -1692,6 +1831,7 @@ def compute_chunk( forecast: xr.Dataset, truth: xr.Dataset, region: t.Optional[Region] = None, + skipna: bool = False, ) -> xr.Dataset: """Spatially averaged RPS of the ensemble forecast.""" rps_per_threshold = [] @@ -1701,12 +1841,15 @@ def compute_chunk( truth_ecdf = xr.where(truth < threshold, 1.0, 0.0) forecast_ecdf = xr.where(forecast < threshold, 1.0, 0.0) ensemble_forecast_ecdf = forecast_ecdf.mean( - self.ensemble_dim, skipna=False + self.ensemble_dim, + skipna=skipna, ) rps_per_threshold.append((ensemble_forecast_ecdf - truth_ecdf) ** 2) - return _spatial_average(sum(rps_per_threshold), region=region) + return _spatial_average( + sum(rps_per_threshold), region=region, skipna=skipna + ) @dataclasses.dataclass @@ -1727,6 +1870,8 @@ class RankHistogram(EnsembleMetric): will be (num_bins - 1) / (N num_bins²). Since the expected value is 1 / num_bins, the relative error is Sqrt(variance) / expected = Sqrt((num_bins - 1) / N). + + NaN values are treated as larger than any other. The skipna kwarg is ignored. """ def __init__( @@ -1770,8 +1915,18 @@ def compute_chunk( forecast: xr.Dataset, truth: xr.Dataset, region: t.Optional[Region] = None, + skipna: bool = False, ) -> xr.Dataset: """Computes one-hot encoding of rank on a chunk of forecast/truth.""" + + if skipna and ( + any(truth[v].isnull().any() for v in truth) + or any(forecast[v].isnull().any() for v in forecast) + ): + logging.warning( + "NaN values detected in truth or forecast. skipna=True but it will be" + " ignored." + ) # Create a fake ensemble member for truth. This is for concatenation. truth_realization = forecast[self.ensemble_dim].data.min() - 1 truth = truth.assign_coords({self.ensemble_dim: truth_realization}) diff --git a/weatherbench2/metrics_test.py b/weatherbench2/metrics_test.py index cc8b74c..dad2440 100644 --- a/weatherbench2/metrics_test.py +++ b/weatherbench2/metrics_test.py @@ -69,7 +69,7 @@ def test_get_lat_weights(self): (np.sqrt(3) - 1) / 2, 1 - np.sqrt(3) / 2, ] - ) + ) # fmt: skip expected = xr.DataArray(expected_data, coords=ds.coords, dims=['latitude']) xr.testing.assert_allclose(expected, weights) @@ -102,7 +102,7 @@ def test_wind_vector_rmse(self): [0, -4, 1], coords={'level': forecast.level} ), } - ) + ) # fmt: skip truth_modifier = xr.Dataset( { 'u_component_of_wind': xr.DataArray( @@ -112,7 +112,7 @@ def test_wind_vector_rmse(self): [0, 4, 1], coords={'level': forecast.level} ), } - ) + ) # fmt: skip forecast = forecast + forecast_modifier truth = truth + truth_modifier @@ -188,7 +188,7 @@ class CRPSTest(parameterized.TestCase): ) def test_vs_brute_force(self, ensemble_size): truth, forecast = get_random_truth_and_forecast(ensemble_size=ensemble_size) - expected_crps = _crps_brute_force(forecast, truth) + expected_crps = _crps_brute_force(forecast, truth, skipna=False) xr.testing.assert_allclose( expected_crps['score'], @@ -203,7 +203,9 @@ def test_ensemble_size_1_gives_mae(self): truth, forecast = get_random_truth_and_forecast(ensemble_size=1) expected_skill = metrics._spatial_average( - abs(truth - forecast.isel({metrics.REALIZATION: 0})) + abs(truth - forecast.isel({metrics.REALIZATION: 0})), + region=None, + skipna=False, ) xr.testing.assert_allclose( @@ -219,7 +221,11 @@ def test_ensemble_size_1_gives_mae(self): expected_skill, # Spread = 0 ) - def test_nan_forecasts_result_in_nan_crps(self): + @parameterized.parameters( + (True,), + (False,), + ) + def test_nan_forecasts_result_in_nan_crps(self, skipna): truth, forecast = get_random_truth_and_forecast( variables=['geopotential', 'temperature'], ensemble_size=7 ) @@ -231,11 +237,14 @@ def test_nan_forecasts_result_in_nan_crps(self): data={'geopotential': new_values, 'temperature': forecast.temperature} ) - crps = metrics.CRPS().compute_chunk(forecast, truth) + crps = metrics.CRPS().compute_chunk(forecast, truth, skipna=skipna) - # The only NaN geopotential is in the very first place. + # The only possible NaN geopotential is in the very first place. score_values = crps.geopotential.values.copy() - self.assertTrue(np.isnan(score_values[0, 0, 0])) + if skipna: + self.assertFalse(np.isnan(score_values[0, 0, 0])) + else: + self.assertTrue(np.isnan(score_values[0, 0, 0])) score_values[0, 0, 0] = 0 # Replace the NaN self.assertTrue(np.all(np.isfinite(score_values))) @@ -243,7 +252,10 @@ def test_nan_forecasts_result_in_nan_crps(self): self.assertTrue(np.all(np.isfinite(crps.temperature.values))) xr.testing.assert_allclose( - crps, _crps_brute_force(forecast, truth)['score'] + crps, + _crps_brute_force(forecast, truth, skipna=skipna)['score'], + rtol=1e-4, + atol=1e-4, ) def test_repeated_forecasts_are_okay(self): @@ -257,7 +269,7 @@ def test_repeated_forecasts_are_okay(self): crps = metrics.CRPS().compute_chunk(forecast, truth) xr.testing.assert_allclose( - crps, _crps_brute_force(forecast, truth)['score'] + crps, _crps_brute_force(forecast, truth, skipna=False)['score'] ) @@ -370,7 +382,7 @@ def test_gaussian_brier_score(self, error, expected_1, expected_2): forecast = schema.mock_forecast_data( variables_2d=['2m_temperature', '2m_temperature_std'], lead_stop='1 day', - **kwargs + **kwargs, ) truth = schema.mock_truth_data(variables_2d=['2m_temperature'], **kwargs) truth = truth + 1.0 @@ -426,7 +438,7 @@ def test_gaussian_ignorance_score(self, error, expected): forecast = schema.mock_forecast_data( variables_2d=['2m_temperature', '2m_temperature_std'], lead_stop='1 day', - **kwargs + **kwargs, ) truth = schema.mock_truth_data(variables_2d=['2m_temperature'], **kwargs) truth = truth + 1.0 @@ -476,7 +488,7 @@ def test_gaussian_rps(self, error, expected): forecast = schema.mock_forecast_data( variables_2d=['2m_temperature', '2m_temperature_std'], lead_stop='1 day', - **kwargs + **kwargs, ) truth = schema.mock_truth_data(variables_2d=['2m_temperature'], **kwargs) q_1 = ( @@ -521,7 +533,9 @@ class RankHistogramTest(parameterized.TestCase): dict(testcase_name='EnsembleSize2', ensemble_size=2), dict(testcase_name='EnsembleSize9_NumBins5', ensemble_size=9, num_bins=5), ) - def test_well_and_mis_calibrated(self, ensemble_size, num_bins=None): + def test_well_and_mis_calibrated( + self, ensemble_size, num_bins=None, frac_nan=None + ): num_bins = ensemble_size + 1 if num_bins is None else num_bins # Forecast and truth come from same distribution truth, forecast = get_random_truth_and_forecast( @@ -531,6 +545,9 @@ def test_well_and_mis_calibrated(self, ensemble_size, num_bins=None): time_stop='2019-12-10', levels=(0, 1, 2, 3, 4), ) + if frac_nan: + truth = test_utils.insert_nan(truth, frac_nan=frac_nan, seed=0) + forecast = test_utils.insert_nan(forecast, frac_nan=frac_nan, seed=1) # level=0 is well calibrated # level=1,2 are under/over dispersed @@ -817,21 +834,23 @@ def test_versus_large_ensemble(self): ) -def _crps_brute_force(forecast: xr.Dataset, truth: xr.Dataset) -> xr.Dataset: +def _crps_brute_force( + forecast: xr.Dataset, truth: xr.Dataset, skipna: bool +) -> xr.Dataset: """The eFAIR version of CRPS from Zamo & Naveau over a chunk of data.""" # This version is simple enough that we can use it as a reference. def _l1_norm(x): - return metrics._spatial_average(abs(x)) + return metrics._spatial_average(abs(x), region=None, skipna=skipna) n_ensemble = forecast.dims[metrics.REALIZATION] - skill = _l1_norm(truth - forecast).mean(metrics.REALIZATION, skipna=False) + skill = _l1_norm(truth - forecast).mean(metrics.REALIZATION, skipna=skipna) if n_ensemble == 1: spread = xr.zeros_like(skill) else: spread = _l1_norm( forecast - forecast.rename({metrics.REALIZATION: 'dummy'}) - ).mean(dim=(metrics.REALIZATION, 'dummy'), skipna=False) * ( + ).mean(dim=(metrics.REALIZATION, 'dummy'), skipna=skipna) * ( n_ensemble / (n_ensemble - 1) ) @@ -950,7 +969,11 @@ def test_ensemble_brier_score(self, error, ens_delta, expected): result['2m_temperature'].values, expected_arr, rtol=1e-4 ) - def test_nan_propagates_to_output(self): + @parameterized.parameters( + (True,), + (False,), + ) + def test_nan_propagates_to_output_unless_skipna(self, skipna): kwargs = { 'variables_2d': ['2m_temperature'], 'variables_3d': [], @@ -970,11 +993,13 @@ def test_nan_propagates_to_output(self): truth = truth + 1.0 forecast_with_nan = xr.where( - forecast.prediction_timedelta < forecast.prediction_timedelta[-1], + forecast.latitude == forecast.latitude[0], np.nan, forecast, ) - truth_with_nan = xr.where(truth.time < truth.time[-1], np.nan, truth) + truth_with_nan = xr.where( + truth.longitude == truth.longitude[0], np.nan, truth + ) climatology_mean = truth.isel(time=0, drop=True).expand_dims(dayofyear=366) climatology_std = ( @@ -993,9 +1018,14 @@ def test_nan_propagates_to_output(self): # When forecast has nan in prediction_timedelta, only that timedelta will # be NaN. result = metrics.EnsembleBrierScore(threshold).compute( - forecast_with_nan, truth + forecast_with_nan, + truth, + skipna=skipna, ) - expected_arr = np.array([[np.nan, 0.0]]) + if skipna: + expected_arr = np.array([[0.0, 0.0]]) + else: + expected_arr = np.array([[np.nan, np.nan]]) np.testing.assert_allclose( result['2m_temperature'].values, expected_arr, @@ -1005,9 +1035,14 @@ def test_nan_propagates_to_output(self): # When truth has nan, the final average over times means the entire # score is NaN. result = metrics.EnsembleBrierScore(threshold).compute( - forecast, truth_with_nan + forecast, + truth_with_nan, + skipna=skipna, ) - expected_arr = np.array([[np.nan, np.nan]]) + if skipna: + expected_arr = np.array([[0.0, 0.0]]) + else: + expected_arr = np.array([[np.nan, np.nan]]) np.testing.assert_allclose( result['2m_temperature'].values, expected_arr, @@ -1016,9 +1051,8 @@ def test_nan_propagates_to_output(self): class DebiasedEnsembleBrierScoreTest(parameterized.TestCase): - def test_versus_large_ensemble(self): + def test_versus_large_ensemble_and_ensure_skipna_works(self): large_ensemble_size = 1000 - threshold = 1.0 # truth, forecast are both Normal(0, 1) truth, forecast = get_random_truth_and_forecast( @@ -1056,6 +1090,27 @@ def test_versus_large_ensemble(self): threshold ).compute(small_ensemble_forecast, truth) + # Get some variants using a bit of NaN values + data_size = np.prod(list(small_ensemble_forecast.sizes.values())) + frac_nan = 0.0005 + self.assertGreater( + data_size * frac_nan, + 40, + msg=f'{frac_nan=} was so small this test is trivial', + ) + small_ensemble_forecast_w_nan = test_utils.insert_nan( + small_ensemble_forecast, frac_nan=frac_nan, seed=0 + ) + truth_w_nan = test_utils.insert_nan(truth, frac_nan=frac_nan, seed=1) + bs_small_ensemble_w_nan = metrics.EnsembleBrierScore(threshold).compute( + small_ensemble_forecast_w_nan, + truth_w_nan, + skipna=True, + ) + bs_debiased_small_ensemble_w_nan = metrics.DebiasedEnsembleBrierScore( + threshold + ).compute(small_ensemble_forecast_w_nan, truth_w_nan, skipna=True) + # Make sure the test is not trivial by showing that without debiasing we get # the expected bias. Since truth/forecast are drawn from the correct # distribution, we know the variance, and then @@ -1071,12 +1126,25 @@ def test_versus_large_ensemble(self): total_points = np.prod(list(truth.dims.values())) stderr = np.sqrt(variance / total_points) + # Large ensemble gives the same result as small ensemble, since we debias. xr.testing.assert_allclose( bs_large_ensemble.mean(), bs_debiased_small_ensemble.mean(), atol=4 * stderr, ) + # The small fraction of NaN values barely changes the results. + xr.testing.assert_allclose( + bs_small_ensemble_w_nan.mean(), + bs_small_ensemble.mean(), + atol=4 * stderr, + ) + xr.testing.assert_allclose( + bs_debiased_small_ensemble_w_nan.mean(), + bs_debiased_small_ensemble.mean(), + atol=4 * stderr, + ) + class EnsembleIgnoranceScoreTest(parameterized.TestCase): diff --git a/weatherbench2/test_utils.py b/weatherbench2/test_utils.py index 071c8ea..f8e3b34 100644 --- a/weatherbench2/test_utils.py +++ b/weatherbench2/test_utils.py @@ -16,6 +16,7 @@ from typing import Any import numpy as np +import xarray as xr def assert_strictly_decreasing(x: Any, axis=-1, err_msg: str = '') -> None: @@ -46,3 +47,17 @@ def assert_negative(x: Any, err_msg: str = '') -> None: 0, err_msg=f'Was not negative. {err_msg}', ) + + +def insert_nan( + ds: xr.Dataset, frac_nan: float = 0.1, seed=802701 +) -> xr.Dataset: + """Copy ds with NaN inserted in every variable.""" + ds = ds.copy() + rng = np.random.RandomState(seed) + for name in ds: + data = ds[name].data + mask = rng.rand(*data.shape) < frac_nan + data = np.where(mask, np.nan, data) + ds[name] = ds[name].copy(data=data) + return ds