From f878c52820fe79d43a35384cef69703948933755 Mon Sep 17 00:00:00 2001 From: Ian Langmore Date: Thu, 2 May 2024 15:07:26 -0700 Subject: [PATCH] Ran build_cleaner. Also some internal dependency changes. PiperOrigin-RevId: 630191048 --- scripts/compute_averages.py | 17 +++++++++++++++-- scripts/compute_climatology.py | 12 ++++++++++-- scripts/compute_derived_variables.py | 17 +++++++++++++++-- scripts/compute_ensemble_mean.py | 19 +++++++++++++++++-- scripts/compute_statistical_moments.py | 10 +++++++++- scripts/compute_zonal_energy_spectrum.py | 17 +++++++++++++++-- scripts/convert_init_to_valid_time.py | 16 ++++++++++++++-- scripts/evaluate.py | 6 ++++++ scripts/expand_climatology.py | 10 +++++++++- scripts/regrid.py | 19 +++++++++++++++++-- scripts/resample_in_time.py | 4 +++- scripts/slice_dataset.py | 4 +++- weatherbench2/evaluation.py | 24 +++++++++++++++++++++--- 13 files changed, 154 insertions(+), 21 deletions(-) diff --git a/scripts/compute_averages.py b/scripts/compute_averages.py index 23f9088..9ca2460 100644 --- a/scripts/compute_averages.py +++ b/scripts/compute_averages.py @@ -86,6 +86,11 @@ None, help='Beam CombineFn fanout. Might be required for large dataset.', ) +NUM_THREADS = flags.DEFINE_integer( + 'num_threads', + None, + help='Number of chunks to read/write in parallel per worker.', +) # pylint: disable=expression-not-assigned @@ -120,7 +125,10 @@ def main(argv: list[str]): with beam.Pipeline(runner=RUNNER.value, argv=argv) as root: chunked = root | xbeam.DatasetToChunks( - source_dataset, source_chunks, split_vars=True + source_dataset, + source_chunks, + split_vars=True, + num_threads=NUM_THREADS.value, ) if weights is not None: @@ -131,7 +139,12 @@ def main(argv: list[str]): ( chunked | xbeam.Mean(AVERAGING_DIMS.value, skipna=False, fanout=FANOUT.value) - | xbeam.ChunksToZarr(OUTPUT_PATH.value, template, target_chunks) + | xbeam.ChunksToZarr( + OUTPUT_PATH.value, + template, + target_chunks, + num_threads=NUM_THREADS.value, + ) ) diff --git a/scripts/compute_climatology.py b/scripts/compute_climatology.py index c0a3c4d..6269d04 100644 --- a/scripts/compute_climatology.py +++ b/scripts/compute_climatology.py @@ -120,6 +120,11 @@ 'precipitation variable. In mm.' ), ) +NUM_THREADS = flags.DEFINE_integer( + 'num_threads', + None, + help='Number of chunks to read/write in parallel per worker.', +) class Quantile: @@ -353,7 +358,10 @@ def _compute_seeps(kv): pcoll = ( root | xbeam.DatasetToChunks( - obs, input_chunks, split_vars=True, num_threads=16 + obs, + input_chunks, + split_vars=True, + num_threads=NUM_THREADS.value, ) | 'RechunkIn' >> xbeam.Rechunk( # pytype: disable=wrong-arg-types @@ -416,7 +424,7 @@ def _compute_seeps(kv): OUTPUT_PATH.value, template=clim_template, zarr_chunks=output_chunks, - num_threads=16, + num_threads=NUM_THREADS.value, ) ) diff --git a/scripts/compute_derived_variables.py b/scripts/compute_derived_variables.py index 25b6aae..30a49e2 100644 --- a/scripts/compute_derived_variables.py +++ b/scripts/compute_derived_variables.py @@ -116,6 +116,11 @@ MAX_MEM_GB = flags.DEFINE_integer( 'max_mem_gb', 1, help='Max memory for rechunking in GB.' ) +NUM_THREADS = flags.DEFINE_integer( + 'num_threads', + None, + help='Number of chunks to read/write in parallel per worker.', +) RUNNER = flags.DEFINE_string('runner', None, 'beam.runners.Runner') @@ -226,7 +231,12 @@ def _is_not_precip(kv: tuple[xbeam.Key, xr.Dataset]) -> bool: # so that with and without rechunking can be computed in parallel pcoll = ( root - | xbeam.DatasetToChunks(source_dataset, source_chunks, split_vars=False) + | xbeam.DatasetToChunks( + source_dataset, + source_chunks, + split_vars=False, + num_threads=NUM_THREADS.value, + ) | beam.MapTuple( lambda k, v: ( # pylint: disable=g-long-lambda k, @@ -274,7 +284,10 @@ def _is_not_precip(kv: tuple[xbeam.Key, xr.Dataset]) -> bool: # Combined _ = pcoll | xbeam.ChunksToZarr( - OUTPUT_PATH.value, template, source_chunks, num_threads=16 + OUTPUT_PATH.value, + template, + source_chunks, + num_threads=NUM_THREADS.value, ) diff --git a/scripts/compute_ensemble_mean.py b/scripts/compute_ensemble_mean.py index 736ea8d..5e6030a 100644 --- a/scripts/compute_ensemble_mean.py +++ b/scripts/compute_ensemble_mean.py @@ -61,6 +61,11 @@ '2020-12-31', help='ISO 8601 timestamp (inclusive) at which to stop evaluation', ) +NUM_THREADS = flags.DEFINE_integer( + 'num_threads', + None, + help='Number of chunks to read/write in parallel per worker.', +) # pylint: disable=expression-not-assigned @@ -88,9 +93,19 @@ def main(argv: list[str]): with beam.Pipeline(runner=RUNNER.value, argv=argv) as root: ( root - | xbeam.DatasetToChunks(source_dataset, source_chunks, split_vars=True) + | xbeam.DatasetToChunks( + source_dataset, + source_chunks, + split_vars=True, + num_threads=NUM_THREADS.value, + ) | xbeam.Mean(REALIZATION_NAME.value, skipna=False) - | xbeam.ChunksToZarr(OUTPUT_PATH.value, template, target_chunks) + | xbeam.ChunksToZarr( + OUTPUT_PATH.value, + template, + target_chunks, + num_threads=NUM_THREADS.value, + ) ) diff --git a/scripts/compute_statistical_moments.py b/scripts/compute_statistical_moments.py index 31ab6bb..281bfbd 100644 --- a/scripts/compute_statistical_moments.py +++ b/scripts/compute_statistical_moments.py @@ -37,6 +37,11 @@ RECHUNK_ITEMSIZE = flags.DEFINE_integer( 'rechunk_itemsize', 4, help='Itemsize for rechunking.' ) +NUM_THREADS = flags.DEFINE_integer( + 'num_threads', + None, + help='Number of chunks to read/write in parallel per worker.', +) def moment_reduce( @@ -143,7 +148,9 @@ def main(argv: list[str]) -> None: with beam.Pipeline(runner=RUNNER.value, argv=argv) as root: # Read - pcoll = root | xbeam.DatasetToChunks(obs, input_chunks, split_vars=True) + pcoll = root | xbeam.DatasetToChunks( + obs, input_chunks, split_vars=True, num_threads=NUM_THREADS.value + ) # Branches to compute statistical moments pcolls = [] @@ -174,6 +181,7 @@ def main(argv: list[str]) -> None: OUTPUT_PATH.value, template=output_template, zarr_chunks=output_chunks, + num_threads=NUM_THREADS.value, ) ) diff --git a/scripts/compute_zonal_energy_spectrum.py b/scripts/compute_zonal_energy_spectrum.py index d1173ce..ee1595f 100644 --- a/scripts/compute_zonal_energy_spectrum.py +++ b/scripts/compute_zonal_energy_spectrum.py @@ -96,6 +96,11 @@ None, help='Beam CombineFn fanout. Might be required for large dataset.', ) +NUM_THREADS = flags.DEFINE_integer( + 'num_threads', + None, + help='Number of chunks to read/write in parallel per worker.', +) RUNNER = flags.DEFINE_string('runner', None, 'beam.runners.Runner') @@ -196,7 +201,12 @@ def main(argv: list[str]) -> None: with beam.Pipeline(runner=RUNNER.value, argv=argv) as root: _ = ( root - | xbeam.DatasetToChunks(source_dataset, source_chunks, split_vars=False) + | xbeam.DatasetToChunks( + source_dataset, + source_chunks, + split_vars=False, + num_threads=NUM_THREADS.value, + ) | beam.MapTuple( lambda k, v: ( # pylint: disable=g-long-lambda k, @@ -207,7 +217,10 @@ def main(argv: list[str]) -> None: | beam.MapTuple(_strip_offsets) | xbeam.Mean(AVERAGING_DIMS.value, fanout=FANOUT.value) | xbeam.ChunksToZarr( - OUTPUT_PATH.value, template, output_chunks, num_threads=16 + OUTPUT_PATH.value, + template, + output_chunks, + num_threads=NUM_THREADS.value, ) ) diff --git a/scripts/convert_init_to_valid_time.py b/scripts/convert_init_to_valid_time.py index 3e7f84c..446d61f 100644 --- a/scripts/convert_init_to_valid_time.py +++ b/scripts/convert_init_to_valid_time.py @@ -102,6 +102,11 @@ INPUT_PATH = flags.DEFINE_string('input_path', None, help='zarr inputs') OUTPUT_PATH = flags.DEFINE_string('output_path', None, help='zarr outputs') RUNNER = flags.DEFINE_string('runner', None, 'beam.runners.Runner') +NUM_THREADS = flags.DEFINE_integer( + 'num_threads', + None, + help='Number of chunks to read/write in parallel per worker.', +) TIME = 'time' DELTA = 'prediction_timedelta' @@ -254,7 +259,9 @@ def main(argv: list[str]) -> None: source_ds.indexes[INIT], ) ) - p |= xarray_beam.DatasetToChunks(source_ds, input_chunks, split_vars=True) + p |= xarray_beam.DatasetToChunks( + source_ds, input_chunks, split_vars=True, num_threads=NUM_THREADS.value + ) if input_chunks != split_chunks: p |= xarray_beam.SplitChunks(split_chunks) p |= beam.FlatMapTuple( @@ -266,7 +273,12 @@ def main(argv: list[str]) -> None: p = (p, padding) | beam.Flatten() if input_chunks != split_chunks: p |= xarray_beam.ConsolidateChunks(output_chunks) - p |= xarray_beam.ChunksToZarr(OUTPUT_PATH.value, template, output_chunks) + p |= xarray_beam.ChunksToZarr( + OUTPUT_PATH.value, + template, + output_chunks, + num_threads=NUM_THREADS.value, + ) if __name__ == '__main__': diff --git a/scripts/evaluate.py b/scripts/evaluate.py index 153d968..be6fbd0 100644 --- a/scripts/evaluate.py +++ b/scripts/evaluate.py @@ -249,6 +249,11 @@ None, help='Beam CombineFn fanout. Might be required for large dataset.', ) +NUM_THREADS = flags.DEFINE_integer( + 'num_threads', + None, + help='Number of chunks to read/write Zarr in parallel per worker.', +) def _wind_vector_error(err_type: str): @@ -623,6 +628,7 @@ def main(argv: list[str]) -> None: runner=RUNNER.value, input_chunks=INPUT_CHUNKS.value, fanout=FANOUT.value, + num_threads=NUM_THREADS.value, argv=argv, ) else: diff --git a/scripts/expand_climatology.py b/scripts/expand_climatology.py index 699e9c8..c330328 100644 --- a/scripts/expand_climatology.py +++ b/scripts/expand_climatology.py @@ -72,6 +72,11 @@ None, help='Desired integer chunk size. If not set, inferred from input chunks.', ) +NUM_THREADS = flags.DEFINE_integer( + 'num_threads', + None, + help='Number of chunks to read/write in parallel per worker.', +) RUNNER = flags.DEFINE_string('runner', None, 'beam.runners.Runner') @@ -149,7 +154,10 @@ def main(argv: list[str]) -> None: | beam.Reshuffle() | beam.FlatMap(select_climatology, climatology, times, base_chunks) | xbeam.ChunksToZarr( - OUTPUT_PATH.value, template=template, zarr_chunks=output_chunks + OUTPUT_PATH.value, + template=template, + zarr_chunks=output_chunks, + num_threads=NUM_THREADS.value, ) ) diff --git a/scripts/regrid.py b/scripts/regrid.py index 79c0e6f..54c070a 100644 --- a/scripts/regrid.py +++ b/scripts/regrid.py @@ -78,6 +78,11 @@ LONGITUDE_NAME = flags.DEFINE_string( 'longitude_name', 'longitude', help='Name of longitude dimension in dataset' ) +NUM_THREADS = flags.DEFINE_integer( + 'num_threads', + None, + help='Number of chunks to read/write in parallel per worker.', +) RUNNER = flags.DEFINE_string('runner', None, 'beam.runners.Runner') @@ -135,11 +140,21 @@ def main(argv): with beam.Pipeline(runner=RUNNER.value, argv=argv) as root: _ = ( root - | xarray_beam.DatasetToChunks(source_ds, input_chunks, split_vars=True) + | xarray_beam.DatasetToChunks( + source_ds, + input_chunks, + split_vars=True, + num_threads=NUM_THREADS.value, + ) | 'Regrid' >> beam.MapTuple(lambda k, v: (k, regridder.regrid_dataset(v))) | xarray_beam.ConsolidateChunks(output_chunks) - | xarray_beam.ChunksToZarr(OUTPUT_PATH.value, template, output_chunks) + | xarray_beam.ChunksToZarr( + OUTPUT_PATH.value, + template, + output_chunks, + num_threads=NUM_THREADS.value, + ) ) diff --git a/scripts/resample_in_time.py b/scripts/resample_in_time.py index 2a69954..dd0e0c9 100644 --- a/scripts/resample_in_time.py +++ b/scripts/resample_in_time.py @@ -119,7 +119,9 @@ help='Add suffix "_mean" to variable name when computing the mean.', ) NUM_THREADS = flags.DEFINE_integer( - 'num_threads', None, help='Number of chunks to load in parallel per worker.' + 'num_threads', + None, + help='Number of chunks to read/write in parallel per worker.', ) TIME_DIM = flags.DEFINE_string( 'time_dim', 'time', help='Name for the time dimension to slice data on.' diff --git a/scripts/slice_dataset.py b/scripts/slice_dataset.py index bd64bc7..9c7b18a 100644 --- a/scripts/slice_dataset.py +++ b/scripts/slice_dataset.py @@ -96,7 +96,9 @@ 'runner', None, help='Beam runner. Use DirectRunner for local execution.' ) NUM_THREADS = flags.DEFINE_integer( - 'num_threads', None, help='Number of chunks to load in parallel per worker.' + 'num_threads', + None, + help='Number of chunks to read/write in parallel per worker.', ) diff --git a/weatherbench2/evaluation.py b/weatherbench2/evaluation.py index f797209..c5c5e7b 100644 --- a/weatherbench2/evaluation.py +++ b/weatherbench2/evaluation.py @@ -509,6 +509,7 @@ class _SaveOutputs(beam.PTransform): eval_name: str data_config: config.Data output_format: str + num_threads: Optional[int] = None def _write_netcdf(self, datasets: list[xr.Dataset]) -> xr.Dataset: combined = xr.combine_by_coords(datasets) @@ -529,7 +530,9 @@ def expand(self, pcoll: beam.PCollection) -> beam.PCollection: output_path = _get_output_path( self.data_config, self.eval_name, self.output_format ) - return pcoll | xbeam.ChunksToZarr(output_path) + return pcoll | xbeam.ChunksToZarr( + output_path, num_threads=self.num_threads + ) else: raise ValueError(f'unrecogonized data format: {self.output_format}') @@ -551,6 +554,7 @@ class _EvaluateAllMetrics(beam.PTransform): data_config: config.Data input_chunks: abc.Mapping[str, int] fanout: Optional[int] = None + num_threads: Optional[int] = None def _evaluate_chunk( self, @@ -662,12 +666,14 @@ def _evaluate( forecast, self.input_chunks, split_vars=False, + num_threads=self.num_threads, ) | beam.MapTuple(self._sel_corresponding_truth_chunk, truth=truth) else: forecast_pipeline = xbeam.DatasetToChunks( [forecast, truth], self.input_chunks, split_vars=False, + num_threads=self.num_threads, ) if self.eval_config.evaluate_climatology: @@ -723,6 +729,7 @@ def evaluate_with_beam( input_chunks: abc.Mapping[str, int], runner: str, fanout: Optional[int] = None, + num_threads: Optional[int] = None, argv: Optional[list[str]] = None, ) -> None: """Run evaluation with a Beam pipeline. @@ -750,6 +757,7 @@ def evaluate_with_beam( input_chunks: Chunking of input datasets. runner: Beam runner. fanout: Beam CombineFn fanout. + num_threads: Number of threads to use for reading/writing data. argv: Other arguments to pass into the Beam pipeline. """ @@ -760,8 +768,18 @@ def evaluate_with_beam( root | f'evaluate_{eval_name}' >> _EvaluateAllMetrics( - eval_name, eval_config, data_config, input_chunks, fanout=fanout + eval_name, + eval_config, + data_config, + input_chunks, + fanout=fanout, + num_threads=num_threads, ) | f'save_{eval_name}' - >> _SaveOutputs(eval_name, data_config, eval_config.output_format) + >> _SaveOutputs( + eval_name, + data_config, + eval_config.output_format, + num_threads=num_threads, + ) )