AssertionError for Phi-3.5-mini-instruct and Qwen2.5-7B-Instruct with NeMo + ThunderFX #1476

Open
mpatel31415 opened this issue Nov 26, 2024 · 6 comments · May be fixed by #1480
Labels
mixology: Issues that the mixology team has surfaced
nemo: Issues needed to support NVIDIA NeMo models.
thunderfx: for things that could be applicable to the dynamo+thunder frontend

@mpatel31415
Contributor

mpatel31415 commented Nov 26, 2024

🐛 Bug

When running Phi-3.5-mini-instruct and Qwen2.5-7B-Instruct with NeMo + ThunderFX, we get the following error:

0: File "/usr/lib/python3.10/copy.py", line 153, in deepcopy
0: y = copier(memo)
0: File "/usr/local/lib/python3.10/dist-packages/torch/fx/graph_module.py", line 793, in __deepcopy__
0: fake_mod = _CodeOnlyModule(copy.deepcopy(self.__dict__, memo))
0: File "/usr/lib/python3.10/copy.py", line 146, in deepcopy
0: y = copier(x, memo)
0: File "/usr/lib/python3.10/copy.py", line 231, in _deepcopy_dict
0: y[deepcopy(key, memo)] = deepcopy(value, memo)
0: File "/usr/lib/python3.10/copy.py", line 153, in deepcopy
0: y = copier(memo)
0: File "/usr/local/lib/python3.10/dist-packages/torch/fx/graph.py", line 940, in __deepcopy__
0: assert isinstance(output_vals, tuple)
0: torch._dynamo.exc.BackendCompilerFailed: backend='<thunder.dynamo.compiler.ThunderCompiler object at 0x7ffd30242dd0>' raised:
0: AssertionError:

(I'll add a file with the full traceback.)
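For readers unfamiliar with this code path, here is a simplified, pure-Python model of the chain in the traceback. `ToyGraph` and `ToyGraphModule` are hypothetical stand-ins, not PyTorch classes. They assume, based on older `torch/fx/graph.py`, that `Graph.__deepcopy__` relies on `graph_copy(..., return_output_node=True)`, which returns a tuple only when it reaches an `output` node and `None` otherwise. A graph with no `output` node therefore trips `assert isinstance(output_vals, tuple)`.

```python
import copy


class ToyGraph:
    """Hypothetical stand-in for torch.fx.Graph; nodes are just op names."""

    def __init__(self, ops):
        self.ops = list(ops)

    def _graph_copy(self):
        # Modeled on Graph.graph_copy(..., return_output_node=True):
        # copy nodes until the "output" node, then return a tuple.
        # With no "output" node, fall through and return None.
        copied = []
        for op in self.ops:
            if op == "output":
                return (copied, op)
            copied.append(op)
        return None

    def __deepcopy__(self, memo):
        output_vals = self._graph_copy()
        # The same kind of check that fails at torch/fx/graph.py:940 above.
        assert isinstance(output_vals, tuple)
        copied, output_op = output_vals
        return ToyGraph(copied + [output_op])


class ToyGraphModule:
    """Hypothetical stand-in for torch.fx.GraphModule."""

    def __init__(self, graph):
        self.graph = graph

    def __deepcopy__(self, memo):
        # Like GraphModule.__deepcopy__, deep-copy the instance __dict__;
        # copy.deepcopy then recurses into ToyGraph.__deepcopy__.
        new = ToyGraphModule.__new__(ToyGraphModule)
        memo[id(self)] = new
        new.__dict__ = copy.deepcopy(self.__dict__, memo)
        return new
```

`copy.deepcopy(ToyGraphModule(ToyGraph(["placeholder", "output"])))` succeeds, while dropping `"output"` from the list raises `AssertionError`, mirroring the failure reported here.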

To Reproduce

The error reproduces on 1x H100.

Dockerfile used (I built it yesterday, and I'm not yet sure how nemo:dev images are versioned, so I can't provide its exact version):

FROM nvcr.io/nvidia/nemo:dev
ARG NVFUSER_REPO=git+https://github.com/NVIDIA/Fuser.git
ARG THUNDER_REPO=git+https://github.com/Lightning-AI/lightning-thunder.git

# Add cloned NeMo latest code
RUN git clone --recursive https://github.com/NVIDIA/NeMo.git /NeMo_cloned
RUN (cd /NeMo_cloned && python -m pip install .)


# Install requirements needed for NeMo, Thunder and NVFUser.
# We have to install them in this complicated way because otherwise Thunder is not
# updated and we cannot use the latest version.
RUN python -m pip install -r /NeMo_cloned/requirements/requirements_lightning.txt && \
    python -m pip install --upgrade ${NVFUSER_REPO}  && \
    python -m pip install --upgrade ${THUNDER_REPO} && \
    python -m pip install --upgrade --no-deps --force-reinstall ${NVFUSER_REPO} && \
    python -m pip install --upgrade --no-deps --force-reinstall ${THUNDER_REPO}
 
# Install Mixology requirements (this can be skipped, so I'm commenting it out)
# COPY requirements/mixology.txt mixology_requirements.txt
# RUN pip install --upgrade -r mixology_requirements.txt

Inside the Docker container, run:

model=microsoft/Phi-3.5-mini-instruct
# Download the model (you might need to set HF_TOKEN and accept the model's terms of use on its Hugging Face page)
huggingface-cli download $model --local-dir checkpoints/$model --cache-dir checkpoints/$model 
# Run benchmark
python bench_targets/llm_peft/_nemo.py --model checkpoints/$model --mbs 1 --seq-length 2048 --jit-backend thunder

The script bench_targets/llm_peft/_nemo.py can be obtained from the internal GitLab repository akoumparouli/nemo_bench. You can contact me or @tfogal if you have any questions.

You can check that the same command works with the eager backend:

python bench_targets/llm_peft/_nemo.py --model checkpoints/$model --mbs 1 --seq-length 2048 --jit-backend eager

Expected behavior

No error for Thunder.

Environment

cc @tfogal

@mpatel31415
Contributor Author

Here is a .txt file with the full traceback: full_traceback.txt

@kiya00
Collaborator

kiya00 commented Nov 26, 2024

I think the cause is PR #1437: it relies on a PyTorch bug fix (pytorch/pytorch#139275) that is probably only available in Torch nightly.

IvanYashchuk added the nemo and mixology labels on Nov 26, 2024
@IvanYashchuk
Collaborator

The error is fixed only with the latest PyTorch (Nov 1st+, pytorch/pytorch@0cf4cc3). What's the PyTorch version used in nvcr.io/nvidia/nemo:dev?

tfogal added the thunderfx label on Nov 26, 2024
@tfogal
Collaborator

tfogal commented Nov 26, 2024

> I think the cause is PR #1437: it relies on a PyTorch bug fix (pytorch/pytorch#139275) that is probably only available in Torch nightly.

The functionality added in #1437 is not (yet) a blocker for our Q4 goals. I recommend a workaround that simply disables the functionality when/if PyTorch is too old.

> The error is fixed only with the latest PyTorch (Nov 1st+, pytorch/pytorch@0cf4cc3). What's the PyTorch version used in nvcr.io/nvidia/nemo:dev?

It is old: 2.4.0a0+3bcc3cddb5.nv24.07.
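A minimal sketch of the suggested gate, assuming the check is done on the `torch.__version__` string. The Nov 1 cutoff for nightly builds comes from this thread; `FIX_RELEASE = (2, 6)` is an assumption about the first stable release that contains pytorch/pytorch@0cf4cc3 and should be checked against PyTorch's release notes. `torch_has_fx_deepcopy_fix` is a hypothetical helper name.

```python
import re

# From the thread: the upstream fix (pytorch/pytorch@0cf4cc3) landed Nov 1, 2024,
# so nightly builds dated on or after this day should contain it.
FIX_NIGHTLY_DATE = 20241101
# ASSUMPTION: first stable (major, minor) release assumed to include the fix.
# Verify against PyTorch's release notes before relying on it.
FIX_RELEASE = (2, 6)

# Matches e.g. "2.4.0a0+3bcc3cddb5.nv24.07", "2.6.0", "2.6.0.dev20241101+cu124".
# Groups: 1 = major, 2 = minor, 3 = nightly build date (if any).
_VERSION_RE = re.compile(r"^(\d+)\.(\d+)(?:\.\d+)?(?:[^+]*?\.dev(\d{8}))?")


def torch_has_fx_deepcopy_fix(version: str) -> bool:
    """Return True if `version` (e.g. torch.__version__) is assumed to have the fix."""
    m = _VERSION_RE.match(version)
    if m is None:
        # Unrecognized format: be conservative and keep the feature disabled.
        return False
    major, minor, nightly_date = int(m.group(1)), int(m.group(2)), m.group(3)
    if nightly_date is not None:
        # Nightly builds: decide by build date rather than version number.
        return int(nightly_date) >= FIX_NIGHTLY_DATE
    return (major, minor) >= FIX_RELEASE
```

With the version reported above, `torch_has_fx_deepcopy_fix("2.4.0a0+3bcc3cddb5.nv24.07")` returns `False`, so a caller could skip the #1437 code path in the nemo:dev container, for example with `if not torch_has_fx_deepcopy_fix(torch.__version__): ...`.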

@IvanYashchuk
Collaborator

> I recommend a workaround that simply disables the functionality when/if PyTorch is too old.

Sure, if we need to make it work with older PyTorch versions, we can do that.

A workaround could be to iterate over all submodules returned by split_module, which is used here:

# `split_module` iterates over nodes and determines the partition to place them based on the callback.
original_split_gm: torch.fx.GraphModule = split_module(
    gm, root_m=None, split_callback=callback, keep_original_order=True, keep_original_node_name=True
)

and add an output node to all submodules that are missing one. @kshitij12345, does this sound like a correct workaround?
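A sketch of that workaround, assuming it runs right after the `split_module` call that produces `original_split_gm`. It only touches the direct children of the split module, and it assumes that a partition with no outputs should return an empty tuple. The function is duck-typed against `torch.fx` (`named_children()`, `Graph.nodes`, `Node.op`, `Graph.output()`, `GraphModule.recompile()`), so it imports nothing from PyTorch; `add_missing_output_nodes` is a hypothetical name.

```python
from typing import Any, List


def add_missing_output_nodes(split_gm: Any) -> List[str]:
    """Give every direct child graph of `split_gm` an output node.

    Duck-typed against torch.fx: `split_gm.named_children()` yields
    (name, module) pairs; a GraphModule child has a `.graph` whose `.nodes`
    carry an `.op` string, plus `Graph.output()` and `GraphModule.recompile()`.
    Returns the names of the submodules that were patched.
    """
    patched = []
    for name, submod in split_gm.named_children():
        graph = getattr(submod, "graph", None)
        if graph is None:
            # Not a GraphModule (e.g. a plain nn.Module): nothing to patch.
            continue
        if any(node.op == "output" for node in graph.nodes):
            continue
        # ASSUMPTION: a partition without outputs should return an empty tuple.
        graph.output(())
        # Regenerate the submodule's forward() from the updated graph.
        submod.recompile()
        patched.append(name)
    return patched
```

Calling `add_missing_output_nodes(original_split_gm)` right after the `split_module(...)` call should leave every `submod_*` graph ending in an `output` node, so a later `copy.deepcopy` would no longer hit the assertion on older PyTorch.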

@kshitij12345
Collaborator

Yes, I think that should work.
