Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

No public description #167

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions weatherbench2/evaluation.py
Original file line number Diff line number Diff line change
Expand Up @@ -390,6 +390,7 @@ def _metric_and_region_loop(
) -> xr.Dataset:
"""Compute metric results looping over metrics and regions in eval config."""
# Compute derived variables
logging.info('Starting _metric_and_region_loop')
for name, dv in eval_config.derived_variables.items():
logging.info(f'Logging: derived_variable {name!r}: {dv}')
forecast[name] = dv.compute(forecast)
Expand Down
2 changes: 2 additions & 0 deletions weatherbench2/metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
import numpy as np
from scipy import stats
from weatherbench2 import thresholds
from weatherbench2 import utils
from weatherbench2.regions import Region
import xarray as xr

Expand Down Expand Up @@ -705,6 +706,7 @@ def compute_chunk(
return _pointwise_crps_skill(forecast, truth, self.ensemble_dim)


@utils.id_lru_cache(maxsize=1)
def _pointwise_crps_spread(
forecast: xr.Dataset, truth: xr.Dataset, ensemble_dim: str
) -> xr.Dataset:
Expand Down
48 changes: 47 additions & 1 deletion weatherbench2/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,8 @@
# limitations under the License.
# ==============================================================================
"""Utility function for WeatherBench2."""
from typing import Callable, Union
import functools
from typing import Callable, Hashable, Union

import fsspec
import numpy as np
Expand Down Expand Up @@ -292,3 +293,48 @@ def random_like(dataset: xr.Dataset, seed: int = 0) -> xr.Dataset:
return dataset.copy(
data={k: rs.normal(size=v.shape) for k, v in dataset.items()}
)


def id_lru_cache(maxsize: int = 5):
  """Like functools.lru_cache but keys non-hashable arguments by identity.

  Hashable arguments are keyed by value: the argument itself is part of the
  cache key, so two different values whose hashes happen to collide (e.g. -1
  and -2 in CPython) never share a cache entry. Non-hashable arguments (e.g.
  numpy arrays or lists) are keyed by object identity. The cache keeps a
  reference to every argument of a cached call, so the `id` of a cached
  argument cannot be recycled by a new object while its entry is alive.

  Note that identity keying means in-place mutation of a non-hashable argument
  is not detected; the previously cached value is returned.

  Warning: This is not threadsafe. Multiple threads reading/writing to the cache
  results in inconsistent behavior.

  Args:
    maxsize: Maximum size of cache. If None, the cache is unbounded. If zero or
      negative, nothing is cached and the function is called every time.

  Returns:
    Decorator to make a function into a caching function.
  """

  class _IdentityKey:
    """Wraps an object so that it hashes and compares by identity."""

    __slots__ = ('obj',)

    def __init__(self, obj):
      # Holding the reference keeps `obj` alive, which keeps id(obj) unique.
      self.obj = obj

    def __hash__(self):
      return id(self.obj)

    def __eq__(self, other):
      return isinstance(other, _IdentityKey) and self.obj is other.obj

  def key_part(x):
    # Actually try hashing rather than using isinstance(x, Hashable): a tuple
    # holding a list is an instance of Hashable yet hash() raises TypeError.
    try:
      hash(x)
    except TypeError:
      return _IdentityKey(x)
    return x

  caching_disabled = maxsize is not None and maxsize <= 0

  def decorating_function(func):
    cache = {}

    @functools.wraps(func)
    def wrapper(*args, **kwargs):
      if caching_disabled:
        return func(*args, **kwargs)

      # Positional and keyword parts are kept in separate tuples so that a
      # positional argument equal to ('name', value) cannot be confused with
      # the keyword argument name=value.
      key = (
          tuple(key_part(a) for a in args),
          tuple((k, key_part(v)) for k, v in kwargs.items()),
      )
      # Python dicts preserve insertion order: the first key is the least
      # recently used entry and the last key is the most recently used one.
      if key in cache:
        # Re-insert so this entry becomes the most recently used.
        value = cache.pop(key)
        cache[key] = value
        return value

      # Compute before evicting so a raising call leaves the cache untouched.
      value = func(*args, **kwargs)
      if maxsize is not None and len(cache) >= maxsize:
        del cache[next(iter(cache))]  # Drop the least recently used entry.
      cache[key] = value
      return value

    return wrapper

  return decorating_function
35 changes: 35 additions & 0 deletions weatherbench2/utils_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
# limitations under the License.
# ==============================================================================
from absl.testing import absltest
import numpy as np
from weatherbench2 import schema
from weatherbench2 import utils
import xarray
Expand Down Expand Up @@ -67,5 +68,39 @@ def testProbabilisticClimatology(self):
self.assertEqual(clim['2m_temperature'].sizes, expected_sizes)


class IdLRUCacheTest(absltest.TestCase):
  """Tests for utils.id_lru_cache."""

  def test_handles_non_hashable_args_and_kwargs(self):
    """Repeated calls with array arguments keep returning the right value."""

    @utils.id_lru_cache(maxsize=2)
    def func(x: np.ndarray, y: np.ndarray, b: float = 1):
      return np.sum(x + y * b)

    def check_repeated_calls(x, y, b):
      # Compute the reference value directly, then call the cached function
      # several times so both the miss and the hit paths are exercised.
      expected = np.sum(x + y * b)
      num_calls = 4
      for _ in range(num_calls):
        self.assertEqual(expected, func(x, y, b=b))

    # Three distinct argument sets are used against a cache of size 2, so the
    # cache is guaranteed to cycle.
    with self.subTest('First set of arrays'):
      x = np.array([1.0, 2.0, 3.0])
      y = x + 2
      check_repeated_calls(x, y, 1.3)

    with self.subTest('Second set of arrays'):
      x = np.array([0.0, -2.0, 0.123])
      y = np.array([10.0, -1.0, 3])
      check_repeated_calls(x, y, 10.3)

    with self.subTest('Third set of arrays'):
      x = np.array([0.0, -20.0])
      y = np.array([10.0, -11.0])
      check_repeated_calls(x, y, -1234)


if __name__ == '__main__':
absltest.main()
Loading