
Commit abfc3ff

No public description
PiperOrigin-RevId: 639917081
langmore authored and Weatherbench2 authors committed Jun 3, 2024
1 parent 104c944 commit abfc3ff
Showing 4 changed files with 82 additions and 7 deletions.
6 changes: 3 additions & 3 deletions scripts/slice_dataset.py
@@ -75,15 +75,15 @@
None,
help=(
'Comma delimited list of variables to drop. If empty, drop no'
-        ' variables.'
+        ' variables. List may include data variables or coords.'
),
)

KEEP_VARIABLES = flags.DEFINE_list(
'keep_variables',
None,
help=(
-        'Comma delimited list of variables to keep. If empty, use'
+        'Comma delimited list of data variables to keep. If empty, use'
' --drop_variables to determine which variables to keep'
),
)
@@ -143,7 +143,7 @@ def main(argv: abc.Sequence[str]) -> None:
ds, input_chunks = xbeam.open_zarr(INPUT_PATH.value)

if DROP_VARIABLES.value:
-    ds = ds[[v for v in ds if v not in DROP_VARIABLES.value]]
+    ds = ds.drop_vars(DROP_VARIABLES.value)
elif KEEP_VARIABLES.value:
ds = ds[KEEP_VARIABLES.value]

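The switch from the list comprehension to ds.drop_vars matches the updated --drop_variables help text: iterating over a Dataset yields data variables only, so coordinate names could not be dropped before, whereas drop_vars removes data variables and coordinates alike (and, by default, raises on names that are not present). A minimal sketch with a made-up dataset:

import numpy as np
import xarray as xr

ds = xr.Dataset(
    {'2m_temperature': ('time', np.zeros(3))},
    coords={'time': np.arange(3), 'level': 500},
)

# Old behaviour: iterating over `ds` yields data variables only, so the
# scalar coordinate 'level' survives the selection.
kept = ds[[v for v in ds if v not in ['level']]]
assert 'level' in kept.coords

# New behaviour: drop_vars removes data variables or coordinates alike.
dropped = ds.drop_vars(['level'])
assert 'level' not in dropped.coords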
1 change: 1 addition & 0 deletions weatherbench2/evaluation.py
@@ -708,6 +708,7 @@ def _evaluate(
forecast_pipeline |= 'TemporalMean' >> xbeam.Mean(
dim='init_time' if self.data_config.by_init else 'time',
fanout=self.fanout,
+        skipna=False,
)

return forecast_pipeline
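Passing skipna=False to the 'TemporalMean' stage changes how gaps in the per-time scores surface: with xarray's default, NaN timesteps are silently excluded from the average, while skipna=False makes the aggregate itself NaN, so missing data stays visible in the final result. A small illustration in plain xarray (the assumption here is that xbeam.Mean forwards skipna to the underlying reduction, as this change implies):

import numpy as np
import xarray as xr

scores = xr.DataArray([0.5, np.nan, 0.7], dims='init_time')

# xarray default: the NaN timestep is quietly dropped from the mean.
print(scores.mean('init_time').item())                # 0.6
# skipna=False: any missing timestep propagates to the aggregate.
print(scores.mean('init_time', skipna=False).item())  # nan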
19 changes: 15 additions & 4 deletions weatherbench2/metrics.py
@@ -114,6 +114,7 @@ def compute(
forecast: xr.Dataset,
truth: xr.Dataset,
region: t.Optional[Region] = None,
+      skipna: bool = False,
) -> xr.Dataset:
"""Evaluate this metric on datasets with full temporal coverages."""
if "time" in forecast.dims:
@@ -124,7 +125,10 @@
raise ValueError(
f"Forecast has neither valid_time or init_time dimension {forecast}"
)
-    return self.compute_chunk(forecast, truth, region=region).mean(avg_dim)
+    return self.compute_chunk(forecast, truth, region=region).mean(
+        avg_dim,
+        skipna=skipna,
+    )


def _spatial_average(
@@ -553,9 +557,10 @@ def compute(
forecast: xr.Dataset,
truth: xr.Dataset,
region: t.Optional[Region] = None,
+      skipna: bool = False,
) -> xr.Dataset:
"""Evaluate this metric on datasets with full temporal coverages."""
-    result = super().compute(forecast, truth, region=region)
+    result = super().compute(forecast, truth, region=region, skipna=skipna)
return result.assign_attrs(ensemble_size=forecast[self.ensemble_dim].size)


@@ -1406,8 +1411,14 @@ def _compute_chunk_impl(
for threshold in threshold_seq:
quantile = threshold.quantile
threshold = threshold.compute(truth)
-      truth_probability = xr.where(truth > threshold, 1.0, 0.0)
-      forecast_probability = xr.where(forecast > threshold, 1.0, 0.0)
+      truth_probability = xr.where(
+          truth.isnull(),
+          np.nan,
+          xr.where(truth > threshold, 1.0, 0.0),
+      )
+      forecast_probability = xr.where(
+          forecast.isnull(), np.nan, xr.where(forecast > threshold, 1.0, 0.0)
+      )
if debias:
mse_of_probabilities = _debiased_ensemble_mean_mse(
forecast_probability,
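The nested xr.where calls are what keep NaNs from being silently treated as "below threshold": a bare xr.where(truth > threshold, 1.0, 0.0) maps NaN to 0.0, because NaN compares False against any threshold. Re-inserting NaN where the input is missing lets the skipna-aware averaging above decide how to handle it. A toy example (values are illustrative):

import numpy as np
import xarray as xr

truth = xr.DataArray([272.0, np.nan, 275.0])
threshold = 274.0

# Bare comparison: the missing value is counted as "below threshold".
naive = xr.where(truth > threshold, 1.0, 0.0)
print(naive.values)   # -> [0., 0., 1.]

# NaN-preserving form used in this change: missing stays missing.
masked = xr.where(truth.isnull(), np.nan, xr.where(truth > threshold, 1.0, 0.0))
print(masked.values)  # -> [0., nan, 1.]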
63 changes: 63 additions & 0 deletions weatherbench2/metrics_test.py
@@ -932,6 +932,69 @@ def test_ensemble_brier_score(self, error, ens_delta, expected):
result['2m_temperature'].values, expected_arr, rtol=1e-4
)

def test_nan_propagates_to_output(self):
kwargs = {
'variables_2d': ['2m_temperature'],
'variables_3d': [],
'time_start': '2022-01-01',
'time_stop': '2022-01-03',
}
forecast = schema.mock_forecast_data(
ensemble_size=4, lead_stop='1 day', **kwargs
)
forecast = (
# Use settings from test_ensemble_brier_score that result in score=0.
forecast
+ 1.0
+ 0.1 * np.arange(-2, 2).reshape((4, 1, 1, 1, 1))
)
truth = schema.mock_truth_data(**kwargs)
truth = truth + 1.0

forecast_with_nan = xr.where(
forecast.prediction_timedelta < forecast.prediction_timedelta[-1],
np.nan,
forecast,
)
truth_with_nan = xr.where(truth.time < truth.time[-1], np.nan, truth)

climatology_mean = truth.isel(time=0, drop=True).expand_dims(dayofyear=366)
climatology_std = (
truth.isel(time=0, drop=True)
.expand_dims(
dayofyear=366,
)
.rename({'2m_temperature': '2m_temperature_std'})
)
climatology = xr.merge([climatology_mean, climatology_std])
threshold = thresholds.GaussianQuantileThreshold(
climatology=climatology, quantile=0.2
)

with self.subTest('forecast has nan'):
# When forecast has nan in prediction_timedelta, only that timedelta will
# be NaN.
result = metrics.EnsembleBrierScore(threshold).compute(
forecast_with_nan, truth
)
expected_arr = np.array([[np.nan, 0.0]])
np.testing.assert_allclose(
result['2m_temperature'].values,
expected_arr,
)

with self.subTest('truth has nan'):
# When truth has nan, the final average over times means the entire
# score is NaN.
result = metrics.EnsembleBrierScore(threshold).compute(
forecast, truth_with_nan
)
expected_arr = np.array([[np.nan, np.nan]])
np.testing.assert_allclose(
result['2m_temperature'].values,
expected_arr,
)


class DebiasedEnsembleBrierScoreTest(parameterized.TestCase):

