Better registration support for a wide range of third-party hardware #20349
base: master
Conversation
Examples are here: https://github.com/uniartisan/RWKV-PEFT/blob/device-enhance/train.py#L499. There are a lot of things to be checked; I will try to do that later and make it clearer in the documentation.
Codecov Report
Attention: Patch coverage is
Additional details and impacted files

@@            Coverage Diff            @@
##           master   #20349    +/-   ##
========================================
- Coverage      88%      88%      -0%
========================================
  Files         267      267
  Lines       23266    23321      +55
========================================
+ Hits        20375    20418      +43
- Misses       2891     2903      +12
Thank you for the interesting PR! I added a few comments.
@@ -46,6 +46,11 @@ def parse_devices(devices: Any) -> Any:
     def get_parallel_devices(devices: Any) -> Any:
         """Gets parallel devices for the Accelerator."""

+    @staticmethod
+    @abstractmethod
+    def get_device() -> Any:
Should we name this `get_device_type` instead? This would be consistent with the fact that in PyTorch, `x.device.type` is a string ("cpu", "cuda", etc.).
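For illustration only, a minimal sketch of what the renamed hook could look like; the `CUDAAccelerator` override below is an invented example, not part of the diff:

```python
from abc import ABC, abstractmethod


class Accelerator(ABC):
    @staticmethod
    @abstractmethod
    def get_device_type() -> str:
        """Return the device type string, matching ``torch.device(...).type`` ("cpu", "cuda", ...)."""


class CUDAAccelerator(Accelerator):
    @staticmethod
    def get_device_type() -> str:
        # Consistent with torch.device("cuda:0").type == "cuda"
        return "cuda"
```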
if _TORCH_GREATER_EQUAL_2_4
else getattr(
    torch,
    "cuda" if not isinstance(device, str) or device.split(":")[0] == "cpu" else device.split(":")[0],
I'm not sure I understand this condition, can you please clarify?
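For context, the condition can be read as the following standalone helper (a sketch, not code from the PR), which makes the fallback behaviour visible:

```python
import torch


def _resolve_device_module(device):
    # Fall back to torch.cuda when `device` is not a string or is "cpu";
    # otherwise use the prefix before ":" (e.g. "xpu:0" -> torch.xpu).
    name = "cuda" if not isinstance(device, str) or device.split(":")[0] == "cpu" else device.split(":")[0]
    return getattr(torch, name)


print(_resolve_device_module("cuda:0").__name__)  # torch.cuda
print(_resolve_device_module("cpu").__name__)     # also torch.cuda, which is the surprising part
```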
@@ -49,13 +49,16 @@ class FSDPPrecision(Precision):

     """

-    def __init__(self, precision: _PRECISION_INPUT, scaler: Optional["ShardedGradScaler"] = None) -> None:
+    def __init__(
+        self, precision: _PRECISION_INPUT, scaler: Optional["ShardedGradScaler"] = None, device: Optional[str] = None
This is `device_type`, since `device` may not be a string and may have a `device_id` appended to it.
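A small illustration of the distinction (not from the PR): a `torch.device` carries an index, while autocast-style APIs want only the type string.

```python
import torch

device = torch.device("cuda:1")   # a full device, possibly with a device_id appended
device_type = device.type         # "cuda" -- the plain string that precision plugins need

assert str(device) == "cuda:1"
assert device_type == "cuda"
```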
@@ -111,7 +114,9 @@ def module_init_context(self) -> AbstractContextManager:
     @override
     def forward_context(self) -> AbstractContextManager:
         if "mixed" in self.precision:
-            return torch.autocast("cuda", dtype=(torch.bfloat16 if self.precision == "bf16-mixed" else torch.float16))
+            return torch.autocast(
+                self.device, dtype=(torch.bfloat16 if self.precision == "bf16-mixed" else torch.float16)
Suggested change:
-                self.device, dtype=(torch.bfloat16 if self.precision == "bf16-mixed" else torch.float16)
+                self.device_type, dtype=(torch.bfloat16 if self.precision == "bf16-mixed" else torch.float16)
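For reference, `torch.autocast` takes the device type string as its first argument, so the suggested `device_type` attribute maps directly onto it. A quick sketch, using "cpu" here so it runs anywhere:

```python
import torch

device_type = "cpu"  # would be "cuda", "xpu", ... depending on the accelerator

with torch.autocast(device_type, dtype=torch.bfloat16):
    x = torch.randn(8, 8)
    y = x @ x  # executed under autocast for the given device type
```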
supported_precision = get_args(_PRECISION_INPUT)
if precision not in supported_precision:
    raise ValueError(
        f"`precision={precision!r})` is not supported in FSDP."
        f" `precision` must be one of: {supported_precision}."
    )
self.device = device if device is not None else "cuda"
Suggested change:
-        self.device = device if device is not None else "cuda"
+        self.device_type = device_type if device_type is not None else "cuda"
@@ -124,7 +124,13 @@ def setup_module(self, module: Module) -> DistributedDataParallel:
     """Wraps the model into a :class:`~torch.nn.parallel.distributed.DistributedDataParallel` module."""
     device_ids = self._determine_ddp_device_ids()
     # https://pytorch.org/docs/stable/notes/cuda.html#id5
-    ctx = torch.cuda.stream(torch.cuda.Stream()) if device_ids is not None else nullcontext()
+    ctx = (
+        getattr(torch, f"{self.root_device.type.split(':')[0]}").stream(
I'd avoid doing this inline. Better to `getattr` and assign to a variable, and then use the variable.

My main question here is what is the contract that ensures that an accelerator has a concept of streams. Unless I read through the code, as a developer I wouldn't know that I need to register streams as `torch.mygpu.stream` and `torch.mygpu.Stream()`.

So we should either guard the strategy to only apply to "cuda", or introduce a stream contract to the accelerator. I'd much rather do the former.
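A sketch of the shape suggested here, with the module assigned to a variable and the stream trick guarded to CUDA only (names are illustrative, not the final implementation):

```python
from contextlib import nullcontext

import torch


def _ddp_setup_context(root_device: torch.device, device_ids):
    # Only CUDA documents torch.cuda.stream / torch.cuda.Stream, so guard on it
    # instead of assuming every accelerator registers the same API under torch.<type>.
    if device_ids is not None and root_device.type == "cuda":
        device_module = getattr(torch, root_device.type)  # torch.cuda, fetched once
        return device_module.stream(device_module.Stream())
    return nullcontext()
```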
@@ -507,7 +507,11 @@ def load_checkpoint(

     optimzer_state_requested = any(isinstance(item, (Optimizer, DeepSpeedOptimizer)) for item in state.values())

-    torch.cuda.empty_cache()
+    if isinstance(self.accelerator, Accelerator) and self.accelerator.get_device() != "cpu":
+        getattr(torch, self.root_device.type.split(":")[0]).empty_cache()
Same goes for the comment above. `empty_cache` is not part of a contract (nor is the fact that a device is registered as a submodule of the `torch` module). It needs to be if we want to rely on calling `empty_cache` on whatever device we pass.

BTW, being a `torch` submodule is too strong of a requirement in my opinion. In this case we should probably guard the strategy to be GPU-only.
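A sketch of such a GPU-only guard (illustrative only; `root_device` is assumed to be a `torch.device`):

```python
import torch


def _empty_device_cache(root_device: torch.device) -> None:
    # Restrict the cache flush to CUDA rather than assuming every device type
    # is registered as a torch submodule that exposes empty_cache().
    if root_device.type == "cuda" and torch.cuda.is_available():
        torch.cuda.empty_cache()
```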
What does this PR do?
Thank you to the Lightning team for providing such an easy-to-use, clearly designed library.

This draft PR aims to provide better registration support for a wide range of third-party hardware, including Intel XPU, Huawei Ascend NPU, Cambricon, Moore Threads, and more. It is designed to integrate third-party hardware with minimally intrusive changes.
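As a rough sketch of how a third-party backend might plug in under this design (the import path, class name, the "npu" string, and the omitted hooks are assumptions for illustration; only the `get_device()` hook comes from this PR):

```python
from lightning.pytorch.accelerators import Accelerator


class MyNPUAccelerator(Accelerator):
    """Illustrative third-party accelerator; the remaining hooks are omitted for brevity."""

    @staticmethod
    def get_device() -> str:
        # The new hook from this PR: report the torch device type string so that
        # autocast, streams, and cache handling can be dispatched generically.
        return "npu"

    # parse_devices, get_parallel_devices, setup_device, teardown,
    # auto_device_count, and is_available would be implemented here as well.
```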
Fixes #<issue_number>
Before submitting
PR review
Anyone in the community is welcome to review the PR.
Before you start reviewing, make sure you have read the review guidelines.
Reviewer checklist
📚 Documentation preview 📚: https://pytorch-lightning--20349.org.readthedocs.build/en/20349/