Skip to content

Commit

Permalink
Ran build_cleaner. Also some internal dependency changes.
Browse files: browse the repository at this point in the history
PiperOrigin-RevId: 630191048
  • Loading branch information
langmore authored and Weatherbench2 authors committed May 7, 2024
1 parent a050b1a commit f878c52
Show file tree
Hide file tree
Showing 13 changed files with 154 additions and 21 deletions.
17 changes: 15 additions & 2 deletions scripts/compute_averages.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,11 @@
None,
help='Beam CombineFn fanout. Might be required for large dataset.',
)
NUM_THREADS = flags.DEFINE_integer(
'num_threads',
None,
help='Number of chunks to read/write in parallel per worker.',
)


# pylint: disable=expression-not-assigned
Expand Down Expand Up @@ -120,7 +125,10 @@ def main(argv: list[str]):

with beam.Pipeline(runner=RUNNER.value, argv=argv) as root:
chunked = root | xbeam.DatasetToChunks(
source_dataset, source_chunks, split_vars=True
source_dataset,
source_chunks,
split_vars=True,
num_threads=NUM_THREADS.value,
)

if weights is not None:
Expand All @@ -131,7 +139,12 @@ def main(argv: list[str]):
(
chunked
| xbeam.Mean(AVERAGING_DIMS.value, skipna=False, fanout=FANOUT.value)
| xbeam.ChunksToZarr(OUTPUT_PATH.value, template, target_chunks)
| xbeam.ChunksToZarr(
OUTPUT_PATH.value,
template,
target_chunks,
num_threads=NUM_THREADS.value,
)
)


Expand Down
12 changes: 10 additions & 2 deletions scripts/compute_climatology.py
Original file line number Diff line number Diff line change
Expand Up @@ -120,6 +120,11 @@
'precipitation variable. In mm.'
),
)
NUM_THREADS = flags.DEFINE_integer(
'num_threads',
None,
help='Number of chunks to read/write in parallel per worker.',
)


class Quantile:
Expand Down Expand Up @@ -353,7 +358,10 @@ def _compute_seeps(kv):
pcoll = (
root
| xbeam.DatasetToChunks(
obs, input_chunks, split_vars=True, num_threads=16
obs,
input_chunks,
split_vars=True,
num_threads=NUM_THREADS.value,
)
| 'RechunkIn'
>> xbeam.Rechunk( # pytype: disable=wrong-arg-types
Expand Down Expand Up @@ -416,7 +424,7 @@ def _compute_seeps(kv):
OUTPUT_PATH.value,
template=clim_template,
zarr_chunks=output_chunks,
num_threads=16,
num_threads=NUM_THREADS.value,
)
)

Expand Down
17 changes: 15 additions & 2 deletions scripts/compute_derived_variables.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,6 +116,11 @@
MAX_MEM_GB = flags.DEFINE_integer(
'max_mem_gb', 1, help='Max memory for rechunking in GB.'
)
NUM_THREADS = flags.DEFINE_integer(
'num_threads',
None,
help='Number of chunks to read/write in parallel per worker.',
)

RUNNER = flags.DEFINE_string('runner', None, 'beam.runners.Runner')

Expand Down Expand Up @@ -226,7 +231,12 @@ def _is_not_precip(kv: tuple[xbeam.Key, xr.Dataset]) -> bool:
# so that with and without rechunking can be computed in parallel
pcoll = (
root
| xbeam.DatasetToChunks(source_dataset, source_chunks, split_vars=False)
| xbeam.DatasetToChunks(
source_dataset,
source_chunks,
split_vars=False,
num_threads=NUM_THREADS.value,
)
| beam.MapTuple(
lambda k, v: ( # pylint: disable=g-long-lambda
k,
Expand Down Expand Up @@ -274,7 +284,10 @@ def _is_not_precip(kv: tuple[xbeam.Key, xr.Dataset]) -> bool:

# Combined
_ = pcoll | xbeam.ChunksToZarr(
OUTPUT_PATH.value, template, source_chunks, num_threads=16
OUTPUT_PATH.value,
template,
source_chunks,
num_threads=NUM_THREADS.value,
)


Expand Down
19 changes: 17 additions & 2 deletions scripts/compute_ensemble_mean.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,11 @@
'2020-12-31',
help='ISO 8601 timestamp (inclusive) at which to stop evaluation',
)
NUM_THREADS = flags.DEFINE_integer(
'num_threads',
None,
help='Number of chunks to read/write in parallel per worker.',
)


# pylint: disable=expression-not-assigned
Expand Down Expand Up @@ -88,9 +93,19 @@ def main(argv: list[str]):
with beam.Pipeline(runner=RUNNER.value, argv=argv) as root:
(
root
| xbeam.DatasetToChunks(source_dataset, source_chunks, split_vars=True)
| xbeam.DatasetToChunks(
source_dataset,
source_chunks,
split_vars=True,
num_threads=NUM_THREADS.value,
)
| xbeam.Mean(REALIZATION_NAME.value, skipna=False)
| xbeam.ChunksToZarr(OUTPUT_PATH.value, template, target_chunks)
| xbeam.ChunksToZarr(
OUTPUT_PATH.value,
template,
target_chunks,
num_threads=NUM_THREADS.value,
)
)


Expand Down
10 changes: 9 additions & 1 deletion scripts/compute_statistical_moments.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,11 @@
RECHUNK_ITEMSIZE = flags.DEFINE_integer(
'rechunk_itemsize', 4, help='Itemsize for rechunking.'
)
NUM_THREADS = flags.DEFINE_integer(
'num_threads',
None,
help='Number of chunks to read/write in parallel per worker.',
)


def moment_reduce(
Expand Down Expand Up @@ -143,7 +148,9 @@ def main(argv: list[str]) -> None:

with beam.Pipeline(runner=RUNNER.value, argv=argv) as root:
# Read
pcoll = root | xbeam.DatasetToChunks(obs, input_chunks, split_vars=True)
pcoll = root | xbeam.DatasetToChunks(
obs, input_chunks, split_vars=True, num_threads=NUM_THREADS.value
)

# Branches to compute statistical moments
pcolls = []
Expand Down Expand Up @@ -174,6 +181,7 @@ def main(argv: list[str]) -> None:
OUTPUT_PATH.value,
template=output_template,
zarr_chunks=output_chunks,
num_threads=NUM_THREADS.value,
)
)

Expand Down
17 changes: 15 additions & 2 deletions scripts/compute_zonal_energy_spectrum.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,6 +96,11 @@
None,
help='Beam CombineFn fanout. Might be required for large dataset.',
)
NUM_THREADS = flags.DEFINE_integer(
'num_threads',
None,
help='Number of chunks to read/write in parallel per worker.',
)

RUNNER = flags.DEFINE_string('runner', None, 'beam.runners.Runner')

Expand Down Expand Up @@ -196,7 +201,12 @@ def main(argv: list[str]) -> None:
with beam.Pipeline(runner=RUNNER.value, argv=argv) as root:
_ = (
root
| xbeam.DatasetToChunks(source_dataset, source_chunks, split_vars=False)
| xbeam.DatasetToChunks(
source_dataset,
source_chunks,
split_vars=False,
num_threads=NUM_THREADS.value,
)
| beam.MapTuple(
lambda k, v: ( # pylint: disable=g-long-lambda
k,
Expand All @@ -207,7 +217,10 @@ def main(argv: list[str]) -> None:
| beam.MapTuple(_strip_offsets)
| xbeam.Mean(AVERAGING_DIMS.value, fanout=FANOUT.value)
| xbeam.ChunksToZarr(
OUTPUT_PATH.value, template, output_chunks, num_threads=16
OUTPUT_PATH.value,
template,
output_chunks,
num_threads=NUM_THREADS.value,
)
)

Expand Down
16 changes: 14 additions & 2 deletions scripts/convert_init_to_valid_time.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,6 +102,11 @@
INPUT_PATH = flags.DEFINE_string('input_path', None, help='zarr inputs')
OUTPUT_PATH = flags.DEFINE_string('output_path', None, help='zarr outputs')
RUNNER = flags.DEFINE_string('runner', None, 'beam.runners.Runner')
NUM_THREADS = flags.DEFINE_integer(
'num_threads',
None,
help='Number of chunks to read/write in parallel per worker.',
)

TIME = 'time'
DELTA = 'prediction_timedelta'
Expand Down Expand Up @@ -254,7 +259,9 @@ def main(argv: list[str]) -> None:
source_ds.indexes[INIT],
)
)
p |= xarray_beam.DatasetToChunks(source_ds, input_chunks, split_vars=True)
p |= xarray_beam.DatasetToChunks(
source_ds, input_chunks, split_vars=True, num_threads=NUM_THREADS.value
)
if input_chunks != split_chunks:
p |= xarray_beam.SplitChunks(split_chunks)
p |= beam.FlatMapTuple(
Expand All @@ -266,7 +273,12 @@ def main(argv: list[str]) -> None:
p = (p, padding) | beam.Flatten()
if input_chunks != split_chunks:
p |= xarray_beam.ConsolidateChunks(output_chunks)
p |= xarray_beam.ChunksToZarr(OUTPUT_PATH.value, template, output_chunks)
p |= xarray_beam.ChunksToZarr(
OUTPUT_PATH.value,
template,
output_chunks,
num_threads=NUM_THREADS.value,
)


if __name__ == '__main__':
Expand Down
6 changes: 6 additions & 0 deletions scripts/evaluate.py
Original file line number Diff line number Diff line change
Expand Up @@ -249,6 +249,11 @@
None,
help='Beam CombineFn fanout. Might be required for large dataset.',
)
NUM_THREADS = flags.DEFINE_integer(
'num_threads',
None,
help='Number of chunks to read/write Zarr in parallel per worker.',
)


def _wind_vector_error(err_type: str):
Expand Down Expand Up @@ -623,6 +628,7 @@ def main(argv: list[str]) -> None:
runner=RUNNER.value,
input_chunks=INPUT_CHUNKS.value,
fanout=FANOUT.value,
num_threads=NUM_THREADS.value,
argv=argv,
)
else:
Expand Down
10 changes: 9 additions & 1 deletion scripts/expand_climatology.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,11 @@
None,
help='Desired integer chunk size. If not set, inferred from input chunks.',
)
NUM_THREADS = flags.DEFINE_integer(
'num_threads',
None,
help='Number of chunks to read/write in parallel per worker.',
)
RUNNER = flags.DEFINE_string('runner', None, 'beam.runners.Runner')


Expand Down Expand Up @@ -149,7 +154,10 @@ def main(argv: list[str]) -> None:
| beam.Reshuffle()
| beam.FlatMap(select_climatology, climatology, times, base_chunks)
| xbeam.ChunksToZarr(
OUTPUT_PATH.value, template=template, zarr_chunks=output_chunks
OUTPUT_PATH.value,
template=template,
zarr_chunks=output_chunks,
num_threads=NUM_THREADS.value,
)
)

Expand Down
19 changes: 17 additions & 2 deletions scripts/regrid.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,11 @@
LONGITUDE_NAME = flags.DEFINE_string(
'longitude_name', 'longitude', help='Name of longitude dimension in dataset'
)
NUM_THREADS = flags.DEFINE_integer(
'num_threads',
None,
help='Number of chunks to read/write in parallel per worker.',
)
RUNNER = flags.DEFINE_string('runner', None, 'beam.runners.Runner')


Expand Down Expand Up @@ -135,11 +140,21 @@ def main(argv):
with beam.Pipeline(runner=RUNNER.value, argv=argv) as root:
_ = (
root
| xarray_beam.DatasetToChunks(source_ds, input_chunks, split_vars=True)
| xarray_beam.DatasetToChunks(
source_ds,
input_chunks,
split_vars=True,
num_threads=NUM_THREADS.value,
)
| 'Regrid'
>> beam.MapTuple(lambda k, v: (k, regridder.regrid_dataset(v)))
| xarray_beam.ConsolidateChunks(output_chunks)
| xarray_beam.ChunksToZarr(OUTPUT_PATH.value, template, output_chunks)
| xarray_beam.ChunksToZarr(
OUTPUT_PATH.value,
template,
output_chunks,
num_threads=NUM_THREADS.value,
)
)


Expand Down
4 changes: 3 additions & 1 deletion scripts/resample_in_time.py
Original file line number Diff line number Diff line change
Expand Up @@ -119,7 +119,9 @@
help='Add suffix "_mean" to variable name when computing the mean.',
)
NUM_THREADS = flags.DEFINE_integer(
'num_threads', None, help='Number of chunks to load in parallel per worker.'
'num_threads',
None,
help='Number of chunks to read/write in parallel per worker.',
)
TIME_DIM = flags.DEFINE_string(
'time_dim', 'time', help='Name for the time dimension to slice data on.'
Expand Down
4 changes: 3 additions & 1 deletion scripts/slice_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,7 +96,9 @@
'runner', None, help='Beam runner. Use DirectRunner for local execution.'
)
NUM_THREADS = flags.DEFINE_integer(
'num_threads', None, help='Number of chunks to load in parallel per worker.'
'num_threads',
None,
help='Number of chunks to read/write in parallel per worker.',
)


Expand Down
Loading

0 comments on commit f878c52

Please sign in to comment.