Skip to content

Commit

Permalink
DS fix, continued (#3145)
Browse files Browse the repository at this point in the history
  • Loading branch information
muellerzr committed Oct 12, 2024
1 parent 42be235 commit a427548
Showing 1 changed file with 4 additions and 1 deletion.
5 changes: 4 additions & 1 deletion tests/deepspeed/test_deepspeed_multiple_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -118,7 +118,10 @@ def test_config_reference_update(self):
accelerator = Accelerator(deepspeed_plugin=ds_plugins)
from transformers.integrations.deepspeed import deepspeed_config

# Note that these have `auto` values being set so we need to adjust
assert accelerator.deepspeed_plugin is zero2
zero2.deepspeed_config["train_micro_batch_size_per_gpu"] = 1
zero2.deepspeed_config.pop("train_batch_size")
assert deepspeed_config() == accelerator.deepspeed_plugin.hf_ds_config.config

accelerator.state.select_deepspeed_plugin("zero3")
Expand Down Expand Up @@ -173,6 +176,6 @@ def test_prepare_multiple_models_zero3_inference(self):
@slow
def test_train_multiple_models(self):
self.test_file_path = self.test_scripts_folder / "test_ds_multiple_model.py"
args = ["--num_processes=2", "--num_machines=1", "--main_process_port=10999", str(self.test_file_path)]
args = ["--num_processes=2", "--num_machines=1", "--main_process_port=0", str(self.test_file_path)]
args = self.parser.parse_args(args)
launch_command(args)

0 comments on commit a427548

Please sign in to comment.