Support Eleuther LM-Eval-Harness in Levanter #675

Open · dlwh wants to merge 41 commits into main from eval_harness

Commits (41)
8d1f3a6
wip
dlwh May 20, 2024
a91afd9
Merge remote-tracking branch 'origin/main' into eval_harness
dlwh May 22, 2024
628c525
ok it runs. garbage but it runs?
dlwh May 23, 2024
d4bf9e2
don't require imports
dlwh May 23, 2024
a32308a
wio
dlwh May 28, 2024
319eb6a
wip
dlwh May 28, 2024
d0a1560
Merge remote-tracking branch 'origin/main' into eval_harness
dlwh Jun 12, 2024
f05b189
almost there
dlwh Jun 12, 2024
7ba7f47
Merge remote-tracking branch 'origin/main' into eval_harness
dlwh Jun 12, 2024
19ac049
maybe we are there?
dlwh Jun 12, 2024
f5dec31
launcher
dlwh Jun 12, 2024
12c0b06
fix (?) logging of loading time etc.
dlwh Jun 12, 2024
463331e
wip
dlwh Jun 13, 2024
2d4d0d8
wip
dlwh Jun 21, 2024
f9ccebb
Merge remote-tracking branch 'origin/main' into eval_harness
dlwh Jul 25, 2024
3040956
Merge remote-tracking branch 'origin/main' into eval_harness
dlwh Nov 15, 2024
c7c5f70
off by one
dlwh Nov 15, 2024
50c20c0
Merge remote-tracking branch 'origin/main' into eval_harness
dlwh Nov 21, 2024
eb441e3
move logging and types to util to make python's module resolution hap…
dlwh Nov 24, 2024
353caa6
hijack HF's download so it works with gcs etc.
dlwh Nov 24, 2024
72fa689
missed some renames?
dlwh Nov 24, 2024
16195b0
rename maybe_fused_next_token_loss
dlwh Nov 24, 2024
e2cab79
add some more tests to make sure different seq lens work
dlwh Nov 24, 2024
4c34eec
bump jax version
dlwh Nov 24, 2024
4ecc630
depend on my fork
dlwh Nov 24, 2024
bcfc225
eval_harness is about there
dlwh Nov 24, 2024
e0ef6f8
refactor
dlwh Nov 26, 2024
e52b61d
pad number of requests to proper length
dlwh Nov 26, 2024
babaee2
sigh
dlwh Nov 26, 2024
dbf5341
revert
dlwh Nov 26, 2024
f898399
Optim config drop stable and add decay (#818)
blahBlahhhJ Nov 22, 2024
5235415
Bump fsspec (#824)
dlwh Nov 24, 2024
4d78749
add cycle_length (#825)
dlwh Nov 25, 2024
2a2a98f
Merge remote-tracking branch 'origin/main' into eval_harness
dlwh Nov 26, 2024
ae9624a
stack_tree
dlwh Nov 26, 2024
e97697c
ok i think we're good
dlwh Nov 29, 2024
65a661c
ok good enough
dlwh Nov 30, 2024
aeb8095
kk
dlwh Nov 30, 2024
3179f67
remove max_examples
dlwh Nov 30, 2024
69ea6b1
initialize
dlwh Nov 30, 2024
5a4e6ce
remove none
dlwh Nov 30, 2024
26 changes: 26 additions & 0 deletions config/gpt2_nano_harness.yaml
@@ -0,0 +1,26 @@
eval_harness:
  task_spec: ["piqa", "hellaswag"]
  max_examples: 32
eval_harness_steps: 50
data:
  id: dlwh/wikitext_103_detokenized
model:
  type: gpt2
  hidden_dim: 32
  num_heads: 4
  num_layers: 2
trainer:
  mp: f32
  num_train_steps: 100

  checkpointer:
    keep:
      - every: 50
    save_interval: 5m

  per_device_parallelism: -1
  train_batch_size: 4

  tensor_parallel_axes: ["mlp", "heads"]
  fsdp_axis: "embed"
  batch_axis: "batch"
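In the config above, eval_harness.task_spec and max_examples pick what the harness evaluates, while eval_harness_steps controls how often it runs during training. A minimal sketch of that scheduling relationship, assuming eval_harness_steps means "run the harness every N training steps"; the dataclass and function names are illustrative, not the PR's actual API:

from dataclasses import dataclass, field
from typing import List

@dataclass
class HarnessScheduleConfig:  # hypothetical name; fields mirror the YAML keys above
    task_spec: List[str] = field(default_factory=lambda: ["piqa", "hellaswag"])
    max_examples: int = 32        # cap per task, so the in-loop eval stays cheap
    eval_harness_steps: int = 50  # assumed meaning: run the harness every N steps

def should_run_harness(step: int, cfg: HarnessScheduleConfig) -> bool:
    """True on the training steps where the harness would be invoked."""
    return step > 0 and step % cfg.eval_harness_steps == 0

cfg = HarnessScheduleConfig()
print([s for s in range(1, 101) if should_run_harness(s, cfg)])  # -> [50, 100]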
52 changes: 52 additions & 0 deletions config/harness/eval_llama3.yaml
@@ -0,0 +1,52 @@
eval_harness:
  task_spec:
    - task: commonsense_qa # 5-way multiple-choice questions based on common-sense, everyday scenarios
      num_fewshot: 10
    - task: agieval_lsat_ar # 3-shot tests in legal domain
      num_fewshot: 3
    - task: arc_easy # 10-shot, four-way MCQ questions involving grade 3-9 basic science
      num_fewshot: 10
    - task: arc_challenge # a (harder) version of arc_easy
      num_fewshot: 10
    - task: boolq # answer yes/no questions based on a passage
      num_fewshot: 10
    - task: copa # use causal reasoning to predict the correct outcome of a given scenario
      num_fewshot: 0
    - task: hellaswag # 4-way multiple choice commonsense reasoning dataset
      num_fewshot: 0
      task_alias: hellaswag_0shot
    - task: hellaswag # 4-way multiple choice commonsense reasoning dataset
      num_fewshot: 10
      task_alias: hellaswag_10shot
    - task: lambada # predict the endings of text passages
      num_fewshot: 0
    - task: openbookqa # 4-way multiple choice question answering task that requires multi-step reasoning
      num_fewshot: 0
    - task: piqa # answer questions based on a passage
      num_fewshot: 10
    - task: wsc273 # Winograd Schema Challenge
      num_fewshot: 0
    - task: winogrande # Winograd challenge, extended to more domains
      num_fewshot: 0
    # requires generation
    ## - task: squadv2 # reading comprehension benchmark
    #   num_fewshot: 10
  max_eval_length: 4096
tokenizer: meta-llama/Meta-Llama-3-8B
model:
  type: llama
#checkpoint_path: gs://marin-us-central2/checkpoints/dclm_baseline_1b_1x_replication_nov12_3404462497seed-b68241/hf/step-54930
checkpoint_path: meta-llama/Meta-Llama-3-8B
checkpoint_is_hf: true
trainer:
  mp: f32
  profiler: true

  per_device_parallelism: -1
  train_batch_size: 512

  tensor_parallel_axes: ["mlp", "heads"]
  fsdp_axis: "embed"
  batch_axis: "batch"
  ray:
    auto_start_cluster: false
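task_spec entries in this file take two shapes: a bare task name, or a mapping with task, num_fewshot, and an optional task_alias (used above to report hellaswag at 0 and 10 shots separately). A minimal sketch, assuming only those three fields, of how such entries could be normalized before handing them to lm-eval; the function below is illustrative, not the harness's own code:

from typing import Union

def normalize_task(entry: Union[str, dict]) -> dict:
    """Expand a task_spec entry into a uniform dict with explicit defaults."""
    if isinstance(entry, str):
        entry = {"task": entry}
    return {
        "task": entry["task"],
        "num_fewshot": entry.get("num_fewshot", 0),
        "task_alias": entry.get("task_alias", entry["task"]),
    }

spec = ["lambada", {"task": "hellaswag", "num_fewshot": 10, "task_alias": "hellaswag_10shot"}]
print([normalize_task(e) for e in spec])

Standalone evaluation of the HF checkpoint named above would presumably go through the PR's new eval_harness entry point with this file as its config; the exact module path and flag are assumptions, e.g. python -m levanter.eval_harness --config_path config/harness/eval_llama3.yaml.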
24 changes: 24 additions & 0 deletions config/harness/harness_nano.yaml
@@ -0,0 +1,24 @@
eval_harness:
  task_spec: ["hellaswag"]
tokenizer: "gpt2"
model:
  type: gpt2
  hidden_dim: 32
  num_heads: 4
  num_layers: 2
trainer:
  mp: f32
  num_train_steps: 100
  profiler: true

  checkpointer:
    keep:
      - every: 50
    save_interval: 5m

  per_device_parallelism: -1
  train_batch_size: 32

  tensor_parallel_axes: ["mlp", "heads"]
  fsdp_axis: "embed"
  batch_axis: "batch"
175 changes: 175 additions & 0 deletions config/olmo/olmo_7b_repro.yaml
@@ -0,0 +1,175 @@
#data: !include data/dolma_olmo_paloma.yaml
data:
  cache_dir: "gs://marin-data/tokenized/OLMo-1B/dolma-v1.7"
  tokenizer: "allenai/OLMo-1B" # requires `pip install ai2-olmo`
  # tokenizer: "meta-llama/Llama-2-7b-hf"
  stop_strategy: restart
  shuffle_buffer_size: 100000
  configs:
    dolma-algebraic-stack:
      train_urls:
        - gs://marin-data/raw/dolma/dolma-v1.7/algebraic-stack-train-{0000..0015}.json.gz
    dolma-arxiv:
      train_urls:
        - gs://marin-data/raw/dolma/dolma-v1.7/arxiv-{0000..0099}.json.gz
    dolma-gutenberg:
      train_urls:
        - gs://marin-data/raw/dolma/dolma-v1.7/books-{0000..0002}.json.gz
    dolma-c4:
      train_urls:
        - gs://marin-data/raw/dolma/dolma-v1.7/c4-{0000..0170}.json.gz
    dolma-cc:
      train_urls:
        - gs://marin-data/raw/dolma/dolma-v1.7/cc_en_head-{0000..0274}.json.gz
        - gs://marin-data/raw/dolma/dolma-v1.7/cc_en_middle-{0000..0238}.json.gz # 239 is missing
        - gs://marin-data/raw/dolma/dolma-v1.7/cc_en_middle-{0240..0379}.json.gz
        - gs://marin-data/raw/dolma/dolma-v1.7/cc_en_tail-{0000..0152}.json.gz # 153 is missing
        - gs://marin-data/raw/dolma/dolma-v1.7/cc_en_tail-{0154..0444}.json.gz
    dolma-cc-news:
      train_urls:
        - gs://marin-data/raw/dolma/dolma-v1.7/cc_news_head-{0000..0004}.json.gz
        - gs://marin-data/raw/dolma/dolma-v1.7/cc_news_middle-{0000..0002}.json.gz
        - gs://marin-data/raw/dolma/dolma-v1.7/cc_news_tail-0000.json.gz
    dolma-falcon:
      train_urls:
        - gs://marin-data/raw/dolma/dolma-v1.7/falcon-{0000..0499}.json.gz
    dolma-megawika:
      train_urls:
        - gs://marin-data/raw/dolma/dolma-v1.7/megawika-{0000..0261}.json.gz
    dolma-owmath:
      train_urls:
        - gs://marin-data/raw/dolma/dolma-v1.7/open-web-math-train-{0000..0012}.json.gz
    dolma-pes2o:
      train_urls:
        - gs://marin-data/raw/dolma/dolma-v1.7/pes2o-{0000..0025}.json.gz
    dolma-reddit:
      train_urls:
        - gs://marin-data/raw/dolma/dolma-v1.7/reddit-{0000..0077}.json.gz
    dolma-stackexchange:
      train_urls:
        - gs://marin-data/raw/dolma/dolma-v1.7/stackexchange-{0000..0025}.json.gz
    dolma-starcoder:
      train_urls:
        - gs://marin-data/raw/dolma/dolma-v1.7/starcoder-{0000..0048}.json.gz
    dolma-flan:
      train_urls:
        - gs://marin-data/raw/dolma/dolma-v1.7/tulu_flan-{0000..0065}.json.gz
    dolma-wiki:
      train_urls:
        - gs://marin-data/raw/dolma/dolma-v1.7/wiki-{0000..0001}.json.gz
    # these are just for eval
    "paloma/4chan":
      validation_urls:
        - gs://levanter-data/paloma/4chan_meta_sep/val/val*.jsonl.gz
    "paloma/c4_100_domains":
      validation_urls:
        - gs://levanter-data/paloma/c4_100_domains/val/val*.jsonl.gz
    "paloma/c4_en":
      validation_urls:
        - gs://levanter-data/paloma/c4_en/val/val*.jsonl.gz
    "paloma/dolma-v1_5":
      validation_urls:
        - gs://levanter-data/paloma/dolma-v1_5/val/val*.jsonl.gz
    "paloma/dolma_100_programing_languages":
      validation_urls:
        - gs://levanter-data/paloma/dolma_100_programing_languages/val/val*.jsonl.gz
    "paloma/dolma_100_subreddits":
      validation_urls:
        - gs://levanter-data/paloma/dolma_100_subreddits/val/val*.jsonl.gz
    "paloma/falcon-refinedweb":
      validation_urls:
        - gs://levanter-data/paloma/falcon-refinedweb/val/val*.jsonl.gz
    "paloma/gab":
      validation_urls:
        - gs://levanter-data/paloma/gab/val/val*.jsonl.gz
    "paloma/m2d2_s2orc_unsplit":
      validation_urls:
        - gs://levanter-data/paloma/m2d2_s2orc_unsplit/val/val*.jsonl.gz
    "paloma/m2d2_wikipedia_unsplit":
      validation_urls:
        - gs://levanter-data/paloma/m2d2_wikipedia_unsplit/val/val*.jsonl.gz
    "paloma/manosphere_meta_sep":
      validation_urls:
        - gs://levanter-data/paloma/manosphere_meta_sep/val/val*.jsonl.gz
    "paloma/mc4":
      validation_urls:
        - gs://levanter-data/paloma/mc4/val/val*.jsonl.gz
    "paloma/ptb":
      validation_urls:
        - gs://levanter-data/paloma/ptb/val/val*.jsonl.gz
    "paloma/redpajama":
      validation_urls:
        - gs://levanter-data/paloma/redpajama/val/val*.jsonl.gz
    "paloma/twitterAAE_HELM_fixed":
      validation_urls:
        - gs://levanter-data/paloma/twitterAAE_HELM_fixed/val/val*.jsonl.gz
    "paloma/wikitext_103":
      validation_urls:
        - gs://levanter-data/paloma/wikitext_103/val/val*.jsonl.gz
  train_weights:
    # sampling proportion comes from https://huggingface.co/datasets/allenai/dolma
    dolma-algebraic-stack: 12.6 # 12.6 * 1.0
    dolma-arxiv: 28.0 # 28.0 * 1.0
    dolma-gutenberg: 5.3 # 5.3 * 1.0
    dolma-c4: 69.2 # 138.4 * 0.5
    dolma-cc: 597.75 # 1,195.5 * 0.5
    dolma-cc-news: 14.3 # 1.0
    dolma-falcon: 456.4 # 1.0, refined web
    dolma-megawika: 4.6 # 1.0
    dolma-owmath: 12.6 # 1.0
    dolma-pes2o: 57.2 # 1.0
    dolma-reddit: 79.9 # 1.0
    dolma-stackexchange: 19.6 # 1.0
    dolma-starcoder: 263.8 # 1.0
    dolma-flan: 16.5 # 6.5 * 1.0
    dolma-wiki: 7.4 # 3.7 * 2.0
    paloma/4chan: 0.0
    paloma/c4_100_domains: 0.0
    paloma/c4_en: 0.0
    paloma/dolma-v1_5: 0.0
    paloma/dolma_100_programing_languages: 0.0
    paloma/dolma_100_subreddits: 0.0
    paloma/falcon-refinedweb: 0.0
    paloma/gab: 0.0
    paloma/m2d2_s2orc_unsplit: 0.0
    paloma/m2d2_wikipedia_unsplit: 0.0
    paloma/manosphere_meta_sep: 0.0
    paloma/mc4: 0.0
    paloma/ptb: 0.0
    paloma/redpajama: 0.0
    paloma/twitterAAE_HELM_fixed: 0.0
    paloma/wikitext_103: 0.0
model: # 7B class model
  type: llama
  seq_len: 2048
  hidden_dim: 4096
  intermediate_dim: 11008
  num_layers: 32
  num_heads: 32
  num_kv_heads: 32
  use_flash_attention: True
  # flash_attention_block_size: 1024

  use_bias: false
  use_layer_norm_weight: false
trainer:
  tracker:
    type: wandb
    project: "marin"
    tags: ["dolma", "olmo", "llama"]

  mp: p=f32,c=bfloat16
  train_batch_size: 2048 # olmo actually uses 2160, table 5 of https://arxiv.org/pdf/2402.00838
  num_train_steps: 750000 # 3,000,000,000,000 / 4,000,000 = 750,000
  steps_per_eval: 1000
  tensor_parallel_axes: ["mlp", "heads"]
  fsdp_axis: "embed"
  batch_axis: "batch"
  replica_dcn_axis_size: 2
optimizer:
  learning_rate: 3E-4
  weight_decay: 0.1
  min_lr_ratio: 0.1
  beta1: 0.9
  beta2: 0.95
  warmup: 2000
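The train_weights above are relative sampling weights; the comments give the underlying token counts (per the Dolma mixing table linked in the config) already scaled by the per-source multiplier. A small illustrative sketch of how to read them, under the assumption that each source's share of the training mixture is its weight divided by the sum, and that the paloma/* entries at 0.0 are eval-only:

weights = {
    "dolma-cc": 597.75,
    "dolma-falcon": 456.4,
    "dolma-starcoder": 263.8,
    "dolma-reddit": 79.9,
    "dolma-c4": 69.2,
    "dolma-pes2o": 57.2,
    "dolma-arxiv": 28.0,
    "dolma-stackexchange": 19.6,
    "dolma-flan": 16.5,
    "dolma-cc-news": 14.3,
    "dolma-algebraic-stack": 12.6,
    "dolma-owmath": 12.6,
    "dolma-wiki": 7.4,
    "dolma-gutenberg": 5.3,
    "dolma-megawika": 4.6,
    # all paloma/* sources are 0.0 and therefore never sampled for training
}
total = sum(weights.values())
for name, w in sorted(weights.items(), key=lambda kv: -kv[1]):
    print(f"{name:>24}: {w / total:6.2%} of the training mixture")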
13 changes: 9 additions & 4 deletions pyproject.toml
@@ -7,7 +7,11 @@ name = "levanter"
version = "1.2"
authors = [
    { name = "David Hall", email = "[email protected]" },
    { name = "Jason Wang"},
    { name = "Ahmed Ahmed"},
    { name = "Ivan Zhou", email = "[email protected]" },
    { name = "Will Held"},
    { name = "Virginia Adams"}
]
description = "Scalable Training for Foundation Models with Named Tensors and JAX"
readme = "README.md"
@@ -47,10 +51,11 @@ dependencies = [
    "pydantic<3",
    "rich~=13.0",
    "filelock~=3.13",
    # "ai2-olmo",
    "async-lru~=2.0",
    "tqdm-loggable>=0.2",
    "deepdiff"
    "deepdiff",
    # "lm-eval==0.4.2",
    "lm-eval @ git+https://github.com/dlwh/lm-evaluation-harness.git@no_torch"
]

[project.urls]
@@ -100,11 +105,11 @@ markers = [
[project.optional-dependencies]
test = [
    "pytest",
    "pytest-forked",
    "pytest-asyncio",
    "flake8",
    "soundfile",
    "librosa",
    "pytest-forked",
    "pytest-asyncio",
]

[tool.setuptools.packages.find]
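The dependency block now pins lm-eval to the dlwh/lm-evaluation-harness no_torch fork, and one of the commits above is titled "don't require imports". A hedged sketch of that pattern, keeping lm-eval importable lazily so the rest of the package loads even when the harness is absent; this is illustrative, not the PR's actual module layout:

try:
    import lm_eval  # provided by the pinned dlwh/lm-evaluation-harness "no_torch" fork
except ImportError:  # harness not installed; defer the failure to the code that needs it
    lm_eval = None

def require_lm_eval() -> None:
    """Raise a readable error if eval-harness features are used without lm-eval installed."""
    if lm_eval is None:
        raise ImportError(
            "lm-eval is not installed. Install the fork pinned in pyproject.toml, e.g.\n"
            "  pip install 'lm-eval @ git+https://github.com/dlwh/lm-evaluation-harness.git@no_torch'"
        )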
8 changes: 6 additions & 2 deletions src/levanter/checkpoint.py
@@ -380,11 +380,15 @@ def load_checkpoint(
         logger.warning("Loading checkpoint in jit. This is not recommended and probably won't work.")

     if discover_latest:
-        checkpoint_path = discover_latest_checkpoint(checkpoint_path)  # type: ignore
+        discovered_checkpoint_path = discover_latest_checkpoint(checkpoint_path)  # type: ignore
+    else:
+        discovered_checkpoint_path = checkpoint_path

-    if checkpoint_path is None or not fs.exists(checkpoint_path):
+    if discovered_checkpoint_path is None or not fs.exists(discovered_checkpoint_path):
         raise FileNotFoundError(f"Could not find checkpoint at {checkpoint_path}")

+    checkpoint_path = discovered_checkpoint_path
+
     logger.info(f"Loading checkpoint from {checkpoint_path}")
     metadata = load_metadata(checkpoint_path, fs)
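The checkpoint.py change above keeps the discovery result in a separate variable, so a failed lookup still reports the path the caller passed rather than the None returned by discover_latest_checkpoint. A self-contained sketch of the resulting behavior, using a local-filesystem stand-in instead of Levanter's fsspec filesystem; names other than the two taken from the diff are illustrative:

import os
from typing import Optional

def discover_latest_checkpoint(base: str) -> Optional[str]:
    """Toy stand-in: return the lexicographically last step-* subdirectory, if any."""
    if not os.path.isdir(base):
        return None
    steps = sorted(d for d in os.listdir(base) if d.startswith("step-"))
    return os.path.join(base, steps[-1]) if steps else None

def resolve_checkpoint(checkpoint_path: str, discover_latest: bool = True) -> str:
    """Mirror of the patched logic: discovery never clobbers the user-supplied path."""
    if discover_latest:
        discovered = discover_latest_checkpoint(checkpoint_path)
    else:
        discovered = checkpoint_path
    if discovered is None or not os.path.exists(discovered):
        # the error names the original path, not the (possibly None) discovery result
        raise FileNotFoundError(f"Could not find checkpoint at {checkpoint_path}")
    return discovered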