From 1675e3f78a87f979e4271c168eeedd8948e81f14 Mon Sep 17 00:00:00 2001 From: Charlie Doern Date: Tue, 27 Aug 2024 10:36:44 -0400 Subject: [PATCH] add support for CPU and MPS do not use distributed when not available, instead use CPU or MPS. This entails a few changes: --device is now a valid flag to the library since `ilab` can pass CPU, MPS, or default to cuda when using CPU or MPS, do not initialize DS, instead put the model on the device and initialize `Adafactor` optimizer which is more efficient and than Adam based one inside of `train` add logic for handling if torch.cuda.is_available and torch.distributed.is_initialized() we dont use distributed torch on consumer systems the train loop needs some custom step and loss logic for a LlamaForCausalLM model, add that in when using CPU or MPS we are always world_size == 1 and local_rank == 0 Signed-off-by: Charlie Doern --- src/instructlab/training/main_ds.py | 208 +++++++++++++----- src/instructlab/training/multipack_sampler.py | 10 +- src/instructlab/training/utils.py | 34 +-- 3 files changed, 177 insertions(+), 75 deletions(-) diff --git a/src/instructlab/training/main_ds.py b/src/instructlab/training/main_ds.py index 07f33728..42fe50aa 100644 --- a/src/instructlab/training/main_ds.py +++ b/src/instructlab/training/main_ds.py @@ -16,11 +16,18 @@ # pylint: disable=no-name-in-module from instructlab.dolomite.hf_models import GPTDolomiteForCausalLM +from torch import nn from torch.distributed import ReduceOp, all_reduce from tqdm import tqdm -from transformers import AutoModelForCausalLM, get_scheduler +from transformers import ( + Adafactor, + AutoModelForCausalLM, + LlamaForCausalLM, + get_scheduler, +) import deepspeed import torch +import torch.distributed # First Party from instructlab.training import config @@ -83,7 +90,7 @@ def get_ds_config(world_size, samples_per_gpu, grad_accum, opts: DeepSpeedOption return ds_config -def setup_model(args, tokenizer, train_loader, grad_accum): +def setup_model(args, tokenizer, train_loader, grad_accum, device): bnb_config = None if args.lora_r > 0 and args.lora_quant_bits == 4: # Third Party @@ -250,25 +257,35 @@ def make_inputs_require_grad(module, input, output): ) # pylint: disable=unbalanced-tuple-unpacking - model, _, _, lr_scheduler = deepspeed.initialize( - model=model, - optimizer=optimizer, - config=get_ds_config( - world_size=torch.distributed.get_world_size(), - samples_per_gpu=args.samples_per_gpu, - grad_accum=grad_accum, - opts=DeepSpeedOptions( - cpu_offload_optimizer=args.cpu_offload_optimizer, - cpu_offload_optimizer_ratio=args.cpu_offload_optimizer_ratio, - cpu_offload_optimizer_pin_memory=args.cpu_offload_optimizer_pin_memory, - save_samples=args.save_samples_ds, + optimizer = None + if device.type == "cuda": + model, _, _, lr_scheduler = deepspeed.initialize( + model=model, + optimizer=optimizer, + config=get_ds_config( + world_size=torch.distributed.get_world_size(), + samples_per_gpu=args.samples_per_gpu, + grad_accum=grad_accum, + opts=DeepSpeedOptions( + cpu_offload_optimizer=args.cpu_offload_optimizer, + cpu_offload_optimizer_ratio=args.cpu_offload_optimizer_ratio, + cpu_offload_optimizer_pin_memory=args.cpu_offload_optimizer_pin_memory, + save_samples=args.save_samples_ds, + ), ), - ), - lr_scheduler=lr_scheduler, - dist_init_required=True, - ) - # model = torch.compile(model) - return model + lr_scheduler=lr_scheduler, + dist_init_required=True, + ) + else: + # If we are using CPU or MPS just place model on that device + # also, initialize Adafactor, a Transformers Optimizer designed to use less resources. + # if we use AdamW here most people will always run out of RAM + model = model.to(device) + optimizer = Adafactor( + model.parameters(), lr=1e-5, scale_parameter=True, relative_step=False + ) + model.gradient_checkpointing_enable() + return model, optimizer # this function is to check if the checkpoint provided can be resumed @@ -331,7 +348,9 @@ def maybe_resume_training(args, model): return model -def train(args, model, tokenizer, train_loader, grad_accum, metric_logger): +def train( + args, model, tokenizer, train_loader, grad_accum, metric_logger, device, optimizer +): model.train() global_step = 1 @@ -359,7 +378,8 @@ def train(args, model, tokenizer, train_loader, grad_accum, metric_logger): ) for epoch in range(args.num_epochs): - torch.distributed.barrier() + if torch.cuda.is_available(): + torch.distributed.barrier() if args.sampler in ("multipack"): train_loader.batch_sampler.set_epoch(epoch) elif args.sampler in ("distributed"): @@ -370,7 +390,12 @@ def train(args, model, tokenizer, train_loader, grad_accum, metric_logger): if local_rank == 0: inner_pb = tqdm(range(len(train_loader)), desc=f"Epoch {epoch}") - aggregated_values = torch.zeros(3, dtype=torch.float32).to(local_rank) + if not torch.cuda.is_available(): + aggregated_values = torch.zeros(3, dtype=torch.float32, device=device).to( + device=device + ) + else: + aggregated_values = torch.zeros(3, dtype=torch.float16).to(local_rank) for batch in train_loader: if global_step <= args.last_step: # in the case of resuming, last_step > 0 @@ -384,7 +409,10 @@ def train(args, model, tokenizer, train_loader, grad_accum, metric_logger): aggregated_values[1] = len(batch["input_ids"]) if not args.is_granite: for k in batch: - batch[k] = batch[k].to(local_rank) + if torch.cuda.is_available(): + batch[k] = batch[k].to(local_rank) + else: + batch[k] = batch[k].to(device="cpu") output = model( **batch, @@ -394,7 +422,8 @@ def train(args, model, tokenizer, train_loader, grad_accum, metric_logger): aggregated_values[2] = loss.item() - all_reduce(aggregated_values, op=ReduceOp.SUM) + if torch.cuda.is_available() and torch.distributed.is_initialized(): + all_reduce(aggregated_values, op=ReduceOp.SUM) num_loss_counted_tokens = aggregated_values[0] loss = ( @@ -404,32 +433,65 @@ def train(args, model, tokenizer, train_loader, grad_accum, metric_logger): print( f"\033[93mPer-token loss scaled by world size: {(loss/num_loss_counted_tokens) * world_size}\033[0m" ) - print( - f"Epoch: {epoch}, Step: {global_step}, Rank: {torch.distributed.get_rank()}, loss = {loss}" - ) - - model.backward(loss) - model.step() + if torch.cuda.is_available(): + rank = torch.distributed.get_rank() + else: + rank = 0 + print(f"Epoch: {epoch}, Step: {global_step}, Rank: {rank}, loss = {loss}") + + # If using a LlamaForCausalLM model (single device CPU, GPU, or MPS) then we cannot use the DS .backward, .step from the model_engine + # instead, use the AdaFactor Optimizer's zero_grad, the loss.backward() and step the optimizer itself. + if torch.cuda.is_available(): + model.backward(loss) + model.step() + else: + optimizer.zero_grad() + loss.backward() + optimizer.step() if local_rank == 0: elapsed_time = time.time() - start overall_throughput = args.samples_per_gpu * world_size / elapsed_time - current_lr = model.lr_scheduler.get_last_lr()[0] - cuda_mem_allocated = torch.cuda.memory_allocated() / (1024**3) - cuda_malloc_retries = torch.cuda.memory_stats()["num_alloc_retries"] - global_grad_norm = model.get_global_grad_norm() + cuda_malloc_retries = 0 + cuda_mem_allocated = 0 + if torch.cuda.is_available(): + cuda_mem_allocated = torch.cuda.memory_allocated() / (1024**3) + cuda_malloc_retries = torch.cuda.memory_stats()["num_alloc_retries"] + norm = None + if not isinstance(model, LlamaForCausalLM): + global_grad_norm = model.get_global_grad_norm() + norm = model.optimizer.single_partition_of_fp32_groups[0].norm() + current_lr = model.lr_scheduler.get_last_lr()[0] + else: + global_grad_norm = nn.utils.clip_grad_norm_( + model.parameters(), max_norm=float("inf") + ) + lr_scheduler = torch.optim.lr_scheduler.StepLR( + optimizer, step_size=10, gamma=0.1 + ) + fp32_params = [ + param.data + for param in model.parameters() + if param.requires_grad + ] + norm = torch.norm(fp32_params[0]) + # for name, param in model.named_parameters(): + # if param.requires_grad: + # fp32_weights = param.data + # fp32_norm = torch.norm(fp32_weights) + # print(f"Norm of {name}: {fp32_norm.item()}") + current_lr = lr_scheduler.get_last_lr()[0] global_grad_norm = ( float(global_grad_norm) if global_grad_norm is not None else None ) - weight_norm = float( - model.optimizer.single_partition_of_fp32_groups[0].norm() - ) + + weight_norm = float(norm) metric_logger.log_sync( { "epoch": epoch, "step": global_step, - "rank": torch.distributed.get_rank(), + "rank": rank, "loss": loss.item(), "overall_throughput": overall_throughput, "lr": current_lr, @@ -470,7 +532,8 @@ def train(args, model, tokenizer, train_loader, grad_accum, metric_logger): global_step += 1 if local_rank == 0: inner_pb.update(1) - torch.cuda.empty_cache() + if torch.cuda.is_available(): + torch.cuda.empty_cache() if args.checkpoint_at_epoch: save_hf_format_ds( @@ -507,13 +570,28 @@ def main(args): # device = torch.device("cuda", args.local_rank) #### distributed init ##### - torch.cuda.set_device(int(os.environ["LOCAL_RANK"])) - args.local_rank = int(os.environ["LOCAL_RANK"]) - deepspeed.init_distributed(timeout=timedelta(minutes=30)) - args.global_rank = torch.distributed.get_rank() - tensor = torch.ByteTensor([False]).cuda() - torch.distributed.all_reduce(tensor) - torch.distributed.barrier() + world_size = 1 + device = None + if not torch.cuda.is_available(): + if ( + args.device == "mps" + and torch.backends.mps.is_available() + and torch.backend.mps.is_built() + ): + device = torch.device("mps") + else: + device = torch.device("cpu") + args.local_rank = 0 + args.global_rank = 0 + elif torch.distributed.is_available(): + world_size = torch.distributed.get_world_size() + torch.cuda.set_device(int(os.environ["LOCAL_RANK"])) + args.local_rank = int(os.environ["LOCAL_RANK"]) + deepspeed.init_distributed(timeout=timedelta(minutes=10)) + args.global_rank = torch.distributed.get_rank() + tensor = torch.ByteTensor([False]).cuda() + torch.distributed.all_reduce(tensor) + torch.distributed.barrier() dataset = setup_dataset( args.data_path, @@ -523,7 +601,7 @@ def main(args): try: packing_max_batch_len, grad_accum = find_packing_max_batch_len_and_grad_accum( - num_gpus=torch.distributed.get_world_size(), + num_gpus=world_size, avg_sample_len=dataset.get_lengths().mean(), effective_batch_size=args.effective_batch_size, max_batch_len_per_gpu=args.max_batch_len, @@ -542,9 +620,7 @@ def main(args): grad_accum = 1 args.sampler = "distributed" - args.samples_per_gpu = ( - args.effective_batch_size // grad_accum // torch.distributed.get_world_size() - ) + args.samples_per_gpu = args.effective_batch_size // grad_accum // world_size train_loader = setup_dataloader( dataset, @@ -580,7 +656,7 @@ def main(args): if args.local_rank == 0: metric_logger.log_sync( { - "num_gpus": torch.distributed.get_world_size(), + "num_gpus": world_size, "avg_sample_len": dataset.get_lengths().mean(), "effective_batch_size": args.effective_batch_size, "max_batch_len_per_gpu": args.max_batch_len, @@ -592,13 +668,24 @@ def main(args): } ) - model = setup_model(args, tokenizer, train_loader, grad_accum) - model = maybe_resume_training(args, model) - - train(args, model, tokenizer, train_loader, grad_accum, metric_logger) + model, optimizer = setup_model(args, tokenizer, train_loader, grad_accum, device) + if device.type == "cuda": + model = maybe_resume_training(args, model) + + train( + args, + model, + tokenizer, + train_loader, + grad_accum, + metric_logger, + device, + optimizer, + ) - torch.distributed.barrier() - torch.distributed.destroy_process_group() + if torch.cuda.is_available() and torch.distributed.is_available(): + torch.distributed.barrier() + torch.distributed.destroy_process_group() # public API @@ -705,6 +792,9 @@ def run_training(torch_args: TorchrunArgs, train_args: TrainingArgs) -> None: if train_args.deepspeed_options.cpu_offload_optimizer_pin_memory: command.append("--cpu_offload_optimizer_pin_memory") + if torch_args.nproc_per_node == 1: + command.append("--standalone") + print(f"\033[92mRunning command: {' '.join(command)}\033[0m") process = None try: @@ -831,6 +921,8 @@ def run_training(torch_args: TorchrunArgs, train_args: TrainingArgs) -> None: ), ) parser.add_argument("--disable_flash_attn", action="store_true") + parser.add_argument("--standalone", action="store_true") + parser.add_argument("--device", type=str, default="cuda") args = parser.parse_args() set_random_seed(args.seed) main(args) diff --git a/src/instructlab/training/multipack_sampler.py b/src/instructlab/training/multipack_sampler.py index 71d1def2..1db8ab33 100644 --- a/src/instructlab/training/multipack_sampler.py +++ b/src/instructlab/training/multipack_sampler.py @@ -30,7 +30,6 @@ from torch.utils.data import Sampler import numba import numpy as np -import torch import torch.distributed as dist @@ -67,11 +66,16 @@ def get_effective_samples_per_minibatch(num_tokens_per_gpu): The function creates a sampler using the MultipackDistributedBatchSampler class, generates batches using the sampler, and then returns the ratio of the dataset size to the number of batches. """ + num_replicas = 1 + rank = 0 + if dist.is_initialized(): + num_replicas = dist.get_world_size() + rank = dist.get_rank() sampler = MultipackDistributedBatchSampler( batch_max_length=num_tokens_per_gpu, lengths=dataset.get_lengths(), - num_replicas=torch.distributed.get_world_size(), - rank=torch.distributed.get_rank(), + num_replicas=num_replicas, + rank=rank, seed=seed, padding=True, ) diff --git a/src/instructlab/training/utils.py b/src/instructlab/training/utils.py index 96d86c57..e2283db1 100644 --- a/src/instructlab/training/utils.py +++ b/src/instructlab/training/utils.py @@ -28,8 +28,6 @@ ) from rich.logging import RichHandler from safetensors.torch import save_file -from torch import distributed as dist -from torch.distributed import get_rank, is_initialized from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import ( CheckpointImpl, apply_activation_checkpointing, @@ -37,6 +35,7 @@ ) import numpy as np import torch +import torch.distributed import torch.nn.functional as F @@ -480,7 +479,7 @@ class UniversalCheckpointArgs: with open(latest_file, "w") as f: f.write(step_folder) - dist.barrier() + torch.distributed.barrier() log_rank_0(f"Preparing universal checkpoint took {time.time() - start} seconds") @@ -508,26 +507,28 @@ def ensure_loadable_granite_checkpoint( # Assumption: tmpdir should be accessible by all ranks, even those # in different nodes tmpdir = Path(tmpdir) / f"tmp.{group_rank}" - if os.path.exists(tmpdir) and (not dist.is_initialized() or local_rank == 0): + if os.path.exists(tmpdir) and ( + not torch.distributed.is_initialized() or local_rank == 0 + ): # need to delete if it exists because import doesnt like it to shutil.rmtree(tmpdir, ignore_errors=True) - if not dist.is_initialized() or local_rank == 0: + if not torch.distributed.is_initialized() or local_rank == 0: import_from_huggingface(model_name_or_path, tmpdir) - if dist.is_initialized(): + if torch.distributed.is_initialized(): # the first barrier is to wait for local rank 0 to finish converting the model # and place into tmpdir - dist.barrier() + torch.distributed.barrier() # return tmpdir out for loading yield tmpdir - if dist.is_initialized(): + if torch.distributed.is_initialized(): # the second barrier is to wait for all the models to finish loading - dist.barrier() + torch.distributed.barrier() - if not dist.is_initialized() or local_rank == 0: + if not torch.distributed.is_initialized() or local_rank == 0: # at this point, we can be confident that the tmpdir is no longer needed shutil.rmtree(tmpdir, ignore_errors=True) @@ -603,7 +604,7 @@ def get_caller(num_frames=1): def log_rank_0(msg, include_caller=False, rank=None, to_print=False): if rank is None: - rank = get_rank() if is_initialized() else 0 + rank = torch.distributed.get_rank() if torch.distributed.is_initialized() else 0 if rank <= 0: if include_caller: msg = f"{get_caller(num_frames=2)}: {msg}" @@ -632,7 +633,11 @@ def save_hf_format_ds( convert_granite=True, is_lora=False, ): - model_to_save = model.module + if torch.cuda.is_available(): + model_to_save = model.module + else: + # if not using DS, the model is an actual model not model_engine + model_to_save = model log_rank_0( f"\033[93mSaving model in huggingface format at samples_seen: {samples_seen}\033[0m", to_print=True, @@ -647,7 +652,7 @@ def save_hf_format_ds( else: WEIGHTS_NAME = "pytorch_model.bin" output_dir = Path(args.output_dir) / "hf_format" / f"samples_{samples_seen}" - if torch.distributed.get_rank() == 0: + if not torch.cuda.is_available() or torch.distributed.get_rank() == 0: if is_lora: model_to_save.merge_adapter() @@ -686,7 +691,8 @@ def save_hf_format_ds( if is_lora: model_to_save.unmerge_adapter() - dist.barrier() + if torch.cuda.is_available() and torch.distributed.is_initialized(): + torch.distributed.barrier() log_rank_0(f"\033[93mModel saved in {output_dir}\033[0m", to_print=True) log_rank_0(f"saving took {time.time() - start} seconds")