Add str method to datamodule #20301

Open · wants to merge 12 commits into base: master · showing changes from 9 commits
32 changes: 32 additions & 0 deletions src/lightning/pytorch/core/datamodule.py
@@ -14,6 +14,7 @@
"""LightningDataModule for loading DataLoaders with ease."""

import inspect
import os
from collections.abc import Iterable
from typing import IO, Any, Optional, Union, cast

@@ -244,3 +245,34 @@ def load_from_checkpoint(
            **kwargs,
        )
        return cast(Self, loaded)

    def __str__(self) -> str:
        """Return a string representation of the datasets that are set up.

        Returns:
            A string representation of the datasets that are set up.

        """
        datasets_info: list[str] = []

        def len_implemented(obj: Dataset) -> bool:
            try:
                len(obj)
                return True
            except NotImplementedError:
                return False

        for attr_name in dir(self):
            attr = getattr(self, attr_name)

            # Get Dataset information
            if isinstance(attr, Dataset):
Collaborator
The problem with this is that there are no guarantees that any datasets actually exist as attributes.

I think a better design would be to leverage the contract and get the dataloaders through self.train_dataloader (caveat: it may be a dataloader or an iterable of dataloaders), self.val_dataloader, etc., access their dataset property (https://pytorch.org/docs/stable/data.html#torch.utils.data.DataLoader), and produce a string from there in a similar way to the current one.

Also, instead of ignoring non-Dataset subclasses in this case, we could call str on whatever is returned as the dataset.

I think it's a valuable addition, wdyt?
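
A minimal sketch of this suggestion, for illustration only (the helper _describe_dataloaders, its hook list, and its output format are hypothetical, not part of this diff):

from torch.utils.data import DataLoader

def _describe_dataloaders(dm) -> str:
    # Lean on the LightningDataModule contract: ask each dataloader hook
    # for its loaders instead of scanning instance attributes.
    lines = []
    for hook_name in ("train_dataloader", "val_dataloader", "test_dataloader", "predict_dataloader"):
        try:
            loaders = getattr(dm, hook_name)()
        except Exception:
            continue  # hook not overridden, or setup() has not run yet
        if isinstance(loaders, DataLoader):
            loaders = [loaders]  # a hook may also return an iterable of loaders
        for loader in loaders:
            # DataLoader exposes its dataset via the documented `dataset`
            # attribute; str() works even on non-Dataset subclasses.
            lines.append(f"{hook_name}: dataset={getattr(loader, 'dataset', None)}")
    return "\n".join(lines) if lines else "No dataloaders are available."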

Author
That is a good idea. I also had a similar idea in the beginning but discarded it because I thought it would add too much overhead. But I did some research, and DataLoader should not add much overhead to the actual dataset, so we can use it in the __str__ method.
Your suggestion would make the method safer while still maintaining the same functionality, since one has to call prepare_data and setup anyway before one is able to access the dataset information.
I will try to implement your suggestion.

                if hasattr(attr, "__len__") and len_implemented(attr):
                    datasets_info.append(f"name={attr_name}, size={len(attr)}")
                else:
                    datasets_info.append(f"name={attr_name}, size=Unavailable")

        if not datasets_info:
            return "No datasets are set up."

        return os.linesep.join(datasets_info)
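
For reference, a usage sketch of the new method; the expected output mirrors what the tests below assert for the fit stage:

from lightning.pytorch.demos.boring_classes import BoringDataModule

dm = BoringDataModule()
dm.setup(stage="fit")
print(dm)
# name=random_full, size=256
# name=random_train, size=64
# name=random_val, size=64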
32 changes: 32 additions & 0 deletions src/lightning/pytorch/demos/boring_classes.py
@@ -188,6 +188,38 @@ def predict_dataloader(self) -> DataLoader:
        return DataLoader(self.random_predict)


class BoringDataModuleNoLen(LightningDataModule):
    """
    .. warning:: This is meant for testing/debugging and is experimental.
    """

    def __init__(self) -> None:
        super().__init__()
        self.random_full = RandomIterableDataset(32, 64 * 4)


class BoringDataModuleLenNotImplemented(LightningDataModule):
    """
    .. warning:: This is meant for testing/debugging and is experimental.
    """

    def __init__(self) -> None:
        super().__init__()

        class DS(Dataset):
            def __init__(self, size: int, length: int):
                self.len = length
                self.data = torch.randn(length, size)

            def __getitem__(self, index: int) -> Tensor:
                return self.data[index]

            def __len__(self) -> int:
                raise NotImplementedError

        self.random_full = DS(32, 64 * 4)


class ManualOptimBoringModel(BoringModel):
    """
    .. warning:: This is meant for testing/debugging and is experimental.
69 changes: 68 additions & 1 deletion tests/tests_pytorch/core/test_datamodules.py
@@ -11,6 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import pickle
from argparse import Namespace
from dataclasses import dataclass
@@ -22,7 +23,12 @@
import torch
from lightning.pytorch import LightningDataModule, Trainer, seed_everything
from lightning.pytorch.callbacks import ModelCheckpoint
from lightning.pytorch.demos.boring_classes import BoringDataModule, BoringModel
from lightning.pytorch.demos.boring_classes import (
    BoringDataModule,
    BoringDataModuleLenNotImplemented,
    BoringDataModuleNoLen,
    BoringModel,
)
from lightning.pytorch.profilers.simple import SimpleProfiler
from lightning.pytorch.trainer.states import TrainerFn
from lightning.pytorch.utilities import AttributeDict
@@ -510,3 +516,64 @@ def prepare_data(self):
    durations = profiler.recorded_durations[key]
    assert len(durations) == 1
    assert durations[0] > 0


def test_datamodule_string_no_datasets():
    dm = BoringDataModule()
    del dm.random_full
    expected_output = "No datasets are set up."
    assert str(dm) == expected_output


def test_datamodule_string_no_length():
    dm = BoringDataModuleNoLen()
    expected_output = "name=random_full, size=Unavailable"
    assert str(dm) == expected_output


def test_datamodule_string_length_not_implemented():
    dm = BoringDataModuleLenNotImplemented()
    expected_output = "name=random_full, size=Unavailable"
    assert str(dm) == expected_output


def test_datamodule_string_fit_setup():
    dm = BoringDataModule()
    dm.setup(stage="fit")

    expected_output = (
        f"name=random_full, size=256{os.linesep}" f"name=random_train, size=64{os.linesep}" f"name=random_val, size=64"
    )
    output = str(dm)

    assert expected_output == output


def test_datamodule_string_validation_setup():
    dm = BoringDataModule()
    dm.setup(stage="validate")

    expected_output = f"name=random_full, size=256{os.linesep}" f"name=random_val, size=64"
    output = str(dm)

    assert expected_output == output


def test_datamodule_string_test_setup():
    dm = BoringDataModule()
    dm.setup(stage="test")

    expected_output = f"name=random_full, size=256{os.linesep}" f"name=random_test, size=64"
    output = str(dm)

    assert expected_output == output


def test_datamodule_string_predict_setup():
    dm = BoringDataModule()
    dm.setup(stage="predict")

    expected_output = f"name=random_full, size=256{os.linesep}" f"name=random_predict, size=64"
    output = str(dm)

    assert expected_output == output