Add convert_module to FSDP #20323
base: master
Conversation
Codecov Report: All modified and coverable lines are covered by tests ✅

Additional details and impacted files:

@@            Coverage Diff            @@
##           master   #20323    +/-   ##
=========================================
  Coverage      88%      88%
=========================================
  Files         267      267
  Lines       23274    23285     +11
=========================================
+ Hits        20381    20392     +11
  Misses       2893     2893
Force-pushed from 7a0c355 to baeb535
Thank you @tshu-w! As a sanity check, can you verify that the issues in #19721 are resolved? (i.e. memory goes back to what PyTorch uses, and no inconsistency errors are produced - these may be good tests to add btw, or at least a scaled-down version thereof). I'll be happy to run things on my end and dig deeper in parallel.
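Not part of the PR, but a minimal, scaled-down check in the spirit of this request might look like the sketch below. It is a hypothetical CPU-only illustration (the model shape and the exact 2x expectation are assumptions); the real verification concerns peak CUDA memory under FSDP with "true" precision.

import torch
from torch import nn

# Hypothetical scaled-down sanity check (not from the PR): after casting to
# bf16, parameter memory should halve relative to fp32, i.e. match what a
# plain PyTorch "true"-precision run would use.
def param_bytes(model: nn.Module) -> int:
    return sum(p.numel() * p.element_size() for p in model.parameters())

model = nn.Sequential(nn.Linear(1024, 1024), nn.Linear(1024, 1024))
fp32 = param_bytes(model)
bf16 = param_bytes(model.to(torch.bfloat16))
assert bf16 == fp32 // 2, (fp32, bf16)
print(f"fp32 params: {fp32} bytes, bf16 params: {bf16} bytes")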
I indeed noticed a decrease in VRAM usage (which I will confirm again in the coming week), even when I initialize the LLM in configure_model:

def configure_model(self):
    # only build the model once
    if self.model is not None:
        return
    self.model = AutoModelForCausalLM.from_pretrained(self.model_name_or_path)
    # suppress the open-end generation warning
    self.model.generation_config.pad_token_id = (
        self.model.generation_config.pad_token_id
        or self.model.generation_config.eos_token_id
    )
    # optionally wrap the base model with PEFT adapters
    if self.hparams.peft_config:
        peft_config = get_peft_config(self.hparams.peft_config)
        self.model = get_peft_model(self.model, peft_config)
    # fall back to a default chat template if the tokenizer does not define one
    if self.tokenizer.chat_template is None:
        self.tokenizer.chat_template = (
            self.chatml_template
            if self.hparams.use_chatml_template
            else self.base_template
        )
    # ChatML needs its special tokens, so add them and resize the embeddings
    if self.hparams.use_chatml_template:
        self.tokenizer.add_tokens(
            ["<|im_start|>", "<|im_end|>"], special_tokens=True
        )
        self.model.resize_token_embeddings(len(self.tokenizer))
    # optionally restore weights from an existing checkpoint
    if self.hparams.ckpt_path:
        checkpoint = torch.load(self.hparams.ckpt_path, weights_only=True)
        self.load_state_dict(checkpoint["state_dict"])
Hey @tshu-w, did you end up digging further?
Checking memory gains on my end
What does this PR do?
Add convert_module for FSDP, as DeepSpeed does. Fixes #19721 (comment)
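For illustration only, here is a minimal sketch of the idea (the mapping and function below are hypothetical, not Lightning's actual API): with a "true" precision setting, the module is cast to the target dtype before FSDP wraps it, mirroring the existing DeepSpeed behaviour, so the wrapped module starts in the training dtype rather than fp32.

import torch
from torch import nn

# Hypothetical sketch of the convert_module idea (not the PR's actual code).
_TRUE_PRECISION_DTYPES = {
    "16-true": torch.float16,
    "bf16-true": torch.bfloat16,
    "32-true": torch.float32,
}

def convert_module(module: nn.Module, precision: str) -> nn.Module:
    dtype = _TRUE_PRECISION_DTYPES.get(precision)
    if dtype is not None:
        module = module.to(dtype=dtype)
    return module

# Usage: cast a toy model for "bf16-true" before sharding.
model = convert_module(nn.Linear(8, 8), "bf16-true")
print(next(model.parameters()).dtype)  # torch.bfloat16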
Before submitting
PR review
Anyone in the community is welcome to review the PR.
Before you start reviewing, make sure you have read the review guidelines.
Reviewer checklist
📚 Documentation preview 📚: https://pytorch-lightning--20323.org.readthedocs.build/en/20323/