From 8d1f3a6d836ef5e0ef8284ef5e7006348abeef50 Mon Sep 17 00:00:00 2001 From: David Hall Date: Sun, 19 May 2024 21:39:30 -0700 Subject: [PATCH 01/34] wip --- config/harness/harness_nano.yaml | 21 +++ pyproject.toml | 1 + src/levanter/data/loader.py | 15 +- src/levanter/eval_harness.py | 287 +++++++++++++++++++++++++++++++ src/levanter/models/lm_model.py | 27 +++ src/levanter/trainer.py | 4 + src/levanter/utils/jax_utils.py | 36 ++++ 7 files changed, 378 insertions(+), 13 deletions(-) create mode 100644 config/harness/harness_nano.yaml create mode 100644 src/levanter/eval_harness.py diff --git a/config/harness/harness_nano.yaml b/config/harness/harness_nano.yaml new file mode 100644 index 000000000..c34e2a83a --- /dev/null +++ b/config/harness/harness_nano.yaml @@ -0,0 +1,21 @@ +tokenizer: "gpt2" +model: + type: gpt2 + hidden_dim: 32 + num_heads: 4 + num_layers: 2 +trainer: + mp: f32 + num_train_steps: 100 + + checkpointer: + keep: + - every: 50 + save_interval: 5m + + per_device_parallelism: -1 + train_batch_size: 32 + + tensor_parallel_axes: ["mlp", "heads"] + fsdp_axis: "embed" + batch_axis: "batch" diff --git a/pyproject.toml b/pyproject.toml index 393a56ed8..45b6e2c94 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -54,6 +54,7 @@ dependencies = [ "pydantic<3", # temporary pin until Ray supports pydantic 2.0 "rich~=13.0", "filelock~=3.13", + "lm-eval==0.4.2" ] [tool.hatch.build] diff --git a/src/levanter/data/loader.py b/src/levanter/data/loader.py index 03a64f7e4..b3ac832bd 100644 --- a/src/levanter/data/loader.py +++ b/src/levanter/data/loader.py @@ -1,5 +1,4 @@ import abc -import functools import logging from collections import defaultdict from typing import Callable, Dict, Iterable, Iterator, List, Optional, Tuple, TypeVar, Union @@ -21,6 +20,7 @@ from levanter.data.dataset import ShardableDataset from levanter.shapes import NamedShapeSpec, ShapeSpec, to_raw_shape from levanter.utils.background_iterable import BackgroundIterable +from levanter.utils.jax_utils import stack_tree from levanter.utils.py_utils import non_caching_cycle @@ -89,7 +89,7 @@ def get_local_batch(begin: int, end: int) -> List[Array]: individual_datums = get_batch_items(begin, end) - device_batch = _stack_tree(self.Batch.name, individual_datums) + device_batch = stack_tree(self.Batch, individual_datums, pad_to_batch_size=False) batch_leaves = jtu.tree_leaves(device_batch) stacked_local_batch[key] = batch_leaves @@ -211,17 +211,6 @@ def local_batch_size(self) -> int: return self.batch_size // self.num_data_process_groups -@functools.partial(jax.jit, static_argnums=(0,)) -def _stack_tree(batch_name, individual_datums): - def _stack_leaves_unchecked(*leaves): - if is_named_array(leaves[0]): - return hax.stack(batch_name, leaves) - else: - return jnp.stack(leaves) - - return jax.tree_map(_stack_leaves_unchecked, *individual_datums, is_leaf=is_named_array) - - class ReplicatedBatchLoader(BatchLoader[Ex]): """A batch loader that creates batches without sharded data loading. All examples are loaded on all machines and then sharded. This is useful if you have a small dataset and want to make a single pass over it. 
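The loader change above replaces the file-local _stack_tree helper with the shared stack_tree utility that this same patch adds to src/levanter/utils/jax_utils.py. The following is a rough usage sketch of that helper, not part of the patch; the axis names and sizes are made up for illustration:

import haliax as hax
import jax.numpy as jnp

from levanter.utils.jax_utils import stack_tree

Pos = hax.Axis("position", 8)  # illustrative axis, not from the patch
examples = [hax.ones(Pos, dtype=jnp.int32) for _ in range(4)]

# Same behavior as the old _stack_tree: stack every leaf along a new "batch" axis.
stacked = stack_tree(hax.Axis("batch", 4), examples, pad_to_batch_size=False)
# stacked now carries a new leading "batch" axis of size 4.

# pad_to_batch_size=True (used by the eval harness below) additionally right-pads a
# short list of examples with zero-filled copies up to the Batch axis size.

The padding option matters for the harness because the final batch of requests is usually smaller than EvalBatch, so the stacked batch is padded and the extra rows are sliced off the results afterwards.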
diff --git a/src/levanter/eval_harness.py b/src/levanter/eval_harness.py
new file mode 100644
index 000000000..18fd29a52
--- /dev/null
+++ b/src/levanter/eval_harness.py
@@ -0,0 +1,287 @@
+# Code for running https://github.com/EleutherAI/lm-evaluation-harness inside Levanter runs
+# References:
+# https://github.com/kingoflolz/mesh-transformer-jax/blob/master/eval_harness.py
+# https://github.com/kingoflolz/mesh-transformer-jax/blob/f8315e3003033b23f21d78361b288953064e0e76/mesh_transformer/TPU_cluster.py#L6
+import dataclasses
+import enum
+import logging
+import warnings
+from dataclasses import dataclass
+from functools import cached_property
+from typing import List, Tuple
+
+import equinox as eqx
+import jax
+import jax.numpy as jnp
+import transformers
+from jax.experimental.multihost_utils import broadcast_one_to_all
+from lm_eval.api.instance import Instance
+from lm_eval.api.model import LM
+from lm_eval import evaluator, tasks
+
+import haliax as hax
+import levanter.config
+from haliax.nn import cross_entropy_loss
+from haliax.partitioning import fsdp, round_axis_for_partitioning
+from levanter.checkpoint import load_checkpoint
+from levanter.data import batched
+
+from levanter.models.lm_model import LmConfig, LmHeadModel, LmExample
+from levanter.trainer import TrainerConfig
+from levanter.utils.jax_utils import stack_tree, use_cpu_device
+from levanter.utils.tree_utils import inference_mode
+
+logger = logging.getLogger(__name__)
+
+
+# Ok this is a bit complicated to do because it's distributed systems and that's always hard.
+# The idea is that we want to pass an LM adaptor to the harness, and then the harness will call the LM adaptor
+# with a request, which we'll format, shard, and send to the model. The model will then return the result to the harness
+# which will then return the result to the user.
+
+# As we so often do, we will coordinate execution through JAX itself.
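Stripped of the Levanter specifics, the request/response pattern described in the comments below comes down to a broadcast-driven loop. The following is only an illustrative sketch, not part of this module; the names and payloads are placeholders, and the one real constraint is that the dummy payload must have the same pytree structure as a genuine request so broadcast_one_to_all lines up across hosts:

from jax.experimental.multihost_utils import broadcast_one_to_all

REQUEST, FINISHED = 0, 1

def leader(requests, run_request, dummy):
    # process 0: announce each request so every host enters the same jitted computation
    for req in requests:
        broadcast_one_to_all((REQUEST, req), is_source=True)
        yield run_request(req)
    broadcast_one_to_all((FINISHED, dummy), is_source=True)

def worker(run_request, dummy):
    # processes 1..n: block on the broadcast, mirror the leader's call, discard the result
    while True:
        kind, req = broadcast_one_to_all((FINISHED, dummy), is_source=False)
        if kind == FINISHED:
            break
        run_request(req)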
+ +# Process 0 will: +# - Pass an adaptor to the eval harness +# - The eval harness will call the adaptor with a request +# - When a request comes in, it will call broadcast_one_to_all with a (REQUEST_TYPE, request) to send the request +# - It then invokes the model with the request and returns the result to the eval harness +# - When finished, it will call broadcast_one_to_all with a (FINISHED_TYPE, result) to send the result + +# Process 1..n will: +# - Wait for a (REQUEST_TYPE, request) broadcast +# - if FINISHED_TYPE, break +# - Invoke the model with the request +# - loop + + +class RequestType: + LOG_LIKELIHOOD = 0 + GENERATE_UNTIL = 1 + LOG_LIKELIHOOD_ROLLING = 2 + FINISHED = 3 + + +class InternalLMAdaptor: + + def __init__(self, EvalBatch: hax.Axis, model: LmHeadModel): + self.EvalBatch = EvalBatch + self.model = model + + self.dummy_example: LmExample = LmExample.causal( + tokens=hax.zeros(model.Pos, dtype=jnp.int32), + ) + + def _eval_loglikelihood(model: LmHeadModel, example: LmExample): + logits = model(example.tokens) + + targets = hax.roll(example.tokens, -1, axis=model.Pos.name) + target_y = hax.nn.one_hot(targets, model.Vocab, dtype=logits.dtype) + loss = cross_entropy_loss(logits, model.Vocab, target_y, where=example.loss_mask, reduction_axis=model.Pos) + # to tell if we got the right answer, we want to check that argmax(logits) == tokens wherever loss_mask is 1 + pred_targets = hax.argmax(logits, axis=model.Vocab) + correct = hax.all(hax.equal(pred_targets, targets) | (not example.loss_mask), axis=model.Pos) + + return loss, correct + + # no sharded outputs + self._jit_loglikelihood = hax.named_jit(_eval_loglikelihood, out_axis_resources={}) + + def loglikelihood(self, examples: LmExample) -> tuple[hax.NamedArray, hax.NamedArray]: + return self._jit_loglikelihood(self.model, examples) + + def do_request(self, request_type, example): + if request_type == RequestType.LOG_LIKELIHOOD: + return self.loglikelihood(example) + elif request_type == RequestType.GENERATE_UNTIL: + raise NotImplementedError() + elif request_type == RequestType.LOG_LIKELIHOOD_ROLLING: + raise NotImplementedError() + else: + raise ValueError(f"Invalid request type {request_type}") + + def worker_loop(self): + dummy_example = self.dummy_example + while True: + request_type, request = broadcast_one_to_all((RequestType.FINISHED, dummy_example), is_source=False) + + if request_type == RequestType.FINISHED: + break + + result = self.do_request(request_type, request) + del result + + def finish(self): + assert jax.process_index() == 0 + broadcast_one_to_all((RequestType.FINISHED, self.dummy_example), is_source=True) + + +class LevanterHarnessLM(LM): + + def __init__(self, adaptor: InternalLMAdaptor, tokenizer): + super().__init__() + self.adaptor = adaptor + self.tokenizer = tokenizer + + def _stack_batch(examples): + return stack_tree(self.adaptor.EvalBatch, examples, pad_to_batch_size=True) + + self._stack_batch_jit = hax.named_jit(_stack_batch) + + def _stack_batch(self, examples): + with use_cpu_device(): + return self._stack_batch_jit(examples) + + def loglikelihood(self, requests: list[Instance]) -> list[tuple[float, bool]]: + """ + Compute log-likelihood of generating a continuation from a context. + Downstream tasks should attempt to use loglikelihood instead of other + LM calls whenever possible. 
+ Args: + requests: + + Returns: + + """ + + Pos = self.adaptor.model.Pos + + contexts = self.tokenizer([req.args[0] for req in requests])["input_ids"] + completions = self.tokenizer([req.args[1] for req in requests])["input_ids"] + + examples: list[LmExample] = [] + with use_cpu_device(): + for context, completion in zip(contexts, completions): + context, completion = self._truncate_or_pad(context, completion) + context, completion = hax.named(context, Pos.name), hax.named(completion, Pos.name) + example = LmExample.from_prompt_and_completion(context, completion, ignore_id=self.tokenizer.pad_token_id) + examples.append(example) + + result = [] + for batch in batched(examples, self.adaptor.EvalBatch.size): + batch_example = self._stack_batch(batch) + out_lls, out_correct = self._dispatch(RequestType.LOG_LIKELIHOOD, batch_example) + result.extend(zip(out_lls, out_correct)) + + # skip padding + result = result[:len(examples)] + + print(contexts) + print(completions) + + return result + + + def _truncate_or_pad(self, context, completion): + max_len = self.adaptor.model.Pos.size + if len(completion) > max_len: + warnings.warn(f"Completion is longer than max length {max_len}. Truncating.") + completion = completion[:max_len] + pad_token_id = self.tokenizer.pad_token_id or self.tokenizer.eos_token_id + + if len(context) + len(completion) > max_len: + context = context[-(max_len - len(completion)):] + else: + # right pad with padding token + context = context + [pad_token_id] * (max_len - len(context) - len(completion)) + + return context, completion + + def _dispatch(self, request_type, request): + broadcast_one_to_all((request_type, request), is_source=True) + return self.adaptor.do_request(request_type, request) + + + def loglikelihood_rolling(self, requests) -> List[Tuple[float]]: + raise NotImplementedError() + + def generate_until(self, requests) -> List[str]: + raise NotImplementedError() + + + +def run_lm_eval_harness(model, tokenizer, EvalBatch, max_examples=None): + adaptor = InternalLMAdaptor(EvalBatch, model) + harness = LevanterHarnessLM(adaptor, tokenizer) + tasks_to_run = tasks.get_task_dict([ + # "lambada", + # "piqa", + "hellaswag", + # "winogrande", + # "mathqa", + # "pubmedqa", + # "boolq", + # "cb", + # "copa", + # "multirc", + # "record", + # "wic", + # "wsc", + ]) + if jax.process_index() == 0: + try: + outputs = evaluator.evaluate(harness, tasks_to_run, limit=max_examples) + finally: + adaptor.finish() + else: + adaptor.worker_loop() + outputs = {} + + return outputs + +@dataclass +class EvalHarnessConfig: + tokenizer: str + checkpoint_path: str + trainer: TrainerConfig = dataclasses.field(default_factory=TrainerConfig) + model: LmConfig = dataclasses.field(default_factory=LmConfig) + + @property + def EvalBatch(self): + return self.trainer.EvalBatch + + @cached_property + def the_tokenizer(self): + return transformers.AutoTokenizer.from_pretrained(self.tokenizer) + + +@levanter.config.main +def run_eval_harness_main(config: EvalHarnessConfig): + config.trainer.initialize() + tokenizer = config.the_tokenizer + + compute_axis_mapping = config.trainer.compute_axis_mapping + parameter_axis_mapping = config.trainer.parameter_axis_mapping + + with config.trainer.device_mesh, hax.axis_mapping(parameter_axis_mapping): + key = jax.random.PRNGKey(0) + + vocab_size = len(tokenizer) + Vocab = round_axis_for_partitioning(hax.Axis("vocab", vocab_size), compute_axis_mapping) + if vocab_size != Vocab.size: + logger.info(f"Rounding vocab size from {vocab_size} to {Vocab.size} for 
partitioning") + + total = config.trainer.max_eval_batches + + if total is not None: + total = total * config.trainer.EvalBatch.size + + # initialize the model + if config.checkpoint_path is not None: + # initialize the model + with use_cpu_device(): + model = eqx.filter_eval_shape(config.model.build, Vocab, key=key) + # TODO: don't load the entire checkpoint into CPU memory when we only need our share of the model + model = load_checkpoint(model, config.checkpoint_path, subpath="model") + model = inference_mode(model, True) + + model = hax.shard_with_axis_mapping(model, parameter_axis_mapping) + + logger.info(f"Running LM eval harness....") + run_lm_eval_harness(model, tokenizer, config.EvalBatch, max_examples=total) + else: + raise ValueError("No checkpoint path provided") + + +if __name__ == '__main__': + run_eval_harness_main() \ No newline at end of file diff --git a/src/levanter/models/lm_model.py b/src/levanter/models/lm_model.py index b68a33f2b..579a34813 100644 --- a/src/levanter/models/lm_model.py +++ b/src/levanter/models/lm_model.py @@ -46,6 +46,33 @@ def causal( attn_mask = AttentionMask.causal() return LmExample(tokens=tokens, loss_mask=loss_mask, attn_mask=attn_mask) + @staticmethod + def from_prompt_and_completion(prompt: hax.NamedArray, completion: hax.NamedArray, *, ignore_id: Optional[int] = None, + all_causal: bool = True) -> "LmExample": + Pos = prompt.axes[0] + tokens = hax.concatenate(Pos.name, [prompt, completion]) + tokens_Pos = tokens.resolve_axis(Pos.name) + + # mask out the prompt tokens + loss_mask = hax.arange(tokens_Pos) >= Pos.size + # also mask out the last token + loss_mask *= (1 - hax.nn.one_hot(-1, Pos, dtype=jnp.float32)) + + if ignore_id is not None: + ignore_mask = tokens != ignore_id + loss_mask *= ignore_mask + + if all_causal: + attn_mask = AttentionMask.causal() + else: + # causal just for the completion part. We don't have a special structured mask for this, so we just + raise NotImplementedError("Not implemented yet") + + return LmExample(tokens=tokens, loss_mask=loss_mask, attn_mask=attn_mask) + + + + # TODO: for some reason, mypy doesn't like the discover_packages_path argument? class LmConfig(draccus.PluginRegistry, abc.ABC, Generic[LmT], discover_packages_path="levanter.models"): # type: ignore diff --git a/src/levanter/trainer.py b/src/levanter/trainer.py index 00e7b280d..067309f09 100644 --- a/src/levanter/trainer.py +++ b/src/levanter/trainer.py @@ -592,10 +592,14 @@ class TrainerConfig: @property def TrainBatch(self): + if self.batch_size <= 0: + raise ValueError("batch_size must be positive. Did you call initialize?") return Axis("batch", self.train_batch_size) @property def EvalBatch(self): + if self.eval_batch_size <= 0: + raise ValueError("eval_batch_size must be positive. 
Did you call initialize?") return Axis("batch", self.eval_batch_size) @property diff --git a/src/levanter/utils/jax_utils.py b/src/levanter/utils/jax_utils.py index 7ff09e44c..7127075b1 100644 --- a/src/levanter/utils/jax_utils.py +++ b/src/levanter/utils/jax_utils.py @@ -1,4 +1,5 @@ import contextlib +import functools import json import warnings from dataclasses import fields @@ -11,6 +12,9 @@ from jax.sharding import PositionalSharding from jaxtyping import PRNGKeyArray, PyTree +import haliax as hax +from haliax import AxisSelector, is_named_array + from haliax.jax_utils import is_jax_array_like @@ -279,3 +283,35 @@ def estimated_free_device_memory(device) -> Optional[float]: in_use = stats.get("bytes_in_use", 0) return (limit - in_use) // (1024.0**3) + + +@functools.partial(jax.jit, static_argnums=(0), static_argnames=("batch", "pad_to_batch_size")) +def stack_tree(batch: AxisSelector, individual_datums: list[X], *, pad_to_batch_size: bool) -> X: + """ + Stacks a tree of NamedArrays or arrays into a single array. NamedArrays get a new axis with the name batch_name, + while regular arrays are stacked normally. + + Args: + batch: Axis or str name of the new axis. + individual_datums: The tree of NamedArrays or arrays to stack + pad_to_batch_size: If True, pads the arrays to the size of the batch axis (assuming batch is an axis). If False, stacks them as is. + """ + if pad_to_batch_size and not isinstance(batch, hax.Axis): + raise ValueError("pad_to_batch_size can only be used with an Axis Batch") + + if pad_to_batch_size: + missing_count = batch.size - len(individual_datums) + + def _stack_leaves_unchecked(*leaves): + if is_named_array(leaves[0]): + return hax.stack(batch.name, leaves + [hax.zeros_like(leaves[0]) for _ in range(missing_count)]) + else: + return jnp.stack(leaves + [jnp.zeros_like(leaves[0]) for _ in range(missing_count)]) + else: + def _stack_leaves_unchecked(*leaves): + if is_named_array(leaves[0]): + return hax.stack(hax.axis_name(batch), leaves) + else: + return jnp.stack(leaves) + + return jax.tree_map(_stack_leaves_unchecked, *individual_datums, is_leaf=is_named_array) From 628c525a3fa3cad0f2c9a2b37cca20b7b942f0fb Mon Sep 17 00:00:00 2001 From: David Hall Date: Wed, 22 May 2024 21:59:40 -0700 Subject: [PATCH 02/34] ok it runs. garbage but it runs? --- src/levanter/checkpoint.py | 8 +- src/levanter/eval_harness.py | 127 +++++++++++++++++++------------- src/levanter/models/lm_model.py | 18 ++--- src/levanter/utils/jax_utils.py | 7 +- 4 files changed, 92 insertions(+), 68 deletions(-) diff --git a/src/levanter/checkpoint.py b/src/levanter/checkpoint.py index 7802a7f07..b517ac1c6 100644 --- a/src/levanter/checkpoint.py +++ b/src/levanter/checkpoint.py @@ -344,11 +344,15 @@ def load_checkpoint( logger.warning("Loading checkpoint in jit. 
This is not recommended and probably won't work.") if discover_latest: - checkpoint_path = discover_latest_checkpoint(checkpoint_path) # type: ignore + discovered_checkpoint_path = discover_latest_checkpoint(checkpoint_path) # type: ignore + else: + discovered_checkpoint_path = checkpoint_path - if checkpoint_path is None or not fs.exists(checkpoint_path): + if discovered_checkpoint_path is None or not fs.exists(discovered_checkpoint_path): raise FileNotFoundError(f"Could not find checkpoint at {checkpoint_path}") + checkpoint_path = discovered_checkpoint_path + logger.info(f"Loading checkpoint from {checkpoint_path}") metadata = load_metadata(checkpoint_path, fs) diff --git a/src/levanter/eval_harness.py b/src/levanter/eval_harness.py index 18fd29a52..0e64864e6 100644 --- a/src/levanter/eval_harness.py +++ b/src/levanter/eval_harness.py @@ -3,34 +3,36 @@ # https://github.com/kingoflolz/mesh-transformer-jax/blob/master/eval_harness.py # https://github.com/kingoflolz/mesh-transformer-jax/blob/f8315e3003033b23f21d78361b288953064e0e76/mesh_transformer/TPU_cluster.py#L6 import dataclasses -import enum +import json import logging import warnings from dataclasses import dataclass from functools import cached_property -from typing import List, Tuple +from typing import List, Optional, Tuple import equinox as eqx import jax import jax.numpy as jnp import transformers from jax.experimental.multihost_utils import broadcast_one_to_all +from lm_eval import evaluator, tasks from lm_eval.api.instance import Instance from lm_eval.api.model import LM -from lm_eval import evaluator, tasks +from tqdm import tqdm import haliax as hax -import levanter.config from haliax.nn import cross_entropy_loss -from haliax.partitioning import fsdp, round_axis_for_partitioning +from haliax.partitioning import round_axis_for_partitioning + +import levanter.config from levanter.checkpoint import load_checkpoint from levanter.data import batched - -from levanter.models.lm_model import LmConfig, LmHeadModel, LmExample +from levanter.models.lm_model import LmConfig, LmExample, LmHeadModel from levanter.trainer import TrainerConfig from levanter.utils.jax_utils import stack_tree, use_cpu_device from levanter.utils.tree_utils import inference_mode + logger = logging.getLogger(__name__) @@ -63,10 +65,10 @@ class RequestType: class InternalLMAdaptor: - - def __init__(self, EvalBatch: hax.Axis, model: LmHeadModel): + def __init__(self, EvalBatch: hax.Axis, model: LmHeadModel, axis_resources): self.EvalBatch = EvalBatch self.model = model + self.axis_resources = axis_resources self.dummy_example: LmExample = LmExample.causal( tokens=hax.zeros(model.Pos, dtype=jnp.int32), @@ -80,12 +82,14 @@ def _eval_loglikelihood(model: LmHeadModel, example: LmExample): loss = cross_entropy_loss(logits, model.Vocab, target_y, where=example.loss_mask, reduction_axis=model.Pos) # to tell if we got the right answer, we want to check that argmax(logits) == tokens wherever loss_mask is 1 pred_targets = hax.argmax(logits, axis=model.Vocab) - correct = hax.all(hax.equal(pred_targets, targets) | (not example.loss_mask), axis=model.Pos) + correct = hax.all(hax.equal(pred_targets, targets) | hax.logical_not(example.loss_mask), axis=model.Pos) return loss, correct # no sharded outputs - self._jit_loglikelihood = hax.named_jit(_eval_loglikelihood, out_axis_resources={}) + self._jit_loglikelihood = hax.named_jit( + _eval_loglikelihood, axis_resources=axis_resources, out_axis_resources={} + ) def loglikelihood(self, examples: LmExample) -> tuple[hax.NamedArray, 
hax.NamedArray]: return self._jit_loglikelihood(self.model, examples) @@ -108,8 +112,7 @@ def worker_loop(self): if request_type == RequestType.FINISHED: break - result = self.do_request(request_type, request) - del result + self.do_request(request_type, request) def finish(self): assert jax.process_index() == 0 @@ -117,7 +120,6 @@ def finish(self): class LevanterHarnessLM(LM): - def __init__(self, adaptor: InternalLMAdaptor, tokenizer): super().__init__() self.adaptor = adaptor @@ -154,24 +156,25 @@ def loglikelihood(self, requests: list[Instance]) -> list[tuple[float, bool]]: for context, completion in zip(contexts, completions): context, completion = self._truncate_or_pad(context, completion) context, completion = hax.named(context, Pos.name), hax.named(completion, Pos.name) - example = LmExample.from_prompt_and_completion(context, completion, ignore_id=self.tokenizer.pad_token_id) + example = LmExample.from_prompt_and_completion( + context, completion, ignore_id=self.tokenizer.pad_token_id + ) examples.append(example) - result = [] - for batch in batched(examples, self.adaptor.EvalBatch.size): + result: list[tuple[float, bool]] = [] + for batch in batched(tqdm(examples, desc="examples", leave=False), self.adaptor.EvalBatch.size): batch_example = self._stack_batch(batch) out_lls, out_correct = self._dispatch(RequestType.LOG_LIKELIHOOD, batch_example) - result.extend(zip(out_lls, out_correct)) + result.extend((ll.item(), correct.item()) for ll, correct in zip(out_lls.array, out_correct.array)) # skip padding - result = result[:len(examples)] + result = result[: len(examples)] - print(contexts) - print(completions) + # print(contexts) + # print(completions) return result - def _truncate_or_pad(self, context, completion): max_len = self.adaptor.model.Pos.size if len(completion) > max_len: @@ -180,7 +183,7 @@ def _truncate_or_pad(self, context, completion): pad_token_id = self.tokenizer.pad_token_id or self.tokenizer.eos_token_id if len(context) + len(completion) > max_len: - context = context[-(max_len - len(completion)):] + context = context[-(max_len - len(completion)) :] else: # right pad with padding token context = context + [pad_token_id] * (max_len - len(context) - len(completion)) @@ -191,7 +194,6 @@ def _dispatch(self, request_type, request): broadcast_one_to_all((request_type, request), is_source=True) return self.adaptor.do_request(request_type, request) - def loglikelihood_rolling(self, requests) -> List[Tuple[float]]: raise NotImplementedError() @@ -199,25 +201,10 @@ def generate_until(self, requests) -> List[str]: raise NotImplementedError() - -def run_lm_eval_harness(model, tokenizer, EvalBatch, max_examples=None): - adaptor = InternalLMAdaptor(EvalBatch, model) +def run_lm_eval_harness(task_spec: list[str], model, tokenizer, EvalBatch, axis_resources, max_examples=None): + adaptor = InternalLMAdaptor(EvalBatch, model, axis_resources) harness = LevanterHarnessLM(adaptor, tokenizer) - tasks_to_run = tasks.get_task_dict([ - # "lambada", - # "piqa", - "hellaswag", - # "winogrande", - # "mathqa", - # "pubmedqa", - # "boolq", - # "cb", - # "copa", - # "multirc", - # "record", - # "wic", - # "wsc", - ]) + tasks_to_run = tasks.get_task_dict(task_spec) if jax.process_index() == 0: try: outputs = evaluator.evaluate(harness, tasks_to_run, limit=max_examples) @@ -229,13 +216,22 @@ def run_lm_eval_harness(model, tokenizer, EvalBatch, max_examples=None): return outputs -@dataclass + +@dataclass(frozen=True) +class LmEvalHarnessConfig: + task_spec: Optional[list[str]] = None + 
max_examples: Optional[int] = None + + +@dataclass(frozen=True) class EvalHarnessConfig: tokenizer: str checkpoint_path: str trainer: TrainerConfig = dataclasses.field(default_factory=TrainerConfig) model: LmConfig = dataclasses.field(default_factory=LmConfig) + eval_harness: LmEvalHarnessConfig = dataclasses.field(default_factory=LmEvalHarnessConfig) + @property def EvalBatch(self): return self.trainer.EvalBatch @@ -245,11 +241,28 @@ def the_tokenizer(self): return transformers.AutoTokenizer.from_pretrained(self.tokenizer) -@levanter.config.main def run_eval_harness_main(config: EvalHarnessConfig): config.trainer.initialize() tokenizer = config.the_tokenizer + task_spec = config.eval_harness.task_spec or [ + # "lambada", + # "piqa", + "hellaswag", + # "winogrande", + # "mathqa", + # "pubmedqa", + # "boolq", + # "cb", + # "copa", + # "multirc", + # "record", + # "wic", + # "wsc", + ] + + max_examples = config.eval_harness.max_examples + compute_axis_mapping = config.trainer.compute_axis_mapping parameter_axis_mapping = config.trainer.parameter_axis_mapping @@ -261,11 +274,6 @@ def run_eval_harness_main(config: EvalHarnessConfig): if vocab_size != Vocab.size: logger.info(f"Rounding vocab size from {vocab_size} to {Vocab.size} for partitioning") - total = config.trainer.max_eval_batches - - if total is not None: - total = total * config.trainer.EvalBatch.size - # initialize the model if config.checkpoint_path is not None: # initialize the model @@ -277,11 +285,24 @@ def run_eval_harness_main(config: EvalHarnessConfig): model = hax.shard_with_axis_mapping(model, parameter_axis_mapping) - logger.info(f"Running LM eval harness....") - run_lm_eval_harness(model, tokenizer, config.EvalBatch, max_examples=total) + logger.info("Running LM eval harness....") + outputs = run_lm_eval_harness( + task_spec, + model, + tokenizer, + config.EvalBatch, + axis_resources=compute_axis_mapping, + max_examples=max_examples, + ) + + logger.info("Finished running LM eval harness") + # log the results as json + with open("lm_eval_results.json", "w") as f: + + json.dump(outputs, f, indent=2) else: raise ValueError("No checkpoint path provided") -if __name__ == '__main__': - run_eval_harness_main() \ No newline at end of file +if __name__ == "__main__": + levanter.config.main(run_eval_harness_main)() diff --git a/src/levanter/models/lm_model.py b/src/levanter/models/lm_model.py index 579a34813..d42a560ed 100644 --- a/src/levanter/models/lm_model.py +++ b/src/levanter/models/lm_model.py @@ -47,16 +47,17 @@ def causal( return LmExample(tokens=tokens, loss_mask=loss_mask, attn_mask=attn_mask) @staticmethod - def from_prompt_and_completion(prompt: hax.NamedArray, completion: hax.NamedArray, *, ignore_id: Optional[int] = None, - all_causal: bool = True) -> "LmExample": - Pos = prompt.axes[0] - tokens = hax.concatenate(Pos.name, [prompt, completion]) - tokens_Pos = tokens.resolve_axis(Pos.name) + def from_prompt_and_completion( + prompt: hax.NamedArray, completion: hax.NamedArray, *, ignore_id: Optional[int] = None, all_causal: bool = True + ) -> "LmExample": + prompt_Pos = prompt.axes[0] + tokens = hax.concatenate(prompt_Pos.name, [prompt, completion]) + tokens_Pos = tokens.resolve_axis(prompt_Pos.name) # mask out the prompt tokens - loss_mask = hax.arange(tokens_Pos) >= Pos.size + loss_mask = hax.arange(tokens_Pos) >= prompt_Pos.size # also mask out the last token - loss_mask *= (1 - hax.nn.one_hot(-1, Pos, dtype=jnp.float32)) + loss_mask *= 1 - hax.nn.one_hot(-1, tokens_Pos, dtype=jnp.float32) if ignore_id is not None: 
ignore_mask = tokens != ignore_id @@ -71,9 +72,6 @@ def from_prompt_and_completion(prompt: hax.NamedArray, completion: hax.NamedArra return LmExample(tokens=tokens, loss_mask=loss_mask, attn_mask=attn_mask) - - - # TODO: for some reason, mypy doesn't like the discover_packages_path argument? class LmConfig(draccus.PluginRegistry, abc.ABC, Generic[LmT], discover_packages_path="levanter.models"): # type: ignore @property diff --git a/src/levanter/utils/jax_utils.py b/src/levanter/utils/jax_utils.py index 7127075b1..89ba6bed9 100644 --- a/src/levanter/utils/jax_utils.py +++ b/src/levanter/utils/jax_utils.py @@ -14,7 +14,6 @@ import haliax as hax from haliax import AxisSelector, is_named_array - from haliax.jax_utils import is_jax_array_like @@ -304,10 +303,12 @@ def stack_tree(batch: AxisSelector, individual_datums: list[X], *, pad_to_batch_ def _stack_leaves_unchecked(*leaves): if is_named_array(leaves[0]): - return hax.stack(batch.name, leaves + [hax.zeros_like(leaves[0]) for _ in range(missing_count)]) + return hax.stack(batch.name, leaves + tuple(hax.zeros_like(leaves[0]) for _ in range(missing_count))) else: - return jnp.stack(leaves + [jnp.zeros_like(leaves[0]) for _ in range(missing_count)]) + return jnp.stack(leaves + tuple(jnp.zeros_like(leaves[0]) for _ in range(missing_count))) + else: + def _stack_leaves_unchecked(*leaves): if is_named_array(leaves[0]): return hax.stack(hax.axis_name(batch), leaves) From d4bf9e2df29c3cb0a710d29eecbfebb9116b5af5 Mon Sep 17 00:00:00 2001 From: David Hall Date: Wed, 22 May 2024 22:01:53 -0700 Subject: [PATCH 03/34] don't require imports --- src/levanter/eval_harness.py | 15 ++++++++++++--- 1 file changed, 12 insertions(+), 3 deletions(-) diff --git a/src/levanter/eval_harness.py b/src/levanter/eval_harness.py index 0e64864e6..1c3d407af 100644 --- a/src/levanter/eval_harness.py +++ b/src/levanter/eval_harness.py @@ -15,9 +15,18 @@ import jax.numpy as jnp import transformers from jax.experimental.multihost_utils import broadcast_one_to_all -from lm_eval import evaluator, tasks -from lm_eval.api.instance import Instance -from lm_eval.api.model import LM + + +try: + from lm_eval import evaluator, tasks + from lm_eval.api.instance import Instance + from lm_eval.api.model import LM +except ImportError: + LM = object + Instance = object + evaluator = object + tasks = object + from tqdm import tqdm import haliax as hax From a32308aea51493faa3969244f12af3300c690420 Mon Sep 17 00:00:00 2001 From: David Hall Date: Mon, 27 May 2024 21:31:09 -0700 Subject: [PATCH 04/34] wio --- src/levanter/eval_harness.py | 22 +++++++++++----------- 1 file changed, 11 insertions(+), 11 deletions(-) diff --git a/src/levanter/eval_harness.py b/src/levanter/eval_harness.py index 1c3d407af..5da76b130 100644 --- a/src/levanter/eval_harness.py +++ b/src/levanter/eval_harness.py @@ -66,14 +66,14 @@ # - loop -class RequestType: +class _RequestType: LOG_LIKELIHOOD = 0 GENERATE_UNTIL = 1 LOG_LIKELIHOOD_ROLLING = 2 FINISHED = 3 -class InternalLMAdaptor: +class _InternalLMAdaptor: def __init__(self, EvalBatch: hax.Axis, model: LmHeadModel, axis_resources): self.EvalBatch = EvalBatch self.model = model @@ -104,11 +104,11 @@ def loglikelihood(self, examples: LmExample) -> tuple[hax.NamedArray, hax.NamedA return self._jit_loglikelihood(self.model, examples) def do_request(self, request_type, example): - if request_type == RequestType.LOG_LIKELIHOOD: + if request_type == _RequestType.LOG_LIKELIHOOD: return self.loglikelihood(example) - elif request_type == RequestType.GENERATE_UNTIL: + elif 
request_type == _RequestType.GENERATE_UNTIL: raise NotImplementedError() - elif request_type == RequestType.LOG_LIKELIHOOD_ROLLING: + elif request_type == _RequestType.LOG_LIKELIHOOD_ROLLING: raise NotImplementedError() else: raise ValueError(f"Invalid request type {request_type}") @@ -116,20 +116,20 @@ def do_request(self, request_type, example): def worker_loop(self): dummy_example = self.dummy_example while True: - request_type, request = broadcast_one_to_all((RequestType.FINISHED, dummy_example), is_source=False) + request_type, request = broadcast_one_to_all((_RequestType.FINISHED, dummy_example), is_source=False) - if request_type == RequestType.FINISHED: + if request_type == _RequestType.FINISHED: break self.do_request(request_type, request) def finish(self): assert jax.process_index() == 0 - broadcast_one_to_all((RequestType.FINISHED, self.dummy_example), is_source=True) + broadcast_one_to_all((_RequestType.FINISHED, self.dummy_example), is_source=True) class LevanterHarnessLM(LM): - def __init__(self, adaptor: InternalLMAdaptor, tokenizer): + def __init__(self, adaptor: _InternalLMAdaptor, tokenizer): super().__init__() self.adaptor = adaptor self.tokenizer = tokenizer @@ -173,7 +173,7 @@ def loglikelihood(self, requests: list[Instance]) -> list[tuple[float, bool]]: result: list[tuple[float, bool]] = [] for batch in batched(tqdm(examples, desc="examples", leave=False), self.adaptor.EvalBatch.size): batch_example = self._stack_batch(batch) - out_lls, out_correct = self._dispatch(RequestType.LOG_LIKELIHOOD, batch_example) + out_lls, out_correct = self._dispatch(_RequestType.LOG_LIKELIHOOD, batch_example) result.extend((ll.item(), correct.item()) for ll, correct in zip(out_lls.array, out_correct.array)) # skip padding @@ -211,7 +211,7 @@ def generate_until(self, requests) -> List[str]: def run_lm_eval_harness(task_spec: list[str], model, tokenizer, EvalBatch, axis_resources, max_examples=None): - adaptor = InternalLMAdaptor(EvalBatch, model, axis_resources) + adaptor = _InternalLMAdaptor(EvalBatch, model, axis_resources) harness = LevanterHarnessLM(adaptor, tokenizer) tasks_to_run = tasks.get_task_dict(task_spec) if jax.process_index() == 0: From 319eb6acb3bc2e849d805bcb64041db461b4a06b Mon Sep 17 00:00:00 2001 From: David Hall Date: Tue, 28 May 2024 10:47:35 -0700 Subject: [PATCH 05/34] wip --- src/levanter/callbacks.py | 37 ++++++++++++++++++++++++++++ src/levanter/eval_harness.py | 46 ++++++++++++++++------------------- src/levanter/main/train_lm.py | 13 +++++++++- src/levanter/trainer.py | 2 +- 4 files changed, 71 insertions(+), 27 deletions(-) diff --git a/src/levanter/callbacks.py b/src/levanter/callbacks.py index bf7878ed4..4f966c650 100644 --- a/src/levanter/callbacks.py +++ b/src/levanter/callbacks.py @@ -15,11 +15,13 @@ from tqdm import tqdm import levanter.tracker +from levanter.eval_harness import LmEvalHarnessConfig from levanter.logging import save_xla_dumps_to_wandb from levanter.tracker.helpers import log_optimizer_hyperparams from levanter.tracker.wandb import WandbConfig from levanter.trainer import StepInfo from levanter.utils.jax_utils import barrier_sync, jnp_to_python +from levanter.utils.tree_utils import inference_mode from levanter.visualization import compute_and_visualize_log_probs as viz_probs @@ -333,3 +335,38 @@ def compute_and_viz_log_probs(step: StepInfo): wandb.log({"log_probs": wandb.Html(path)}, step=step.step) return compute_and_viz_log_probs + + +def lm_eval_harness(config: LmEvalHarnessConfig, tokenizer, EvalBatch, axis_resources): + from 
levanter.eval_harness import run_lm_eval_harness + + def lm_eval_harness(step: StepInfo, force=False): + if step.step == 0 and not force: + return # don't run eval on the first step + + model = inference_mode(step.model, True) + outputs = run_lm_eval_harness(model, config.task_spec_or_default(), tokenizer, EvalBatch, axis_resources, max_examples=config.max_examples) + + if jax.process_index() == 0: + with tempfile.NamedTemporaryFile("w", delete=False, suffix=".json") as f: + import json + + json.dump(outputs, f) + levanter.tracker.current_tracker().log_artifact(f.name, name=f"lm_eval_output.{step.step}", type="lm_eval_output") + + # also log accuracy statistics etc + metrics_to_log = {} + for task, metrics in outputs["results"]: + for metric, value in metrics.items(): + if metric.endswith(",none"): + metric = metric[:-len(",none")] + + if metric != "alias": + # levanter.tracker.log_metrics({f"lm_eval/{task}/{metric}": value}, step=step.step) + metrics_to_log[f"lm_eval/{task}/{metric}"] = value + + + levanter.tracker.log_metrics(metrics_to_log, step=step.step) + + + return lm_eval_harness diff --git a/src/levanter/eval_harness.py b/src/levanter/eval_harness.py index 5da76b130..cf0aa9bf1 100644 --- a/src/levanter/eval_harness.py +++ b/src/levanter/eval_harness.py @@ -210,7 +210,7 @@ def generate_until(self, requests) -> List[str]: raise NotImplementedError() -def run_lm_eval_harness(task_spec: list[str], model, tokenizer, EvalBatch, axis_resources, max_examples=None): +def run_lm_eval_harness(model, task_spec: list[str], tokenizer, EvalBatch, axis_resources, max_examples=None) -> dict: adaptor = _InternalLMAdaptor(EvalBatch, model, axis_resources) harness = LevanterHarnessLM(adaptor, tokenizer) tasks_to_run = tasks.get_task_dict(task_spec) @@ -231,6 +231,23 @@ class LmEvalHarnessConfig: task_spec: Optional[list[str]] = None max_examples: Optional[int] = None + def task_spec_or_default(self): + return self.task_spec or [ + # "lambada", + # "piqa", + "hellaswag", + # "winogrande", + # "mathqa", + # "pubmedqa", + # "boolq", + # "cb", + # "copa", + # "multirc", + # "record", + # "wic", + # "wsc", + ] + @dataclass(frozen=True) class EvalHarnessConfig: @@ -254,22 +271,7 @@ def run_eval_harness_main(config: EvalHarnessConfig): config.trainer.initialize() tokenizer = config.the_tokenizer - task_spec = config.eval_harness.task_spec or [ - # "lambada", - # "piqa", - "hellaswag", - # "winogrande", - # "mathqa", - # "pubmedqa", - # "boolq", - # "cb", - # "copa", - # "multirc", - # "record", - # "wic", - # "wsc", - ] - + task_spec = config.eval_harness.task_spec_or_default() max_examples = config.eval_harness.max_examples compute_axis_mapping = config.trainer.compute_axis_mapping @@ -295,14 +297,8 @@ def run_eval_harness_main(config: EvalHarnessConfig): model = hax.shard_with_axis_mapping(model, parameter_axis_mapping) logger.info("Running LM eval harness....") - outputs = run_lm_eval_harness( - task_spec, - model, - tokenizer, - config.EvalBatch, - axis_resources=compute_axis_mapping, - max_examples=max_examples, - ) + outputs = run_lm_eval_harness(model, task_spec, tokenizer, config.EvalBatch, + axis_resources=compute_axis_mapping, max_examples=max_examples) logger.info("Finished running LM eval harness") # log the results as json diff --git a/src/levanter/main/train_lm.py b/src/levanter/main/train_lm.py index 57992a956..54a37ab4e 100644 --- a/src/levanter/main/train_lm.py +++ b/src/levanter/main/train_lm.py @@ -15,6 +15,7 @@ from levanter import callbacks from levanter.compat.hf_checkpoints import 
HFCompatConfig, save_hf_checkpoint_callback from levanter.data.text import CausalLmDataset, LMDatasetConfig, LMMixtureDatasetConfig +from levanter.eval_harness import LmEvalHarnessConfig from levanter.models.gpt2 import Gpt2Config from levanter.models.lm_model import LmConfig from levanter.optim import AdamConfig, OptimizerConfig @@ -46,7 +47,10 @@ class TrainLmConfig: hf_upload: Optional[str] = None hf_save_steps: int = 10000 - update_hessian_steps: int = 10 + eval_harness: Optional[LmEvalHarnessConfig] = None + eval_harness_steps: int = 10000 + num_eval_harness_samples: Optional[int] = None + def main(config: TrainLmConfig): @@ -163,6 +167,13 @@ def main(config: TrainLmConfig): every=config.hf_save_steps, ) + if config.eval_harness is not None: + eval_harness = config.eval_harness + trainer.add_hook( + callbacks.lm_eval_harness(eval_harness, tokenizer, EvalBatch, compute_axis_mapping), + every=config.eval_harness_steps, + ) + # visualize log probs @named_jit( in_axis_resources=parameter_axis_mapping, diff --git a/src/levanter/trainer.py b/src/levanter/trainer.py index 067309f09..efd2e8d39 100644 --- a/src/levanter/trainer.py +++ b/src/levanter/trainer.py @@ -592,7 +592,7 @@ class TrainerConfig: @property def TrainBatch(self): - if self.batch_size <= 0: + if self.train_batch_size <= 0: raise ValueError("batch_size must be positive. Did you call initialize?") return Axis("batch", self.train_batch_size) From f05b189b06481a4c234f6a73f01840e57475fa78 Mon Sep 17 00:00:00 2001 From: David Hall Date: Wed, 12 Jun 2024 10:16:36 -0700 Subject: [PATCH 06/34] almost there --- config/gpt2_nano_harness.yaml | 26 ++++++++++++++++++++++++++ src/levanter/eval.py | 1 + src/levanter/main/train_lm.py | 1 - 3 files changed, 27 insertions(+), 1 deletion(-) create mode 100644 config/gpt2_nano_harness.yaml diff --git a/config/gpt2_nano_harness.yaml b/config/gpt2_nano_harness.yaml new file mode 100644 index 000000000..b79d6209a --- /dev/null +++ b/config/gpt2_nano_harness.yaml @@ -0,0 +1,26 @@ +eval_harness: + task_spec: ["lambada", "piqa", "hellaswag"] + max_examples: 32 +eval_harness_steps: 50 +data: + id: dlwh/wikitext_103_detokenized +model: + type: gpt2 + hidden_dim: 32 + num_heads: 4 + num_layers: 2 +trainer: + mp: f32 + num_train_steps: 100 + + checkpointer: + keep: + - every: 50 + save_interval: 5m + + per_device_parallelism: -1 + train_batch_size: 32 + + tensor_parallel_axes: ["mlp", "heads"] + fsdp_axis: "embed" + batch_axis: "batch" diff --git a/src/levanter/eval.py b/src/levanter/eval.py index ceaf9e073..0b2e066d5 100644 --- a/src/levanter/eval.py +++ b/src/levanter/eval.py @@ -201,6 +201,7 @@ def evaluate(self, m: LmHeadModel): mean_losses_per_tag = hax.zeros(self.dataset.Tag, dtype=np.float32) state = (RunningMean.zeros_like(total_loss), RunningMean.zeros_like(mean_losses_per_tag)) + state = hax.shard(state) for batch, tags in tqdm.tqdm(self.loader, "eval"): state = self.accum_for_batch(m, state, batch, tags) diff --git a/src/levanter/main/train_lm.py b/src/levanter/main/train_lm.py index c71bf7b12..4d4efabd0 100644 --- a/src/levanter/main/train_lm.py +++ b/src/levanter/main/train_lm.py @@ -50,7 +50,6 @@ class TrainLmConfig: data_seed: Optional[int] = None # if provided, will override the data seed from the trainer eval_harness: Optional[LmEvalHarnessConfig] = None eval_harness_steps: int = 10000 - num_eval_harness_samples: Optional[int] = None def main(config: TrainLmConfig): From 19ac049063f577ae6bf95f470a7a3d5875773e4b Mon Sep 17 00:00:00 2001 From: David Hall Date: Wed, 12 Jun 2024 10:25:47 
-0700 Subject: [PATCH 07/34] maybe we are there? --- src/levanter/callbacks.py | 19 +++++++++++++------ src/levanter/eval_harness.py | 22 ++++++++++++++-------- 2 files changed, 27 insertions(+), 14 deletions(-) diff --git a/src/levanter/callbacks.py b/src/levanter/callbacks.py index 4f966c650..766983b11 100644 --- a/src/levanter/callbacks.py +++ b/src/levanter/callbacks.py @@ -345,28 +345,35 @@ def lm_eval_harness(step: StepInfo, force=False): return # don't run eval on the first step model = inference_mode(step.model, True) - outputs = run_lm_eval_harness(model, config.task_spec_or_default(), tokenizer, EvalBatch, axis_resources, max_examples=config.max_examples) + outputs = run_lm_eval_harness( + model, + config.task_spec_or_default(), + tokenizer, + EvalBatch, + axis_resources, + max_examples=config.max_examples, + ) if jax.process_index() == 0: with tempfile.NamedTemporaryFile("w", delete=False, suffix=".json") as f: import json json.dump(outputs, f) - levanter.tracker.current_tracker().log_artifact(f.name, name=f"lm_eval_output.{step.step}", type="lm_eval_output") + levanter.tracker.current_tracker().log_artifact( + f.name, name=f"lm_eval_output.{step.step}", type="lm_eval_output" + ) # also log accuracy statistics etc metrics_to_log = {} - for task, metrics in outputs["results"]: + for task, metrics in outputs["results"].items(): for metric, value in metrics.items(): if metric.endswith(",none"): - metric = metric[:-len(",none")] + metric = metric[: -len(",none")] if metric != "alias": # levanter.tracker.log_metrics({f"lm_eval/{task}/{metric}": value}, step=step.step) metrics_to_log[f"lm_eval/{task}/{metric}"] = value - levanter.tracker.log_metrics(metrics_to_log, step=step.step) - return lm_eval_harness diff --git a/src/levanter/eval_harness.py b/src/levanter/eval_harness.py index cf0aa9bf1..65f69e58e 100644 --- a/src/levanter/eval_harness.py +++ b/src/levanter/eval_harness.py @@ -73,7 +73,7 @@ class _RequestType: FINISHED = 3 -class _InternalLMAdaptor: +class _InternalLMAdapter: def __init__(self, EvalBatch: hax.Axis, model: LmHeadModel, axis_resources): self.EvalBatch = EvalBatch self.model = model @@ -129,7 +129,7 @@ def finish(self): class LevanterHarnessLM(LM): - def __init__(self, adaptor: _InternalLMAdaptor, tokenizer): + def __init__(self, adaptor: _InternalLMAdapter, tokenizer): super().__init__() self.adaptor = adaptor self.tokenizer = tokenizer @@ -211,16 +211,16 @@ def generate_until(self, requests) -> List[str]: def run_lm_eval_harness(model, task_spec: list[str], tokenizer, EvalBatch, axis_resources, max_examples=None) -> dict: - adaptor = _InternalLMAdaptor(EvalBatch, model, axis_resources) - harness = LevanterHarnessLM(adaptor, tokenizer) + adapter = _InternalLMAdapter(EvalBatch, model, axis_resources) + harness = LevanterHarnessLM(adapter, tokenizer) tasks_to_run = tasks.get_task_dict(task_spec) if jax.process_index() == 0: try: outputs = evaluator.evaluate(harness, tasks_to_run, limit=max_examples) finally: - adaptor.finish() + adapter.finish() else: - adaptor.worker_loop() + adapter.worker_loop() outputs = {} return outputs @@ -297,8 +297,14 @@ def run_eval_harness_main(config: EvalHarnessConfig): model = hax.shard_with_axis_mapping(model, parameter_axis_mapping) logger.info("Running LM eval harness....") - outputs = run_lm_eval_harness(model, task_spec, tokenizer, config.EvalBatch, - axis_resources=compute_axis_mapping, max_examples=max_examples) + outputs = run_lm_eval_harness( + model, + task_spec, + tokenizer, + config.EvalBatch, + 
axis_resources=compute_axis_mapping, + max_examples=max_examples, + ) logger.info("Finished running LM eval harness") # log the results as json From f5dec318bfbda922619cd507b11d4f724a047d18 Mon Sep 17 00:00:00 2001 From: David Hall Date: Wed, 12 Jun 2024 10:27:22 -0700 Subject: [PATCH 08/34] launcher --- config/gpt2_small_fast.yaml | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/config/gpt2_small_fast.yaml b/config/gpt2_small_fast.yaml index 6242a37bc..9eaa09a78 100644 --- a/config/gpt2_small_fast.yaml +++ b/config/gpt2_small_fast.yaml @@ -23,3 +23,7 @@ optimizer: learning_rate: 1E-3 weight_decay: 0.1 warmup: 0.01 +eval_harness: + task_spec: ["lambada", "piqa", "hellaswag"] + max_examples: 32 +eval_harness_steps: 1000 From 12c0b061a867e03d288714a436b57cbeecc059b3 Mon Sep 17 00:00:00 2001 From: David Hall Date: Wed, 12 Jun 2024 14:44:44 -0700 Subject: [PATCH 09/34] fix (?) logging of loading time etc. --- src/levanter/logging.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/levanter/logging.py b/src/levanter/logging.py index 78588669f..86301b4ef 100644 --- a/src/levanter/logging.py +++ b/src/levanter/logging.py @@ -80,7 +80,8 @@ def fn(): return time.perf_counter() - start yield fn - end = time.time() + end = time.perf_counter() + done = True def silence_transformer_nag(): From 463331e1428f67b0dbc04395fe79b102f77eefc9 Mon Sep 17 00:00:00 2001 From: David Hall Date: Wed, 12 Jun 2024 21:41:23 -0700 Subject: [PATCH 10/34] wip --- config/gpt2_nano_harness.yaml | 2 +- config/olmo/olmo_7b_repro.yaml | 177 ++++++++++++++++++++++++++ pyproject.toml | 2 +- scripts/launch_gpt2_small_fast_tpu.sh | 18 +-- src/levanter/eval_harness.py | 30 +++-- src/levanter/trainer.py | 16 ++- 6 files changed, 212 insertions(+), 33 deletions(-) create mode 100644 config/olmo/olmo_7b_repro.yaml diff --git a/config/gpt2_nano_harness.yaml b/config/gpt2_nano_harness.yaml index b79d6209a..5e0a0a36f 100644 --- a/config/gpt2_nano_harness.yaml +++ b/config/gpt2_nano_harness.yaml @@ -1,5 +1,5 @@ eval_harness: - task_spec: ["lambada", "piqa", "hellaswag"] + task_spec: ["piqa", "hellaswag"] max_examples: 32 eval_harness_steps: 50 data: diff --git a/config/olmo/olmo_7b_repro.yaml b/config/olmo/olmo_7b_repro.yaml new file mode 100644 index 000000000..645426012 --- /dev/null +++ b/config/olmo/olmo_7b_repro.yaml @@ -0,0 +1,177 @@ +#data: !include data/dolma_olmo_paloma.yaml +data: + cache_dir: "gs://marin-data/tokenized/OLMo-1B/dolma-v1.7" + tokenizer: "allenai/OLMo-1B" # requires `pip install ai2-olmo` + # tokenizer: "meta-llama/Llama-2-7b-hf" + stop_strategy: restart + shuffle_buffer_size: 100000 + configs: + dolma-algebraic-stack: + train_urls: + - gs://marin-data/raw/dolma/dolma-v1.7/algebraic-stack-train-{0000..0015}.json.gz + dolma-arxiv: + train_urls: + - gs://marin-data/raw/dolma/dolma-v1.7/arxiv-{0000..0099}.json.gz + dolma-gutenberg: + train_urls: + - gs://marin-data/raw/dolma/dolma-v1.7/books-{0000..0002}.json.gz + dolma-c4: + train_urls: + - gs://marin-data/raw/dolma/dolma-v1.7/c4-{0000..0170}.json.gz + dolma-cc: + train_urls: + - gs://marin-data/raw/dolma/dolma-v1.7/cc_en_head-{0000..0274}.json.gz + - gs://marin-data/raw/dolma/dolma-v1.7/cc_en_middle-{0000..0238}.json.gz # 239 is missing + - gs://marin-data/raw/dolma/dolma-v1.7/cc_en_middle-{0240..0379}.json.gz + - gs://marin-data/raw/dolma/dolma-v1.7/cc_en_tail-{0000..0152}.json.gz # 153 is missing + - gs://marin-data/raw/dolma/dolma-v1.7/cc_en_tail-{0154..0444}.json.gz + dolma-cc-news: + train_urls: + - 
gs://marin-data/raw/dolma/dolma-v1.7/cc_news_head-{0000..0004}.json.gz + - gs://marin-data/raw/dolma/dolma-v1.7/cc_news_middle-{0000..0002}.json.gz + - gs://marin-data/raw/dolma/dolma-v1.7/cc_news_tail-0000.json.gz + dolma-falcon: + train_urls: + - gs://marin-data/raw/dolma/dolma-v1.7/falcon-{0000..0499}.json.gz + dolma-megawika: + train_urls: + - gs://marin-data/raw/dolma/dolma-v1.7/megawika-{0000..0261}.json.gz + dolma-owmath: + train_urls: + - gs://marin-data/raw/dolma/dolma-v1.7/open-web-math-train-{0000..0012}.json.gz + dolma-pes2o: + train_urls: + - gs://marin-data/raw/dolma/dolma-v1.7/pes2o-{0000..0025}.json.gz + dolma-reddit: + train_urls: + - gs://marin-data/raw/dolma/dolma-v1.7/reddit-{0000..0077}.json.gz + dolma-stackexchange: + train_urls: + - gs://marin-data/raw/dolma/dolma-v1.7/stackexchange-{0000..0025}.json.gz + dolma-starcoder: + train_urls: + - gs://marin-data/raw/dolma/dolma-v1.7/starcoder-{0000..0048}.json.gz + dolma-flan: + train_urls: + - gs://marin-data/raw/dolma/dolma-v1.7/tulu_flan-{0000..0065}.json.gz + dolma-wiki: + train_urls: + - gs://marin-data/raw/dolma/dolma-v1.7/wiki-{0000..0001}.json.gz + # these are just for eval + "paloma/4chan": + validation_urls: + - gs://levanter-data/paloma/4chan_meta_sep/val/val*.jsonl.gz + "paloma/c4_100_domains": + validation_urls: + - gs://levanter-data/paloma/c4_100_domains/val/val*.jsonl.gz + "paloma/c4_en": + validation_urls: + - gs://levanter-data/paloma/c4_en/val/val*.jsonl.gz + "paloma/dolma-v1_5": + validation_urls: + - gs://levanter-data/paloma/dolma-v1_5/val/val*.jsonl.gz + "paloma/dolma_100_programing_languages": + validation_urls: + - gs://levanter-data/paloma/dolma_100_programing_languages/val/val*.jsonl.gz + "paloma/dolma_100_subreddits": + validation_urls: + - gs://levanter-data/paloma/dolma_100_subreddits/val/val*.jsonl.gz + "paloma/falcon-refinedweb": + validation_urls: + - gs://levanter-data/paloma/falcon-refinedweb/val/val*.jsonl.gz + "paloma/gab": + validation_urls: + - gs://levanter-data/paloma/gab/val/val*.jsonl.gz + "paloma/m2d2_s2orc_unsplit": + validation_urls: + - gs://levanter-data/paloma/m2d2_s2orc_unsplit/val/val*.jsonl.gz + "paloma/m2d2_wikipedia_unsplit": + validation_urls: + - gs://levanter-data/paloma/m2d2_wikipedia_unsplit/val/val*.jsonl.gz + "paloma/manosphere_meta_sep": + validation_urls: + - gs://levanter-data/paloma/manosphere_meta_sep/val/val*.jsonl.gz + "paloma/mc4": + validation_urls: + - gs://levanter-data/paloma/mc4/val/val*.jsonl.gz + "paloma/ptb": + validation_urls: + - gs://levanter-data/paloma/ptb/val/val*.jsonl.gz + "paloma/redpajama": + validation_urls: + - gs://levanter-data/paloma/redpajama/val/val*.jsonl.gz + "paloma/twitterAAE_HELM_fixed": + validation_urls: + - gs://levanter-data/paloma/twitterAAE_HELM_fixed/val/val*.jsonl.gz + "paloma/wikitext_103": + validation_urls: + - gs://levanter-data/paloma/wikitext_103/val/val*.jsonl.gz + train_weights: + # sampling proportion comes from https://huggingface.co/datasets/allenai/dolma + dolma-algebraic-stack: 12.6 # 12.6 * 1.0 + dolma-arxiv: 28.0 # 28.0 * 1.0 + dolma-gutenberg: 5.3 # 5.3 * 1.0 + dolma-c4: 69.2 # 138.4 * 0.5 + dolma-cc: 597.75 # 1,195.5 * 0.5 + dolma-cc-news: 14.3 # 1.0 + dolma-falcon: 456.4 # 1.0, refined web + dolma-megawika: 4.6 # 1.0 + dolma-owmath: 12.6 # 1.0 + dolma-pes2o: 57.2 # 1.0 + dolma-reddit: 79.9 # 1.0 + dolma-stackexchange: 19.6 # 1.0 + dolma-starcoder: 263.8 # 1.0 + dolma-flan: 16.5 # 6.5 * 1.0 + dolma-wiki: 7.4 # 3.7 * 2.0 + paloma/4chan: 0.0 + paloma/c4_100_domains: 0.0 + paloma/c4_en: 0.0 + 
paloma/dolma-v1_5: 0.0 + paloma/dolma_100_programing_languages: 0.0 + paloma/dolma_100_subreddits: 0.0 + paloma/falcon-refinedweb: 0.0 + paloma/gab: 0.0 + paloma/m2d2_s2orc_unsplit: 0.0 + paloma/m2d2_wikipedia_unsplit: 0.0 + paloma/manosphere_meta_sep: 0.0 + paloma/mc4: 0.0 + paloma/ptb: 0.0 + paloma/redpajama: 0.0 + paloma/twitterAAE_HELM_fixed: 0.0 + paloma/wikitext_103: 0.0 +model: # 7B class model + type: llama + seq_len: 2048 + hidden_dim: 4096 + intermediate_dim: 11008 + num_layers: 32 + num_heads: 32 + num_kv_heads: 32 + use_flash_attention: True +# flash_attention_block_size: 1024 + + use_bias: false + use_layer_norm_weight: false +trainer: + tracker: + type: wandb + project: "marin" + tags: ["dolma", "olmo", "llama"] + + mp: p=f32,c=bfloat16 + train_batch_size: 2048 # olmo actually uses 2160 table 5 of https://arxiv.org/pdf/2402.00838 + num_train_steps: 750000 # 3,000,000,000,000 / 4,000,000 = 750,000 + steps_per_eval: 1000 + tensor_parallel_axes: ["mlp", "heads"] + fsdp_axis: "embed" + batch_axis: "batch" + replica_dcn_axis_size: 2 +optimizer: + learning_rate: 3E-4 + weight_decay: 0.1 + min_lr_ratio: 0.1 + beta1: 0.9 + beta2: 0.95 + warmup: 2000 + # NB olmo uses linear decay + decay: linear diff --git a/pyproject.toml b/pyproject.toml index d8d043069..42698f463 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -26,7 +26,7 @@ dependencies = [ # "haliax>=1.3,<2.0", # Haliax changes in step with levanter, so we'll just use the git version except for releases. # "haliax @ git+https://github.com/stanford-crfm/haliax.git@main", - "haliax>=1.4.dev301", + "haliax>=1.4.dev302", "equinox>=0.11.4", "jaxtyping>=0.2.20", "tokenizers>=0.15.2", diff --git a/scripts/launch_gpt2_small_fast_tpu.sh b/scripts/launch_gpt2_small_fast_tpu.sh index ed64660f7..6366b63fc 100644 --- a/scripts/launch_gpt2_small_fast_tpu.sh +++ b/scripts/launch_gpt2_small_fast_tpu.sh @@ -1,18 +1,6 @@ # Launches the "gpt_small_fast" model on a TPU node -if [ -z "$WANDB_API_KEY" ]; then - echo "Error: WANDB_API_KEY not set" - exit 1 -fi - -if [ -z "$GIT_BRANCH" ]; then - GIT_BRANCH=$(git rev-parse --abbrev-ref HEAD) -fi - -echo "Launching GPT2 small fast on TPU with git branch $GIT_BRANCH" - -bash infra/babysit-tpu-vm.sh levanter-itest-32 -p -z us-east1-d -t v3-32 -b $GIT_BRANCH -- \ - XLA_FLAGS="--xla_dump_to=/tmp/output_folder/xla_dumps --xla_dump_hlo_pass_re=.*" \ - WANDB_API_KEY=$WANDB_API_KEY levanter/infra/run.sh python levanter/src/levanter/main/train_lm.py \ - --config_path levanter/config/gpt2_small_fast.yaml \ +python infra/launch.py --tpu_name levanter-itest-32 --preemptible --zone us-east1-d --tpu_type v3-32 --foreground -- \ + python src/levanter/main/train_lm.py \ + --config_path config/gpt2_small_fast.yaml \ --trainer.checkpointer.base_path gs://levanter-checkpoints/gpt-itest/ --trainer.checkpointer.save_interval 30m $* diff --git a/src/levanter/eval_harness.py b/src/levanter/eval_harness.py index 65f69e58e..8c4b231a7 100644 --- a/src/levanter/eval_harness.py +++ b/src/levanter/eval_harness.py @@ -38,7 +38,7 @@ from levanter.data import batched from levanter.models.lm_model import LmConfig, LmExample, LmHeadModel from levanter.trainer import TrainerConfig -from levanter.utils.jax_utils import stack_tree, use_cpu_device +from levanter.utils.jax_utils import local_cpu_mesh, stack_tree, use_cpu_device from levanter.utils.tree_utils import inference_mode @@ -79,9 +79,9 @@ def __init__(self, EvalBatch: hax.Axis, model: LmHeadModel, axis_resources): self.model = model self.axis_resources = axis_resources - 
self.dummy_example: LmExample = LmExample.causal( + self.dummy_example: LmExample = hax.named_jit(lambda: LmExample.causal( tokens=hax.zeros(model.Pos, dtype=jnp.int32), - ) + ))() def _eval_loglikelihood(model: LmHeadModel, example: LmExample): logits = model(example.tokens) @@ -116,13 +116,20 @@ def do_request(self, request_type, example): def worker_loop(self): dummy_example = self.dummy_example while True: + logger.info("Waiting for request") request_type, request = broadcast_one_to_all((_RequestType.FINISHED, dummy_example), is_source=False) if request_type == _RequestType.FINISHED: break + logger.info("Doing request") self.do_request(request_type, request) + def dispatch(self, request_type, request): + print(self.dummy_example, request) + broadcast_one_to_all((request_type, request), is_source=True) + return self.do_request(request_type, request) + def finish(self): assert jax.process_index() == 0 broadcast_one_to_all((_RequestType.FINISHED, self.dummy_example), is_source=True) @@ -161,19 +168,23 @@ def loglikelihood(self, requests: list[Instance]) -> list[tuple[float, bool]]: completions = self.tokenizer([req.args[1] for req in requests])["input_ids"] examples: list[LmExample] = [] - with use_cpu_device(): + logger.info("Creating examples") + with local_cpu_mesh(): for context, completion in zip(contexts, completions): context, completion = self._truncate_or_pad(context, completion) context, completion = hax.named(context, Pos.name), hax.named(completion, Pos.name) - example = LmExample.from_prompt_and_completion( + example = jax.jit(lambda: LmExample.from_prompt_and_completion( context, completion, ignore_id=self.tokenizer.pad_token_id - ) + ))() examples.append(example) + logger.info("Created examples") result: list[tuple[float, bool]] = [] for batch in batched(tqdm(examples, desc="examples", leave=False), self.adaptor.EvalBatch.size): - batch_example = self._stack_batch(batch) - out_lls, out_correct = self._dispatch(_RequestType.LOG_LIKELIHOOD, batch_example) + logger.info("Processing batch") + with local_cpu_mesh(): + batch_example = self._stack_batch(batch) + out_lls, out_correct = self.adaptor.dispatch(_RequestType.LOG_LIKELIHOOD, batch_example) result.extend((ll.item(), correct.item()) for ll, correct in zip(out_lls.array, out_correct.array)) # skip padding @@ -199,9 +210,6 @@ def _truncate_or_pad(self, context, completion): return context, completion - def _dispatch(self, request_type, request): - broadcast_one_to_all((request_type, request), is_source=True) - return self.adaptor.do_request(request_type, request) def loglikelihood_rolling(self, requests) -> List[Tuple[float]]: raise NotImplementedError() diff --git a/src/levanter/trainer.py b/src/levanter/trainer.py index da5c0de60..f874662a0 100644 --- a/src/levanter/trainer.py +++ b/src/levanter/trainer.py @@ -557,11 +557,11 @@ class TrainerConfig: ) # overrides axis_mapping for parameter """logical->physical mapping for parameter/optimizer sharding. fsdp_axis and tensor_parallel_axes are preferred""" - """Interchip Interconnect (ICI) & Data Center Networking (DCN) shardings https://cloud.google.com/tpu/docs/multislice-introduction""" - replica_ici_axis_size: int = 1 - model_axis_size: int = 1 - """how many devices within each slice for sharding with DP. 
Fix TP=1, the rest of the devices is for FSDP.""" - replica_dcn_axis_size: int = 1 + # Interchip Interconnect (ICI) & Data Center Networking (DCN) shardings https://cloud.google.com/tpu/docs/multislice-introduction + replica_ici_axis_size: int = 1 # how many parameter replicas there should be "within" each slice (ICI) + model_axis_size: int = 1 # axis size for tensor parallelism (TP) + replica_dcn_axis_size: Optional[int] = None # how many parameter replicas there should be "across" slices (DCN) + auto_replicas: bool = True # whether to automatically set replica_dcn_axis_size based on num_slices """how many slices in the multislice scheme for sharding with DP and TP. The rest of the devices is for FSDP.""" # Config related to batch sizes @@ -767,6 +767,12 @@ def _validate_and_set_defaults(self): ): raise ValueError("either model_axis_size or local_device_count must be divisible by the other") + # handle replica_dcn_axis_size + if self.auto_replicas and self.replica_dcn_axis_size is None: + if self.num_slices > 1: + logger.info(f"Setting replica_dcn_axis_size to {self.num_slices}") + self.replica_dcn_axis_size = self.num_slices + assert self.train_batch_size != -1 or self.per_device_parallelism != -1 if self.per_device_parallelism == -1: From 2d4d0d8c235eaa50ad90236e9896a11dcf820546 Mon Sep 17 00:00:00 2001 From: David Hall Date: Thu, 20 Jun 2024 22:40:14 -0700 Subject: [PATCH 11/34] wip --- config/harness/harness_nano.yaml | 3 + config/olmo/olmo_7b_repro.yaml | 2 - src/levanter/eval_harness.py | 175 ++++++++++++------------------- src/levanter/main/eval_lm.py | 4 +- src/levanter/models/attention.py | 8 +- src/levanter/models/lm_model.py | 13 +-- src/levanter/trainer.py | 13 ++- src/levanter/utils/jax_utils.py | 7 +- 8 files changed, 93 insertions(+), 132 deletions(-) diff --git a/config/harness/harness_nano.yaml b/config/harness/harness_nano.yaml index c34e2a83a..833291e5c 100644 --- a/config/harness/harness_nano.yaml +++ b/config/harness/harness_nano.yaml @@ -1,3 +1,5 @@ +eval_harness: + task_spec: ["hellaswag"] tokenizer: "gpt2" model: type: gpt2 @@ -7,6 +9,7 @@ model: trainer: mp: f32 num_train_steps: 100 + profiler: true checkpointer: keep: diff --git a/config/olmo/olmo_7b_repro.yaml b/config/olmo/olmo_7b_repro.yaml index 645426012..aca11c419 100644 --- a/config/olmo/olmo_7b_repro.yaml +++ b/config/olmo/olmo_7b_repro.yaml @@ -173,5 +173,3 @@ optimizer: beta1: 0.9 beta2: 0.95 warmup: 2000 - # NB olmo uses linear decay - decay: linear diff --git a/src/levanter/eval_harness.py b/src/levanter/eval_harness.py index 8c4b231a7..e5517b16f 100644 --- a/src/levanter/eval_harness.py +++ b/src/levanter/eval_harness.py @@ -16,6 +16,8 @@ import transformers from jax.experimental.multihost_utils import broadcast_one_to_all +from levanter.compat.hf_checkpoints import HFCheckpointConverter +from levanter.models.gpt2 import Gpt2Config try: from lm_eval import evaluator, tasks @@ -25,7 +27,7 @@ LM = object Instance = object evaluator = object - tasks = object + # tasks = object from tqdm import tqdm @@ -33,6 +35,8 @@ from haliax.nn import cross_entropy_loss from haliax.partitioning import round_axis_for_partitioning +from jax_sourceror import sourcerize + import levanter.config from levanter.checkpoint import load_checkpoint from levanter.data import batched @@ -73,15 +77,13 @@ class _RequestType: FINISHED = 3 -class _InternalLMAdapter: - def __init__(self, EvalBatch: hax.Axis, model: LmHeadModel, axis_resources): +class LevanterHarnessLM(LM): + def __init__(self, EvalBatch: hax.Axis, model: LmHeadModel, 
axis_resources, tokenizer): + super().__init__() self.EvalBatch = EvalBatch self.model = model self.axis_resources = axis_resources - - self.dummy_example: LmExample = hax.named_jit(lambda: LmExample.causal( - tokens=hax.zeros(model.Pos, dtype=jnp.int32), - ))() + self.tokenizer = tokenizer def _eval_loglikelihood(model: LmHeadModel, example: LmExample): logits = model(example.tokens) @@ -100,55 +102,8 @@ def _eval_loglikelihood(model: LmHeadModel, example: LmExample): _eval_loglikelihood, axis_resources=axis_resources, out_axis_resources={} ) - def loglikelihood(self, examples: LmExample) -> tuple[hax.NamedArray, hax.NamedArray]: - return self._jit_loglikelihood(self.model, examples) - - def do_request(self, request_type, example): - if request_type == _RequestType.LOG_LIKELIHOOD: - return self.loglikelihood(example) - elif request_type == _RequestType.GENERATE_UNTIL: - raise NotImplementedError() - elif request_type == _RequestType.LOG_LIKELIHOOD_ROLLING: - raise NotImplementedError() - else: - raise ValueError(f"Invalid request type {request_type}") - - def worker_loop(self): - dummy_example = self.dummy_example - while True: - logger.info("Waiting for request") - request_type, request = broadcast_one_to_all((_RequestType.FINISHED, dummy_example), is_source=False) - - if request_type == _RequestType.FINISHED: - break - - logger.info("Doing request") - self.do_request(request_type, request) - - def dispatch(self, request_type, request): - print(self.dummy_example, request) - broadcast_one_to_all((request_type, request), is_source=True) - return self.do_request(request_type, request) - - def finish(self): - assert jax.process_index() == 0 - broadcast_one_to_all((_RequestType.FINISHED, self.dummy_example), is_source=True) - - -class LevanterHarnessLM(LM): - def __init__(self, adaptor: _InternalLMAdapter, tokenizer): - super().__init__() - self.adaptor = adaptor - self.tokenizer = tokenizer - - def _stack_batch(examples): - return stack_tree(self.adaptor.EvalBatch, examples, pad_to_batch_size=True) - - self._stack_batch_jit = hax.named_jit(_stack_batch) - def _stack_batch(self, examples): - with use_cpu_device(): - return self._stack_batch_jit(examples) + return stack_tree(self.EvalBatch, examples, pad_to_batch_size=True) def loglikelihood(self, requests: list[Instance]) -> list[tuple[float, bool]]: """ @@ -162,41 +117,43 @@ def loglikelihood(self, requests: list[Instance]) -> list[tuple[float, bool]]: """ - Pos = self.adaptor.model.Pos + Pos = self.model.Pos contexts = self.tokenizer([req.args[0] for req in requests])["input_ids"] completions = self.tokenizer([req.args[1] for req in requests])["input_ids"] examples: list[LmExample] = [] - logger.info("Creating examples") - with local_cpu_mesh(): - for context, completion in zip(contexts, completions): - context, completion = self._truncate_or_pad(context, completion) - context, completion = hax.named(context, Pos.name), hax.named(completion, Pos.name) - example = jax.jit(lambda: LmExample.from_prompt_and_completion( - context, completion, ignore_id=self.tokenizer.pad_token_id - ))() - examples.append(example) - logger.info("Created examples") + + @hax.named_jit + def _jit_create_example(tokens, prompt_len): + tokens = hax.named(tokens, self.model.Pos) + return LmExample.from_prompt_and_completion(self.model.Pos, tokens, prompt_len, ignore_id=self.tokenizer.pad_token_id) + + # TODO: offload this to an evalbatchloader + for context, completion in zip(tqdm(contexts, desc="Creating examples"), completions): + tokens, length = 
self._truncate_or_pad(context, completion) + tokens = jnp.array(tokens) + length = jnp.array(length) + example = _jit_create_example(tokens, length) + examples.append(example) result: list[tuple[float, bool]] = [] - for batch in batched(tqdm(examples, desc="examples", leave=False), self.adaptor.EvalBatch.size): + for batch in batched(tqdm(examples, desc="examples", leave=False), self.EvalBatch.size): logger.info("Processing batch") - with local_cpu_mesh(): - batch_example = self._stack_batch(batch) - out_lls, out_correct = self.adaptor.dispatch(_RequestType.LOG_LIKELIHOOD, batch_example) + batch_example = self._stack_batch(batch) + # batch_example = jax.device_put(batch_example, jax.local_devices()[0]) + source = sourcerize(self._jit_loglikelihood)(self.model, batch_example) + print(source, flush=True) + out_lls, out_correct = self._jit_loglikelihood(self.model, batch_example) result.extend((ll.item(), correct.item()) for ll, correct in zip(out_lls.array, out_correct.array)) # skip padding result = result[: len(examples)] - # print(contexts) - # print(completions) - return result def _truncate_or_pad(self, context, completion): - max_len = self.adaptor.model.Pos.size + max_len = self.model.Pos.size if len(completion) > max_len: warnings.warn(f"Completion is longer than max length {max_len}. Truncating.") completion = completion[:max_len] @@ -208,8 +165,7 @@ def _truncate_or_pad(self, context, completion): # right pad with padding token context = context + [pad_token_id] * (max_len - len(context) - len(completion)) - return context, completion - + return jnp.array(context + completion), len(context) def loglikelihood_rolling(self, requests) -> List[Tuple[float]]: raise NotImplementedError() @@ -219,17 +175,9 @@ def generate_until(self, requests) -> List[str]: def run_lm_eval_harness(model, task_spec: list[str], tokenizer, EvalBatch, axis_resources, max_examples=None) -> dict: - adapter = _InternalLMAdapter(EvalBatch, model, axis_resources) - harness = LevanterHarnessLM(adapter, tokenizer) + harness = LevanterHarnessLM(EvalBatch, model, axis_resources, tokenizer) tasks_to_run = tasks.get_task_dict(task_spec) - if jax.process_index() == 0: - try: - outputs = evaluator.evaluate(harness, tasks_to_run, limit=max_examples) - finally: - adapter.finish() - else: - adapter.worker_loop() - outputs = {} + outputs = evaluator.evaluate(harness, tasks_to_run, limit=max_examples) return outputs @@ -261,8 +209,9 @@ def task_spec_or_default(self): class EvalHarnessConfig: tokenizer: str checkpoint_path: str + checkpoint_is_hf: bool = False trainer: TrainerConfig = dataclasses.field(default_factory=TrainerConfig) - model: LmConfig = dataclasses.field(default_factory=LmConfig) + model: LmConfig = dataclasses.field(default_factory=Gpt2Config) eval_harness: LmEvalHarnessConfig = dataclasses.field(default_factory=LmEvalHarnessConfig) @@ -294,33 +243,41 @@ def run_eval_harness_main(config: EvalHarnessConfig): logger.info(f"Rounding vocab size from {vocab_size} to {Vocab.size} for partitioning") # initialize the model - if config.checkpoint_path is not None: - # initialize the model + if config.checkpoint_is_hf: + model_config = config.model + converter: HFCheckpointConverter = model_config.default_hf_checkpoint_converter # type: ignore + converter = converter.replaced(reference_checkpoint=config.checkpoint_path, tokenizer=tokenizer) + model = converter.load_pretrained( + model_config.model_type, config.checkpoint_path, dtype=config.trainer.mp.compute_dtype + ) + else: with use_cpu_device(): model = 
eqx.filter_eval_shape(config.model.build, Vocab, key=key) - # TODO: don't load the entire checkpoint into CPU memory when we only need our share of the model model = load_checkpoint(model, config.checkpoint_path, subpath="model") - model = inference_mode(model, True) - - model = hax.shard_with_axis_mapping(model, parameter_axis_mapping) - - logger.info("Running LM eval harness....") - outputs = run_lm_eval_harness( - model, - task_spec, - tokenizer, - config.EvalBatch, - axis_resources=compute_axis_mapping, - max_examples=max_examples, - ) + model = hax.shard(model, parameter_axis_mapping) - logger.info("Finished running LM eval harness") - # log the results as json - with open("lm_eval_results.json", "w") as f: + model = inference_mode(model, True) - json.dump(outputs, f, indent=2) - else: - raise ValueError("No checkpoint path provided") + ex = LmExample.from_prompt_and_completion(model.Pos, hax.zeros(model.Pos, dtype=int), 100) + source = sourcerize(model.compute_loss)(ex) + + print(source, flush=True) + + logger.info("Running LM eval harness....") + outputs = run_lm_eval_harness( + model, + task_spec, + tokenizer, + config.EvalBatch, + axis_resources=compute_axis_mapping, + max_examples=max_examples, + ) + + logger.info("Finished running LM eval harness") + # log the results as json + with open("lm_eval_results.json", "w") as f: + + json.dump(outputs, f, indent=2) if __name__ == "__main__": diff --git a/src/levanter/main/eval_lm.py b/src/levanter/main/eval_lm.py index 4137c6b25..0c4f326b6 100644 --- a/src/levanter/main/eval_lm.py +++ b/src/levanter/main/eval_lm.py @@ -97,9 +97,9 @@ def compute_loss(model: LmHeadModel, example: LmExample): if config.hf_checkpoint is not None: # load the huggingface model model_config = config.model - if not hasattr(model_config, "hf_checkpoint_converter"): + if not hasattr(model_config, "default_hf_checkpoint_converter"): raise ValueError("Model config does not have an HF checkpoint converter. 
Can't load HF checkpoint.") - converter: HFCheckpointConverter = model_config.hf_checkpoint_converter + converter: HFCheckpointConverter = model_config.default_hf_checkpoint_converter converter = converter.replaced(reference_checkpoint=config.hf_checkpoint, tokenizer=tokenizer) model_from_hf_checkpoint = converter.load_pretrained( model_config.model_type, config.hf_checkpoint, dtype=mp.compute_dtype diff --git a/src/levanter/models/attention.py b/src/levanter/models/attention.py index e7c94f50b..1132018e0 100644 --- a/src/levanter/models/attention.py +++ b/src/levanter/models/attention.py @@ -848,13 +848,15 @@ def wrap_flash_attention(q, k, v): ) if mask is None: - kernel_mask = splash_attention_mask.FullMask(_shape=(Sq, Sk)) + masks = [splash_attention_mask.FullMask(_shape=(Sq, Sk)) for i in range(Hq)] + kernel_mask = splash_attention_mask.MultiHeadMask(masks=masks) elif isinstance(mask, AttentionMask): if mask.is_causal: masks = [splash_attention_mask.CausalMask(shape=(Sq, Sq)) for i in range(Hq)] - kernel_mask = splash_attention_mask.MultiHeadMask(masks=masks) else: - kernel_mask = splash_attention_mask.FullMask(_shape=(Sq, Sk)) + masks = [splash_attention_mask.FullMask(_shape=(Sq, Sk)) for i in range(Hq)] + + kernel_mask = splash_attention_mask.MultiHeadMask(masks=masks) if mask.explicit_mask is not None: raise NotImplementedError("Explicit masks are not yet supported for splash attention") diff --git a/src/levanter/models/lm_model.py b/src/levanter/models/lm_model.py index d42a560ed..f52c43845 100644 --- a/src/levanter/models/lm_model.py +++ b/src/levanter/models/lm_model.py @@ -7,7 +7,7 @@ from jax.random import PRNGKey import haliax as hax -from haliax import Axis, NamedArray +from haliax import Axis, NamedArray, NamedOrNumeric from haliax.nn import cross_entropy_loss from levanter.models.attention import AttentionMask @@ -48,16 +48,13 @@ def causal( @staticmethod def from_prompt_and_completion( - prompt: hax.NamedArray, completion: hax.NamedArray, *, ignore_id: Optional[int] = None, all_causal: bool = True + Pos, + tokens: hax.NamedArray, prompt_length: NamedOrNumeric, *, ignore_id: Optional[int] = None, all_causal: bool = True ) -> "LmExample": - prompt_Pos = prompt.axes[0] - tokens = hax.concatenate(prompt_Pos.name, [prompt, completion]) - tokens_Pos = tokens.resolve_axis(prompt_Pos.name) - # mask out the prompt tokens - loss_mask = hax.arange(tokens_Pos) >= prompt_Pos.size + loss_mask = hax.arange(Pos) >= prompt_length # also mask out the last token - loss_mask *= 1 - hax.nn.one_hot(-1, tokens_Pos, dtype=jnp.float32) + loss_mask *= 1 - hax.nn.one_hot(-1, Pos, dtype=jnp.float32) if ignore_id is not None: ignore_mask = tokens != ignore_id diff --git a/src/levanter/trainer.py b/src/levanter/trainer.py index f874662a0..c076b2739 100644 --- a/src/levanter/trainer.py +++ b/src/levanter/trainer.py @@ -652,7 +652,7 @@ def device_mesh(self) -> Mesh: self.replica_ici_axis_size, self.data_ici_axis_size, self.model_axis_size, - self.replica_dcn_axis_size, + self.replica_dcn_axis_size, # type: ignore self.data_dcn_axis_size, ) @@ -768,10 +768,13 @@ def _validate_and_set_defaults(self): raise ValueError("either model_axis_size or local_device_count must be divisible by the other") # handle replica_dcn_axis_size - if self.auto_replicas and self.replica_dcn_axis_size is None: - if self.num_slices > 1: - logger.info(f"Setting replica_dcn_axis_size to {self.num_slices}") - self.replica_dcn_axis_size = self.num_slices + if self.replica_dcn_axis_size is None: + if self.auto_replicas: + if 
self.num_slices > 1: + logger.info(f"Setting replica_dcn_axis_size to {self.num_slices}") + self.replica_dcn_axis_size = self.num_slices + else: + self.replica_dcn_axis_size = 1 assert self.train_batch_size != -1 or self.per_device_parallelism != -1 diff --git a/src/levanter/utils/jax_utils.py b/src/levanter/utils/jax_utils.py index 2fcfd4462..c7ae84564 100644 --- a/src/levanter/utils/jax_utils.py +++ b/src/levanter/utils/jax_utils.py @@ -1,5 +1,4 @@ import contextlib -import functools import json import warnings from dataclasses import fields @@ -43,7 +42,9 @@ def use_cpu_device(): def local_cpu_mesh(): """Temporarily sets the default device to CPU""" cpu = jax.local_devices(backend="cpu")[0] - mesh = jax.sharding.Mesh(np.array([cpu]).reshape(1, 1), ("data", "model")) + mesh = jax.sharding.Mesh( + np.array([cpu]).reshape(1, 1, 1), (ResourceAxis.REPLICA, ResourceAxis.DATA, ResourceAxis.MODEL) + ) with use_cpu_device(), mesh: yield mesh @@ -335,7 +336,7 @@ def estimated_free_device_memory(device) -> Optional[float]: return (limit - in_use) // (1024.0**3) -@functools.partial(jax.jit, static_argnums=(0), static_argnames=("batch", "pad_to_batch_size")) +# @functools.partial(jax.jit, static_argnums=(0), static_argnames=("batch", "pad_to_batch_size")) def stack_tree(batch: AxisSelector, individual_datums: list[X], *, pad_to_batch_size: bool) -> X: """ Stacks a tree of NamedArrays or arrays into a single array. NamedArrays get a new axis with the name batch_name, From c7c5f703639afa836cfd0f8209f362b719deb5ae Mon Sep 17 00:00:00 2001 From: David Hall Date: Fri, 15 Nov 2024 15:52:58 -0800 Subject: [PATCH 12/34] off by one --- config/gpt2_nano_harness.yaml | 2 +- src/levanter/models/lm_model.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/config/gpt2_nano_harness.yaml b/config/gpt2_nano_harness.yaml index 5e0a0a36f..8a241a058 100644 --- a/config/gpt2_nano_harness.yaml +++ b/config/gpt2_nano_harness.yaml @@ -19,7 +19,7 @@ trainer: save_interval: 5m per_device_parallelism: -1 - train_batch_size: 32 + train_batch_size: 4 tensor_parallel_axes: ["mlp", "heads"] fsdp_axis: "embed" diff --git a/src/levanter/models/lm_model.py b/src/levanter/models/lm_model.py index 6f998deac..7c6c29f32 100644 --- a/src/levanter/models/lm_model.py +++ b/src/levanter/models/lm_model.py @@ -57,7 +57,7 @@ def from_prompt_and_completion( all_causal: bool = True, ) -> "LmExample": # mask out the prompt tokens - loss_mask = hax.arange(Pos) >= prompt_length + loss_mask = hax.arange(Pos) >= prompt_length - 1 # also mask out the last token loss_mask *= 1 - hax.nn.one_hot(-1, Pos, dtype=jnp.float32) From eb441e3df915fab94306df3f0a4bf33bf39bc063 Mon Sep 17 00:00:00 2001 From: David Hall Date: Sat, 23 Nov 2024 23:50:08 -0800 Subject: [PATCH 13/34] move logging and types to util to make python's module resolution happierq --- src/levanter/__init__.py | 1 - src/levanter/callbacks.py | 2 +- src/levanter/checkpoint.py | 2 +- src/levanter/compat/hf_checkpoints.py | 2 +- src/levanter/data/audio.py | 2 +- src/levanter/data/text.py | 6 +++--- src/levanter/doremi.py | 2 +- src/levanter/eval.py | 2 +- src/levanter/lora.py | 2 +- src/levanter/main/cache_dataset.py | 2 +- src/levanter/models/backpack.py | 2 +- src/levanter/models/gemma.py | 4 ++-- src/levanter/models/gpt2.py | 2 +- src/levanter/models/llama.py | 4 ++-- src/levanter/models/mistral.py | 2 +- src/levanter/models/mpt.py | 2 +- src/levanter/models/qwen.py | 4 ++-- src/levanter/models/whisper.py | 2 +- src/levanter/trainer.py | 6 +++--- 
src/levanter/trainer_state.py | 2 +- src/levanter/utils/hf_utils.py | 2 +- src/levanter/{ => utils}/logging.py | 0 src/levanter/{ => utils}/types.py | 0 23 files changed, 27 insertions(+), 28 deletions(-) rename src/levanter/{ => utils}/logging.py (100%) rename src/levanter/{ => utils}/types.py (100%) diff --git a/src/levanter/__init__.py b/src/levanter/__init__.py index b969828bc..f9570aaf7 100644 --- a/src/levanter/__init__.py +++ b/src/levanter/__init__.py @@ -3,7 +3,6 @@ import levanter.data as data import levanter.distributed as distributed import levanter.eval as eval -import levanter.logging as logging import levanter.models as models import levanter.optim as optim import levanter.tracker as tracker diff --git a/src/levanter/callbacks.py b/src/levanter/callbacks.py index 3906dd12b..cb8c016c0 100644 --- a/src/levanter/callbacks.py +++ b/src/levanter/callbacks.py @@ -20,12 +20,12 @@ import levanter.tracker from levanter.data import AsyncDataset, DataLoader from levanter.eval_harness import LmEvalHarnessConfig -from levanter.logging import save_xla_dumps_to_wandb from levanter.tracker.helpers import log_optimizer_hyperparams from levanter.tracker.wandb import WandbConfig from levanter.trainer import StepInfo from levanter.utils import flop_utils from levanter.utils.jax_utils import barrier_sync, jnp_to_python +from levanter.utils.logging import save_xla_dumps_to_wandb from levanter.utils.tree_utils import inference_mode from levanter.visualization import compute_and_visualize_log_probs as viz_probs diff --git a/src/levanter/checkpoint.py b/src/levanter/checkpoint.py index 92733415a..db2395b4b 100644 --- a/src/levanter/checkpoint.py +++ b/src/levanter/checkpoint.py @@ -26,7 +26,7 @@ from haliax.jax_utils import is_in_jit, is_jax_array_like from levanter.tensorstore_serialization import tree_deserialize_leaves_tensorstore, tree_serialize_leaves_tensorstore -from levanter.types import FilterSpec +from levanter.utils.types import FilterSpec logger = logging.getLogger(__name__) diff --git a/src/levanter/compat/hf_checkpoints.py b/src/levanter/compat/hf_checkpoints.py index 5822c3fba..f4ad33757 100644 --- a/src/levanter/compat/hf_checkpoints.py +++ b/src/levanter/compat/hf_checkpoints.py @@ -33,7 +33,6 @@ from haliax.partitioning import ResourceMapping from haliax.state_dict import from_torch_compatible_state_dict, save_state_dict, to_torch_compatible_state_dict -from levanter.logging import silence_transformer_nag from levanter.models.asr_model import ASRMixin from levanter.models.lm_model import LmConfig, LmHeadModel from levanter.trainer import StepInfo @@ -41,6 +40,7 @@ from levanter.utils.cloud_utils import temp_dir_before_upload from levanter.utils.hf_utils import HfTokenizer from levanter.utils.jax_utils import best_effort_sharding, local_cpu_mesh, use_cpu_device +from levanter.utils.logging import silence_transformer_nag from levanter.utils.py_utils import dataclass_with_default_init, logical_cpu_memory_size diff --git a/src/levanter/data/audio.py b/src/levanter/data/audio.py index 9bfc1e142..b2235e863 100644 --- a/src/levanter/data/audio.py +++ b/src/levanter/data/audio.py @@ -30,10 +30,10 @@ from levanter.data.text import BatchTokenizer # intercept the logging nonsense here -from levanter.logging import silence_transformer_nag from levanter.models.asr_model import AudioTextExample from levanter.store.cache import CacheOptions, TreeCache, build_or_load_cache from levanter.utils.jax_utils import key_iterator +from levanter.utils.logging import silence_transformer_nag 
silence_transformer_nag() # noqa diff --git a/src/levanter/data/text.py b/src/levanter/data/text.py index 053372207..1532a7d06 100644 --- a/src/levanter/data/text.py +++ b/src/levanter/data/text.py @@ -30,9 +30,6 @@ from levanter.data import AsyncDataset from levanter.data.dataset import MappedAsyncDataset from levanter.data.mixture import MixtureDataset, StopStrategy - -# intercept the logging nonsense here -from levanter.logging import silence_transformer_nag # noqa from levanter.models.attention import AttentionMask from levanter.models.lm_model import LmExample from levanter.store.cache import CacheOptions, TreeCache @@ -41,6 +38,9 @@ from levanter.utils.fsspec_utils import expand_glob from levanter.utils.hf_utils import HfTokenizer, num_cpus_used_by_tokenizer +# intercept the logging nonsense here +from levanter.utils.logging import silence_transformer_nag # noqa + silence_transformer_nag() # noqa from transformers import BatchEncoding, PreTrainedTokenizer, PreTrainedTokenizerBase, PreTrainedTokenizerFast # noqa diff --git a/src/levanter/doremi.py b/src/levanter/doremi.py index 6d9165cfc..c066e55d5 100644 --- a/src/levanter/doremi.py +++ b/src/levanter/doremi.py @@ -18,8 +18,8 @@ from levanter.data.mixture import MixtureDataset from levanter.tracker import capture_time from levanter.trainer import M, StepInfo, Trainer, TrainerConfig, TrainerState -from levanter.types import ComputeLossFunction from levanter.utils.tree_utils import inference_mode +from levanter.utils.types import ComputeLossFunction logger = logging.getLogger(__name__) diff --git a/src/levanter/eval.py b/src/levanter/eval.py index 9fe9ab0d7..6f40888cd 100644 --- a/src/levanter/eval.py +++ b/src/levanter/eval.py @@ -17,10 +17,10 @@ import levanter.tracker from levanter.data import AsyncDataset, DataLoader -from levanter.logging import LoadingTimeTrackerIterator from levanter.models.lm_model import LmExample, LmHeadModel, compute_next_token_loss from levanter.trainer import StepInfo from levanter.utils.hf_utils import HfTokenizer, byte_length_of_token +from levanter.utils.logging import LoadingTimeTrackerIterator from levanter.utils.stat_utils import Arrayish, RunningMean from levanter.utils.tree_utils import inference_mode diff --git a/src/levanter/lora.py b/src/levanter/lora.py index 1e0f37d67..cdabee3a5 100644 --- a/src/levanter/lora.py +++ b/src/levanter/lora.py @@ -64,10 +64,10 @@ ) from levanter.compat.hf_checkpoints import HFCheckpointConverter, RepoRef, upload_to_hub -from levanter.logging import silence_transformer_nag from levanter.trainer import StepInfo from levanter.utils.cloud_utils import temp_dir_before_upload from levanter.utils.jax_utils import join_key, key_iterator, leaf_key_paths +from levanter.utils.logging import silence_transformer_nag silence_transformer_nag() diff --git a/src/levanter/main/cache_dataset.py b/src/levanter/main/cache_dataset.py index 92471e997..73eb518b2 100644 --- a/src/levanter/main/cache_dataset.py +++ b/src/levanter/main/cache_dataset.py @@ -6,9 +6,9 @@ from levanter.data.metrics_monitor import LoggingMetricsMonitor, RichMetricsMonitor from levanter.data.text import BatchTokenizer, LMDatasetConfig from levanter.distributed import RayConfig -from levanter.logging import init_logging from levanter.store.cache import build_or_load_cache from levanter.tracker import NoopConfig, TrackerConfig +from levanter.utils.logging import init_logging logger = logging.getLogger(__name__) diff --git a/src/levanter/models/backpack.py b/src/levanter/models/backpack.py index 715706f8e..42157f947 
100644 --- a/src/levanter/models/backpack.py +++ b/src/levanter/models/backpack.py @@ -15,10 +15,10 @@ from haliax.state_dict import ModuleWithStateDictSerialization, StateDict, with_prefix from levanter.compat.hf_checkpoints import HFCheckpointConverter, LmWithHfSerializationMixin -from levanter.logging import silence_transformer_nag from levanter.models.attention import AttentionMask, materialize_mask from levanter.models.gpt2 import ACT2FN, Gpt2Config, Gpt2Transformer from levanter.models.lm_model import LmConfig +from levanter.utils.logging import silence_transformer_nag silence_transformer_nag() diff --git a/src/levanter/models/gemma.py b/src/levanter/models/gemma.py index 23e2bf6dc..93c360792 100644 --- a/src/levanter/models/gemma.py +++ b/src/levanter/models/gemma.py @@ -14,7 +14,6 @@ from haliax.state_dict import ModuleWithStateDictSerialization from levanter.compat.hf_checkpoints import HFCheckpointConverter, HFCompatConfig -from levanter.logging import silence_transformer_nag from levanter.models.attention import AttentionBackend, AttentionMask from levanter.models.llama import ( # Gemma attention and MLP is identical to LLama LlamaAttention, @@ -23,8 +22,9 @@ ) from levanter.models.lm_model import LmConfig, LmHeadModel from levanter.models.rotary import DefaultRotaryEmbeddingsConfig, RotaryEmbeddingsConfig -from levanter.types import BlockFoldable from levanter.utils.flop_utils import lm_flops_per_token +from levanter.utils.logging import silence_transformer_nag +from levanter.utils.types import BlockFoldable silence_transformer_nag() diff --git a/src/levanter/models/gpt2.py b/src/levanter/models/gpt2.py index 1d2fe5892..db2ed693c 100644 --- a/src/levanter/models/gpt2.py +++ b/src/levanter/models/gpt2.py @@ -17,10 +17,10 @@ from haliax.state_dict import ModuleWithStateDictSerialization from levanter.compat.hf_checkpoints import HFCheckpointConverter, HFCompatConfig, LmWithHfSerializationMixin -from levanter.logging import silence_transformer_nag from levanter.models.attention import AttentionBackend, AttentionMask, dot_product_attention from levanter.models.lm_model import LmConfig from levanter.utils.flop_utils import lm_flops_per_token +from levanter.utils.logging import silence_transformer_nag silence_transformer_nag() diff --git a/src/levanter/models/llama.py b/src/levanter/models/llama.py index 6b04ec540..76a786fd9 100644 --- a/src/levanter/models/llama.py +++ b/src/levanter/models/llama.py @@ -15,13 +15,13 @@ from haliax.state_dict import ModuleWithStateDictSerialization from levanter.compat.hf_checkpoints import HFCheckpointConverter, HFCompatConfig -from levanter.logging import silence_transformer_nag from levanter.models.attention import AttentionBackend, AttentionMask, dot_product_attention from levanter.models.gpt2 import ACT2FN from levanter.models.lm_model import LmConfig, LmHeadModel from levanter.models.rotary import DefaultRotaryEmbeddingsConfig, RotaryEmbeddingsConfig -from levanter.types import BlockFoldable from levanter.utils.flop_utils import lm_flops_per_token +from levanter.utils.logging import silence_transformer_nag +from levanter.utils.types import BlockFoldable silence_transformer_nag() diff --git a/src/levanter/models/mistral.py b/src/levanter/models/mistral.py index b9f19ef41..d7ac00b83 100644 --- a/src/levanter/models/mistral.py +++ b/src/levanter/models/mistral.py @@ -11,11 +11,11 @@ from haliax.state_dict import ModuleWithStateDictSerialization from levanter.compat.hf_checkpoints import HFCheckpointConverter -from levanter.logging import 
silence_transformer_nag from levanter.models.attention import AttentionBackend, AttentionMask from levanter.models.llama import LlamaConfig, LlamaEmbedding, LlamaTransformer from levanter.models.lm_model import LmConfig, LmHeadModel from levanter.utils.flop_utils import lm_flops_per_token +from levanter.utils.logging import silence_transformer_nag silence_transformer_nag() diff --git a/src/levanter/models/mpt.py b/src/levanter/models/mpt.py index e77e967d7..8a2d6a1c5 100644 --- a/src/levanter/models/mpt.py +++ b/src/levanter/models/mpt.py @@ -19,11 +19,11 @@ import levanter.models.attention from levanter.compat.hf_checkpoints import HFCheckpointConverter, HFCompatConfig, LmWithHfSerializationMixin -from levanter.logging import silence_transformer_nag from levanter.models.attention import AttentionMask from levanter.models.lm_model import LmConfig from levanter.utils.flop_utils import lm_flops_per_token from levanter.utils.jax_utils import use_cpu_device +from levanter.utils.logging import silence_transformer_nag silence_transformer_nag() diff --git a/src/levanter/models/qwen.py b/src/levanter/models/qwen.py index 807a768ad..7f8afa951 100644 --- a/src/levanter/models/qwen.py +++ b/src/levanter/models/qwen.py @@ -13,13 +13,13 @@ from haliax.state_dict import ModuleWithStateDictSerialization from levanter.compat.hf_checkpoints import HFCheckpointConverter -from levanter.logging import silence_transformer_nag from levanter.models.attention import AttentionMask, dot_product_attention from levanter.models.llama import LlamaConfig, LlamaEmbedding, LlamaMlp, LlamaRMSNorm, LlamaTransformer from levanter.models.lm_model import LmConfig, LmHeadModel from levanter.models.rotary import RotaryEmbeddingsConfig -from levanter.types import BlockFoldable from levanter.utils.flop_utils import lm_flops_per_token +from levanter.utils.logging import silence_transformer_nag +from levanter.utils.types import BlockFoldable silence_transformer_nag() diff --git a/src/levanter/models/whisper.py b/src/levanter/models/whisper.py index 7239626f7..a9c5d528b 100644 --- a/src/levanter/models/whisper.py +++ b/src/levanter/models/whisper.py @@ -17,10 +17,10 @@ from haliax.state_dict import ModuleWithStateDictSerialization from levanter.compat.hf_checkpoints import HFCheckpointConverter, HFCompatConfig, ModelWithHfSerializationMixin -from levanter.logging import silence_transformer_nag from levanter.models.asr_model import ASRConfig, ASRMixin from levanter.models.attention import AttentionBackend, AttentionMask, dot_product_attention from levanter.models.lm_model import LmConfig +from levanter.utils.logging import silence_transformer_nag silence_transformer_nag() diff --git a/src/levanter/trainer.py b/src/levanter/trainer.py index 9138d0a47..6e62be7a5 100644 --- a/src/levanter/trainer.py +++ b/src/levanter/trainer.py @@ -42,9 +42,9 @@ from haliax.types import Scalar import levanter.checkpoint -import levanter.logging import levanter.tracker import levanter.tracker.wandb +import levanter.utils.logging from levanter import tracker from levanter.checkpoint import CheckpointerConfig, load_checkpoint_or_initialize from levanter.config import JsonAtom @@ -53,10 +53,10 @@ from levanter.grad_accum import microbatched from levanter.tracker import TrackerConfig, capture_time from levanter.trainer_state import TrainerState, saveable_training_mask -from levanter.types import ComputeLossFunction, FilterSpec from levanter.utils import cloud_utils, fsspec_utils from levanter.utils.jax_utils import create_fsdp_mesh from 
levanter.utils.tree_utils import inference_mode +from levanter.utils.types import ComputeLossFunction, FilterSpec logger = pylogging.getLogger(__name__) @@ -630,7 +630,7 @@ def initialize(self): self._validate_and_set_defaults() id = self._maybe_set_id() - levanter.logging.init_logging(self.log_dir, f"{id}.log") + levanter.utils.logging.init_logging(self.log_dir, f"{id}.log") _initialize_global_tracker(self.tracker, id) self.ray.initialize() diff --git a/src/levanter/trainer_state.py b/src/levanter/trainer_state.py index 15800bd17..549267681 100644 --- a/src/levanter/trainer_state.py +++ b/src/levanter/trainer_state.py @@ -12,8 +12,8 @@ from haliax.quantization import Fp8Config, apply_updates, fp8_linear_layers, partition_for_grad_overwrite from haliax.types import IntScalar, Scalar -from levanter.types import FilterTree from levanter.utils.jax_utils import is_inexact_arrayish +from levanter.utils.types import FilterTree M = TypeVar("M", bound=PyTree) diff --git a/src/levanter/utils/hf_utils.py b/src/levanter/utils/hf_utils.py index 41e4488d4..08205edf6 100644 --- a/src/levanter/utils/hf_utils.py +++ b/src/levanter/utils/hf_utils.py @@ -4,7 +4,7 @@ from transformers import PreTrainedTokenizer, PreTrainedTokenizerFast -from levanter.logging import silence_transformer_nag +from levanter.utils.logging import silence_transformer_nag from levanter.utils.py_utils import logical_cpu_core_count diff --git a/src/levanter/logging.py b/src/levanter/utils/logging.py similarity index 100% rename from src/levanter/logging.py rename to src/levanter/utils/logging.py diff --git a/src/levanter/types.py b/src/levanter/utils/types.py similarity index 100% rename from src/levanter/types.py rename to src/levanter/utils/types.py From 353caa64c682cf4409f350563cbf27849d70dbe8 Mon Sep 17 00:00:00 2001 From: David Hall Date: Sat, 23 Nov 2024 23:50:47 -0800 Subject: [PATCH 14/34] hijack HF's download so it works with gcs etc. 
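Rather than teach every loader about fsspec individually, this temporarily monkeypatches
transformers' hf_hub_download inside a context manager: when the repo id looks like a URL
(gs://, s3://, ...), the requested file is copied into a temp dir that lives for the
duration of the context; anything else falls through to the real Hub download. A rough
usage sketch (illustrative only; the gs:// path below is made up):

    from levanter.compat.hf_checkpoints import load_tokenizer

    tokenizer = load_tokenizer("gpt2")  # plain Hub repos behave as before
    tokenizer = load_tokenizer("gs://my-bucket/run-0/hf/step-1000")  # fsspec paths resolve via the patched download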
--- src/levanter/compat/hf_checkpoints.py | 201 ++++++++++++++------------ 1 file changed, 107 insertions(+), 94 deletions(-) diff --git a/src/levanter/compat/hf_checkpoints.py b/src/levanter/compat/hf_checkpoints.py index f4ad33757..dc2f0e16d 100644 --- a/src/levanter/compat/hf_checkpoints.py +++ b/src/levanter/compat/hf_checkpoints.py @@ -1,4 +1,5 @@ import abc +import contextlib import dataclasses import json import logging @@ -10,7 +11,6 @@ from dataclasses import dataclass from functools import cached_property from typing import Generic, Optional, Tuple, Type, TypeVar, Union, cast -from urllib.parse import urlparse import draccus import equinox as eqx @@ -21,8 +21,9 @@ import mergedeep import safetensors import safetensors.numpy +import transformers.utils.hub from huggingface_hub import HfApi, hf_hub_download, repo_exists, snapshot_download -from huggingface_hub.utils import EntryNotFoundError, GatedRepoError, HFValidationError +from huggingface_hub.utils import EntryNotFoundError, GatedRepoError, HFValidationError, RepositoryNotFoundError from jax.experimental.multihost_utils import sync_global_devices from jax.random import PRNGKey from jaxtyping import Array @@ -324,11 +325,8 @@ def _infer_config_class(hf_config_class, ref, trust_remote_code): if ref is None: raise ValueError("Must provide either config class or reference_checkpoint") path, rev = ref.model_name_or_path, ref.revision - config = AutoConfig.from_pretrained( - path, - revision=rev, - trust_remote_code=trust_remote_code, - ) + with _patch_hf_hub_download(): + config = AutoConfig.from_pretrained(path, revision=rev, trust_remote_code=trust_remote_code) clss = type(config) elif isinstance(hf_config_class, str): if ref is None: @@ -423,7 +421,9 @@ def config_from_hf_checkpoint(self, ref: Optional[Union[str, RepoRef]] = None) - def hf_config_from_hf_checkpoint(self, ref: Optional[Union[str, RepoRef]] = None) -> HfConfig: path, rev = self._get_ref(ref) - config = AutoConfig.from_pretrained(path, revision=rev, trust_remote_code=self.trust_remote_code) + + with _patch_hf_hub_download(): + config = AutoConfig.from_pretrained(path, revision=rev, trust_remote_code=self.trust_remote_code) return config def _get_ref(self, ref) -> Tuple[str, Optional[str]]: @@ -450,49 +450,51 @@ def load_state_dict(self, ref: Optional[Union[str, RepoRef]] = None, dtype: Opti except HFValidationError: pass - # TODO: load models from gcs etc. - if os.path.exists(os.path.join(id, SAFE_TENSORS_MODEL)): - state_dict = _load_safe_tensors(os.path.join(id, SAFE_TENSORS_MODEL), dtype) - elif os.path.exists(os.path.join(id, PYTORCH_MODEL)): - state_dict = _load_torch(os.path.join(id, PYTORCH_MODEL), dtype) - else: - try: - model_path = hf_hub_download(id, SAFE_TENSORS_MODEL, revision=rev) - state_dict = _load_safe_tensors(model_path, dtype) - except (EntryNotFoundError, HFValidationError): - model_path = hf_hub_download(id, PYTORCH_MODEL, revision=rev) - state_dict = _load_torch(model_path, dtype) + with _patch_hf_hub_download() as hf_hub_download: + # TODO: load models from gcs etc. 
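+            # Resolution order: local safetensors, then local torch weights, then a
+            # (patched) hub/fsspec download of safetensors, falling back to torch weights.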
+ if os.path.exists(os.path.join(id, SAFE_TENSORS_MODEL)): + state_dict = _load_safe_tensors(os.path.join(id, SAFE_TENSORS_MODEL), dtype) + elif os.path.exists(os.path.join(id, PYTORCH_MODEL)): + state_dict = _load_torch(os.path.join(id, PYTORCH_MODEL), dtype) + else: + try: + model_path = hf_hub_download(id, SAFE_TENSORS_MODEL, revision=rev) + state_dict = _load_safe_tensors(model_path, dtype) + except (EntryNotFoundError, HFValidationError): + model_path = hf_hub_download(id, PYTORCH_MODEL, revision=rev) + state_dict = _load_torch(model_path, dtype) - return state_dict + return state_dict def _load_shards(self, id: str, index_file: str, rev: Optional[str], dtype) -> dict: """Load model from sharded files based on the provided index.""" - index_path = os.path.join(id, index_file) - if not os.path.exists(index_path): - # Download the index file if not found locally - index_path = hf_hub_download(id, index_file, revision=rev) - - with open(index_path, "r", encoding="utf-8") as f: - index = json.load(f) - - shard_files = list(set(index["weight_map"].values())) - final_state_dict = {} - - # right now we do safe tensors thing - # where we load into memory then update some dict - if "safetensors" in index_file: - loader = _load_safe_tensors - else: - loader = _load_torch + with _patch_hf_hub_download() as hf_hub_download: + index_path = os.path.join(id, index_file) + if not os.path.exists(index_path): + # Download the index file if not found locally + index_path = hf_hub_download(id, index_file, revision=rev) + + with open(index_path, "r", encoding="utf-8") as f: + index = json.load(f) + + shard_files = list(set(index["weight_map"].values())) + final_state_dict = {} + + # right now we do safe tensors thing + # where we load into memory then update some dict + if "safetensors" in index_file: + loader = _load_safe_tensors + else: + loader = _load_torch - for shard_file in shard_files: - shard_path = os.path.join(id, shard_file) - if not os.path.exists(shard_path): - # Download the shard if not found locally - shard_path = hf_hub_download(id, shard_file, revision=rev) + for shard_file in shard_files: + shard_path = os.path.join(id, shard_file) + if not os.path.exists(shard_path): + # Download the shard if not found locally + shard_path = hf_hub_download(id, shard_file, revision=rev) - shard_state_dict = loader(shard_path, dtype) - final_state_dict.update(shard_state_dict) + shard_state_dict = loader(shard_path, dtype) + final_state_dict.update(shard_state_dict) return final_state_dict @@ -588,22 +590,6 @@ def load_from_state_dict(template, state_dict): lev_model = eqx.filter_eval_shape(lm_model_cls.init, Vocab, config, key=PRNGKey(0)) lev_model = load_from_state_dict(lev_model, state_dict) - # all_arrays: list[jax.Array] = get_backend().live_arrays() - # total_size = sum(a.size * a.itemsize for a in all_arrays) - # print(f"Total size of live arrays: {total_size / 1e9:.2f} GB") - # gc.collect() # sometimes takes a while to free buffers otherwise - # try: - # get_backend().defragment() - # except Exception as e: - # warnings.warn(f"Could not defragment because {e}") - # pass - # all_arrays = get_backend().live_arrays() - # total_size = sum(a.size * a.itemsize for a in all_arrays) - # print(f"Total size of live arrays: {total_size / 1e9:.2f} GB") - # all_arrays = get_backend().live_arrays() - # total_size = sum(a.size * a.itemsize for a in all_arrays) - # print(f"Total size of live arrays: {total_size / 1e9:.2f} GB") - return lev_model def _save_pretrained_local( @@ -874,45 +860,20 @@ def cb(step: 
StepInfo): return cb -def arbitrary_load_from_hf( - model_name_or_path, from_pretrained_lambda, revision=None, local_cache_dir=None, trust_remote_code=True -) -> Union[HfTokenizer | ProcessorMixin]: - is_url_like = urlparse(model_name_or_path).scheme != "" - if is_url_like: - if revision is not None: - raise ValueError("revision is not supported for URLs") - # tokenizers are directories, so we have to copy them locally - if local_cache_dir is None: - local_cache_dir = tempfile.mkdtemp() - - fs, path = fsspec.core.url_to_fs(model_name_or_path) - fs.get(path, local_cache_dir, recursive=True) - base_path = os.path.basename(path) - return from_pretrained_lambda(os.path.join(local_cache_dir, base_path), trust_remote_code=trust_remote_code) - else: - return from_pretrained_lambda(model_name_or_path, revision=revision, trust_remote_code=trust_remote_code) - - def load_tokenizer(model_name_or_path, revision=None, local_cache_dir=None, trust_remote_code=True) -> HfTokenizer: """Like AutoTokenizer.from_pretrained, but works with gs:// paths or anything on fsspec""" - return arbitrary_load_from_hf( - model_name_or_path, - AutoTokenizer.from_pretrained, - revision=revision, - local_cache_dir=local_cache_dir, - trust_remote_code=trust_remote_code, - ) + with _patch_hf_hub_download(): + return AutoTokenizer.from_pretrained( + model_name_or_path, revision=revision, cache_dir=local_cache_dir, trust_remote_code=trust_remote_code + ) def load_processor(model_name_or_path, revision=None, local_cache_dir=None, trust_remote_code=True) -> ProcessorMixin: """Like AutoProcessor.from_pretrained, but works with gs:// paths or anything on fsspec""" - return arbitrary_load_from_hf( - model_name_or_path, - AutoProcessor.from_pretrained, - revision=revision, - local_cache_dir=local_cache_dir, - trust_remote_code=trust_remote_code, - ) + with _patch_hf_hub_download(): + return AutoProcessor.from_pretrained( + model_name_or_path, revision=revision, cache_dir=local_cache_dir, trust_remote_code=trust_remote_code + ) _sync_count = 0 @@ -1111,3 +1072,55 @@ def _should_use_cpu_for_checkpoint_loading(): return False if sum(accel_memory) < cpu_memory: return True + + +def _is_hf_hub_model(ref: RepoRef): + api = HfApi() + + try: + api.model_info(repo_id=ref.model_name_or_path) + return True + except RepositoryNotFoundError: + return False + + +@contextlib.contextmanager +def _patch_hf_hub_download(): + """ + Temporarily monkeypatch `hf_hub_download` to handle fsspec URLs, ensuring the temporary directory + persists for the lifetime of the context manager. + """ + original_hf_hub_download = transformers.utils.hub.hf_hub_download + + # Create a temporary directory that persists through the context manager + with tempfile.TemporaryDirectory() as tmpdir: + + def custom_hf_hub_download(*args, **kwargs): + """ + Custom implementation of hf_hub_download to handle fsspec URLs. 
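+        fsspec repo ids (e.g. gs://...) are fetched into the surrounding temp dir;
+        anything else is delegated to the original hf_hub_download.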
+ """ + repo_id = kwargs.get("repo_id", args[0] if len(args) > 0 else None) + filename = kwargs.get("filename", args[1] if len(args) > 1 else None) + + if repo_id and filename and _is_url_like(repo_id): + fs, path = fsspec.core.url_to_fs(repo_id) + remote_path = os.path.join(path, filename) + local_path = os.path.join(tmpdir, filename) + + if not fs.exists(remote_path): + raise EntryNotFoundError(f"File {remote_path} not found") + + fs.get(remote_path, local_path) + return local_path + + # Fallback to the original implementation + return original_hf_hub_download(*args, **kwargs) + + # Monkeypatch hf_hub_download + transformers.utils.hub.hf_hub_download = custom_hf_hub_download + + try: + yield custom_hf_hub_download + finally: + # Restore the original implementation + transformers.utils.hub.hf_hub_download = original_hf_hub_download From 72fa689a3e454faa70006c0fd78988b00812cd86 Mon Sep 17 00:00:00 2001 From: David Hall Date: Sat, 23 Nov 2024 23:51:28 -0800 Subject: [PATCH 15/34] missed some renames? --- src/levanter/utils/jax_utils.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/levanter/utils/jax_utils.py b/src/levanter/utils/jax_utils.py index 5498306e8..6dae17de5 100644 --- a/src/levanter/utils/jax_utils.py +++ b/src/levanter/utils/jax_utils.py @@ -274,7 +274,7 @@ def best_effort_sharding(shape, *, devices=None, mesh=None): return sharding else: # get the existing mesh and find the FSDP axis - fsdp_axis = mesh.axis_names.index(haliax.partitioning.ResourceAxis.DATA) + fsdp_axis = mesh.axis_names.index(hax.partitioning.ResourceAxis.DATA) num_devices = mesh.devices.shape[fsdp_axis] for i in range(len(shape) - 1, -1, -1): @@ -286,7 +286,7 @@ def best_effort_sharding(shape, *, devices=None, mesh=None): return NamedSharding(mesh, PartitionSpec(None)) axis_sharding = [None] * len(shape) - axis_sharding[sharded_axis] = haliax.partitioning.ResourceAxis.DATA + axis_sharding[sharded_axis] = hax.partitioning.ResourceAxis.DATA sharding = NamedSharding(mesh, PartitionSpec(*axis_sharding)) return sharding From 16195b07298fc34774e3d481b05b957c39ec56d6 Mon Sep 17 00:00:00 2001 From: David Hall Date: Sat, 23 Nov 2024 23:52:46 -0800 Subject: [PATCH 16/34] rename maybe_fused_next_token_loss --- src/levanter/models/lm_model.py | 16 ++++--- src/levanter/models/loss.py | 77 +++++++++++++++++++++++++-------- tests/test_text.py | 10 +++-- 3 files changed, 75 insertions(+), 28 deletions(-) diff --git a/src/levanter/models/lm_model.py b/src/levanter/models/lm_model.py index ddce51e0f..ae731afcf 100644 --- a/src/levanter/models/lm_model.py +++ b/src/levanter/models/lm_model.py @@ -4,6 +4,7 @@ import draccus import equinox as eqx +import jax import jax.numpy as jnp from jax.random import PRNGKey @@ -11,7 +12,7 @@ from haliax import Axis, NamedArray, NamedOrNumeric from levanter.models.attention import AttentionMask -from levanter.models.loss import next_token_loss +from levanter.models.loss import maybe_fused_next_token_loss LmConfigT = TypeVar("LmConfigT", bound="LmConfig") @@ -58,12 +59,13 @@ def from_prompt_and_completion( ) -> "LmExample": # mask out the prompt tokens loss_mask = hax.arange(Pos) >= prompt_length - 1 - # also mask out the last token - loss_mask *= 1 - hax.nn.one_hot(-1, Pos, dtype=jnp.float32) - + # don't predict the padding if ignore_id is not None: - ignore_mask = tokens != ignore_id - loss_mask *= ignore_mask + targets = hax.roll(tokens, -1, Pos) + loss_mask = loss_mask & (targets != ignore_id) + + # don't predict the last token + loss_mask = loss_mask & (1 - 
hax.nn.one_hot(-1, Pos, dtype=jax.numpy.bool_)) if all_causal: attn_mask = AttentionMask.causal() @@ -216,7 +218,7 @@ def compute_next_token_loss( """ activations = model.activations(example.tokens, example.attn_mask, key=key) - loss = next_token_loss( + loss = maybe_fused_next_token_loss( model.Pos, model.Embed, model.Vocab, diff --git a/src/levanter/models/loss.py b/src/levanter/models/loss.py index d705eda4d..bf0bd380e 100644 --- a/src/levanter/models/loss.py +++ b/src/levanter/models/loss.py @@ -10,7 +10,7 @@ from haliax.nn import cross_entropy_loss_and_log_normalizers -def next_token_loss( +def maybe_fused_next_token_loss( Pos: hax.AxisSelector, Embed: hax.AxisSelector, Vocab: hax.AxisSelector, @@ -46,6 +46,15 @@ def next_token_loss( Pos = pred_embeddings.resolve_axis(Pos) Vocab = pred_lm_head.resolve_axis(Vocab) + if block_size is None: + # Full softmax computation + logits = hax.dot(pred_embeddings, pred_lm_head, axis=Embed) + if dtype is not None: + logits = logits.astype(dtype) + + # Shift target tokens to predict the next token + return next_token_loss(Pos, Vocab, logits, true_ids, loss_mask, reduction, reduction_axis, logsumexp_weight) + # Shift target tokens to predict the next token target_y = hax.roll(true_ids, -1, Pos) @@ -56,22 +65,6 @@ def next_token_loss( else: loss_mask = not_last_loss_mask - if block_size is None: - # Full softmax computation - logits = hax.dot(pred_embeddings, pred_lm_head, axis=Embed) - if dtype is not None: - logits = logits.astype(dtype) - target_y_full = hax.nn.one_hot(target_y, Vocab, dtype=pred_embeddings.dtype) - return cross_entropy_and_logsumexp_penalty( - logits, - Vocab, - target_y_full, - reduction=reduction, - reduction_axis=reduction_axis, - where=loss_mask, - logsumexp_weight=logsumexp_weight, - ) - # Compute the loss with optional block-wise processing return fused_cross_entropy_loss_and_logsumexp_penalty( pred_embeddings, @@ -88,9 +81,57 @@ def next_token_loss( ) +def next_token_loss( + Pos: hax.AxisSelector, + Vocab: hax.AxisSelector, + logits: NamedArray, + true_ids: NamedArray, + loss_mask: Optional[NamedArray] = None, + reduction: Optional[hax.ReductionFunction] = hax.mean, + reduction_axis: Optional[hax.AxisSelection] = None, + logsumexp_weight: Optional[float] = None, +): + """ + Compute the next token loss with optional logsumexp penalty. + + Args: + Pos: axis selector for the position axis + Vocab: axis selector for the vocabulary axis + logits: predicted logits + true_ids: true token IDs (not shifted) + loss_mask: mask to apply to the loss + reduction: reduction function or None to disable reduction + reduction_axis: axis to apply reduction. 
None means all axes + logsumexp_weight: weight for the logsumexp penalty + Returns: + NamedArray: computed loss + """ + Pos = logits.resolve_axis(Pos) + + target_y = hax.roll(true_ids, -1, Pos) + target_y_full = hax.nn.one_hot(target_y, Vocab, dtype=logits.dtype) + + # Create a mask that excludes the last token + not_last_loss_mask = 1 - hax.nn.one_hot(-1, Pos, dtype=jnp.float32) # type: ignore + if loss_mask is not None: + loss_mask = loss_mask * not_last_loss_mask + else: + loss_mask = not_last_loss_mask + + return cross_entropy_and_logsumexp_penalty( + Vocab=Vocab, + pred_y=logits, + target_y=target_y_full, + reduction=reduction, + reduction_axis=reduction_axis, + where=loss_mask, + logsumexp_weight=logsumexp_weight, + ) + + def cross_entropy_and_logsumexp_penalty( - pred_y: NamedArray, Vocab: hax.Axis, + pred_y: NamedArray, target_y: NamedArray, *, reduction: Optional[hax.ReductionFunction] = hax.mean, diff --git a/tests/test_text.py b/tests/test_text.py index e4e51acbc..f293a9429 100644 --- a/tests/test_text.py +++ b/tests/test_text.py @@ -7,7 +7,7 @@ from levanter.data.text import BatchTokenizer, LMDatasetConfig from levanter.models.lm_model import LmExample -from levanter.models.loss import next_token_loss +from levanter.models.loss import maybe_fused_next_token_loss from tests.test_utils import skip_if_hf_model_not_accessible @@ -39,8 +39,12 @@ def test_lm_example_handles_ignore_id(): lm_head = hax.zeros((Embed, Vocab)) lm_head = lm_head.at[Vocab, ignore_id].set(-100) - ignored_loss = next_token_loss(Pos, Embed, Vocab, logits, lm_head, tokens, loss_mask=ex_ignore.loss_mask) - no_ignore_loss = next_token_loss(Pos, Embed, Vocab, logits, lm_head, tokens, loss_mask=ex_no_ignore.loss_mask) + ignored_loss = maybe_fused_next_token_loss( + Pos, Embed, Vocab, logits, lm_head, tokens, loss_mask=ex_ignore.loss_mask + ) + no_ignore_loss = maybe_fused_next_token_loss( + Pos, Embed, Vocab, logits, lm_head, tokens, loss_mask=ex_no_ignore.loss_mask + ) assert no_ignore_loss.item() >= ignored_loss.item() + 100 / Pos.size From e2cab795ff837f5ced212639d92a8de0a5a14932 Mon Sep 17 00:00:00 2001 From: David Hall Date: Sat, 23 Nov 2024 23:53:11 -0800 Subject: [PATCH 17/34] add some more tests to make sure different seq lens work --- tests/test_llama.py | 45 +++++++++++++++++++++++++++++++++++++++----- tests/test_llama3.py | 7 +++++-- 2 files changed, 45 insertions(+), 7 deletions(-) diff --git a/tests/test_llama.py b/tests/test_llama.py index 87576205d..9c08043d8 100644 --- a/tests/test_llama.py +++ b/tests/test_llama.py @@ -82,7 +82,9 @@ def test_llama_rotary_embedding(): @skip_if_no_torch -def test_apply_rotary_pos_emb(): +@pytest.mark.parametrize("model_seq_len", [128, 256]) +@pytest.mark.parametrize("test_seq_len", [64, 128, 256]) +def test_apply_rotary_pos_emb(model_seq_len, test_seq_len): import torch from transformers.models.llama.modeling_llama import apply_rotary_pos_emb as hf_apply_rotary_pos_emb from transformers.models.llama.modeling_llama import rotate_half as hf_rotate_half @@ -95,9 +97,9 @@ def assert_equal_out(hax_out, torch_out: torch.Tensor): def named_array_to_tensor(named_array): return torch.from_numpy(np.array(named_array.array)) - llama_config = _get_llama_config() + llama_config = _get_llama_config(seq_len=model_seq_len) - Pos = llama_config.Pos + Pos = llama_config.Pos.resize(test_seq_len) Heads = llama_config.Heads HeadSize = llama_config.HeadSize Batch = hax.Axis("batch", 2) @@ -138,7 +140,8 @@ def named_array_to_tensor(named_array): @skip_if_no_torch 
@pytest.mark.parametrize("use_flash", [True, False]) @pytest.mark.parametrize("num_kv_heads", [1, 2, 4]) -def test_llama_attention(use_flash, num_kv_heads): +@pytest.mark.parametrize("test_seq_len", [64, 128, 256]) +def test_llama_attention(use_flash, num_kv_heads, test_seq_len): import torch from transformers.models.llama.modeling_llama import LlamaAttention as HFLlamaAttention @@ -154,7 +157,10 @@ def test_llama_attention(use_flash, num_kv_heads): x, mask = _get_random_inputs(config) x_torch = torch.from_numpy(np.array(x.array)) batch_size = x_torch.shape[0] - explicit_mask = torch.from_numpy(np.array(mask.materialize(config.Pos, config.KeyPos).array)) + test_Pos = config.Pos.resize(test_seq_len) + test_KeyPos = config.KeyPos.resize(test_seq_len) + + explicit_mask = torch.from_numpy(np.array(mask.materialize(test_Pos, test_KeyPos).array)) mask_torch = explicit_mask.broadcast_to((batch_size, 1, -1, -1)) # the torch mask is really a bias, so we need to invert it and make it a big negative number @@ -389,3 +395,32 @@ def test_state_dict_consistency(scan_layers, num_kv_heads): hf_model = LlamaForCausalLM(hf_config) levanter_state_dict = hax.state_dict.to_torch_compatible_state_dict(model) assert set(hf_model.state_dict().keys()) == set(levanter_state_dict.keys()) + + +@pytest.mark.parametrize("num_kv_heads", [2, 4]) +def test_llama_seq_len_doesnt_change_predictions(num_kv_heads): + config = LlamaConfig( + seq_len=128, + hidden_dim=16, + num_heads=4, + num_kv_heads=num_kv_heads, + gradient_checkpointing=False, + ) + Vocab = hax.Axis("vocab", 1000) + + # Make input and attn_mask + input_256 = hax.random.randint(random.PRNGKey(0), config.Pos, 0, Vocab.size) + input_128 = input_256[config.Pos, :128] + attn_mask = AttentionMask.causal() + + model = LlamaLMHeadModel.init(Vocab=Vocab, config=config, key=random.PRNGKey(0)) + + @hax.named_jit + def compute(model, input): + model_output = model(input, attn_mask=attn_mask) + return model_output + + jax_out_1 = compute(model, input_128) + jax_out_2 = compute(model, input_256)[config.Pos, :128] + + assert np.allclose(jax_out_1.array, jax_out_2.array, rtol=1e-6, atol=1e-6) diff --git a/tests/test_llama3.py b/tests/test_llama3.py index 653ba723c..9bf4a7fbc 100644 --- a/tests/test_llama3.py +++ b/tests/test_llama3.py @@ -2,6 +2,7 @@ import tempfile import numpy as np +import pytest from jax import random import haliax as hax @@ -65,7 +66,8 @@ def get_config(vocab_size=1000): @skip_if_no_torch -def test_llama_roundtrip(): +@pytest.mark.parametrize("test_seq_len", [128, 256, 512]) +def test_llama3_roundtrip(test_seq_len): import torch from transformers import AutoModelForCausalLM, LlamaForCausalLM @@ -77,7 +79,8 @@ def test_llama_roundtrip(): config = LlamaConfig.from_hf_config(hf_config) # Make input and attn_mask - input = hax.random.randint(random.PRNGKey(0), config.Pos, 0, Vocab.size) + test_Pos = config.Pos.resize(test_seq_len) + input = hax.random.randint(random.PRNGKey(0), test_Pos, 0, Vocab.size) attn_mask = AttentionMask.causal() input_torch = torch.from_numpy(np.array(input.array)).to(torch.int32).unsqueeze(0) From 4c34eecd1f09f05a646ce84528fa84f351a0f46b Mon Sep 17 00:00:00 2001 From: David Hall Date: Sat, 23 Nov 2024 23:53:33 -0800 Subject: [PATCH 18/34] bump jax version --- docker/tpu/Dockerfile.base | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docker/tpu/Dockerfile.base b/docker/tpu/Dockerfile.base index d276c974d..e2e032e82 100644 --- a/docker/tpu/Dockerfile.base +++ b/docker/tpu/Dockerfile.base @@ -5,7 +5,7 @@ RUN pip 
install virtualenv # venv binaries encode their directory, so we need to setup the venv in the final location RUN virtualenv -p python3.10 /opt/levanter/.venv ENV PATH /opt/levanter/.venv/bin:$PATH -#RUN /opt/levanter/.venv/bin/pip install -U "jax[tpu]==0.4.30" -f https://storage.googleapis.com/jax-releases/libtpu_releases.html +#RUN /opt/levanter/.venv/bin/pip install -U "jax[tpu]==0.4.34" -f https://storage.googleapis.com/jax-releases/libtpu_releases.html RUN /opt/levanter/.venv/bin/pip install -U "jax[tpu]@git+https://github.com/dlwh/jax@retry_refuse" -f https://storage.googleapis.com/jax-releases/libtpu_releases.html # Install package dependencies to make incremental builds faster. From 4ecc630c2ddb991de78fa103939e579e1244d21e Mon Sep 17 00:00:00 2001 From: David Hall Date: Sat, 23 Nov 2024 23:53:44 -0800 Subject: [PATCH 19/34] depend on my fork --- pyproject.toml | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index ecb3dd21b..e713b73b8 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -50,7 +50,8 @@ dependencies = [ "async-lru~=2.0", "tqdm-loggable>=0.2", "deepdiff", - "lm-eval==0.4.2", +# "lm-eval==0.4.2", + "lm-eval @ git+https://github.com/dlwh/lm-evaluation-harness.git@no_torch" ] [project.urls] From bcfc225a80a6ded59b0e901f993322d3c4ea2b76 Mon Sep 17 00:00:00 2001 From: David Hall Date: Sat, 23 Nov 2024 23:55:30 -0800 Subject: [PATCH 20/34] eval_harness is about there --- config/harness/eval_marin_dclm_ckpt.yaml | 27 +++ src/levanter/eval_harness.py | 201 +++++++++++++++-------- 2 files changed, 157 insertions(+), 71 deletions(-) create mode 100644 config/harness/eval_marin_dclm_ckpt.yaml diff --git a/config/harness/eval_marin_dclm_ckpt.yaml b/config/harness/eval_marin_dclm_ckpt.yaml new file mode 100644 index 000000000..f503fcb2c --- /dev/null +++ b/config/harness/eval_marin_dclm_ckpt.yaml @@ -0,0 +1,27 @@ +eval_harness: + task_spec: ["hellaswag"] +# max_examples: 9984 # this is the max that ends up being divisible by 512 after expansion + max_examples: 8 # this is the max that ends up being divisible by 512 after expansion + max_eval_length: 128 +#tokenizer: gs://marin-us-central2/checkpoints/dclm_baseline_1b_1x_replication_nov12_3404462497seed-b68241/hf/step-54930 +#tokenizer: gs://levanter-checkpoints/marin/olmoish7b_v4_1024_0627/dlwh_7b0627/hf/step-715001/ +#tokenizer: gs://levanter-checkpoints/marin/olmoish7b_v4_1024_0627/dlwh_7b0627/step-510000/ +#tokenizer: "EleutherAI/gpt-neox-20b" +tokenizer: meta-llama/Meta-Llama-3-8B +model: + type: llama +#checkpoint_path: gs://marin-us-central2/checkpoints/dclm_baseline_1b_1x_replication_nov12_3404462497seed-b68241/hf/step-54930 +checkpoint_path: meta-llama/Meta-Llama-3-8B +checkpoint_is_hf: true +trainer: + mp: f32 + profiler: true + + per_device_parallelism: -1 + train_batch_size: 512 + + tensor_parallel_axes: ["mlp", "heads"] + fsdp_axis: "embed" + batch_axis: "batch" + ray: + auto_start_cluster: false diff --git a/src/levanter/eval_harness.py b/src/levanter/eval_harness.py index 5a72f856e..bde8f688b 100644 --- a/src/levanter/eval_harness.py +++ b/src/levanter/eval_harness.py @@ -3,21 +3,23 @@ # https://github.com/kingoflolz/mesh-transformer-jax/blob/master/eval_harness.py # https://github.com/kingoflolz/mesh-transformer-jax/blob/f8315e3003033b23f21d78361b288953064e0e76/mesh_transformer/TPU_cluster.py#L6 import dataclasses +import functools import json import logging import typing -import warnings from dataclasses import dataclass from functools import cached_property -from 
typing import List, Optional, Tuple +from typing import List, Optional, Sequence, Tuple import equinox as eqx import jax import jax.numpy as jnp -import transformers -from levanter.compat.hf_checkpoints import HFCheckpointConverter +import haliax + +from levanter.compat.hf_checkpoints import HFCheckpointConverter, load_tokenizer from levanter.models.gpt2 import Gpt2Config +from levanter.models.loss import next_token_loss try: @@ -33,15 +35,14 @@ from tqdm import tqdm import haliax as hax -from haliax.nn import cross_entropy_loss from haliax.partitioning import round_axis_for_partitioning import levanter.config from levanter.checkpoint import load_checkpoint -from levanter.data import batched +from levanter.data import AsyncDataset, DataLoader from levanter.models.lm_model import LmConfig, LmExample, LmHeadModel from levanter.trainer import TrainerConfig -from levanter.utils.jax_utils import stack_tree, use_cpu_device +from levanter.utils.jax_utils import use_cpu_device from levanter.utils.tree_utils import inference_mode @@ -76,94 +77,133 @@ class _RequestType: FINISHED = 3 +@functools.partial(jax.jit, static_argnums=(0, 3)) +def _jit_create_example(Pos, tokens, prompt_len, pad_token_id): + tokens = hax.named(tokens, Pos) + return LmExample.from_prompt_and_completion(Pos, tokens, prompt_len, ignore_id=pad_token_id) + + +class EvalDataset(AsyncDataset[LmExample]): + def __init__(self, Pos, tokenizer, examples: list[Instance]): + super().__init__() + self.examples = examples + self.Pos = Pos + self.tokenizer = tokenizer + + async def async_len(self) -> int: + return len(self.examples) + + async def final_length_is_known(self) -> bool: + return True + + def is_finite(self) -> bool: + return True + + async def current_len(self) -> Optional[int]: + return len(self.examples) + + async def get_batch(self, indices: Sequence[int]) -> List[LmExample]: + out = [] + pad_token_id = self.tokenizer.pad_token_id + + reqs = [(self.examples[i].args[0], self.examples[i].args[1]) for i in indices] + + for context, completion in reqs: + whole_enc = self.tokenizer(context + completion) + context_enc = self.tokenizer(context) + + context_enc_len = len(context_enc["input_ids"]) + + tokens, length = self._truncate_or_pad(whole_enc, context_enc_len) + example = _jit_create_example(self.Pos, tokens, length, pad_token_id) + + out.append(example) + + return out + + def _truncate_or_pad(self, encoded, prompt_length): + if self.tokenizer.pad_token_id is None: + self.tokenizer.pad_token_id = self.tokenizer.eos_token_id + + ex_pad = self.tokenizer.pad( + encoded, + padding="max_length", + max_length=self.Pos.size, + return_tensors="np", + ) + + truncated = ex_pad["input_ids"][-self.Pos.size :] + # if we truncated the prompt, we need to adjust the prompt length + if len(truncated) < len(encoded): + prompt_length -= len(encoded) - len(truncated) + if prompt_length < 0: + prompt_length = 0 + logger.warning("Prompt length is negative after truncation. 
Setting to 0.") + + return truncated, prompt_length + + class LevanterHarnessLM(LM): - def __init__(self, EvalBatch: hax.Axis, model: LmHeadModel, axis_resources, tokenizer): + def __init__(self, EvalBatch: hax.Axis, EvalPos: hax.Axis, model: LmHeadModel, axis_resources, tokenizer): super().__init__() self.EvalBatch = EvalBatch + self.EvalPos = EvalPos self.model = model self.axis_resources = axis_resources self.tokenizer = tokenizer def _eval_loglikelihood(model: LmHeadModel, example: LmExample): - logits = model(example.tokens) + logits = model(example.tokens, attn_mask=example.attn_mask) + logits = logits.astype(jnp.float32) + Pos = logits.resolve_axis(self.EvalPos.name) + + loss = next_token_loss( + Pos=Pos, + Vocab=model.Vocab, + logits=logits, + true_ids=example.tokens, + loss_mask=example.loss_mask, + reduction=hax.sum, + reduction_axis=Pos, + ) - targets = hax.roll(example.tokens, -1, axis=model.Pos.name) - target_y = hax.nn.one_hot(targets, model.Vocab, dtype=logits.dtype) - loss = cross_entropy_loss(logits, model.Vocab, target_y, where=example.loss_mask, reduction_axis=model.Pos) - # to tell if we got the right answer, we want to check that argmax(logits) == tokens wherever loss_mask is 1 + not_last_loss_mask = 1 - hax.nn.one_hot(-1, Pos, dtype=bool) pred_targets = hax.argmax(logits, axis=model.Vocab) - correct = hax.all(hax.equal(pred_targets, targets) | hax.logical_not(example.loss_mask), axis=model.Pos) + targets = hax.roll(example.tokens, -1, axis=Pos) + freebie = hax.logical_not(example.loss_mask * not_last_loss_mask) + correct = hax.all(hax.equal(pred_targets, targets) + freebie, axis=Pos) - return loss, correct + return -loss, correct # no sharded outputs self._jit_loglikelihood = hax.named_jit( _eval_loglikelihood, axis_resources=axis_resources, out_axis_resources={} ) - def _stack_batch(self, examples): - return stack_tree(self.EvalBatch, examples, pad_to_batch_size=True) - def loglikelihood(self, requests: list[Instance]) -> list[tuple[float, bool]]: """ Compute log-likelihood of generating a continuation from a context. Downstream tasks should attempt to use loglikelihood instead of other LM calls whenever possible. 
- Args: - requests: - - Returns: - """ + dataset = EvalDataset(self.EvalPos, self.tokenizer, requests) - contexts = self.tokenizer([req.args[0] for req in requests])["input_ids"] - completions = self.tokenizer([req.args[1] for req in requests])["input_ids"] - - examples: list[LmExample] = [] - - @hax.named_jit - def _jit_create_example(tokens, prompt_len): - tokens = hax.named(tokens, self.model.Pos) - return LmExample.from_prompt_and_completion( - self.model.Pos, tokens, prompt_len, ignore_id=self.tokenizer.pad_token_id - ) + mesh = haliax.partitioning._get_mesh() - # TODO: offload this to an evalbatchloader - for context, completion in zip(tqdm(contexts, desc="Creating examples"), completions): - tokens, length = self._truncate_or_pad(context, completion) - tokens = jnp.array(tokens) - length = jnp.array(length) - example = _jit_create_example(tokens, length) - examples.append(example) + loader = DataLoader( + self.EvalBatch, dataset, max_buffered_batches=1024, mesh=mesh, axis_resources=self.axis_resources + ) result: list[tuple[float, bool]] = [] - for batch in batched(tqdm(examples, desc="examples", leave=False), self.EvalBatch.size): - logger.info("Processing batch") - batch_example = self._stack_batch(batch) - # batch_example = jax.device_put(batch_example, jax.local_devices()[0]) - out_lls, out_correct = self._jit_loglikelihood(self.model, batch_example) + for batch in tqdm(loader, desc="Loglikelihood", unit="ba"): + out_lls, out_correct = self._jit_loglikelihood(self.model, batch) result.extend((ll.item(), correct.item()) for ll, correct in zip(out_lls.array, out_correct.array)) # skip padding - result = result[: len(examples)] + result = result[: len(requests)] return result - def _truncate_or_pad(self, context, completion): - max_len = self.model.Pos.size - if len(completion) > max_len: - warnings.warn(f"Completion is longer than max length {max_len}. 
Truncating.") - completion = completion[:max_len] - pad_token_id = self.tokenizer.pad_token_id or self.tokenizer.eos_token_id - - if len(context) + len(completion) > max_len: - context = context[-(max_len - len(completion)) :] - else: - # right pad with padding token - context = context + [pad_token_id] * (max_len - len(context) - len(completion)) - - return jnp.array(context + completion), len(context) - def loglikelihood_rolling(self, requests) -> List[Tuple[float]]: raise NotImplementedError() @@ -171,8 +211,17 @@ def generate_until(self, requests) -> List[str]: raise NotImplementedError() -def run_lm_eval_harness(model, task_spec: list[str], tokenizer, EvalBatch, axis_resources, max_examples=None) -> dict: - harness = LevanterHarnessLM(EvalBatch, model, axis_resources, tokenizer) +def run_lm_eval_harness( + model, + task_spec: list[str], + tokenizer, + EvalBatch, + axis_resources, + max_examples: int | None = None, + max_eval_length: int | None = None, +) -> dict: + EvalPos = model.Pos if max_eval_length is None else model.Pos.resize(max_eval_length) + harness = LevanterHarnessLM(EvalBatch, EvalPos, model, axis_resources, tokenizer) tasks_to_run = tasks.get_task_dict(task_spec) outputs = evaluator.evaluate(harness, tasks_to_run, limit=max_examples) @@ -181,13 +230,14 @@ def run_lm_eval_harness(model, task_spec: list[str], tokenizer, EvalBatch, axis_ @dataclass(frozen=True) class LmEvalHarnessConfig: - task_spec: Optional[list[str]] = None - max_examples: Optional[int] = None + task_spec: list[str] | None = None + max_examples: int | None = None + max_eval_length: int | None = None def task_spec_or_default(self): return self.task_spec or [ # "lambada", - # "piqa", + "piqa", "hellaswag", # "winogrande", # "mathqa", @@ -218,7 +268,7 @@ def EvalBatch(self): @cached_property def the_tokenizer(self): - return transformers.AutoTokenizer.from_pretrained(self.tokenizer) + return load_tokenizer(self.tokenizer) def run_eval_harness_main(config: EvalHarnessConfig): @@ -244,10 +294,10 @@ def run_eval_harness_main(config: EvalHarnessConfig): # initialize the model if config.checkpoint_is_hf: model_config = config.model - converter: HFCheckpointConverter = model_config.default_hf_checkpoint_converter # type: ignore + converter: HFCheckpointConverter = model_config.hf_checkpoint_converter() converter = converter.replaced(reference_checkpoint=config.checkpoint_path, tokenizer=tokenizer) model = converter.load_pretrained( - model_config.model_type, model_config, ref=config.checkpoint_path, dtype=config.trainer.mp.compute_dtype # type: ignore + model_config.model_type, ref=config.checkpoint_path, dtype=config.trainer.mp.compute_dtype # type: ignore ) else: with use_cpu_device(): @@ -265,14 +315,23 @@ def run_eval_harness_main(config: EvalHarnessConfig): config.EvalBatch, axis_resources=compute_axis_mapping, max_examples=max_examples, + max_eval_length=config.eval_harness.max_eval_length, ) logger.info("Finished running LM eval harness") # log the results as json with open("lm_eval_results.json", "w") as f: - json.dump(outputs, f, indent=2) + # also write to stdout + if jax.process_index() == 0: + print(json.dumps(outputs, indent=2), flush=True) + + # also log the results + levanter.tracker.current_tracker().log_artifact("lm_eval_results.json") + + return outputs + if __name__ == "__main__": levanter.config.main(run_eval_harness_main)() From e0ef6f8d5c54ccf922777cbbf789d64409dd12db Mon Sep 17 00:00:00 2001 From: David Hall Date: Mon, 25 Nov 2024 21:52:06 -0800 Subject: [PATCH 21/34] refactor --- 
src/levanter/callbacks.py | 44 --------------- src/levanter/eval_harness.py | 103 +++++++++++++++++++++------------- src/levanter/main/train_lm.py | 4 +- 3 files changed, 67 insertions(+), 84 deletions(-) diff --git a/src/levanter/callbacks.py b/src/levanter/callbacks.py index cb8c016c0..983750685 100644 --- a/src/levanter/callbacks.py +++ b/src/levanter/callbacks.py @@ -19,14 +19,12 @@ import levanter.tracker from levanter.data import AsyncDataset, DataLoader -from levanter.eval_harness import LmEvalHarnessConfig from levanter.tracker.helpers import log_optimizer_hyperparams from levanter.tracker.wandb import WandbConfig from levanter.trainer import StepInfo from levanter.utils import flop_utils from levanter.utils.jax_utils import barrier_sync, jnp_to_python from levanter.utils.logging import save_xla_dumps_to_wandb -from levanter.utils.tree_utils import inference_mode from levanter.visualization import compute_and_visualize_log_probs as viz_probs @@ -425,45 +423,3 @@ def _tqdm_logging_one_time_setup(): return _did_tqdm_logging_one_time_setup = True tqdm_logging.tqdm_logging.set_log_rate(timedelta(seconds=60)) - - -def lm_eval_harness(config: LmEvalHarnessConfig, tokenizer, EvalBatch, axis_resources): - from levanter.eval_harness import run_lm_eval_harness - - def lm_eval_harness(step: StepInfo, force=False): - if step.step == 0 and not force: - return # don't run eval on the first step - - model = inference_mode(step.model, True) - outputs = run_lm_eval_harness( - model, - config.task_spec_or_default(), - tokenizer, - EvalBatch, - axis_resources, - max_examples=config.max_examples, - ) - - if jax.process_index() == 0: - with tempfile.NamedTemporaryFile("w", delete=False, suffix=".json") as f: - import json - - json.dump(outputs, f) - levanter.tracker.current_tracker().log_artifact( - f.name, name=f"lm_eval_output.{step.step}", type="lm_eval_output" - ) - - # also log accuracy statistics etc - metrics_to_log = {} - for task, metrics in outputs["results"].items(): - for metric, value in metrics.items(): - if metric.endswith(",none"): - metric = metric[: -len(",none")] - - if metric != "alias": - # levanter.tracker.log_metrics({f"lm_eval/{task}/{metric}": value}, step=step.step) - metrics_to_log[f"lm_eval/{task}/{metric}"] = value - - levanter.tracker.log_metrics(metrics_to_log, step=step.step) - - return lm_eval_harness diff --git a/src/levanter/eval_harness.py b/src/levanter/eval_harness.py index bde8f688b..e6b2eb0bc 100644 --- a/src/levanter/eval_harness.py +++ b/src/levanter/eval_harness.py @@ -6,6 +6,7 @@ import functools import json import logging +import tempfile import typing from dataclasses import dataclass from functools import cached_property @@ -17,6 +18,7 @@ import haliax +import levanter.tracker from levanter.compat.hf_checkpoints import HFCheckpointConverter, load_tokenizer from levanter.models.gpt2 import Gpt2Config from levanter.models.loss import next_token_loss @@ -32,7 +34,7 @@ evaluator = object # tasks = object -from tqdm import tqdm +from tqdm_loggable.auto import tqdm import haliax as hax from haliax.partitioning import round_axis_for_partitioning @@ -41,7 +43,7 @@ from levanter.checkpoint import load_checkpoint from levanter.data import AsyncDataset, DataLoader from levanter.models.lm_model import LmConfig, LmExample, LmHeadModel -from levanter.trainer import TrainerConfig +from levanter.trainer import StepInfo, TrainerConfig from levanter.utils.jax_utils import use_cpu_device from levanter.utils.tree_utils import inference_mode @@ -49,40 +51,6 @@ logger = 
logging.getLogger(__name__) -# Ok this is a bit complicated to do because it's distributed systems and that's always hard. -# The idea is that we want to pass an LM adaptor to the harness, and then the harness will call the LM adaptor -# with a request, which we'll format, shard, and send to the model. The model will then return the result to the harness -# which will then return the result to the user. - -# As we so often do, we will coordinate execution through JAX itself. - -# Process 0 will: -# - Pass an adaptor to the eval harness -# - The eval harness will call the adaptor with a request -# - When a request comes in, it will call broadcast_one_to_all with a (REQUEST_TYPE, request) to send the request -# - It then invokes the model with the request and returns the result to the eval harness -# - When finished, it will call broadcast_one_to_all with a (FINISHED_TYPE, result) to send the result - -# Process 1..n will: -# - Wait for a (REQUEST_TYPE, request) broadcast -# - if FINISHED_TYPE, break -# - Invoke the model with the request -# - loop - - -class _RequestType: - LOG_LIKELIHOOD = 0 - GENERATE_UNTIL = 1 - LOG_LIKELIHOOD_ROLLING = 2 - FINISHED = 3 - - -@functools.partial(jax.jit, static_argnums=(0, 3)) -def _jit_create_example(Pos, tokens, prompt_len, pad_token_id): - tokens = hax.named(tokens, Pos) - return LmExample.from_prompt_and_completion(Pos, tokens, prompt_len, ignore_id=pad_token_id) - - class EvalDataset(AsyncDataset[LmExample]): def __init__(self, Pos, tokenizer, examples: list[Instance]): super().__init__() @@ -211,6 +179,12 @@ def generate_until(self, requests) -> List[str]: raise NotImplementedError() +@functools.partial(jax.jit, static_argnums=(0, 3)) +def _jit_create_example(Pos, tokens, prompt_len, pad_token_id): + tokens = hax.named(tokens, Pos) + return LmExample.from_prompt_and_completion(Pos, tokens, prompt_len, ignore_id=pad_token_id) + + def run_lm_eval_harness( model, task_spec: list[str], @@ -219,11 +193,12 @@ def run_lm_eval_harness( axis_resources, max_examples: int | None = None, max_eval_length: int | None = None, + log_samples: bool = False, ) -> dict: EvalPos = model.Pos if max_eval_length is None else model.Pos.resize(max_eval_length) harness = LevanterHarnessLM(EvalBatch, EvalPos, model, axis_resources, tokenizer) tasks_to_run = tasks.get_task_dict(task_spec) - outputs = evaluator.evaluate(harness, tasks_to_run, limit=max_examples) + outputs = evaluator.evaluate(harness, tasks_to_run, limit=max_examples, log_samples=log_samples) return outputs @@ -233,6 +208,7 @@ class LmEvalHarnessConfig: task_spec: list[str] | None = None max_examples: int | None = None max_eval_length: int | None = None + log_samples: bool = False def task_spec_or_default(self): return self.task_spec or [ @@ -242,9 +218,9 @@ def task_spec_or_default(self): # "winogrande", # "mathqa", # "pubmedqa", - # "boolq", + "boolq", # "cb", - # "copa", + "copa", # "multirc", # "record", # "wic", @@ -316,6 +292,7 @@ def run_eval_harness_main(config: EvalHarnessConfig): axis_resources=compute_axis_mapping, max_examples=max_examples, max_eval_length=config.eval_harness.max_eval_length, + log_samples=config.eval_harness.log_samples, ) logger.info("Finished running LM eval harness") @@ -329,9 +306,57 @@ def run_eval_harness_main(config: EvalHarnessConfig): # also log the results levanter.tracker.current_tracker().log_artifact("lm_eval_results.json") + log_report_to_tracker("lm_eval", outputs, levanter.tracker.current_tracker()) return outputs +def log_report_to_tracker(prefix: str, report: dict, 
tracker: Optional[levanter.tracker.Tracker] = None): + if tracker is None: + tracker = levanter.tracker.current_tracker() + + to_log = {} + for task_name, task_results in report["results"].items(): + for metric_name, metric_value in task_results.items(): + if metric_name.ends_with(",none"): + metric_name = metric_name[:-5] + + if isinstance(metric_value, float | int): + to_log[f"{prefix}/{task_name}/{metric_name}"] = metric_value + + tracker.log(to_log, step=None) + + +def lm_eval_harness(config: LmEvalHarnessConfig, tokenizer, EvalBatch, axis_resources): + def lm_eval_harness(step: StepInfo, force=False): + if step.step == 0 and not force: + return # don't run eval on the first step + + model = inference_mode(step.model, True) + outputs = run_lm_eval_harness( + model, + config.task_spec_or_default(), + tokenizer, + EvalBatch, + axis_resources, + max_examples=config.max_examples, + max_eval_length=config.max_eval_length, + log_samples=config.log_samples, + ) + + if jax.process_index() == 0: + with tempfile.NamedTemporaryFile("w", delete=False, suffix=".json") as f: + import json + + json.dump(outputs, f) + levanter.tracker.current_tracker().log_artifact( + f.name, name=f"lm_eval_output.{step.step}", type="lm_eval_output" + ) + + log_report_to_tracker("lm_eval", outputs, levanter.tracker.current_tracker()) + + return lm_eval_harness + + if __name__ == "__main__": levanter.config.main(run_eval_harness_main)() diff --git a/src/levanter/main/train_lm.py b/src/levanter/main/train_lm.py index 9c598b63c..cf327956b 100644 --- a/src/levanter/main/train_lm.py +++ b/src/levanter/main/train_lm.py @@ -13,6 +13,8 @@ from haliax.partitioning import named_jit, round_axis_for_partitioning import levanter +import levanter.eval +import levanter.eval_harness from levanter import callbacks from levanter.checkpoint import EpochCheckpointer, load_checkpoint from levanter.compat.hf_checkpoints import HFCompatConfig, save_hf_checkpoint_callback @@ -253,7 +255,7 @@ def main(config: TrainLmConfig): if config.eval_harness is not None: eval_harness = config.eval_harness trainer.add_hook( - callbacks.lm_eval_harness(eval_harness, tokenizer, EvalBatch, compute_axis_mapping), + levanter.eval_harness.lm_eval_harness(eval_harness, tokenizer, EvalBatch, compute_axis_mapping), every=config.eval_harness_steps, ) From e52b61d70ce48a1a24d2b20e8317e57e28f0e716 Mon Sep 17 00:00:00 2001 From: David Hall Date: Mon, 25 Nov 2024 22:17:58 -0800 Subject: [PATCH 22/34] pad number of requests to proper length --- src/levanter/eval_harness.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/src/levanter/eval_harness.py b/src/levanter/eval_harness.py index e6b2eb0bc..37626a8b0 100644 --- a/src/levanter/eval_harness.py +++ b/src/levanter/eval_harness.py @@ -154,6 +154,10 @@ def loglikelihood(self, requests: list[Instance]) -> list[tuple[float, bool]]: Downstream tasks should attempt to use loglikelihood instead of other LM calls whenever possible. 
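The change below pads the request list with a dummy instance so it divides evenly into `EvalBatch`-sized batches, then trims the extra results away. A toy sketch of that round-up-and-trim bookkeeping (hypothetical helper name, plain ints standing in for `Instance` objects):

```
def pad_to_batch_multiple(requests: list, batch_size: int, dummy) -> list:
    remainder = len(requests) % batch_size
    if remainder:
        requests = requests + [dummy] * (batch_size - remainder)
    return requests

reqs = pad_to_batch_multiple([10, 11, 12, 13, 14], batch_size=4, dummy=0)
assert len(reqs) == 8                      # 5 real requests plus 3 dummies
results = [r * 2 for r in reqs][:5]        # run the padded batches, then drop the padded tail
```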
""" + # pad requests to be a multiple of the batch size + initial_length = len(requests) + dummy_instance = dataclasses.replace(requests[0], arguments=("hello", " there"), idx=len(requests)) + requests = requests + [dummy_instance] * (len(requests) % self.EvalBatch.size) dataset = EvalDataset(self.EvalPos, self.tokenizer, requests) mesh = haliax.partitioning._get_mesh() @@ -168,7 +172,7 @@ def loglikelihood(self, requests: list[Instance]) -> list[tuple[float, bool]]: result.extend((ll.item(), correct.item()) for ll, correct in zip(out_lls.array, out_correct.array)) # skip padding - result = result[: len(requests)] + result = result[:initial_length] return result From babaee2922f36ba2673fd4119876201b772c83cd Mon Sep 17 00:00:00 2001 From: David Hall Date: Mon, 25 Nov 2024 22:21:12 -0800 Subject: [PATCH 23/34] sigh --- src/levanter/trainer.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/levanter/trainer.py b/src/levanter/trainer.py index 6e62be7a5..42b4d9332 100644 --- a/src/levanter/trainer.py +++ b/src/levanter/trainer.py @@ -796,7 +796,7 @@ def _validate_and_set_defaults(self): if self.per_device_eval_parallelism == -1: self.per_device_eval_parallelism = self.per_device_parallelism - if self.replica_dcn_axis_size == -1: + if self.replica_dcn_axis_size == -1 or self.replica_dcn_axis_size is None: self.replica_dcn_axis_size = self.num_slices logger.info(f"Setting replica_dcn_axis_size to {self.replica_dcn_axis_size}") From dbf534176468a7bf67d47918e9b0eb1811ab0fc3 Mon Sep 17 00:00:00 2001 From: David Hall Date: Mon, 25 Nov 2024 22:49:09 -0800 Subject: [PATCH 24/34] revert --- src/levanter/trainer.py | 27 +++++++-------------------- 1 file changed, 7 insertions(+), 20 deletions(-) diff --git a/src/levanter/trainer.py b/src/levanter/trainer.py index 42b4d9332..fb353592d 100644 --- a/src/levanter/trainer.py +++ b/src/levanter/trainer.py @@ -556,11 +556,11 @@ class TrainerConfig: ) # overrides axis_mapping for parameter """logical->physical mapping for parameter/optimizer sharding. fsdp_axis and tensor_parallel_axes are preferred""" - # Interchip Interconnect (ICI) & Data Center Networking (DCN) shardings https://cloud.google.com/tpu/docs/multislice-introduction - replica_ici_axis_size: int = 1 # how many parameter replicas there should be "within" each slice (ICI) - model_axis_size: int = 1 # axis size for tensor parallelism (TP) - replica_dcn_axis_size: Optional[int] = None # how many parameter replicas there should be "across" slices (DCN) - auto_replicas: bool = True # whether to automatically set replica_dcn_axis_size based on num_slices + """Interchip Interconnect (ICI) & Data Center Networking (DCN) shardings https://cloud.google.com/tpu/docs/multislice-introduction""" + replica_ici_axis_size: int = 1 + model_axis_size: int = 1 + """how many devices within each slice for sharding with DP. Fix TP=1, the rest of the devices is for FSDP.""" + replica_dcn_axis_size: int = 1 """how many slices in the multislice scheme for sharding with DP and TP. The rest of the devices is for FSDP.""" # Config related to batch sizes @@ -599,14 +599,10 @@ class TrainerConfig: @property def TrainBatch(self): - if self.train_batch_size <= 0: - raise ValueError("batch_size must be positive. Did you call initialize?") return Axis("batch", self.train_batch_size) @property def EvalBatch(self): - if self.eval_batch_size <= 0: - raise ValueError("eval_batch_size must be positive. 
Did you call initialize?") return Axis("batch", self.eval_batch_size) @property @@ -654,7 +650,7 @@ def device_mesh(self) -> Mesh: self.replica_ici_axis_size, self.data_ici_axis_size, self.model_axis_size, - self.replica_dcn_axis_size, # type: ignore + self.replica_dcn_axis_size, self.data_dcn_axis_size, ) @@ -769,15 +765,6 @@ def _validate_and_set_defaults(self): ): raise ValueError("either model_axis_size or local_device_count must be divisible by the other") - # handle replica_dcn_axis_size - if self.replica_dcn_axis_size is None: - if self.auto_replicas: - if self.num_slices > 1: - logger.info(f"Setting replica_dcn_axis_size to {self.num_slices}") - self.replica_dcn_axis_size = self.num_slices - else: - self.replica_dcn_axis_size = 1 - assert self.train_batch_size != -1 or self.per_device_parallelism != -1 if self.per_device_parallelism == -1: @@ -796,7 +783,7 @@ def _validate_and_set_defaults(self): if self.per_device_eval_parallelism == -1: self.per_device_eval_parallelism = self.per_device_parallelism - if self.replica_dcn_axis_size == -1 or self.replica_dcn_axis_size is None: + if self.replica_dcn_axis_size == -1: self.replica_dcn_axis_size = self.num_slices logger.info(f"Setting replica_dcn_axis_size to {self.replica_dcn_axis_size}") From f8983991a40809844fd058b44339a26b8f4f7fa5 Mon Sep 17 00:00:00 2001 From: Jason Wang Date: Thu, 21 Nov 2024 23:29:32 -0800 Subject: [PATCH 25/34] Optim config drop stable and add decay (#818) --- docs/Configuration-Guide.md | 2 +- src/levanter/optim/config.py | 12 ++++++++---- tests/test_optimizer_config.py | 35 +++++++++++++++------------------- 3 files changed, 24 insertions(+), 25 deletions(-) diff --git a/docs/Configuration-Guide.md b/docs/Configuration-Guide.md index f20488ee2..0b00c0800 100644 --- a/docs/Configuration-Guide.md +++ b/docs/Configuration-Guide.md @@ -302,7 +302,7 @@ which are common to all optimizers (and most have to do with learning rate sched | `lr_schedule` | The type of learning rate schedule for decay. See below. | `cosine` | | `min_lr_ratio` | The minimum learning rate ratio. | `0.1` | | `warmup` | Warmup fraction or number of steps | `0.01` | -| `stable` | Stable fraction or number of steps | `0.0` | +| `decay` | Decay fraction or number of steps | `None` | | `cycles` | The number of cycles for the learning rate, or steps where cycles end | `None` | | `rewarmup` | The learning rate re-warmup, if using cycles. | `0.0` | diff --git a/src/levanter/optim/config.py b/src/levanter/optim/config.py index d814a6b64..7b684efeb 100644 --- a/src/levanter/optim/config.py +++ b/src/levanter/optim/config.py @@ -26,8 +26,8 @@ class OptimizerConfig(draccus.ChoiceRegistry, abc.ABC): """The lr scheduler operates on 4 stages: [warmup] - {[stable] - [decay]} x haps - [cooldown]""" warmup: float = 0.01 """fraction of training steps to use as warmup, or steps to use. 0.0 means no warmup""" - stable: float = 0.00 - """fraction of training steps to use as cooldown, or steps to use. 0.0 means no cooldown""" + decay: Optional[float] = None + """fraction of training steps to use as decay, or steps to use. None means full decay""" rewarmup: float = 0.0 "If using a cycle, how much of the cycle to use as re-warmup. 0.0 means no re-warmup." 
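The `decay` field above accepts either a fraction of a cycle or an absolute step count, and `None` decays over everything after warmup. A small sketch of how such a value resolves into warmup/stable/decay step counts (toy helper and numbers, not the config's own method):

```
def as_steps(frac_or_steps: float, cycle_steps: int) -> int:
    # values <= 1.0 are read as a fraction of the cycle, larger values as whole steps
    return int(frac_or_steps * cycle_steps) if frac_or_steps <= 1.0 else int(frac_or_steps)

cycle_steps = 1000
warmup_steps = as_steps(0.01, cycle_steps)                 # 10
decay_steps = as_steps(0.1, cycle_steps)                   # 100
stable_steps = cycle_steps - warmup_steps - decay_steps    # 890; with decay=None there is no stable phase
```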
cooldown: Optional[float] = None @@ -174,8 +174,12 @@ def lr_scheduler(self, num_train_steps): schedules.append(warmup) boundaries.append(start + warmup_steps) - stable_steps = _convert_ratio_or_steps(self.stable, cycle_steps) - lr_decay_steps = cycle_steps - stable_steps - warmup_steps + lr_decay_steps = ( + _convert_ratio_or_steps(self.decay, cycle_steps) + if self.decay is not None + else cycle_steps - warmup_steps + ) + stable_steps = cycle_steps - warmup_steps - lr_decay_steps if stable_steps != 0: stable = optax.constant_schedule(self.learning_rate) diff --git a/tests/test_optimizer_config.py b/tests/test_optimizer_config.py index 9c5b91d7c..70737df7c 100644 --- a/tests/test_optimizer_config.py +++ b/tests/test_optimizer_config.py @@ -8,11 +8,10 @@ def test_no_stable_weirdness(): learning_rate=2e-6, # 2x10^-6 weight_decay=0.0, warmup=0.03, - stable=0.0, min_lr_ratio=0.0, lr_schedule="linear", max_grad_norm=None, - haps=None, + cycles=None, weight_decay_modules=None, default_weight_decay_mask=None, ) @@ -33,10 +32,8 @@ def test_constant_schedule(): learning_rate=1e-3, weight_decay=0.0, warmup=0.0, - stable=0.0, min_lr_ratio=1.0, # No decay lr_schedule="constant", - haps=None, cycles=None, ) @@ -52,10 +49,8 @@ def test_warmup_and_cosine_decay(): learning_rate=1e-2, weight_decay=0.0, warmup=0.1, # 10% of steps - stable=0.0, min_lr_ratio=0.1, lr_schedule="cosine", - haps=None, cycles=None, ) @@ -75,7 +70,6 @@ def test_linear_schedule_with_cycles(): learning_rate=5e-4, weight_decay=0.0, warmup=50, - stable=0.0, min_lr_ratio=0.2, lr_schedule="linear", cycles=2, @@ -105,30 +99,33 @@ def test_linear_schedule_with_cycles(): assert np.isclose(sched_fn(999), 0.2 * 5e-4, atol=1e-5) -def test_haps_schedule(): +def test_wsds_schedule(): optimizer = AdamConfig( learning_rate=1e-3, weight_decay=0.0, warmup=0.0, - stable=0.0, + decay=0.1, min_lr_ratio=0.1, lr_schedule="cosine", - haps=[300, 700], + cycles=[300, 700], ) sched_fn = optimizer.lr_scheduler(1000) - # Before first haps + # First cycle assert np.isclose(sched_fn(0), 1e-3) + assert np.isclose(sched_fn(269), 1e-3) + assert sched_fn(271) < 1e-3 - # First haps + # Second cycle assert np.isclose(sched_fn(300), 1e-3) + assert np.isclose(sched_fn(659), 1e-3) + assert sched_fn(661) < 1e-3 - # After first haps - assert sched_fn(301) < 1e-3 - - # Before second haps - assert sched_fn(699) < sched_fn(301) + # Thrid cycle + assert np.isclose(sched_fn(701), 1e-3) + assert np.isclose(sched_fn(969), 1e-3) + assert sched_fn(971) < 1e-3 def test_inv_sqrt_decay_schedule(): @@ -136,10 +133,9 @@ def test_inv_sqrt_decay_schedule(): learning_rate=1e-3, weight_decay=0.0, warmup=0.1, - stable=0.0, min_lr_ratio=0.1, lr_schedule="inv_sqrt", - haps=None, + cycles=None, ) sched_fn = optimizer.lr_scheduler(100_000) @@ -157,7 +153,6 @@ def test_rewarmup_schedule(): learning_rate=1e-2, weight_decay=0.0, warmup=0.2, # 20% of cycle - stable=0.0, min_lr_ratio=0.2, lr_schedule="linear", cycles=2, From 52354159fa2d2d1c31220014780dc8d12d0e8754 Mon Sep 17 00:00:00 2001 From: David Hall Date: Sun, 24 Nov 2024 00:31:24 -0800 Subject: [PATCH 26/34] Bump fsspec (#824) --- pyproject.toml | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index e713b73b8..47fa528d3 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -25,17 +25,17 @@ dependencies = [ "equinox>=0.11.7", "jaxtyping>=0.2.34", "tokenizers>=0.15.2", - "transformers>=4.41.2,<4.46.0", + "transformers>=4.41.2,<4.47.0", "optax>=0.1.9", "wandb>=0.17.8", "draccus>=0.9.3", 
"pyarrow>=11.0.0", "zstandard>=0.20.0", "datasets>=3.1.0,<4.0", - "gcsfs>=2024.2,<2024.10", + "gcsfs>=2024.2,<2025", "braceexpand>=0.1.7", "jmp>=0.0.3", - "fsspec[http]>=2024.2,<2024.10", + "fsspec[http]>=2024.2,<2025", "tensorstore>=0.1.65", "pytimeparse>=1.1.8", "humanfriendly==10.0", From 4d78749659e7ee4f34881e436f56a6db719ea863 Mon Sep 17 00:00:00 2001 From: David Hall Date: Sun, 24 Nov 2024 21:12:53 -0800 Subject: [PATCH 27/34] add cycle_length (#825) --- docs/Configuration-Guide.md | 60 ++++++++++++++++++------------- src/levanter/optim/config.py | 66 ++++++++++++++++++++++++---------- tests/test_optimizer_config.py | 63 +++++++++++++++++++++++++++++++- 3 files changed, 146 insertions(+), 43 deletions(-) diff --git a/docs/Configuration-Guide.md b/docs/Configuration-Guide.md index 0b00c0800..0df811d79 100644 --- a/docs/Configuration-Guide.md +++ b/docs/Configuration-Guide.md @@ -295,21 +295,21 @@ All optimizers in Levanter are based on the [levanter.optim.OptimizerConfig][] d which are common to all optimizers (and most have to do with learning rate scheduling): -| Parameter | Description | Default | -|-----------------|-----------------------------------------------------------------------|----------| -| `weight_decay` | The weight decay. | `0.0` | -| `learning_rate` | The learning rate. | `1e-4` | -| `lr_schedule` | The type of learning rate schedule for decay. See below. | `cosine` | -| `min_lr_ratio` | The minimum learning rate ratio. | `0.1` | -| `warmup` | Warmup fraction or number of steps | `0.01` | -| `decay` | Decay fraction or number of steps | `None` | -| `cycles` | The number of cycles for the learning rate, or steps where cycles end | `None` | -| `rewarmup` | The learning rate re-warmup, if using cycles. | `0.0` | - -By default, Levanter uses a cosine learning rate schedule with a warmup. The learning rate is decayed to +| Parameter | Description | Default | +|-----------------|-------------------------------------------------------------------------------|----------| +| `weight_decay` | The weight decay. | `0.0` | +| `learning_rate` | The learning rate. | `1e-4` | +| `lr_schedule` | The type of learning rate schedule for decay. See below. | `cosine` | +| `min_lr_ratio` | The minimum learning rate ratio. | `0.1` | +| `warmup` | Warmup fraction or number of steps | `0.01` | +| `decay` | Decay fraction or number of steps | `None` | +| `rewarmup` | The learning rate re-warmup, if using cycles. | `0.0` | +| `cycles` | The number of cycles for the learning rate, or steps where cycles end | `None` | +| `cycle_length` | How long the cycles should be (as an int, fraction), or list of cycle lengths | `None` | + +By default, Levanter uses a cosine learning rate decay with warmup. The learning rate is decayed to `min_lr_ratio * learning_rate` over the course of the training run. This is a fairly standard default for LLM training. - #### Learning Rate Schedules The `lr_schedule` parameter specifies the learning rate schedule. The following schedules are supported: @@ -328,8 +328,11 @@ By default, there is only one cycle, and Levanter's LR schedule looks like this: [warmup] -> [stable] -> [decay] ``` -But you can specify more with the `cycles` parameter. If you specify an int for `cycles`, the -learning rate will cycle through the schedule `cycles` times. Levanter's LR schedule looks like this: +But you can specify more with either the `cycles` or `cycle_length` parameters. 
+If you want to use a learning rate schedule with cycles, you can specify the number of cycles with the `cycles` +or `cycle_length` parameters. The LR will be decayed to `min_lr_ratio * learning_rate` at the end of each cycle. +With cycles, Levanter's LR schedule looks like this: + ``` [warmup] -> [stable] -> [decay] -> {[rewarmup] -> [stable] -> [decay]} x (cycles - 1) @@ -348,27 +351,37 @@ Here's what the phases mean: * `decay`: The decay period. The LR will decay to `min_lr_ratio * learning_rate` over this period. * `rewarmup`: The re-warmup period. If using cycles, the LR will be re-warmed from the final value of the previous cycle back to the peak value of the next cycle. +Also note that if *rewarmup* is 0, there will be no rewarmup period, meaning the LR will jump +back to the max LR. This is the default, and works surprisingly well. In addition, the stable +and decay phase of the first cycle will generally be different from the stable and decay phase of the other cycles, +since rewarmup and warmup are typically different. + +`stable` cannot be specified directly. It is the period between `warmup` and `decay` in the first cycle, and the period +between `rewarmup` and `decay` in subsequent cycles. By default, there is no `stable` period. + All of these parameters can be specified in terms of a fraction of the total number of steps of a cycle or as an absolute number of steps. -If you want to use a learning rate schedule with cycles, you can specify the number of cycles with the `cycles` -parameter. The LR will be decayed to `min_lr_ratio * learning_rate` at the end of each cycle. +Here are what the `cycles` and `cycle_length` parameters mean: + +* `cycle_length`: If you specify an int or float for `cycle_length`, the learning rate will cycle through the +schedule with the specified length. This is equivalent to specifying `cycles` as `num_train_steps / cycle_length`. +If `cycle_length` is a float < 1.0, it is interpreted as a fraction of the total number of steps. +If you specify a list of ints, the learning rate will cycle through the schedule with the specified cycle lengths. +* `cycles`: If you specify an int for `cycles`, the learning rate will cycle through the schedule `cycles` times. +If you specify a list of ints, the learning rate will cycle through the schedule with the specified steps as the minima +of the cycles. +It is an error to specify both `cycles` and `cycle_length`. You can also specify `cycles` as a list, e.g. `[10000, 25000, 50000]`. In this case, `cycles` is interpreted as the minima for the cycles, with the first and final steps being cycle minima as well. `cycles` as an int is equivalent to list `cycles` with the low points evenly spaced at `[num_train_steps / (c + 1)]`. -Also note that if *rewarmup* is 0, there will be no rewarmup period, meaning the LR will jump -back to the max LR. This is the default. In addition, the stable -and decay phase of the first cycle will generally be different from the stable and decay phase of the other cycles, -since rewarmup and warmup are typically different. - See [our paper on WSD-S](https://arxiv.org/pdf/2410.05192) for more information on cyclic LR schedules for training LLMs with short or no rewarmup. - ### AdamConfig Additionally, [levanter.optim.AdamConfig][] has the following fields: @@ -381,7 +394,6 @@ Additionally, [levanter.optim.AdamConfig][] has the following fields: | `max_grad_norm` | The maximum gradient norm (for clipping). 
| `1.0` | - ## LM Model Config [levanter.models.lm_model.LmConfig][] is a Draccus "choice class" that acts as a base class for all autoregressive diff --git a/src/levanter/optim/config.py b/src/levanter/optim/config.py index 7b684efeb..40be5576d 100644 --- a/src/levanter/optim/config.py +++ b/src/levanter/optim/config.py @@ -8,6 +8,7 @@ import draccus import equinox as eqx import jax +import numpy as np import optax from jax import numpy as jnp @@ -24,16 +25,18 @@ class OptimizerConfig(draccus.ChoiceRegistry, abc.ABC): min_lr_ratio: float = 0.1 """The lr scheduler operates on 4 stages: [warmup] - {[stable] - [decay]} x haps - [cooldown]""" - warmup: float = 0.01 + warmup: int | float = 0.01 """fraction of training steps to use as warmup, or steps to use. 0.0 means no warmup""" - decay: Optional[float] = None + decay: int | float | None = None """fraction of training steps to use as decay, or steps to use. None means full decay""" - rewarmup: float = 0.0 + rewarmup: int | float = 0.0 "If using a cycle, how much of the cycle to use as re-warmup. 0.0 means no re-warmup." cooldown: Optional[float] = None """Deprecated, as its semantics are confusing.""" - cycles: int | None | list[int] = None - """ Number of cycles to use. If None or 1, use a single cycle. Overriden by haps.""" + cycle_length: int | float | None | list[int] = None + """ Length of cycle. If <= 1, it is treated as a fraction of the total number of steps. None is equivalent to 1.0.""" + cycles: int | list[int] | None = None + """Number of cycles or a list of cycle endpoints. Can use at most one of cycle_length, cycles, or haps.""" lr_schedule: str = "cosine" # constant, cosine, linear haps: Optional[list[int]] = None @@ -145,16 +148,13 @@ def mask_fn(model): def lr_scheduler(self, num_train_steps): if self.cooldown is not None: warnings.warn("cooldown is deprecated. Just use the normal schedule.", DeprecationWarning) - cooldown_steps = _convert_ratio_or_steps(self.cooldown, num_train_steps) + cooldown_steps = _convert_frac_or_steps(self.cooldown, num_train_steps) else: cooldown_steps = 0 total_main_steps = num_train_steps - cooldown_steps cooldown_points = self._get_cycle_minima(total_main_steps) - cooldown_points.insert(0, 0) - cooldown_points.append(num_train_steps) - min_lr = self.learning_rate * self.min_lr_ratio schedules = [] @@ -165,9 +165,9 @@ def lr_scheduler(self, num_train_steps): for cycle, (start, end) in enumerate(zip(cooldown_points[:-1], cooldown_points[1:])): cycle_steps = end - start if cycle == 0: # warmup - warmup_steps = _convert_ratio_or_steps(self.warmup, cycle_steps) + warmup_steps = _convert_frac_or_steps(self.warmup, cycle_steps) else: - warmup_steps = _convert_ratio_or_steps(self.rewarmup, cycle_steps) + warmup_steps = _convert_frac_or_steps(self.rewarmup, cycle_steps) if warmup_steps != 0: warmup = optax.linear_schedule(previous_end, self.learning_rate, warmup_steps) @@ -175,7 +175,7 @@ def lr_scheduler(self, num_train_steps): boundaries.append(start + warmup_steps) lr_decay_steps = ( - _convert_ratio_or_steps(self.decay, cycle_steps) + _convert_frac_or_steps(self.decay, cycle_steps) if self.decay is not None else cycle_steps - warmup_steps ) @@ -218,7 +218,31 @@ def lr_scheduler(self, num_train_steps): return schedule def _get_cycle_minima(self, total_main_steps): - if self.haps is not None: + if self.cycle_length is not None: + if self.cycles is not None: + raise ValueError("Can't use both cycle_length and cycles.") + if self.haps is not None: + warnings.warn("haps is deprecated. 
Use cycles instead.", DeprecationWarning) + raise ValueError("Can't use both cycle_length and haps.") + + if isinstance(self.cycle_length, int | float): + cycle_length = _convert_frac_or_steps(self.cycle_length, total_main_steps) + cooldown_points = [i * cycle_length for i in range(1, total_main_steps // cycle_length)] + if total_main_steps % cycle_length != 0: + warnings.warn( + "Cycle length does not divide total number of steps. The last cycle will be shorter." + ) + + elif isinstance(self.cycle_length, list): + lengths = np.array(self.cycle_length) + steps = np.cumsum(lengths) + if steps[-1] > total_main_steps: + raise ValueError(f"Cycle lengths exceed total number of steps: {steps[-1]} > {total_main_steps}") + cooldown_points = steps.tolist() + else: + raise ValueError("Invalid cycle_length. Must be a fraction, number of steps, or a list of steps.") + + elif self.haps is not None: warnings.warn("haps is deprecated. Use cycles instead.", DeprecationWarning) cooldown_points = list(self.haps) elif isinstance(self.cycles, int): @@ -228,6 +252,9 @@ def _get_cycle_minima(self, total_main_steps): cooldown_points = list(self.cycles) else: cooldown_points = [] + + cooldown_points.insert(0, 0) + cooldown_points.append(total_main_steps) return cooldown_points @@ -247,11 +274,14 @@ def schedule(count): return schedule -def _convert_ratio_or_steps(ratio_or_steps: float, num_train_steps: int): - if ratio_or_steps < 1.0: - return int(ratio_or_steps * num_train_steps) - else: - return int(ratio_or_steps) +def _convert_frac_or_steps(frac_or_steps: float | int, num_train_steps: int): + # if it's greater than 1, it must be a whole number of steps + if frac_or_steps < 0.0 or (frac_or_steps > 1.0 and frac_or_steps % 1 != 0): + raise ValueError(f"Invalid fraction {frac_or_steps}. Must be between 0 and 1. 
You can also use (whole) steps.") + if frac_or_steps <= 1.0: + return int(frac_or_steps * num_train_steps) + + return int(frac_or_steps) @dataclass diff --git a/tests/test_optimizer_config.py b/tests/test_optimizer_config.py index 70737df7c..8301152cd 100644 --- a/tests/test_optimizer_config.py +++ b/tests/test_optimizer_config.py @@ -122,7 +122,7 @@ def test_wsds_schedule(): assert np.isclose(sched_fn(659), 1e-3) assert sched_fn(661) < 1e-3 - # Thrid cycle + # Third cycle assert np.isclose(sched_fn(701), 1e-3) assert np.isclose(sched_fn(969), 1e-3) assert sched_fn(971) < 1e-3 @@ -182,3 +182,64 @@ def test_rewarmup_schedule(): # Final decay phase assert sched_fn(999 - 1) > sched_fn(999) assert np.isclose(sched_fn(999), 0.2e-2, atol=1e-4) # End of second decay + + +def test_linear_schedule_with_cycle_length(): + optimizer = AdamConfig( + learning_rate=5e-4, + weight_decay=0.0, + warmup=50, + min_lr_ratio=0.2, + lr_schedule="linear", + cycle_length=500, + ) + + sched_fn = optimizer.lr_scheduler(1000) + + # Warmup phase + assert np.isclose(sched_fn(0), 0.0) + assert np.isclose(sched_fn(50), 5e-4) + + num_main_steps = 1000 + + # First cycle decay + assert np.isclose(sched_fn(499), 0.2 * 5e-4, atol=1e-5) + + # Second cycle starts + assert np.isclose(sched_fn(500), 5e-4) + + # midway through second cycle + midpoint = 500 - 1 + num_main_steps // 4 + assert np.isclose(sched_fn(midpoint), (5e-4 + 0.2 * 5e-4) / 2, atol=1e-5) + + # Final value + assert np.isclose(sched_fn(999), 0.2 * 5e-4, atol=1e-5) + + +def test_wsds_schedule_with_cycle_points(): + optimizer = AdamConfig( + learning_rate=1e-3, + weight_decay=0.0, + warmup=0.0, + decay=0.1, + min_lr_ratio=0.1, + lr_schedule="cosine", + cycle_length=[300, 400], + ) + + sched_fn = optimizer.lr_scheduler(1000) + + # First cycle + assert np.isclose(sched_fn(0), 1e-3) + assert np.isclose(sched_fn(269), 1e-3) + assert sched_fn(271) < 1e-3 + + # Second cycle + assert np.isclose(sched_fn(300), 1e-3) + assert np.isclose(sched_fn(659), 1e-3) + assert sched_fn(661) < 1e-3 + + # Third cycle + assert np.isclose(sched_fn(701), 1e-3) + assert np.isclose(sched_fn(969), 1e-3) + assert sched_fn(971) < 1e-3 From ae9624af1292f9f42185642836d151fcda77e315 Mon Sep 17 00:00:00 2001 From: David Hall Date: Mon, 25 Nov 2024 22:52:20 -0800 Subject: [PATCH 28/34] stack_tree --- src/levanter/utils/jax_utils.py | 35 --------------------------------- 1 file changed, 35 deletions(-) diff --git a/src/levanter/utils/jax_utils.py b/src/levanter/utils/jax_utils.py index 6dae17de5..4ebad0221 100644 --- a/src/levanter/utils/jax_utils.py +++ b/src/levanter/utils/jax_utils.py @@ -13,7 +13,6 @@ from jaxtyping import PRNGKeyArray, PyTree import haliax as hax -from haliax import AxisSelector, is_named_array from haliax.jax_utils import is_jax_array_like from haliax.partitioning import ResourceAxis @@ -335,37 +334,3 @@ def estimated_free_device_memory(device) -> Optional[float]: in_use = stats.get("bytes_in_use", 0) return (limit - in_use) // (1024.0**3) - - -# @functools.partial(jax.jit, static_argnums=(0), static_argnames=("batch", "pad_to_batch_size")) -def stack_tree(batch: AxisSelector, individual_datums: list[X], *, pad_to_batch_size: bool) -> X: - """ - Stacks a tree of NamedArrays or arrays into a single array. NamedArrays get a new axis with the name batch_name, - while regular arrays are stacked normally. - - Args: - batch: Axis or str name of the new axis. 
- individual_datums: The tree of NamedArrays or arrays to stack - pad_to_batch_size: If True, pads the arrays to the size of the batch axis (assuming batch is an axis). If False, stacks them as is. - """ - if pad_to_batch_size and not isinstance(batch, hax.Axis): - raise ValueError("pad_to_batch_size can only be used with an Axis Batch") - - if pad_to_batch_size: - missing_count = batch.size - len(individual_datums) - - def _stack_leaves_unchecked(*leaves): - if is_named_array(leaves[0]): - return hax.stack(batch.name, leaves + tuple(hax.zeros_like(leaves[0]) for _ in range(missing_count))) - else: - return jnp.stack(leaves + tuple(jnp.zeros_like(leaves[0]) for _ in range(missing_count))) - - else: - - def _stack_leaves_unchecked(*leaves): - if is_named_array(leaves[0]): - return hax.stack(hax.axis_name(batch), leaves) - else: - return jnp.stack(leaves) - - return jax.tree_map(_stack_leaves_unchecked, *individual_datums, is_leaf=is_named_array) From e97697ca9b822f4af5f19544359b106bfff407ac Mon Sep 17 00:00:00 2001 From: David Hall Date: Fri, 29 Nov 2024 00:52:25 -0800 Subject: [PATCH 29/34] ok i think we're good --- config/gpt2_small_fast.yaml | 11 +- config/harness/eval_marin_dclm_ckpt.yaml | 50 +++++++- pyproject.toml | 8 +- src/levanter/eval_harness.py | 148 +++++++++++++++-------- src/levanter/main/train_lm.py | 22 ++-- tests/test_eval_harness.py | 27 +++++ 6 files changed, 196 insertions(+), 70 deletions(-) create mode 100644 tests/test_eval_harness.py diff --git a/config/gpt2_small_fast.yaml b/config/gpt2_small_fast.yaml index 4e8ade15e..d9afe4b36 100644 --- a/config/gpt2_small_fast.yaml +++ b/config/gpt2_small_fast.yaml @@ -27,6 +27,13 @@ optimizer: weight_decay: 0.1 warmup: 0.01 eval_harness: - task_spec: ["lambada", "piqa", "hellaswag"] - max_examples: 32 + task_spec: + - piqa + - task: hellaswag + num_fewshot: 10 + task_alias: hellaswag_10shot + - task: hellaswag + task_alias: hellaswag_0shot + - lambada + max_examples: 128 eval_harness_steps: 1000 diff --git a/config/harness/eval_marin_dclm_ckpt.yaml b/config/harness/eval_marin_dclm_ckpt.yaml index f503fcb2c..cb3373d3c 100644 --- a/config/harness/eval_marin_dclm_ckpt.yaml +++ b/config/harness/eval_marin_dclm_ckpt.yaml @@ -1,8 +1,50 @@ eval_harness: - task_spec: ["hellaswag"] -# max_examples: 9984 # this is the max that ends up being divisible by 512 after expansion - max_examples: 8 # this is the max that ends up being divisible by 512 after expansion - max_eval_length: 128 + task_spec: + # EvalTaskConfig("agieval_lsat_ar", 3), # 3-shot tests in legal domain + # EvalTaskConfig("arc_easy", 10), # 10-shot, four-way MCQ questions involving grade 3-9 basic science + # EvalTaskConfig("arc_challenge", 10), # a (harder) version of arc_easy + # EvalTaskConfig("boolq", 10), # answer yes/no questions based on a passage + # EvalTaskConfig("commonsense_qa", 10), # 5-way multiple-choice questions based on common-sense, everyday scenarios + # EvalTaskConfig("copa", 0), # use causal reasoning to predict the correct outcome of a given scenario + # EvalTaskConfig("hellaswag", 0), # 4-way multiple choice commonsense reasoning dataset + # EvalTaskConfig("hellaswag", 10), # 4-way multiple choice commonsense reasoning dataset + # EvalTaskConfig("lambada", 0), # predict the endings of text passages + # EvalTaskConfig("openbookqa", 0), # 4-way multiple choice question answering task that requires multi-step reasoning + # EvalTaskConfig("piqa", 10), # answer questions based on a passage + # EvalTaskConfig("squadv2", 10), # reading comprehension 
benchmark + # EvalTaskConfig("wsc273", 0), # Winograd Schema Challenge + # EvalTaskConfig("winogrande", 0), # Winograd challenge, extended to more domains + - task: agieval_lsat_ar # 3-shot tests in legal domain + num_fewshot: 3 + - task: arc_easy # 10-shot, four-way MCQ questions involving grade 3-9 basic science + num_fewshot: 10 + - task: arc_challenge # a (harder) version of arc_easy + num_fewshot: 10 + - task: boolq # answer yes/no questions based on a passage + num_fewshot: 10 + - task: copa # use causal reasoning to predict the correct outcome of a given scenario + num_fewshot: 0 + - task: hellaswag # 4-way multiple choice commonsense reasoning dataset + num_fewshot: 0 + task_alias: hellaswag_0shot + - task: hellaswag # 4-way multiple choice commonsense reasoning dataset + num_fewshot: 10 + task_alias: hellaswag_10shot + - task: lambada # predict the endings of text passages + num_fewshot: 0 + - task: openbookqa # 4-way multiple choice question answering task that requires multi-step reasoning + num_fewshot: 0 + - task: piqa # answer questions based on a passage + num_fewshot: 10 + - task: wsc273 # Winograd Schema Challenge + num_fewshot: 0 + - task: winogrande # Winograd challenge, extended to more domains + num_fewshot: 0 + # - task: commonsense_qa # 5-way multiple-choice questions based on common-sense, everyday scenarios + # num_fewshot: 10 +## - task: squadv2 # reading comprehension benchmark +# num_fewshot: 10 + max_eval_length: 4096 #tokenizer: gs://marin-us-central2/checkpoints/dclm_baseline_1b_1x_replication_nov12_3404462497seed-b68241/hf/step-54930 #tokenizer: gs://levanter-checkpoints/marin/olmoish7b_v4_1024_0627/dlwh_7b0627/hf/step-715001/ #tokenizer: gs://levanter-checkpoints/marin/olmoish7b_v4_1024_0627/dlwh_7b0627/step-510000/ diff --git a/pyproject.toml b/pyproject.toml index 47fa528d3..00e56ed8e 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -7,7 +7,11 @@ name = "levanter" version = "1.2" authors = [ { name = "David Hall", email = "dlwh@cs.stanford.edu" }, + { name = "Jason Wang"}, + { name = "Ahmed Ahmed"}, { name = "Ivan Zhou", email = "ivanz@stanford.edu" }, + { name = "Will Held"}, + { name = "Virginia Adams"} ] description = "Scalable Training for Foundation Models with Named Tensors and JAX" readme = "README.md" @@ -101,11 +105,11 @@ markers = [ [project.optional-dependencies] test = [ "pytest", + "pytest-forked", + "pytest-asyncio", "flake8", "soundfile", "librosa", - "pytest-forked", - "pytest-asyncio", ] [tool.setuptools.packages.find] diff --git a/src/levanter/eval_harness.py b/src/levanter/eval_harness.py index 37626a8b0..ab3c6d3d2 100644 --- a/src/levanter/eval_harness.py +++ b/src/levanter/eval_harness.py @@ -25,14 +25,13 @@ try: - from lm_eval import evaluator, tasks + from lm_eval import evaluator from lm_eval.api.instance import Instance from lm_eval.api.model import LM except ImportError: LM = object Instance = object evaluator = object - # tasks = object from tqdm_loggable.auto import tqdm @@ -157,7 +156,8 @@ def loglikelihood(self, requests: list[Instance]) -> list[tuple[float, bool]]: # pad requests to be a multiple of the batch size initial_length = len(requests) dummy_instance = dataclasses.replace(requests[0], arguments=("hello", " there"), idx=len(requests)) - requests = requests + [dummy_instance] * (len(requests) % self.EvalBatch.size) + requests = requests + [dummy_instance] * (self.EvalBatch.size - len(requests) % self.EvalBatch.size) + assert len(requests) % self.EvalBatch.size == 0 dataset = EvalDataset(self.EvalPos, 
self.tokenizer, requests) mesh = haliax.partitioning._get_mesh() @@ -171,6 +171,7 @@ def loglikelihood(self, requests: list[Instance]) -> list[tuple[float, bool]]: out_lls, out_correct = self._jit_loglikelihood(self.model, batch) result.extend((ll.item(), correct.item()) for ll, correct in zip(out_lls.array, out_correct.array)) + assert len(result) >= initial_length # skip padding result = result[:initial_length] @@ -189,51 +190,72 @@ def _jit_create_example(Pos, tokens, prompt_len, pad_token_id): return LmExample.from_prompt_and_completion(Pos, tokens, prompt_len, ignore_id=pad_token_id) -def run_lm_eval_harness( - model, - task_spec: list[str], - tokenizer, - EvalBatch, - axis_resources, - max_examples: int | None = None, - max_eval_length: int | None = None, - log_samples: bool = False, -) -> dict: - EvalPos = model.Pos if max_eval_length is None else model.Pos.resize(max_eval_length) - harness = LevanterHarnessLM(EvalBatch, EvalPos, model, axis_resources, tokenizer) - tasks_to_run = tasks.get_task_dict(task_spec) - outputs = evaluator.evaluate(harness, tasks_to_run, limit=max_examples, log_samples=log_samples) +@dataclass(frozen=True) +class TaskConfig: + """ + This is a dataclass that represents the configuration for a task in the LM Eval Harness. It is used to specify + the configuration for a task in the LM Eval Harness, and is used to generate the task dictionary that the LM Eval + Harness expects. - return outputs + nb that LM Eval Harness has its own TaskConfig, but its defaults are not the same as just passing in + a dict, and we want the behavior of passing in a dict. + + See Also: + [LM Eval Harness TaskConfig](https://github.com/EleutherAI/lm-evaluation-harness/blob/0ef7548d7c3f01108e7c12900a5e5eb4b4a668f7/lm_eval/api/task.py#L55) + """ + + task: str + task_alias: str | None = None + num_fewshot: int | None = None + + use_prompt: str | None = None + description: str | None = None + target_delimiter: str | None = None + fewshot_delimiter: str | None = None + + def to_dict(self): + base_dict = dataclasses.asdict(self) + return {k: v for k, v in base_dict.items() if v is not None} @dataclass(frozen=True) class LmEvalHarnessConfig: - task_spec: list[str] | None = None + task_spec: list[TaskConfig | str] | None = None max_examples: int | None = None max_eval_length: int | None = None log_samples: bool = False - def task_spec_or_default(self): - return self.task_spec or [ - # "lambada", - "piqa", - "hellaswag", - # "winogrande", - # "mathqa", - # "pubmedqa", - "boolq", - # "cb", - "copa", - # "multirc", - # "record", - # "wic", - # "wsc", - ] + def task_spec_or_default(self) -> list[str | dict]: + if self.task_spec is None: + return ["hellaswag", "piqa"] + return [task.to_dict() if isinstance(task, TaskConfig) else task for task in self.task_spec] + + def to_task_dict(self) -> dict: + import lm_eval.tasks as tasks + + manager = tasks.TaskManager() + # we need to do it this way b/c i can't figure out how to run e.g. 
hellaswag 0 shot and 10 shot in a single run + this_tasks = {} + for task in self.task_spec_or_default(): + try: + if isinstance(task, str): + this_tasks.update(tasks.get_task_dict(task, manager)) + else: + our_name = task.get("task_alias", task["task"]) if isinstance(task, dict) else task + our_name = our_name.replace(" ", "_") + task_dict = tasks.get_task_dict([task], manager) + this_task = task_dict.popitem()[1] + # hacky, but this allows us to run multiple instances of the same task with different fewshot settings + this_task.config.task = our_name + this_tasks[our_name] = this_task + except Exception: + logger.exception(f"Failed to load task {task}") + raise ValueError(f"Failed to load task {task}") + return this_tasks @dataclass(frozen=True) -class EvalHarnessConfig: +class EvalHarnessMainConfig: tokenizer: str checkpoint_path: str checkpoint_is_hf: bool = False @@ -251,13 +273,35 @@ def the_tokenizer(self): return load_tokenizer(self.tokenizer) -def run_eval_harness_main(config: EvalHarnessConfig): +def run_lm_eval_harness( + config: LmEvalHarnessConfig, + model, + tokenizer, + EvalBatch, + axis_resources, +) -> dict: + # tasks_to_run = tasks.get_task_dict(config.task_spec_or_default(), tasks.TaskManager()) + tasks_to_run = config.to_task_dict() + + outputs = _actually_run_eval_harness(config, model, tasks_to_run, tokenizer, EvalBatch, axis_resources) + + return outputs + + +def _actually_run_eval_harness(config: LmEvalHarnessConfig, model, tasks_to_run, tokenizer, EvalBatch, axis_resources): + max_examples = config.max_examples + max_eval_length = config.max_eval_length + + EvalPos = model.Pos if max_eval_length is None else model.Pos.resize(max_eval_length) + harness = LevanterHarnessLM(EvalBatch, EvalPos, model, axis_resources, tokenizer) + outputs = evaluator.evaluate(harness, tasks_to_run, limit=max_examples, log_samples=config.log_samples) + return outputs + + +def run_eval_harness_main(config: EvalHarnessMainConfig): config.trainer.initialize() tokenizer = config.the_tokenizer - task_spec = config.eval_harness.task_spec_or_default() - max_examples = config.eval_harness.max_examples - compute_axis_mapping = config.trainer.compute_axis_mapping parameter_axis_mapping = config.trainer.parameter_axis_mapping @@ -289,14 +333,11 @@ def run_eval_harness_main(config: EvalHarnessConfig): logger.info("Running LM eval harness....") outputs = run_lm_eval_harness( + config.eval_harness, model, - task_spec, tokenizer, config.EvalBatch, axis_resources=compute_axis_mapping, - max_examples=max_examples, - max_eval_length=config.eval_harness.max_eval_length, - log_samples=config.eval_harness.log_samples, ) logger.info("Finished running LM eval harness") @@ -322,7 +363,7 @@ def log_report_to_tracker(prefix: str, report: dict, tracker: Optional[levanter. to_log = {} for task_name, task_results in report["results"].items(): for metric_name, metric_value in task_results.items(): - if metric_name.ends_with(",none"): + if metric_name.endswith(",none"): metric_name = metric_name[:-5] if isinstance(metric_value, float | int): @@ -332,20 +373,22 @@ def log_report_to_tracker(prefix: str, report: dict, tracker: Optional[levanter. 
def lm_eval_harness(config: LmEvalHarnessConfig, tokenizer, EvalBatch, axis_resources): + tasks_to_run = config.to_task_dict() + def lm_eval_harness(step: StepInfo, force=False): - if step.step == 0 and not force: - return # don't run eval on the first step + # if step.step == 0 and not force: + # return # don't run eval on the first step + + print(config.task_spec_or_default()) model = inference_mode(step.model, True) - outputs = run_lm_eval_harness( + outputs = _actually_run_eval_harness( + config, model, - config.task_spec_or_default(), + tasks_to_run, tokenizer, EvalBatch, axis_resources, - max_examples=config.max_examples, - max_eval_length=config.max_eval_length, - log_samples=config.log_samples, ) if jax.process_index() == 0: @@ -364,3 +407,4 @@ def lm_eval_harness(step: StepInfo, force=False): if __name__ == "__main__": levanter.config.main(run_eval_harness_main)() + print("Done", flush=True) diff --git a/src/levanter/main/train_lm.py b/src/levanter/main/train_lm.py index cf327956b..cf78b81f8 100644 --- a/src/levanter/main/train_lm.py +++ b/src/levanter/main/train_lm.py @@ -126,6 +126,18 @@ def main(config: TrainLmConfig): Pos = config.model.Pos KeyPos = config.model.KeyPos + # to do partitioning, our dimensions have to be divisible by the size of the physical axes they're mapped to + # For most things, we just insist you specify the config right, but tokenizers often have strange numbers of + # tokens: gpt-2 has 50257, for example. So we round up. + vocab_size = len(tokenizer) + Vocab = round_axis_for_partitioning(Axis("vocab", vocab_size), parameter_axis_mapping) + if vocab_size != Vocab.size: + logger.info(f"Rounding vocab size from {vocab_size} to {Vocab.size} for partitioning") + + logger.info(f"initializing model with key {model_key}") + state = trainer.initial_state(training_key, model_init=lambda: config.model.build(Vocab, key=model_key)) + logger.info(f"model initialized with {parameter_count(state.model)} parameters") + # TODO: fix this tagged_eval_datasets: list = config.data.tagged_eval_sets(Pos.size) # TokenSeqDataset is config.data.train_set(Pos.size, key=data_key) @@ -156,16 +168,6 @@ def main(config: TrainLmConfig): ) trainer.add_hook(epoch_checkpointer, every=1) - # to do partitioning, our dimensions have to be divisible by the size of the physical axes they're mapped to - # For most things, we just insist you specify the config right, but tokenizers often have strange numbers of - # tokens: gpt-2 has 50257, for example. So we round up. 
- vocab_size = len(tokenizer) - Vocab = round_axis_for_partitioning(Axis("vocab", vocab_size), parameter_axis_mapping) - if vocab_size != Vocab.size: - logger.info(f"Rounding vocab size from {vocab_size} to {Vocab.size} for partitioning") - - state = trainer.initial_state(training_key, model_init=lambda: config.model.build(Vocab, key=model_key)) - seek_dataloader = True if int(state.step) == 0 and config.initialize_from_checkpoint_path is not None: state = load_checkpoint(state, config.initialize_from_checkpoint_path) diff --git a/tests/test_eval_harness.py b/tests/test_eval_harness.py new file mode 100644 index 000000000..bf5d663ce --- /dev/null +++ b/tests/test_eval_harness.py @@ -0,0 +1,27 @@ +from levanter.eval_harness import LmEvalHarnessConfig, TaskConfig +from test_utils import skip_if_module_missing + + +@skip_if_module_missing("lm_eval") +def test_task_config(): + task_spec = [ + TaskConfig( + task="hellaswag", + task_alias="hellaswag_10shot", + num_fewshot=10, + ), + TaskConfig( + task="hellaswag", + task_alias="hellaswag_5shot", + num_fewshot=5, + ), + "lambada_openai", + ] + + config = LmEvalHarnessConfig( + task_spec=task_spec, + ) + + q = config.to_task_dict() + + assert len(q) == 3 From 65a661c3ef6442111d66cb60c16f698aa8bd19ed Mon Sep 17 00:00:00 2001 From: David Hall Date: Fri, 29 Nov 2024 19:58:28 -0800 Subject: [PATCH 30/34] ok good enough --- ..._marin_dclm_ckpt.yaml => eval_llama3.yaml} | 24 +-- src/levanter/eval_harness.py | 153 +++++++++++++++++- 2 files changed, 156 insertions(+), 21 deletions(-) rename config/harness/{eval_marin_dclm_ckpt.yaml => eval_llama3.yaml} (50%) diff --git a/config/harness/eval_marin_dclm_ckpt.yaml b/config/harness/eval_llama3.yaml similarity index 50% rename from config/harness/eval_marin_dclm_ckpt.yaml rename to config/harness/eval_llama3.yaml index cb3373d3c..df96a1ba0 100644 --- a/config/harness/eval_marin_dclm_ckpt.yaml +++ b/config/harness/eval_llama3.yaml @@ -1,19 +1,7 @@ eval_harness: task_spec: - # EvalTaskConfig("agieval_lsat_ar", 3), # 3-shot tests in legal domain - # EvalTaskConfig("arc_easy", 10), # 10-shot, four-way MCQ questions involving grade 3-9 basic science - # EvalTaskConfig("arc_challenge", 10), # a (harder) version of arc_easy - # EvalTaskConfig("boolq", 10), # answer yes/no questions based on a passage - # EvalTaskConfig("commonsense_qa", 10), # 5-way multiple-choice questions based on common-sense, everyday scenarios - # EvalTaskConfig("copa", 0), # use causal reasoning to predict the correct outcome of a given scenario - # EvalTaskConfig("hellaswag", 0), # 4-way multiple choice commonsense reasoning dataset - # EvalTaskConfig("hellaswag", 10), # 4-way multiple choice commonsense reasoning dataset - # EvalTaskConfig("lambada", 0), # predict the endings of text passages - # EvalTaskConfig("openbookqa", 0), # 4-way multiple choice question answering task that requires multi-step reasoning - # EvalTaskConfig("piqa", 10), # answer questions based on a passage - # EvalTaskConfig("squadv2", 10), # reading comprehension benchmark - # EvalTaskConfig("wsc273", 0), # Winograd Schema Challenge - # EvalTaskConfig("winogrande", 0), # Winograd challenge, extended to more domains + - task: commonsense_qa # 5-way multiple-choice questions based on common-sense, everyday scenarios + num_fewshot: 10 - task: agieval_lsat_ar # 3-shot tests in legal domain num_fewshot: 3 - task: arc_easy # 10-shot, four-way MCQ questions involving grade 3-9 basic science @@ -40,15 +28,11 @@ eval_harness: num_fewshot: 0 - task: winogrande # Winograd 
challenge, extended to more domains num_fewshot: 0 - # - task: commonsense_qa # 5-way multiple-choice questions based on common-sense, everyday scenarios - # num_fewshot: 10 + # requires generation ## - task: squadv2 # reading comprehension benchmark # num_fewshot: 10 + max_examples: 16 max_eval_length: 4096 -#tokenizer: gs://marin-us-central2/checkpoints/dclm_baseline_1b_1x_replication_nov12_3404462497seed-b68241/hf/step-54930 -#tokenizer: gs://levanter-checkpoints/marin/olmoish7b_v4_1024_0627/dlwh_7b0627/hf/step-715001/ -#tokenizer: gs://levanter-checkpoints/marin/olmoish7b_v4_1024_0627/dlwh_7b0627/step-510000/ -#tokenizer: "EleutherAI/gpt-neox-20b" tokenizer: meta-llama/Meta-Llama-3-8B model: type: llama diff --git a/src/levanter/eval_harness.py b/src/levanter/eval_harness.py index ab3c6d3d2..e6f6beb07 100644 --- a/src/levanter/eval_harness.py +++ b/src/levanter/eval_harness.py @@ -15,6 +15,7 @@ import equinox as eqx import jax import jax.numpy as jnp +import numpy as np import haliax @@ -294,10 +295,155 @@ def _actually_run_eval_harness(config: LmEvalHarnessConfig, model, tasks_to_run, EvalPos = model.Pos if max_eval_length is None else model.Pos.resize(max_eval_length) harness = LevanterHarnessLM(EvalBatch, EvalPos, model, axis_resources, tokenizer) - outputs = evaluator.evaluate(harness, tasks_to_run, limit=max_examples, log_samples=config.log_samples) + # we always log_samples here and filter out the samples later if we don't want them + outputs = evaluator.evaluate(harness, tasks_to_run, limit=max_examples, log_samples=True) + + averages = _compute_averages(outputs) + outputs["averages"] = averages + + if not config.log_samples: + del outputs["samples"] + return outputs +def _compute_averages(outputs): + """ + Compute macro and micro averages of all metrics. + + Args: + outputs: Dictionary with results and samples: + - "results": Dictionary of task-level results. + - "samples": Dictionary of task-level sample counts. + + Returns: + Averages dictionary with macro and micro averages for all metrics. + """ + averages = {} + metric_keys = set() + + # Collect all possible metrics across tasks + for task_results in outputs["results"].values(): + metric_keys.update(k for k in task_results.keys() if "stderr" not in k and k != "alias") + + examples_per_task = [len(task_samples) for task_samples in outputs["samples"].values()] + + # Compute macro and micro averages + for metric in metric_keys: + # Collect valid tasks for this metric + valid_tasks = [ + (task_results.get(metric), examples_per_task[i]) + for i, (task_name, task_results) in enumerate(outputs["results"].items()) + if metric in task_results + ] + + if not valid_tasks: + continue # Skip metrics with no valid tasks + + # Separate metric values and weights + metric_values, this_examples_per_task = zip(*valid_tasks) + + # Compute macro and micro averages + averages["macro_avg_" + metric] = np.mean(metric_values) + averages["micro_avg_" + metric] = np.average(metric_values, weights=this_examples_per_task) + + return averages + + +NAT_TO_BIT = 1 / np.log(2) + +# eval_harness isn't consistent enough for this to actually be workable +# def _compute_extra_metrics(samples): +# """ +# Compute a few "soft" measures of accuracy for each task, based on the outputs of the eval harness. +# +# Specifically, we compute: +# - "bpb": bits per byte of the correct completion +# - "logprob": log probability of the correct completion +# - "choice_logprob": log probability of the correct choice normalized w.r.t. 
the other choices +# - "choice_prob_norm": probability of the length-normalized correct choice normalized w.r.t. the other choices +# +# Args: +# samples: Dictionary with task data, where each task has a list of samples. Each sample contains: +# - "doc": The original task document (can include metadata such as the answer key) +# - "target": Index of the correct answer (0-indexed), or +# "doc.answer" for 1-indexed answers. +# - "arguments": List of [input, completion] pairs +# - "resps": List of [log probability, is_correct] pairs for completions +# +# Returns: +# A dictionary with per-task aggregated metrics. +# """ +# # TODO: move to eval harness and use more sane logic +# # uses the samples which has one of two structures (that I've seen) +# # { "": [ {"doc": {...,}, "target": <0-indexed answer>, "arguments": [[input, completion], "resps": [[score, is_correct], ...], ...}, ...] } +# # { "": [ {"doc": {..., "answer": "[1-indexed answer]"}, "target": "", "arguments": [input, completion], "resps": [[score, is_correct], ...], ...}, ...] } +# metrics = {} +# +# for task, samples in samples.items(): +# bpb_list = [] +# logprob_list = [] +# choice_logprob_list = [] +# choice_prob_norm_list = [] +# +# for sample in samples: +# # Extract the correct answer index (supporting both 0-indexed `target` and 1-indexed `doc.answer`) +# if "answer" in sample["doc"]: +# target = int(sample["doc"]["answer"]) - 1 # Convert 1-indexed to 0-indexed +# elif "label" in sample["doc"]: +# target = int(sample["doc"]["label"]) +# elif "target" in sample and isinstance(sample["target"], int): +# target = sample["target"] # 0-indexed target +# elif "target" in sample and isinstance(sample["target"], str): +# # see if it's A-Z: +# if len(sample["target"]) == 1 and "A" <= sample["target"] <= "Z": +# target = ord(sample["target"]) - ord("A") +# else: +# raise ValueError(f"Invalid target: {sample['target']}. {sample}") +# elif "target" in sample and isinstance(sample["target"], list): +# target = sample["target"][0] +# else: +# raise KeyError(f"Missing `target` or `doc.answer` in sample. doc id: {sample['doc_id']}. Hash: {sample['doc_hash']}\n\n{sample}") +# +# resps = sample["filtered_resps"] # List of [log probability, is_correct] +# arguments = sample["arguments"] # [input, completion] pairs +# +# # Compute byte lengths for each choice +# byte_lengths = [max(1, len(completion.encode("utf-8"))) for _, completion in arguments] +# +# # Compute log probabilities for each choice +# log_probs = np.array([resp[0] for resp in resps]) # Extract log probabilities +# assert log_probs.shape == (len(arguments),), f"Log probs shape: {log_probs.shape}, arguments: {len(arguments)}. 
doc: {sample}" +# normalized_log_probs = log_probs - np.logaddexp.reduce(log_probs) +# +# # Metrics for the correct answer +# correct_logprob = log_probs[target] +# correct_bpb = -correct_logprob / byte_lengths[target] * NAT_TO_BIT +# correct_choice_logprob = normalized_log_probs[target] +# +# # Compute length-normalized weights (w_i) +# bpb_values = -log_probs / np.array(byte_lengths) * NAT_TO_BIT +# bpb_weights = np.exp(-bpb_values) +# bpb_weights /= max(bpb_weights.sum(), 1e-8) # Avoid division by zero +# correct_choice_prob_norm = bpb_weights[target] +# +# # Append metrics +# bpb_list.append(correct_bpb) +# logprob_list.append(correct_logprob) +# choice_logprob_list.append(correct_choice_logprob) +# choice_prob_norm_list.append(correct_choice_prob_norm) +# +# # Aggregate metrics for the task +# metrics[task] = { +# "bpb": np.mean(bpb_list) if bpb_list else 0.0, +# "logprob": np.mean(logprob_list) if logprob_list else 0.0, +# "choice_logprob": np.mean(choice_logprob_list) if choice_logprob_list else 0.0, +# "choice_prob_norm": np.mean(choice_prob_norm_list) if choice_prob_norm_list else 0.0, +# } +# +# return metrics + + def run_eval_harness_main(config: EvalHarnessMainConfig): config.trainer.initialize() tokenizer = config.the_tokenizer @@ -369,6 +515,11 @@ def log_report_to_tracker(prefix: str, report: dict, tracker: Optional[levanter. if isinstance(metric_value, float | int): to_log[f"{prefix}/{task_name}/{metric_name}"] = metric_value + if "averages" in report: + for metric_name, metric_value in report["averages"].items(): + if isinstance(metric_value, float | int): + to_log[f"{prefix}/averages/{metric_name}"] = metric_value + tracker.log(to_log, step=None) From aeb8095fb41a2178bd2eb9e43577295b44adefdb Mon Sep 17 00:00:00 2001 From: David Hall Date: Fri, 29 Nov 2024 19:58:44 -0800 Subject: [PATCH 31/34] kk --- config/gpt2_small_fast.yaml | 11 ----------- 1 file changed, 11 deletions(-) diff --git a/config/gpt2_small_fast.yaml b/config/gpt2_small_fast.yaml index d9afe4b36..054977899 100644 --- a/config/gpt2_small_fast.yaml +++ b/config/gpt2_small_fast.yaml @@ -26,14 +26,3 @@ optimizer: learning_rate: 1E-3 weight_decay: 0.1 warmup: 0.01 -eval_harness: - task_spec: - - piqa - - task: hellaswag - num_fewshot: 10 - task_alias: hellaswag_10shot - - task: hellaswag - task_alias: hellaswag_0shot - - lambada - max_examples: 128 -eval_harness_steps: 1000 From 3179f67107ecf953010a04f5900c1803d127760e Mon Sep 17 00:00:00 2001 From: David Hall Date: Fri, 29 Nov 2024 20:01:33 -0800 Subject: [PATCH 32/34] remove max_examples --- config/harness/eval_llama3.yaml | 1 - 1 file changed, 1 deletion(-) diff --git a/config/harness/eval_llama3.yaml b/config/harness/eval_llama3.yaml index df96a1ba0..260620102 100644 --- a/config/harness/eval_llama3.yaml +++ b/config/harness/eval_llama3.yaml @@ -31,7 +31,6 @@ eval_harness: # requires generation ## - task: squadv2 # reading comprehension benchmark # num_fewshot: 10 - max_examples: 16 max_eval_length: 4096 tokenizer: meta-llama/Meta-Llama-3-8B model: From 69ea6b1b028e133d9c234a5d6d1006a2f53f7f9c Mon Sep 17 00:00:00 2001 From: David Hall Date: Fri, 29 Nov 2024 21:43:40 -0800 Subject: [PATCH 33/34] initialize --- src/levanter/main/train_lm.py | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/src/levanter/main/train_lm.py b/src/levanter/main/train_lm.py index cf78b81f8..2ce2135fd 100644 --- a/src/levanter/main/train_lm.py +++ b/src/levanter/main/train_lm.py @@ -134,10 +134,6 @@ def main(config: TrainLmConfig): if vocab_size != Vocab.size: 
logger.info(f"Rounding vocab size from {vocab_size} to {Vocab.size} for partitioning") - logger.info(f"initializing model with key {model_key}") - state = trainer.initial_state(training_key, model_init=lambda: config.model.build(Vocab, key=model_key)) - logger.info(f"model initialized with {parameter_count(state.model)} parameters") - # TODO: fix this tagged_eval_datasets: list = config.data.tagged_eval_sets(Pos.size) # TokenSeqDataset is config.data.train_set(Pos.size, key=data_key) @@ -168,6 +164,8 @@ def main(config: TrainLmConfig): ) trainer.add_hook(epoch_checkpointer, every=1) + state = trainer.initial_state(training_key, model_init=lambda: config.model.build(Vocab, key=model_key)) + seek_dataloader = True if int(state.step) == 0 and config.initialize_from_checkpoint_path is not None: state = load_checkpoint(state, config.initialize_from_checkpoint_path) From 5a4e6ce11702fd313279e4bbf7b929a86f97af72 Mon Sep 17 00:00:00 2001 From: David Hall Date: Fri, 29 Nov 2024 21:48:53 -0800 Subject: [PATCH 34/34] remove none --- src/levanter/eval_harness.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/src/levanter/eval_harness.py b/src/levanter/eval_harness.py index e6f6beb07..9f6faee23 100644 --- a/src/levanter/eval_harness.py +++ b/src/levanter/eval_harness.py @@ -518,6 +518,9 @@ def log_report_to_tracker(prefix: str, report: dict, tracker: Optional[levanter. if "averages" in report: for metric_name, metric_value in report["averages"].items(): if isinstance(metric_value, float | int): + if metric_name.endswith(",none"): + metric_name = metric_name[:-5] + to_log[f"{prefix}/averages/{metric_name}"] = metric_value tracker.log(to_log, step=None)