diff --git a/src/instructlab/training/main_ds.py b/src/instructlab/training/main_ds.py index b56d5b85..2e7cf8b6 100644 --- a/src/instructlab/training/main_ds.py +++ b/src/instructlab/training/main_ds.py @@ -663,28 +663,20 @@ def run_training(torch_args: TorchrunArgs, train_args: TrainingArgs) -> None: command.append("--cpu_offload_optimizer_pin_memory") print(f"\033[92mRunning command: {' '.join(command)}\033[0m") - process = None + process = StreamablePopen( + f"{train_args.ckpt_output_dir}/full_logs_global{torch_args.node_rank}.log", + command, + ) + print("\033[91mTerminating process 🤖\033[0m") + process.terminate() try: - process = StreamablePopen( - f"{train_args.ckpt_output_dir}/full_logs_global{torch_args.node_rank}.log", - command, - ) - - except KeyboardInterrupt: - print("Process interrupted by user") - except Exception as e: - print(f"An error occurred: {str(e)}") - finally: - if "process" not in locals() or process is None: - return - - print("\033[91mTerminating process 🤖\033[0m") - process.terminate() - try: - process.wait(timeout=60) - except subprocess.TimeoutExpired: - print("\033[91mProcess did not terminate in time, killing it.\033[0m") - process.kill() + rc = process.wait(timeout=60) + if rc: + raise RuntimeError(f"Training process exited with code {rc}") + except subprocess.TimeoutExpired as e: + print("\033[91mProcess did not terminate in time, killing it.\033[0m") + process.kill() + raise RuntimeError("Training process timed out on exit.") from e if __name__ == "__main__":