diff --git a/config/gpt2_nano_harness.yaml b/config/gpt2_nano_harness.yaml new file mode 100644 index 000000000..8a241a058 --- /dev/null +++ b/config/gpt2_nano_harness.yaml @@ -0,0 +1,26 @@ +eval_harness: + task_spec: ["piqa", "hellaswag"] + max_examples: 32 +eval_harness_steps: 50 +data: + id: dlwh/wikitext_103_detokenized +model: + type: gpt2 + hidden_dim: 32 + num_heads: 4 + num_layers: 2 +trainer: + mp: f32 + num_train_steps: 100 + + checkpointer: + keep: + - every: 50 + save_interval: 5m + + per_device_parallelism: -1 + train_batch_size: 4 + + tensor_parallel_axes: ["mlp", "heads"] + fsdp_axis: "embed" + batch_axis: "batch" diff --git a/config/harness/eval_llama3.yaml b/config/harness/eval_llama3.yaml new file mode 100644 index 000000000..260620102 --- /dev/null +++ b/config/harness/eval_llama3.yaml @@ -0,0 +1,52 @@ +eval_harness: + task_spec: + - task: commonsense_qa # 5-way multiple-choice questions based on common-sense, everyday scenarios + num_fewshot: 10 + - task: agieval_lsat_ar # 3-shot tests in legal domain + num_fewshot: 3 + - task: arc_easy # 10-shot, four-way MCQ questions involving grade 3-9 basic science + num_fewshot: 10 + - task: arc_challenge # a (harder) version of arc_easy + num_fewshot: 10 + - task: boolq # answer yes/no questions based on a passage + num_fewshot: 10 + - task: copa # use causal reasoning to predict the correct outcome of a given scenario + num_fewshot: 0 + - task: hellaswag # 4-way multiple choice commonsense reasoning dataset + num_fewshot: 0 + task_alias: hellaswag_0shot + - task: hellaswag # 4-way multiple choice commonsense reasoning dataset + num_fewshot: 10 + task_alias: hellaswag_10shot + - task: lambada # predict the endings of text passages + num_fewshot: 0 + - task: openbookqa # 4-way multiple choice question answering task that requires multi-step reasoning + num_fewshot: 0 + - task: piqa # answer questions based on a passage + num_fewshot: 10 + - task: wsc273 # Winograd Schema Challenge + num_fewshot: 0 + - task: winogrande # Winograd challenge, extended to more domains + num_fewshot: 0 + # requires generation +## - task: squadv2 # reading comprehension benchmark +# num_fewshot: 10 + max_eval_length: 4096 +tokenizer: meta-llama/Meta-Llama-3-8B +model: + type: llama +#checkpoint_path: gs://marin-us-central2/checkpoints/dclm_baseline_1b_1x_replication_nov12_3404462497seed-b68241/hf/step-54930 +checkpoint_path: meta-llama/Meta-Llama-3-8B +checkpoint_is_hf: true +trainer: + mp: f32 + profiler: true + + per_device_parallelism: -1 + train_batch_size: 512 + + tensor_parallel_axes: ["mlp", "heads"] + fsdp_axis: "embed" + batch_axis: "batch" + ray: + auto_start_cluster: false diff --git a/config/harness/harness_nano.yaml b/config/harness/harness_nano.yaml new file mode 100644 index 000000000..833291e5c --- /dev/null +++ b/config/harness/harness_nano.yaml @@ -0,0 +1,24 @@ +eval_harness: + task_spec: ["hellaswag"] +tokenizer: "gpt2" +model: + type: gpt2 + hidden_dim: 32 + num_heads: 4 + num_layers: 2 +trainer: + mp: f32 + num_train_steps: 100 + profiler: true + + checkpointer: + keep: + - every: 50 + save_interval: 5m + + per_device_parallelism: -1 + train_batch_size: 32 + + tensor_parallel_axes: ["mlp", "heads"] + fsdp_axis: "embed" + batch_axis: "batch" diff --git a/config/olmo/olmo_7b_repro.yaml b/config/olmo/olmo_7b_repro.yaml new file mode 100644 index 000000000..aca11c419 --- /dev/null +++ b/config/olmo/olmo_7b_repro.yaml @@ -0,0 +1,175 @@ +#data: !include data/dolma_olmo_paloma.yaml +data: + cache_dir: 
"gs://marin-data/tokenized/OLMo-1B/dolma-v1.7" + tokenizer: "allenai/OLMo-1B" # requires `pip install ai2-olmo` + # tokenizer: "meta-llama/Llama-2-7b-hf" + stop_strategy: restart + shuffle_buffer_size: 100000 + configs: + dolma-algebraic-stack: + train_urls: + - gs://marin-data/raw/dolma/dolma-v1.7/algebraic-stack-train-{0000..0015}.json.gz + dolma-arxiv: + train_urls: + - gs://marin-data/raw/dolma/dolma-v1.7/arxiv-{0000..0099}.json.gz + dolma-gutenberg: + train_urls: + - gs://marin-data/raw/dolma/dolma-v1.7/books-{0000..0002}.json.gz + dolma-c4: + train_urls: + - gs://marin-data/raw/dolma/dolma-v1.7/c4-{0000..0170}.json.gz + dolma-cc: + train_urls: + - gs://marin-data/raw/dolma/dolma-v1.7/cc_en_head-{0000..0274}.json.gz + - gs://marin-data/raw/dolma/dolma-v1.7/cc_en_middle-{0000..0238}.json.gz # 239 is missing + - gs://marin-data/raw/dolma/dolma-v1.7/cc_en_middle-{0240..0379}.json.gz + - gs://marin-data/raw/dolma/dolma-v1.7/cc_en_tail-{0000..0152}.json.gz # 153 is missing + - gs://marin-data/raw/dolma/dolma-v1.7/cc_en_tail-{0154..0444}.json.gz + dolma-cc-news: + train_urls: + - gs://marin-data/raw/dolma/dolma-v1.7/cc_news_head-{0000..0004}.json.gz + - gs://marin-data/raw/dolma/dolma-v1.7/cc_news_middle-{0000..0002}.json.gz + - gs://marin-data/raw/dolma/dolma-v1.7/cc_news_tail-0000.json.gz + dolma-falcon: + train_urls: + - gs://marin-data/raw/dolma/dolma-v1.7/falcon-{0000..0499}.json.gz + dolma-megawika: + train_urls: + - gs://marin-data/raw/dolma/dolma-v1.7/megawika-{0000..0261}.json.gz + dolma-owmath: + train_urls: + - gs://marin-data/raw/dolma/dolma-v1.7/open-web-math-train-{0000..0012}.json.gz + dolma-pes2o: + train_urls: + - gs://marin-data/raw/dolma/dolma-v1.7/pes2o-{0000..0025}.json.gz + dolma-reddit: + train_urls: + - gs://marin-data/raw/dolma/dolma-v1.7/reddit-{0000..0077}.json.gz + dolma-stackexchange: + train_urls: + - gs://marin-data/raw/dolma/dolma-v1.7/stackexchange-{0000..0025}.json.gz + dolma-starcoder: + train_urls: + - gs://marin-data/raw/dolma/dolma-v1.7/starcoder-{0000..0048}.json.gz + dolma-flan: + train_urls: + - gs://marin-data/raw/dolma/dolma-v1.7/tulu_flan-{0000..0065}.json.gz + dolma-wiki: + train_urls: + - gs://marin-data/raw/dolma/dolma-v1.7/wiki-{0000..0001}.json.gz + # these are just for eval + "paloma/4chan": + validation_urls: + - gs://levanter-data/paloma/4chan_meta_sep/val/val*.jsonl.gz + "paloma/c4_100_domains": + validation_urls: + - gs://levanter-data/paloma/c4_100_domains/val/val*.jsonl.gz + "paloma/c4_en": + validation_urls: + - gs://levanter-data/paloma/c4_en/val/val*.jsonl.gz + "paloma/dolma-v1_5": + validation_urls: + - gs://levanter-data/paloma/dolma-v1_5/val/val*.jsonl.gz + "paloma/dolma_100_programing_languages": + validation_urls: + - gs://levanter-data/paloma/dolma_100_programing_languages/val/val*.jsonl.gz + "paloma/dolma_100_subreddits": + validation_urls: + - gs://levanter-data/paloma/dolma_100_subreddits/val/val*.jsonl.gz + "paloma/falcon-refinedweb": + validation_urls: + - gs://levanter-data/paloma/falcon-refinedweb/val/val*.jsonl.gz + "paloma/gab": + validation_urls: + - gs://levanter-data/paloma/gab/val/val*.jsonl.gz + "paloma/m2d2_s2orc_unsplit": + validation_urls: + - gs://levanter-data/paloma/m2d2_s2orc_unsplit/val/val*.jsonl.gz + "paloma/m2d2_wikipedia_unsplit": + validation_urls: + - gs://levanter-data/paloma/m2d2_wikipedia_unsplit/val/val*.jsonl.gz + "paloma/manosphere_meta_sep": + validation_urls: + - gs://levanter-data/paloma/manosphere_meta_sep/val/val*.jsonl.gz + "paloma/mc4": + validation_urls: + - 
gs://levanter-data/paloma/mc4/val/val*.jsonl.gz + "paloma/ptb": + validation_urls: + - gs://levanter-data/paloma/ptb/val/val*.jsonl.gz + "paloma/redpajama": + validation_urls: + - gs://levanter-data/paloma/redpajama/val/val*.jsonl.gz + "paloma/twitterAAE_HELM_fixed": + validation_urls: + - gs://levanter-data/paloma/twitterAAE_HELM_fixed/val/val*.jsonl.gz + "paloma/wikitext_103": + validation_urls: + - gs://levanter-data/paloma/wikitext_103/val/val*.jsonl.gz + train_weights: + # sampling proportion comes from https://huggingface.co/datasets/allenai/dolma + dolma-algebraic-stack: 12.6 # 12.6 * 1.0 + dolma-arxiv: 28.0 # 28.0 * 1.0 + dolma-gutenberg: 5.3 # 5.3 * 1.0 + dolma-c4: 69.2 # 138.4 * 0.5 + dolma-cc: 597.75 # 1,195.5 * 0.5 + dolma-cc-news: 14.3 # 1.0 + dolma-falcon: 456.4 # 1.0, refined web + dolma-megawika: 4.6 # 1.0 + dolma-owmath: 12.6 # 1.0 + dolma-pes2o: 57.2 # 1.0 + dolma-reddit: 79.9 # 1.0 + dolma-stackexchange: 19.6 # 1.0 + dolma-starcoder: 263.8 # 1.0 + dolma-flan: 16.5 # 6.5 * 1.0 + dolma-wiki: 7.4 # 3.7 * 2.0 + paloma/4chan: 0.0 + paloma/c4_100_domains: 0.0 + paloma/c4_en: 0.0 + paloma/dolma-v1_5: 0.0 + paloma/dolma_100_programing_languages: 0.0 + paloma/dolma_100_subreddits: 0.0 + paloma/falcon-refinedweb: 0.0 + paloma/gab: 0.0 + paloma/m2d2_s2orc_unsplit: 0.0 + paloma/m2d2_wikipedia_unsplit: 0.0 + paloma/manosphere_meta_sep: 0.0 + paloma/mc4: 0.0 + paloma/ptb: 0.0 + paloma/redpajama: 0.0 + paloma/twitterAAE_HELM_fixed: 0.0 + paloma/wikitext_103: 0.0 +model: # 7B class model + type: llama + seq_len: 2048 + hidden_dim: 4096 + intermediate_dim: 11008 + num_layers: 32 + num_heads: 32 + num_kv_heads: 32 + use_flash_attention: True +# flash_attention_block_size: 1024 + + use_bias: false + use_layer_norm_weight: false +trainer: + tracker: + type: wandb + project: "marin" + tags: ["dolma", "olmo", "llama"] + + mp: p=f32,c=bfloat16 + train_batch_size: 2048 # olmo actually uses 2160 table 5 of https://arxiv.org/pdf/2402.00838 + num_train_steps: 750000 # 3,000,000,000,000 / 4,000,000 = 750,000 + steps_per_eval: 1000 + tensor_parallel_axes: ["mlp", "heads"] + fsdp_axis: "embed" + batch_axis: "batch" + replica_dcn_axis_size: 2 +optimizer: + learning_rate: 3E-4 + weight_decay: 0.1 + min_lr_ratio: 0.1 + beta1: 0.9 + beta2: 0.95 + warmup: 2000 diff --git a/pyproject.toml b/pyproject.toml index a2f89cdeb..00e56ed8e 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -7,7 +7,11 @@ name = "levanter" version = "1.2" authors = [ { name = "David Hall", email = "dlwh@cs.stanford.edu" }, + { name = "Jason Wang"}, + { name = "Ahmed Ahmed"}, { name = "Ivan Zhou", email = "ivanz@stanford.edu" }, + { name = "Will Held"}, + { name = "Virginia Adams"} ] description = "Scalable Training for Foundation Models with Named Tensors and JAX" readme = "README.md" @@ -47,10 +51,11 @@ dependencies = [ "pydantic<3", "rich~=13.0", "filelock~=3.13", - # "ai2-olmo", "async-lru~=2.0", "tqdm-loggable>=0.2", - "deepdiff" + "deepdiff", +# "lm-eval==0.4.2", + "lm-eval @ git+https://github.com/dlwh/lm-evaluation-harness.git@no_torch" ] [project.urls] @@ -100,11 +105,11 @@ markers = [ [project.optional-dependencies] test = [ "pytest", + "pytest-forked", + "pytest-asyncio", "flake8", "soundfile", "librosa", - "pytest-forked", - "pytest-asyncio", ] [tool.setuptools.packages.find] diff --git a/src/levanter/checkpoint.py b/src/levanter/checkpoint.py index ba684b8e5..db2395b4b 100644 --- a/src/levanter/checkpoint.py +++ b/src/levanter/checkpoint.py @@ -380,11 +380,15 @@ def load_checkpoint( logger.warning("Loading checkpoint in 
jit. This is not recommended and probably won't work.") if discover_latest: - checkpoint_path = discover_latest_checkpoint(checkpoint_path) # type: ignore + discovered_checkpoint_path = discover_latest_checkpoint(checkpoint_path) # type: ignore + else: + discovered_checkpoint_path = checkpoint_path - if checkpoint_path is None or not fs.exists(checkpoint_path): + if discovered_checkpoint_path is None or not fs.exists(discovered_checkpoint_path): raise FileNotFoundError(f"Could not find checkpoint at {checkpoint_path}") + checkpoint_path = discovered_checkpoint_path + logger.info(f"Loading checkpoint from {checkpoint_path}") metadata = load_metadata(checkpoint_path, fs) diff --git a/src/levanter/eval_harness.py b/src/levanter/eval_harness.py new file mode 100644 index 000000000..9f6faee23 --- /dev/null +++ b/src/levanter/eval_harness.py @@ -0,0 +1,564 @@ +# Code for running https://github.com/EleutherAI/lm-evaluation-harness inside Levanter runs +# References: +# https://github.com/kingoflolz/mesh-transformer-jax/blob/master/eval_harness.py +# https://github.com/kingoflolz/mesh-transformer-jax/blob/f8315e3003033b23f21d78361b288953064e0e76/mesh_transformer/TPU_cluster.py#L6 +import dataclasses +import functools +import json +import logging +import tempfile +import typing +from dataclasses import dataclass +from functools import cached_property +from typing import List, Optional, Sequence, Tuple + +import equinox as eqx +import jax +import jax.numpy as jnp +import numpy as np + +import haliax + +import levanter.tracker +from levanter.compat.hf_checkpoints import HFCheckpointConverter, load_tokenizer +from levanter.models.gpt2 import Gpt2Config +from levanter.models.loss import next_token_loss + + +try: + from lm_eval import evaluator + from lm_eval.api.instance import Instance + from lm_eval.api.model import LM +except ImportError: + LM = object + Instance = object + evaluator = object + +from tqdm_loggable.auto import tqdm + +import haliax as hax +from haliax.partitioning import round_axis_for_partitioning + +import levanter.config +from levanter.checkpoint import load_checkpoint +from levanter.data import AsyncDataset, DataLoader +from levanter.models.lm_model import LmConfig, LmExample, LmHeadModel +from levanter.trainer import StepInfo, TrainerConfig +from levanter.utils.jax_utils import use_cpu_device +from levanter.utils.tree_utils import inference_mode + + +logger = logging.getLogger(__name__) + + +class EvalDataset(AsyncDataset[LmExample]): + def __init__(self, Pos, tokenizer, examples: list[Instance]): + super().__init__() + self.examples = examples + self.Pos = Pos + self.tokenizer = tokenizer + + async def async_len(self) -> int: + return len(self.examples) + + async def final_length_is_known(self) -> bool: + return True + + def is_finite(self) -> bool: + return True + + async def current_len(self) -> Optional[int]: + return len(self.examples) + + async def get_batch(self, indices: Sequence[int]) -> List[LmExample]: + out = [] + pad_token_id = self.tokenizer.pad_token_id + + reqs = [(self.examples[i].args[0], self.examples[i].args[1]) for i in indices] + + for context, completion in reqs: + whole_enc = self.tokenizer(context + completion) + context_enc = self.tokenizer(context) + + context_enc_len = len(context_enc["input_ids"]) + + tokens, length = self._truncate_or_pad(whole_enc, context_enc_len) + example = _jit_create_example(self.Pos, tokens, length, pad_token_id) + + out.append(example) + + return out + + def _truncate_or_pad(self, encoded, prompt_length): + if 
self.tokenizer.pad_token_id is None: + self.tokenizer.pad_token_id = self.tokenizer.eos_token_id + + ex_pad = self.tokenizer.pad( + encoded, + padding="max_length", + max_length=self.Pos.size, + return_tensors="np", + ) + + truncated = ex_pad["input_ids"][-self.Pos.size :] + # if we truncated the prompt, we need to adjust the prompt length + if len(truncated) < len(encoded): + prompt_length -= len(encoded) - len(truncated) + if prompt_length < 0: + prompt_length = 0 + logger.warning("Prompt length is negative after truncation. Setting to 0.") + + return truncated, prompt_length + + +class LevanterHarnessLM(LM): + def __init__(self, EvalBatch: hax.Axis, EvalPos: hax.Axis, model: LmHeadModel, axis_resources, tokenizer): + super().__init__() + self.EvalBatch = EvalBatch + self.EvalPos = EvalPos + self.model = model + self.axis_resources = axis_resources + self.tokenizer = tokenizer + + def _eval_loglikelihood(model: LmHeadModel, example: LmExample): + logits = model(example.tokens, attn_mask=example.attn_mask) + logits = logits.astype(jnp.float32) + Pos = logits.resolve_axis(self.EvalPos.name) + + loss = next_token_loss( + Pos=Pos, + Vocab=model.Vocab, + logits=logits, + true_ids=example.tokens, + loss_mask=example.loss_mask, + reduction=hax.sum, + reduction_axis=Pos, + ) + + not_last_loss_mask = 1 - hax.nn.one_hot(-1, Pos, dtype=bool) + pred_targets = hax.argmax(logits, axis=model.Vocab) + targets = hax.roll(example.tokens, -1, axis=Pos) + freebie = hax.logical_not(example.loss_mask * not_last_loss_mask) + correct = hax.all(hax.equal(pred_targets, targets) + freebie, axis=Pos) + + return -loss, correct + + # no sharded outputs + self._jit_loglikelihood = hax.named_jit( + _eval_loglikelihood, axis_resources=axis_resources, out_axis_resources={} + ) + + def loglikelihood(self, requests: list[Instance]) -> list[tuple[float, bool]]: + """ + Compute log-likelihood of generating a continuation from a context. + Downstream tasks should attempt to use loglikelihood instead of other + LM calls whenever possible. 
+ """ + # pad requests to be a multiple of the batch size + initial_length = len(requests) + dummy_instance = dataclasses.replace(requests[0], arguments=("hello", " there"), idx=len(requests)) + requests = requests + [dummy_instance] * (self.EvalBatch.size - len(requests) % self.EvalBatch.size) + assert len(requests) % self.EvalBatch.size == 0 + dataset = EvalDataset(self.EvalPos, self.tokenizer, requests) + + mesh = haliax.partitioning._get_mesh() + + loader = DataLoader( + self.EvalBatch, dataset, max_buffered_batches=1024, mesh=mesh, axis_resources=self.axis_resources + ) + + result: list[tuple[float, bool]] = [] + for batch in tqdm(loader, desc="Loglikelihood", unit="ba"): + out_lls, out_correct = self._jit_loglikelihood(self.model, batch) + result.extend((ll.item(), correct.item()) for ll, correct in zip(out_lls.array, out_correct.array)) + + assert len(result) >= initial_length + # skip padding + result = result[:initial_length] + + return result + + def loglikelihood_rolling(self, requests) -> List[Tuple[float]]: + raise NotImplementedError() + + def generate_until(self, requests) -> List[str]: + raise NotImplementedError() + + +@functools.partial(jax.jit, static_argnums=(0, 3)) +def _jit_create_example(Pos, tokens, prompt_len, pad_token_id): + tokens = hax.named(tokens, Pos) + return LmExample.from_prompt_and_completion(Pos, tokens, prompt_len, ignore_id=pad_token_id) + + +@dataclass(frozen=True) +class TaskConfig: + """ + This is a dataclass that represents the configuration for a task in the LM Eval Harness. It is used to specify + the configuration for a task in the LM Eval Harness, and is used to generate the task dictionary that the LM Eval + Harness expects. + + nb that LM Eval Harness has its own TaskConfig, but its defaults are not the same as just passing in + a dict, and we want the behavior of passing in a dict. + + See Also: + [LM Eval Harness TaskConfig](https://github.com/EleutherAI/lm-evaluation-harness/blob/0ef7548d7c3f01108e7c12900a5e5eb4b4a668f7/lm_eval/api/task.py#L55) + """ + + task: str + task_alias: str | None = None + num_fewshot: int | None = None + + use_prompt: str | None = None + description: str | None = None + target_delimiter: str | None = None + fewshot_delimiter: str | None = None + + def to_dict(self): + base_dict = dataclasses.asdict(self) + return {k: v for k, v in base_dict.items() if v is not None} + + +@dataclass(frozen=True) +class LmEvalHarnessConfig: + task_spec: list[TaskConfig | str] | None = None + max_examples: int | None = None + max_eval_length: int | None = None + log_samples: bool = False + + def task_spec_or_default(self) -> list[str | dict]: + if self.task_spec is None: + return ["hellaswag", "piqa"] + return [task.to_dict() if isinstance(task, TaskConfig) else task for task in self.task_spec] + + def to_task_dict(self) -> dict: + import lm_eval.tasks as tasks + + manager = tasks.TaskManager() + # we need to do it this way b/c i can't figure out how to run e.g. 
hellaswag 0 shot and 10 shot in a single run + this_tasks = {} + for task in self.task_spec_or_default(): + try: + if isinstance(task, str): + this_tasks.update(tasks.get_task_dict(task, manager)) + else: + our_name = task.get("task_alias", task["task"]) if isinstance(task, dict) else task + our_name = our_name.replace(" ", "_") + task_dict = tasks.get_task_dict([task], manager) + this_task = task_dict.popitem()[1] + # hacky, but this allows us to run multiple instances of the same task with different fewshot settings + this_task.config.task = our_name + this_tasks[our_name] = this_task + except Exception: + logger.exception(f"Failed to load task {task}") + raise ValueError(f"Failed to load task {task}") + return this_tasks + + +@dataclass(frozen=True) +class EvalHarnessMainConfig: + tokenizer: str + checkpoint_path: str + checkpoint_is_hf: bool = False + trainer: TrainerConfig = dataclasses.field(default_factory=TrainerConfig) + model: LmConfig = dataclasses.field(default_factory=Gpt2Config) + + eval_harness: LmEvalHarnessConfig = dataclasses.field(default_factory=LmEvalHarnessConfig) + + @property + def EvalBatch(self): + return self.trainer.EvalBatch + + @cached_property + def the_tokenizer(self): + return load_tokenizer(self.tokenizer) + + +def run_lm_eval_harness( + config: LmEvalHarnessConfig, + model, + tokenizer, + EvalBatch, + axis_resources, +) -> dict: + # tasks_to_run = tasks.get_task_dict(config.task_spec_or_default(), tasks.TaskManager()) + tasks_to_run = config.to_task_dict() + + outputs = _actually_run_eval_harness(config, model, tasks_to_run, tokenizer, EvalBatch, axis_resources) + + return outputs + + +def _actually_run_eval_harness(config: LmEvalHarnessConfig, model, tasks_to_run, tokenizer, EvalBatch, axis_resources): + max_examples = config.max_examples + max_eval_length = config.max_eval_length + + EvalPos = model.Pos if max_eval_length is None else model.Pos.resize(max_eval_length) + harness = LevanterHarnessLM(EvalBatch, EvalPos, model, axis_resources, tokenizer) + # we always log_samples here and filter out the samples later if we don't want them + outputs = evaluator.evaluate(harness, tasks_to_run, limit=max_examples, log_samples=True) + + averages = _compute_averages(outputs) + outputs["averages"] = averages + + if not config.log_samples: + del outputs["samples"] + + return outputs + + +def _compute_averages(outputs): + """ + Compute macro and micro averages of all metrics. + + Args: + outputs: Dictionary with results and samples: + - "results": Dictionary of task-level results. + - "samples": Dictionary of task-level sample counts. + + Returns: + Averages dictionary with macro and micro averages for all metrics. 
+ """ + averages = {} + metric_keys = set() + + # Collect all possible metrics across tasks + for task_results in outputs["results"].values(): + metric_keys.update(k for k in task_results.keys() if "stderr" not in k and k != "alias") + + examples_per_task = [len(task_samples) for task_samples in outputs["samples"].values()] + + # Compute macro and micro averages + for metric in metric_keys: + # Collect valid tasks for this metric + valid_tasks = [ + (task_results.get(metric), examples_per_task[i]) + for i, (task_name, task_results) in enumerate(outputs["results"].items()) + if metric in task_results + ] + + if not valid_tasks: + continue # Skip metrics with no valid tasks + + # Separate metric values and weights + metric_values, this_examples_per_task = zip(*valid_tasks) + + # Compute macro and micro averages + averages["macro_avg_" + metric] = np.mean(metric_values) + averages["micro_avg_" + metric] = np.average(metric_values, weights=this_examples_per_task) + + return averages + + +NAT_TO_BIT = 1 / np.log(2) + +# eval_harness isn't consistent enough for this to actually be workable +# def _compute_extra_metrics(samples): +# """ +# Compute a few "soft" measures of accuracy for each task, based on the outputs of the eval harness. +# +# Specifically, we compute: +# - "bpb": bits per byte of the correct completion +# - "logprob": log probability of the correct completion +# - "choice_logprob": log probability of the correct choice normalized w.r.t. the other choices +# - "choice_prob_norm": probability of the length-normalized correct choice normalized w.r.t. the other choices +# +# Args: +# samples: Dictionary with task data, where each task has a list of samples. Each sample contains: +# - "doc": The original task document (can include metadata such as the answer key) +# - "target": Index of the correct answer (0-indexed), or +# "doc.answer" for 1-indexed answers. +# - "arguments": List of [input, completion] pairs +# - "resps": List of [log probability, is_correct] pairs for completions +# +# Returns: +# A dictionary with per-task aggregated metrics. +# """ +# # TODO: move to eval harness and use more sane logic +# # uses the samples which has one of two structures (that I've seen) +# # { "": [ {"doc": {...,}, "target": <0-indexed answer>, "arguments": [[input, completion], "resps": [[score, is_correct], ...], ...}, ...] } +# # { "": [ {"doc": {..., "answer": "[1-indexed answer]"}, "target": "", "arguments": [input, completion], "resps": [[score, is_correct], ...], ...}, ...] } +# metrics = {} +# +# for task, samples in samples.items(): +# bpb_list = [] +# logprob_list = [] +# choice_logprob_list = [] +# choice_prob_norm_list = [] +# +# for sample in samples: +# # Extract the correct answer index (supporting both 0-indexed `target` and 1-indexed `doc.answer`) +# if "answer" in sample["doc"]: +# target = int(sample["doc"]["answer"]) - 1 # Convert 1-indexed to 0-indexed +# elif "label" in sample["doc"]: +# target = int(sample["doc"]["label"]) +# elif "target" in sample and isinstance(sample["target"], int): +# target = sample["target"] # 0-indexed target +# elif "target" in sample and isinstance(sample["target"], str): +# # see if it's A-Z: +# if len(sample["target"]) == 1 and "A" <= sample["target"] <= "Z": +# target = ord(sample["target"]) - ord("A") +# else: +# raise ValueError(f"Invalid target: {sample['target']}. 
{sample}") +# elif "target" in sample and isinstance(sample["target"], list): +# target = sample["target"][0] +# else: +# raise KeyError(f"Missing `target` or `doc.answer` in sample. doc id: {sample['doc_id']}. Hash: {sample['doc_hash']}\n\n{sample}") +# +# resps = sample["filtered_resps"] # List of [log probability, is_correct] +# arguments = sample["arguments"] # [input, completion] pairs +# +# # Compute byte lengths for each choice +# byte_lengths = [max(1, len(completion.encode("utf-8"))) for _, completion in arguments] +# +# # Compute log probabilities for each choice +# log_probs = np.array([resp[0] for resp in resps]) # Extract log probabilities +# assert log_probs.shape == (len(arguments),), f"Log probs shape: {log_probs.shape}, arguments: {len(arguments)}. doc: {sample}" +# normalized_log_probs = log_probs - np.logaddexp.reduce(log_probs) +# +# # Metrics for the correct answer +# correct_logprob = log_probs[target] +# correct_bpb = -correct_logprob / byte_lengths[target] * NAT_TO_BIT +# correct_choice_logprob = normalized_log_probs[target] +# +# # Compute length-normalized weights (w_i) +# bpb_values = -log_probs / np.array(byte_lengths) * NAT_TO_BIT +# bpb_weights = np.exp(-bpb_values) +# bpb_weights /= max(bpb_weights.sum(), 1e-8) # Avoid division by zero +# correct_choice_prob_norm = bpb_weights[target] +# +# # Append metrics +# bpb_list.append(correct_bpb) +# logprob_list.append(correct_logprob) +# choice_logprob_list.append(correct_choice_logprob) +# choice_prob_norm_list.append(correct_choice_prob_norm) +# +# # Aggregate metrics for the task +# metrics[task] = { +# "bpb": np.mean(bpb_list) if bpb_list else 0.0, +# "logprob": np.mean(logprob_list) if logprob_list else 0.0, +# "choice_logprob": np.mean(choice_logprob_list) if choice_logprob_list else 0.0, +# "choice_prob_norm": np.mean(choice_prob_norm_list) if choice_prob_norm_list else 0.0, +# } +# +# return metrics + + +def run_eval_harness_main(config: EvalHarnessMainConfig): + config.trainer.initialize() + tokenizer = config.the_tokenizer + + compute_axis_mapping = config.trainer.compute_axis_mapping + parameter_axis_mapping = config.trainer.parameter_axis_mapping + + with config.trainer.device_mesh, hax.axis_mapping(parameter_axis_mapping): + key = jax.random.PRNGKey(0) + + vocab_size = len(tokenizer) + Vocab = round_axis_for_partitioning(hax.Axis("vocab", vocab_size), compute_axis_mapping) + if vocab_size != Vocab.size: + logger.info(f"Rounding vocab size from {vocab_size} to {Vocab.size} for partitioning") + + model: LmHeadModel + + # initialize the model + if config.checkpoint_is_hf: + model_config = config.model + converter: HFCheckpointConverter = model_config.hf_checkpoint_converter() + converter = converter.replaced(reference_checkpoint=config.checkpoint_path, tokenizer=tokenizer) + model = converter.load_pretrained( + model_config.model_type, ref=config.checkpoint_path, dtype=config.trainer.mp.compute_dtype # type: ignore + ) + else: + with use_cpu_device(): + model = eqx.filter_eval_shape(config.model.build, Vocab, key=key) + model = load_checkpoint(model, config.checkpoint_path, subpath="model") + model = hax.shard(model, parameter_axis_mapping) + + model = typing.cast(LmHeadModel, inference_mode(model, True)) + + logger.info("Running LM eval harness....") + outputs = run_lm_eval_harness( + config.eval_harness, + model, + tokenizer, + config.EvalBatch, + axis_resources=compute_axis_mapping, + ) + + logger.info("Finished running LM eval harness") + # log the results as json + with open("lm_eval_results.json", 
"w") as f: + json.dump(outputs, f, indent=2) + + # also write to stdout + if jax.process_index() == 0: + print(json.dumps(outputs, indent=2), flush=True) + + # also log the results + levanter.tracker.current_tracker().log_artifact("lm_eval_results.json") + log_report_to_tracker("lm_eval", outputs, levanter.tracker.current_tracker()) + + return outputs + + +def log_report_to_tracker(prefix: str, report: dict, tracker: Optional[levanter.tracker.Tracker] = None): + if tracker is None: + tracker = levanter.tracker.current_tracker() + + to_log = {} + for task_name, task_results in report["results"].items(): + for metric_name, metric_value in task_results.items(): + if metric_name.endswith(",none"): + metric_name = metric_name[:-5] + + if isinstance(metric_value, float | int): + to_log[f"{prefix}/{task_name}/{metric_name}"] = metric_value + + if "averages" in report: + for metric_name, metric_value in report["averages"].items(): + if isinstance(metric_value, float | int): + if metric_name.endswith(",none"): + metric_name = metric_name[:-5] + + to_log[f"{prefix}/averages/{metric_name}"] = metric_value + + tracker.log(to_log, step=None) + + +def lm_eval_harness(config: LmEvalHarnessConfig, tokenizer, EvalBatch, axis_resources): + tasks_to_run = config.to_task_dict() + + def lm_eval_harness(step: StepInfo, force=False): + # if step.step == 0 and not force: + # return # don't run eval on the first step + + print(config.task_spec_or_default()) + + model = inference_mode(step.model, True) + outputs = _actually_run_eval_harness( + config, + model, + tasks_to_run, + tokenizer, + EvalBatch, + axis_resources, + ) + + if jax.process_index() == 0: + with tempfile.NamedTemporaryFile("w", delete=False, suffix=".json") as f: + import json + + json.dump(outputs, f) + levanter.tracker.current_tracker().log_artifact( + f.name, name=f"lm_eval_output.{step.step}", type="lm_eval_output" + ) + + log_report_to_tracker("lm_eval", outputs, levanter.tracker.current_tracker()) + + return lm_eval_harness + + +if __name__ == "__main__": + levanter.config.main(run_eval_harness_main)() + print("Done", flush=True) diff --git a/src/levanter/main/eval_lm.py b/src/levanter/main/eval_lm.py index 116a08f18..6b4ca8b86 100644 --- a/src/levanter/main/eval_lm.py +++ b/src/levanter/main/eval_lm.py @@ -99,7 +99,7 @@ def compute_loss(model: LmHeadModel, example: LmExample): if config.hf_checkpoint is not None: # load the huggingface model model_config = config.model - if not hasattr(model_config, "hf_checkpoint_converter"): + if not hasattr(model_config, "default_hf_checkpoint_converter"): raise ValueError("Model config does not have an HF checkpoint converter. 
Can't load HF checkpoint.") converter: HFCheckpointConverter = model_config.hf_checkpoint_converter() converter = converter.replaced(reference_checkpoint=config.hf_checkpoint, tokenizer=tokenizer) diff --git a/src/levanter/main/train_lm.py b/src/levanter/main/train_lm.py index 99165c017..2ce2135fd 100644 --- a/src/levanter/main/train_lm.py +++ b/src/levanter/main/train_lm.py @@ -13,6 +13,8 @@ from haliax.partitioning import named_jit, round_axis_for_partitioning import levanter +import levanter.eval +import levanter.eval_harness from levanter import callbacks from levanter.checkpoint import EpochCheckpointer, load_checkpoint from levanter.compat.hf_checkpoints import HFCompatConfig, save_hf_checkpoint_callback @@ -23,6 +25,7 @@ SupervisedSourceConfig, mk_supervised_datasets, ) +from levanter.eval_harness import LmEvalHarnessConfig from levanter.models.gpt2 import Gpt2Config from levanter.models.lm_model import LmConfig, compute_next_token_loss from levanter.optim import AdamConfig, OptimizerConfig @@ -56,11 +59,12 @@ class TrainLmConfig: hf_upload: Optional[str] = None hf_save_steps: int = 10000 - update_hessian_steps: int = 10 data_seed: Optional[int] = None # if provided, will override the data seed from the trainer initialize_from_checkpoint_path: Optional[str] = None # if provided, will initialize from this checkpoint, used for llama style data mixture epoch: int = 0 + eval_harness: Optional[LmEvalHarnessConfig] = None + eval_harness_steps: int = 10000 def main(config: TrainLmConfig): @@ -122,6 +126,14 @@ def main(config: TrainLmConfig): Pos = config.model.Pos KeyPos = config.model.KeyPos + # to do partitioning, our dimensions have to be divisible by the size of the physical axes they're mapped to + # For most things, we just insist you specify the config right, but tokenizers often have strange numbers of + # tokens: gpt-2 has 50257, for example. So we round up. + vocab_size = len(tokenizer) + Vocab = round_axis_for_partitioning(Axis("vocab", vocab_size), parameter_axis_mapping) + if vocab_size != Vocab.size: + logger.info(f"Rounding vocab size from {vocab_size} to {Vocab.size} for partitioning") + # TODO: fix this tagged_eval_datasets: list = config.data.tagged_eval_sets(Pos.size) # TokenSeqDataset is config.data.train_set(Pos.size, key=data_key) @@ -152,14 +164,6 @@ def main(config: TrainLmConfig): ) trainer.add_hook(epoch_checkpointer, every=1) - # to do partitioning, our dimensions have to be divisible by the size of the physical axes they're mapped to - # For most things, we just insist you specify the config right, but tokenizers often have strange numbers of - # tokens: gpt-2 has 50257, for example. So we round up. 
- vocab_size = len(tokenizer) - Vocab = round_axis_for_partitioning(Axis("vocab", vocab_size), parameter_axis_mapping) - if vocab_size != Vocab.size: - logger.info(f"Rounding vocab size from {vocab_size} to {Vocab.size} for partitioning") - state = trainer.initial_state(training_key, model_init=lambda: config.model.build(Vocab, key=model_key)) seek_dataloader = True @@ -248,6 +252,13 @@ def main(config: TrainLmConfig): every=config.hf_save_steps, ) + if config.eval_harness is not None: + eval_harness = config.eval_harness + trainer.add_hook( + levanter.eval_harness.lm_eval_harness(eval_harness, tokenizer, EvalBatch, compute_axis_mapping), + every=config.eval_harness_steps, + ) + # visualize log probs @named_jit( in_axis_resources=parameter_axis_mapping, diff --git a/src/levanter/models/lm_model.py b/src/levanter/models/lm_model.py index 7e51ecadd..ae731afcf 100644 --- a/src/levanter/models/lm_model.py +++ b/src/levanter/models/lm_model.py @@ -4,11 +4,12 @@ import draccus import equinox as eqx +import jax import jax.numpy as jnp from jax.random import PRNGKey import haliax as hax -from haliax import Axis, NamedArray +from haliax import Axis, NamedArray, NamedOrNumeric from levanter.models.attention import AttentionMask from levanter.models.loss import maybe_fused_next_token_loss @@ -47,6 +48,33 @@ def causal( attn_mask = AttentionMask.causal() return LmExample(tokens=tokens, loss_mask=loss_mask, attn_mask=attn_mask) + @staticmethod + def from_prompt_and_completion( + Pos, + tokens: hax.NamedArray, + prompt_length: NamedOrNumeric, + *, + ignore_id: Optional[int] = None, + all_causal: bool = True, + ) -> "LmExample": + # mask out the prompt tokens + loss_mask = hax.arange(Pos) >= prompt_length - 1 + # don't predict the padding + if ignore_id is not None: + targets = hax.roll(tokens, -1, Pos) + loss_mask = loss_mask & (targets != ignore_id) + + # don't predict the last token + loss_mask = loss_mask & (1 - hax.nn.one_hot(-1, Pos, dtype=jax.numpy.bool_)) + + if all_causal: + attn_mask = AttentionMask.causal() + else: + # causal just for the completion part. We don't have a special structured mask for this, so we just + raise NotImplementedError("Not implemented yet") + + return LmExample(tokens=tokens, loss_mask=loss_mask, attn_mask=attn_mask) + # TODO: for some reason, mypy doesn't like the discover_packages_path argument? 
@dataclass(frozen=True) diff --git a/src/levanter/utils/jax_utils.py b/src/levanter/utils/jax_utils.py index 1d7205365..4ebad0221 100644 --- a/src/levanter/utils/jax_utils.py +++ b/src/levanter/utils/jax_utils.py @@ -12,7 +12,7 @@ from jax.sharding import Mesh, NamedSharding, PartitionSpec, PositionalSharding from jaxtyping import PRNGKeyArray, PyTree -import haliax +import haliax as hax from haliax.jax_utils import is_jax_array_like from haliax.partitioning import ResourceAxis @@ -252,7 +252,7 @@ def best_effort_sharding(shape, *, devices=None, mesh=None): devices = jax.devices() if mesh is None: - mesh = haliax.partitioning._get_mesh() + mesh = hax.partitioning._get_mesh() if mesh.devices.shape == (): mesh = None @@ -273,7 +273,7 @@ def best_effort_sharding(shape, *, devices=None, mesh=None): return sharding else: # get the existing mesh and find the FSDP axis - fsdp_axis = mesh.axis_names.index(haliax.partitioning.ResourceAxis.DATA) + fsdp_axis = mesh.axis_names.index(hax.partitioning.ResourceAxis.DATA) num_devices = mesh.devices.shape[fsdp_axis] for i in range(len(shape) - 1, -1, -1): @@ -285,7 +285,7 @@ def best_effort_sharding(shape, *, devices=None, mesh=None): return NamedSharding(mesh, PartitionSpec(None)) axis_sharding = [None] * len(shape) - axis_sharding[sharded_axis] = haliax.partitioning.ResourceAxis.DATA + axis_sharding[sharded_axis] = hax.partitioning.ResourceAxis.DATA sharding = NamedSharding(mesh, PartitionSpec(*axis_sharding)) return sharding diff --git a/tests/test_eval_harness.py b/tests/test_eval_harness.py new file mode 100644 index 000000000..bf5d663ce --- /dev/null +++ b/tests/test_eval_harness.py @@ -0,0 +1,27 @@ +from levanter.eval_harness import LmEvalHarnessConfig, TaskConfig +from test_utils import skip_if_module_missing + + +@skip_if_module_missing("lm_eval") +def test_task_config(): + task_spec = [ + TaskConfig( + task="hellaswag", + task_alias="hellaswag_10shot", + num_fewshot=10, + ), + TaskConfig( + task="hellaswag", + task_alias="hellaswag_5shot", + num_fewshot=5, + ), + "lambada_openai", + ] + + config = LmEvalHarnessConfig( + task_spec=task_spec, + ) + + q = config.to_task_dict() + + assert len(q) == 3 diff --git a/tests/test_llama.py b/tests/test_llama.py index 87576205d..9c08043d8 100644 --- a/tests/test_llama.py +++ b/tests/test_llama.py @@ -82,7 +82,9 @@ def test_llama_rotary_embedding(): @skip_if_no_torch -def test_apply_rotary_pos_emb(): +@pytest.mark.parametrize("model_seq_len", [128, 256]) +@pytest.mark.parametrize("test_seq_len", [64, 128, 256]) +def test_apply_rotary_pos_emb(model_seq_len, test_seq_len): import torch from transformers.models.llama.modeling_llama import apply_rotary_pos_emb as hf_apply_rotary_pos_emb from transformers.models.llama.modeling_llama import rotate_half as hf_rotate_half @@ -95,9 +97,9 @@ def assert_equal_out(hax_out, torch_out: torch.Tensor): def named_array_to_tensor(named_array): return torch.from_numpy(np.array(named_array.array)) - llama_config = _get_llama_config() + llama_config = _get_llama_config(seq_len=model_seq_len) - Pos = llama_config.Pos + Pos = llama_config.Pos.resize(test_seq_len) Heads = llama_config.Heads HeadSize = llama_config.HeadSize Batch = hax.Axis("batch", 2) @@ -138,7 +140,8 @@ def named_array_to_tensor(named_array): @skip_if_no_torch @pytest.mark.parametrize("use_flash", [True, False]) @pytest.mark.parametrize("num_kv_heads", [1, 2, 4]) -def test_llama_attention(use_flash, num_kv_heads): +@pytest.mark.parametrize("test_seq_len", [64, 128, 256]) +def test_llama_attention(use_flash, 
num_kv_heads, test_seq_len): import torch from transformers.models.llama.modeling_llama import LlamaAttention as HFLlamaAttention @@ -154,7 +157,10 @@ def test_llama_attention(use_flash, num_kv_heads): x, mask = _get_random_inputs(config) x_torch = torch.from_numpy(np.array(x.array)) batch_size = x_torch.shape[0] - explicit_mask = torch.from_numpy(np.array(mask.materialize(config.Pos, config.KeyPos).array)) + test_Pos = config.Pos.resize(test_seq_len) + test_KeyPos = config.KeyPos.resize(test_seq_len) + + explicit_mask = torch.from_numpy(np.array(mask.materialize(test_Pos, test_KeyPos).array)) mask_torch = explicit_mask.broadcast_to((batch_size, 1, -1, -1)) # the torch mask is really a bias, so we need to invert it and make it a big negative number @@ -389,3 +395,32 @@ def test_state_dict_consistency(scan_layers, num_kv_heads): hf_model = LlamaForCausalLM(hf_config) levanter_state_dict = hax.state_dict.to_torch_compatible_state_dict(model) assert set(hf_model.state_dict().keys()) == set(levanter_state_dict.keys()) + + +@pytest.mark.parametrize("num_kv_heads", [2, 4]) +def test_llama_seq_len_doesnt_change_predictions(num_kv_heads): + config = LlamaConfig( + seq_len=128, + hidden_dim=16, + num_heads=4, + num_kv_heads=num_kv_heads, + gradient_checkpointing=False, + ) + Vocab = hax.Axis("vocab", 1000) + + # Make input and attn_mask + input_256 = hax.random.randint(random.PRNGKey(0), config.Pos, 0, Vocab.size) + input_128 = input_256[config.Pos, :128] + attn_mask = AttentionMask.causal() + + model = LlamaLMHeadModel.init(Vocab=Vocab, config=config, key=random.PRNGKey(0)) + + @hax.named_jit + def compute(model, input): + model_output = model(input, attn_mask=attn_mask) + return model_output + + jax_out_1 = compute(model, input_128) + jax_out_2 = compute(model, input_256)[config.Pos, :128] + + assert np.allclose(jax_out_1.array, jax_out_2.array, rtol=1e-6, atol=1e-6) diff --git a/tests/test_llama3.py b/tests/test_llama3.py index 653ba723c..9bf4a7fbc 100644 --- a/tests/test_llama3.py +++ b/tests/test_llama3.py @@ -2,6 +2,7 @@ import tempfile import numpy as np +import pytest from jax import random import haliax as hax @@ -65,7 +66,8 @@ def get_config(vocab_size=1000): @skip_if_no_torch -def test_llama_roundtrip(): +@pytest.mark.parametrize("test_seq_len", [128, 256, 512]) +def test_llama3_roundtrip(test_seq_len): import torch from transformers import AutoModelForCausalLM, LlamaForCausalLM @@ -77,7 +79,8 @@ def test_llama_roundtrip(): config = LlamaConfig.from_hf_config(hf_config) # Make input and attn_mask - input = hax.random.randint(random.PRNGKey(0), config.Pos, 0, Vocab.size) + test_Pos = config.Pos.resize(test_seq_len) + input = hax.random.randint(random.PRNGKey(0), test_Pos, 0, Vocab.size) attn_mask = AttentionMask.causal() input_torch = torch.from_numpy(np.array(input.array)).to(torch.int32).unsqueeze(0)
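# --- Reviewer note (not part of the patch) ---
# A minimal sketch of how the standalone harness entry point added in
# src/levanter/eval_harness.py might be exercised programmatically, using only the
# dataclasses introduced above. The checkpoint path is a placeholder, and the tiny
# Gpt2Config mirrors config/harness/harness_nano.yaml; an equivalent YAML config can
# instead be passed to the module's draccus entry point
# (roughly: python -m levanter.eval_harness --config_path config/harness/harness_nano.yaml).
from levanter.eval_harness import (
    EvalHarnessMainConfig,
    LmEvalHarnessConfig,
    TaskConfig,
    run_eval_harness_main,
)
from levanter.models.gpt2 import Gpt2Config

config = EvalHarnessMainConfig(
    tokenizer="gpt2",
    checkpoint_path="checkpoints/my-run/step-1000",  # placeholder: any Levanter checkpoint dir
    checkpoint_is_hf=False,  # set True (and point at an HF repo or path) to load an HF checkpoint instead
    model=Gpt2Config(hidden_dim=32, num_heads=4, num_layers=2),
    eval_harness=LmEvalHarnessConfig(
        task_spec=[TaskConfig(task="hellaswag", num_fewshot=0)],
        max_examples=32,  # keep small for a smoke test
    ),
)

# Writes lm_eval_results.json, logs it to the current tracker, and returns the harness output dict.
outputs = run_eval_harness_main(config)
print(outputs["results"])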