Better registration support for a wide range of third-party hardware #20349

Open
uniartisan wants to merge 4 commits into master from device-enhance
Conversation

@uniartisan uniartisan commented Oct 19, 2024

What does this PR do?

Thank you to the lightning team for providing such an easy-to-use, clearly designed library.

This draft PR aims to provide better registration support for a wide range of third-party hardware. It is designed to integrate third-party hardware with minimally intrusive changes, including Intel XPU, Huawei Ascend NPU, Cambricon, Moore Threads, and more.
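For illustration only, here is a rough sketch (not the code in this PR) of what a third-party accelerator could look like if it implements the proposed device-type hook on top of the existing `Accelerator` base class; the class name, the "npu" backend, and the exact set of required hooks are assumptions:

```python
from typing import Any, List, Union

import torch

from lightning.fabric.accelerators import Accelerator


class NPUAccelerator(Accelerator):
    """Hypothetical third-party accelerator sketch (an imaginary NPU backend)."""

    @staticmethod
    def get_device() -> str:
        # Proposed hook from this PR: report the torch device type string.
        return "npu"

    @staticmethod
    def parse_devices(devices: Any) -> Union[int, List[int]]:
        return devices

    @staticmethod
    def get_parallel_devices(devices: Union[int, List[int]]) -> List[torch.device]:
        # Assumes the vendor extension has registered the "npu" device type with PyTorch.
        indices = range(devices) if isinstance(devices, int) else devices
        return [torch.device("npu", i) for i in indices]

    @staticmethod
    def auto_device_count() -> int:
        return 1  # assumption: query the vendor runtime here

    @staticmethod
    def is_available() -> bool:
        return False  # assumption: probe for the vendor extension here

    def setup_device(self, device: torch.device) -> None:
        pass  # e.g. select the device via the vendor's torch extension

    def teardown(self) -> None:
        pass
```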

Fixes #<issue_number>

Before submitting
  • Was this discussed/agreed via a GitHub issue? (not for typos and docs)
  • Did you read the contributor guideline, Pull Request section?
  • Did you make sure your PR does only one thing, instead of bundling different changes together?
  • Did you make sure to update the documentation with your changes? (if necessary)
  • Did you write any new necessary tests? (not for typos and docs)
  • Did you verify new and existing tests pass locally with your changes?
  • Did you list all the breaking changes introduced by this pull request?
  • Did you update the CHANGELOG? (not for typos, docs, test updates, or minor internal changes/refactors)

PR review

Anyone in the community is welcome to review the PR.
Before you start reviewing, make sure you have read the review guidelines. In short, see the following bullet list:

Reviewer checklist
  • Is this pull request ready for review? (if not, please submit in draft mode)
  • Check that all items from Before submitting are resolved
  • Make sure the title is self-explanatory and the description concisely explains the PR
  • Add labels and milestones (and optionally projects) to the PR so it can be classified

📚 Documentation preview 📚: https://pytorch-lightning--20349.org.readthedocs.build/en/20349/

@github-actions github-actions bot added the fabric (lightning.fabric.Fabric) and pl (generic label for PyTorch Lightning package) labels on Oct 19, 2024
@uniartisan (Author) commented:

Examples here: https://github.com/uniartisan/RWKV-PEFT/blob/device-enhance/train.py#L499

There are a lot of things still to be checked; I will work through them later and make this clearer in the documentation.
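For context, a minimal usage sketch of how such an accelerator could be handed to Fabric; `my_vendor_plugin` and `NPUAccelerator` are hypothetical, and the registry call assumes the current `ACCELERATOR_REGISTRY` interface, which may differ:

```python
import lightning as L
from my_vendor_plugin import NPUAccelerator  # hypothetical package shipping the accelerator sketched above

# Passing an Accelerator instance directly is already supported by Fabric.
fabric = L.Fabric(accelerator=NPUAccelerator(), devices=1)

# Registering it under a short name (assumed registry interface) would let users select it by string.
from lightning.fabric.accelerators import ACCELERATOR_REGISTRY

ACCELERATOR_REGISTRY.register("npu", NPUAccelerator, description="Hypothetical third-party NPU backend")
fabric = L.Fabric(accelerator="npu", devices=1)
```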


codecov bot commented Oct 22, 2024

Codecov Report

Attention: Patch coverage is 81.69014% with 13 lines in your changes missing coverage. Please review.

Project coverage is 88%. Comparing base (1e88899) to head (31c3412).
Report is 2 commits behind head on master.

Additional details and impacted files
@@           Coverage Diff           @@
##           master   #20349   +/-   ##
=======================================
- Coverage      88%      88%   -0%     
=======================================
  Files         267      267           
  Lines       23266    23321   +55     
=======================================
+ Hits        20375    20418   +43     
- Misses       2891     2903   +12     

@uniartisan uniartisan force-pushed the device-enhance branch 2 times, most recently from 1c83154 to 15595bf on October 22, 2024 08:00
@uniartisan uniartisan force-pushed the device-enhance branch 5 times, most recently from 7fee905 to 4299dfe on October 22, 2024 09:49
@lantiga (Collaborator) left a comment:

Thank you for the interesting PR! I added a few comments.

@@ -46,6 +46,11 @@ def parse_devices(devices: Any) -> Any:
def get_parallel_devices(devices: Any) -> Any:
"""Gets parallel devices for the Accelerator."""

@staticmethod
@abstractmethod
def get_device() -> Any:
Collaborator:

Should we name this get_device_type instead? That would be consistent with the fact that in PyTorch, x.device.type is a string ("cpu", "cuda", etc.).
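For reference, a quick illustration of that behaviour in plain PyTorch:

```python
import torch

device = torch.device("cuda:1")
print(device.type)   # "cuda" -- the type string, without the index
print(device.index)  # 1
print(torch.device("cpu").type)  # "cpu"

# A get_device_type() returning "cuda"/"cpu"/... would therefore line up with x.device.type.
```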

if _TORCH_GREATER_EQUAL_2_4
else getattr(
torch,
"cuda" if not isinstance(device, str) or device.split(":")[0] == "cpu" else device.split(":")[0],
Collaborator:

Not sure I understand this condition; can you please clarify?
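One possible reading of the condition, spelled out as a hypothetical helper (this is an interpretation of the diff above, not code from the PR):

```python
from types import ModuleType
from typing import Union

import torch


def _backend_module(device: Union[str, torch.device]) -> ModuleType:
    """Fall back to torch.cuda when `device` is not a string or names the CPU;
    otherwise look up torch.<device_type>, e.g. "xpu:0" -> torch.xpu (which only
    exists when that backend's extension has been imported)."""
    if not isinstance(device, str) or device.split(":")[0] == "cpu":
        return torch.cuda
    return getattr(torch, device.split(":")[0])
```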

@@ -49,13 +49,16 @@ class FSDPPrecision(Precision):

"""

def __init__(self, precision: _PRECISION_INPUT, scaler: Optional["ShardedGradScaler"] = None) -> None:
def __init__(
self, precision: _PRECISION_INPUT, scaler: Optional["ShardedGradScaler"] = None, device: Optional[str] = None
Collaborator:

This should be device_type, since device may not be a string and may have a device_id appended to it.

@@ -111,7 +114,9 @@ def module_init_context(self) -> AbstractContextManager:
@override
def forward_context(self) -> AbstractContextManager:
if "mixed" in self.precision:
return torch.autocast("cuda", dtype=(torch.bfloat16 if self.precision == "bf16-mixed" else torch.float16))
return torch.autocast(
self.device, dtype=(torch.bfloat16 if self.precision == "bf16-mixed" else torch.float16)
Collaborator:

Suggested change:
- self.device, dtype=(torch.bfloat16 if self.precision == "bf16-mixed" else torch.float16)
+ self.device_type, dtype=(torch.bfloat16 if self.precision == "bf16-mixed" else torch.float16)
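For what it's worth, torch.autocast already accepts the device type as a plain string, so threading a device_type attribute through changes nothing else; a small standalone sketch:

```python
import torch

device_type = "cuda"    # would come from the accelerator: "cpu", "cuda", "xpu", ...
dtype = torch.bfloat16  # or torch.float16 for "16-mixed"

with torch.autocast(device_type, dtype=dtype):
    ...  # forward pass runs under mixed precision for that backend
```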

supported_precision = get_args(_PRECISION_INPUT)
if precision not in supported_precision:
raise ValueError(
f"`precision={precision!r})` is not supported in FSDP."
f" `precision` must be one of: {supported_precision}."
)
self.device = device if device is not None else "cuda"
Collaborator:

Suggested change:
- self.device = device if device is not None else "cuda"
+ self.device_type = device_type if device_type is not None else "cuda"

@@ -124,7 +124,13 @@ def setup_module(self, module: Module) -> DistributedDataParallel:
"""Wraps the model into a :class:`~torch.nn.parallel.distributed.DistributedDataParallel` module."""
device_ids = self._determine_ddp_device_ids()
# https://pytorch.org/docs/stable/notes/cuda.html#id5
ctx = torch.cuda.stream(torch.cuda.Stream()) if device_ids is not None else nullcontext()
ctx = (
getattr(torch, f"{self.root_device.type.split(':')[0]}").stream(
Collaborator:

I'd avoid doing this inline. Better to getattr into a variable and then use the variable.

My main question here is: what is the contract that ensures an accelerator has a concept of streams? Unless I read through the code, as a developer I wouldn't know that I need to register streams as torch.mygpu.stream and torch.mygpu.Stream().

So we should either guard the strategy so it only applies to "cuda", or introduce a stream contract on the accelerator. I'd much rather do the former.
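A minimal sketch of the CUDA-only guard suggested here (the helper name and signature are hypothetical, not the PR's code):

```python
from contextlib import AbstractContextManager, nullcontext
from typing import List, Optional

import torch


def _setup_module_context(root_device: torch.device, device_ids: Optional[List[int]]) -> AbstractContextManager:
    """Use a CUDA side stream only when the device type is known to be "cuda";
    every other backend falls back to a no-op context."""
    if device_ids is not None and root_device.type == "cuda":
        return torch.cuda.stream(torch.cuda.Stream())
    return nullcontext()
```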

@@ -507,7 +507,11 @@ def load_checkpoint(

optimzer_state_requested = any(isinstance(item, (Optimizer, DeepSpeedOptimizer)) for item in state.values())

torch.cuda.empty_cache()
if isinstance(self.accelerator, Accelerator) and self.accelerator.get_device() != "cpu":
getattr(torch, self.root_device.type.split(":")[0]).empty_cache()
Collaborator:

Same goes for the comment above. empty_cache is not part of a contract (nor is the fact that a device is registered as a submodule of the torch module). It needs to be if we want to rely on calling empty_cache on whatever device we pass.

BTW, being a torch submodule is too strong a requirement in my opinion.

In this case we should probably guard the strategy to be GPU-only.
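Along the same lines, a guarded cache clear could look like this sketch (helper name is hypothetical):

```python
import torch


def _empty_device_cache(root_device: torch.device) -> None:
    """Only clear the cache on backends where empty_cache is a known part of the API."""
    if root_device.type == "cuda":
        torch.cuda.empty_cache()
    # Other accelerators: skip, or go through an explicit hook on the Accelerator
    # rather than assuming torch.<type>.empty_cache() exists.
```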

Labels: accelerator, docs (Documentation related), fabric (lightning.fabric.Fabric), pl (generic label for PyTorch Lightning package)
Projects: none yet
2 participants