Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Custom dataloader registry support #2932

Open
wants to merge 62 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 60 commits
Commits
Show all changes
62 commits
Select commit Hold shift + click to select a range
7088e4b
copying CZI custom dataloader into our repo
ori-kron-wis Jul 28, 2024
cc72b05
added some fixes to the custom dataloader stuff
ori-kron-wis Jul 30, 2024
46048e3
Some suggestions
canergen Jul 30, 2024
14f343d
Changes to datamodule pipeline
canergen Jul 31, 2024
17282cd
Fixed attr_dict
canergen Jul 31, 2024
a4143f5
added some fixes based on custom data loader test
ori-kron-wis Aug 1, 2024
69abc47
Changes to dataloader
canergen Aug 6, 2024
dc21a3d
copying CZI custom dataloader into our repo
ori-kron-wis Jul 28, 2024
a1098b3
added some fixes to the custom dataloader stuff
ori-kron-wis Jul 30, 2024
b07216b
Some suggestions
canergen Jul 30, 2024
a578af1
Changes to datamodule pipeline
canergen Jul 31, 2024
42434ec
Fixed attr_dict
canergen Jul 31, 2024
3d0c890
added some fixes based on custom data loader test
ori-kron-wis Aug 1, 2024
eff5b1e
Changes to dataloader
canergen Aug 6, 2024
cbdc26e
Merge remote-tracking branch 'origin/ori-2907-custom-dataloader-regis…
ori-kron-wis Aug 7, 2024
18d65a6
add changes to tests and some merging with main following custom data…
ori-kron-wis Aug 7, 2024
4fe3ee1
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Aug 7, 2024
1110966
just put the cutom dataloder2 test under remarks so hook tests will r…
ori-kron-wis Aug 7, 2024
7972bdc
fixes
ori-kron-wis Aug 7, 2024
2d86c43
additional external models fixes once there is a registry
ori-kron-wis Aug 7, 2024
3c44d86
fixed a few failed tests
ori-kron-wis Aug 11, 2024
c0889d8
fix archesmixin init and added new custom dataloader test and github …
ori-kron-wis Aug 11, 2024
8fe043c
fix again for from __future__ import annotations
ori-kron-wis Aug 11, 2024
d8cf0f6
fix for run custom dataloader in github action
ori-kron-wis Aug 11, 2024
c41e8b2
rollback
ori-kron-wis Aug 11, 2024
6ec5d4d
added label to the new githubaction for custom dataloader
ori-kron-wis Aug 11, 2024
6bce317
fix for github action for custom dataloaders
ori-kron-wis Aug 12, 2024
1f4ae9d
another fix to custom dataloder test and github action
ori-kron-wis Aug 12, 2024
de1f30b
another fix to custom dataloder test and github action
ori-kron-wis Aug 12, 2024
49fa01e
another fix to custom dataloder test and github action
ori-kron-wis Aug 12, 2024
e33a935
another fix to custom dataloder test and github action
ori-kron-wis Aug 12, 2024
48627d9
another fix to custom dataloder test and github action
ori-kron-wis Aug 12, 2024
609094d
another fix to custom dataloder test and github action
ori-kron-wis Aug 12, 2024
8cf3517
another fix to custom dataloder test and github action
ori-kron-wis Aug 12, 2024
ba5a028
another fix to custom dataloder test and github action
ori-kron-wis Aug 12, 2024
a7dc3fe
another fix to custom dataloder test and github action
ori-kron-wis Aug 12, 2024
f3ff0f8
another fix to custom dataloder test and github action
ori-kron-wis Aug 12, 2024
083c76e
Merge branch 'main' into ori-2907-custom-dataloader-registry
ori-kron-wis Sep 9, 2024
70bba69
Merge branch 'main' into ori-2907-custom-dataloader-registry
ori-kron-wis Sep 15, 2024
8c75662
Merge branch 'main' into ori-2907-custom-dataloader-registry
ori-kron-wis Sep 16, 2024
b6eb2f1
Returned REGISTRY_KEYS for import, after was drop in recent merges
ori-kron-wis Sep 16, 2024
2979ea2
It is ok to drop it after scarches categorial covariates fix
ori-kron-wis Sep 16, 2024
67e9b34
Merge branch 'main' into ori-2907-custom-dataloader-registry
ori-kron-wis Sep 17, 2024
11fe33a
Merge branch 'main' into ori-2907-custom-dataloader-registry
ori-kron-wis Sep 17, 2024
4a648ff
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Sep 17, 2024
e3831cb
moved to type checking blocks beucase of ruff updates
ori-kron-wis Sep 17, 2024
e1837bd
Merge branch 'main' into ori-2907-custom-dataloader-registry
ori-kron-wis Sep 26, 2024
bf4d3bf
Merge remote-tracking branch 'origin/main' into ori-2907-custom-datal…
ori-kron-wis Oct 7, 2024
2cc8ff9
updated for CZI custom dataloader test and backend
ori-kron-wis Oct 9, 2024
e62dc3a
Merge branch 'main' into ori-2907-custom-dataloader-registry
ori-kron-wis Oct 9, 2024
41fd877
added cellxgene-census folder as well for debug (will not be merged)
ori-kron-wis Oct 9, 2024
10ada9c
added cellxgene-census packge to run test
ori-kron-wis Oct 9, 2024
dd3649c
added torchdata packge to run test
ori-kron-wis Oct 9, 2024
c6acb5a
fixed the test workwflow
ori-kron-wis Oct 9, 2024
b35c6eb
adding the lamindb as well
ori-kron-wis Oct 10, 2024
1801604
fix the c.dataloders test
ori-kron-wis Oct 10, 2024
ed77a65
fix the c.dataloders test
ori-kron-wis Oct 10, 2024
fc831d5
fix the c.dataloders test
ori-kron-wis Oct 10, 2024
7400621
fix the c.dataloders test
ori-kron-wis Oct 10, 2024
47376ca
fix the c.dataloders test
ori-kron-wis Oct 10, 2024
f94f7fa
removed redundat functions in code base
ori-kron-wis Oct 13, 2024
962f043
Added scanvi support, including CZI datamodule fix for it
ori-kron-wis Oct 15, 2024
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .github/workflows/test_linux.yml
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ jobs:
DISPLAY: :42
COLUMNS: 120
run: |
coverage run -m pytest -v --color=yes
coverage run -m pytest -v --color=yes -m "not custom_dataloader"
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

what's this?

coverage report

- uses: codecov/codecov-action@v4
Expand Down
89 changes: 89 additions & 0 deletions .github/workflows/test_linux_custom_dataloader.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,89 @@
name: test (custom dataloaders)

on:
push:
branches: [main, "[0-9]+.[0-9]+.x"]
pull_request:
branches: [main, "[0-9]+.[0-9]+.x"]
types: [labeled, synchronize, opened]
schedule:
- cron: "0 10 * * *" # runs at 10:00 UTC (03:00 PST) every day
workflow_dispatch:

concurrency:
group: ${{ github.workflow }}-${{ github.ref }}
cancel-in-progress: true

jobs:
test:
# if PR has label "custom_dataloader" or "all tests" or if scheduled or manually triggered
if: >-
(
contains(github.event.pull_request.labels.*.name, 'custom_dataloader') ||
contains(github.event.pull_request.labels.*.name, 'all tests') ||
contains(github.event_name, 'schedule') ||
contains(github.event_name, 'workflow_dispatch')
)

runs-on: ${{ matrix.os }}

defaults:
run:
shell: bash -e {0} # -e to fail on error

strategy:
fail-fast: false
matrix:
os: [ubuntu-latest]
python: ["3.11"]

name: integration

env:
OS: ${{ matrix.os }}
PYTHON: ${{ matrix.python }}

steps:
- uses: actions/checkout@v4

- uses: actions/setup-python@v5
with:
python-version: ${{ matrix.python }}
cache: "pip"
cache-dependency-path: "**/pyproject.toml"

- name: Install dependencies
run: |
python -m pip install --upgrade pip wheel uv
python -m uv pip install --system "scvi-tools[tests] @ ."
python -m pip install scdataloader
python -m pip install cellxgene-census
python -m pip install tiledbsoma
python -m pip install s3fs
python -m pip install torchdata
python -m pip install psutil
python -m pip install lamindb
python -m pip install bionty
python -m pip install biomart

- name: Install Specific Branch of Repository
env:
GH_TOKEN: ${{ secrets.GH_TOKEN }}
run: |
git config --global url."https://${GH_TOKEN}:[email protected]/".insteadOf "https://github.com/"
git clone --single-branch --branch ebezzi/census-scvi-datamodule https://github.com/ori-kron-wis/cellxgene-census.git
git clone --single-branch --branch main https://github.com/jkobject/scDataLoader.git

- name: Run specific custom dataloader pytest
env:
MPLBACKEND: agg
PLATFORM: ${{ matrix.os }}
DISPLAY: :42
COLUMNS: 120
run: |
coverage run -m pytest tests/dataloaders/test_custom_dataloader.py -v --color=yes --custom-dataloader-tests
coverage report

- uses: codecov/codecov-action@v4
with:
token: ${{ secrets.CODECOV_TOKEN }}
1 change: 1 addition & 0 deletions cellxgene-census
Submodule cellxgene-census added at 9ae33a
3 changes: 3 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -112,6 +112,9 @@ tutorials = [
"scvi-tools[optional]",
"squidpy",
]
dataloaders = [
"scdataloader"
]

all = ["scvi-tools[dev,docs,tutorials]"]

Expand Down
10 changes: 10 additions & 0 deletions src/scvi/data/_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,8 @@
import scipy.sparse as sp_sparse
from anndata import AnnData

from scvi.utils import attrdict

try:
# anndata >= 0.10
from anndata.experimental import CSCDataset, CSRDataset
Expand Down Expand Up @@ -162,6 +164,14 @@ def _set_data_in_registry(
setattr(adata, attr_name, attribute)


def _get_summary_stats_from_registry(registry: dict) -> attrdict:
summary_stats = {}
for field_registry in registry[_constants._FIELD_REGISTRIES_KEY].values():
field_summary_stats = field_registry[_constants._SUMMARY_STATS_KEY]
summary_stats.update(field_summary_stats)
return attrdict(summary_stats)


def _verify_and_correct_data_format(adata: AnnData, attr_name: str, attr_key: str | None):
"""Check data format and correct if necessary.

Expand Down
3 changes: 2 additions & 1 deletion src/scvi/external/stereoscope/_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,8 @@ class RNAStereoscope(UnsupervisedTrainingMixin, BaseModelClass):

def __init__(
self,
sc_adata: AnnData,
sc_adata: AnnData | None = None,
registry: dict | None = None,
**model_kwargs,
):
super().__init__(sc_adata)
Expand Down
1 change: 1 addition & 0 deletions src/scvi/external/stereoscope/_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -140,6 +140,7 @@ def __init__(
n_spots: int,
sc_params: tuple[np.ndarray],
prior_weight: Literal["n_obs", "minibatch"] = "n_obs",
**model_kwargs,
):
super().__init__()
# unpack and copy parameters
Expand Down
3 changes: 2 additions & 1 deletion src/scvi/model/_amortizedlda.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,8 @@ class AmortizedLDA(PyroSviTrainMixin, BaseModelClass):

def __init__(
self,
adata: AnnData,
adata: AnnData | None = None,
registry: dict | None = None,
n_topics: int = 20,
n_hidden: int = 128,
cell_topic_prior: float | Sequence[float] | None = None,
Expand Down
3 changes: 2 additions & 1 deletion src/scvi/model/_autozi.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,7 +104,8 @@ class AUTOZI(VAEMixin, UnsupervisedTrainingMixin, BaseModelClass):

def __init__(
self,
adata: AnnData,
adata: AnnData | None = None,
registry: dict | None = None,
n_hidden: int = 128,
n_latent: int = 10,
n_layers: int = 1,
Expand Down
3 changes: 2 additions & 1 deletion src/scvi/model/_condscvi.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,8 @@ class CondSCVI(RNASeqMixin, VAEMixin, UnsupervisedTrainingMixin, BaseModelClass)

def __init__(
self,
adata: AnnData,
adata: AnnData | None = None,
registry: dict | None = None,
n_hidden: int = 128,
n_latent: int = 5,
n_layers: int = 2,
Expand Down
3 changes: 2 additions & 1 deletion src/scvi/model/_jaxscvi.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,8 @@ class JaxSCVI(JaxTrainingMixin, BaseModelClass):

def __init__(
self,
adata: AnnData,
adata: AnnData | None = None,
registry: dict | None = None,
n_hidden: int = 128,
n_latent: int = 10,
dropout_rate: float = 0.1,
Expand Down
3 changes: 2 additions & 1 deletion src/scvi/model/_linear_scvi.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,8 @@ class LinearSCVI(RNASeqMixin, VAEMixin, UnsupervisedTrainingMixin, BaseModelClas

def __init__(
self,
adata: AnnData,
adata: AnnData | None = None,
registry: dict | None = None,
n_hidden: int = 128,
n_latent: int = 10,
n_layers: int = 1,
Expand Down
7 changes: 4 additions & 3 deletions src/scvi/model/_multivi.py
Original file line number Diff line number Diff line change
Expand Up @@ -140,9 +140,10 @@ class MULTIVI(VAEMixin, UnsupervisedTrainingMixin, BaseModelClass, ArchesMixin):

def __init__(
self,
adata: AnnData,
n_genes: int,
n_regions: int,
adata: AnnData | None = None,
registry: dict | None = None,
n_genes: int | None = None,
n_regions: int | None = None,
modality_weights: Literal["equal", "cell", "universal"] = "equal",
modality_penalty: Literal["Jeffreys", "MMD", "None"] = "Jeffreys",
n_hidden: int | None = None,
Expand Down
3 changes: 2 additions & 1 deletion src/scvi/model/_peakvi.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,7 +92,8 @@ class PEAKVI(ArchesMixin, VAEMixin, UnsupervisedTrainingMixin, BaseModelClass):

def __init__(
self,
adata: AnnData,
adata: AnnData | None = None,
registry: dict | None = None,
n_hidden: int | None = None,
n_latent: int | None = None,
n_layers_encoder: int = 2,
Expand Down
50 changes: 32 additions & 18 deletions src/scvi/model/_scanvi.py
Original file line number Diff line number Diff line change
Expand Up @@ -118,7 +118,8 @@ class SCANVI(RNASeqMixin, VAEMixin, ArchesMixin, BaseMinifiedModeModelClass):

def __init__(
self,
adata: AnnData,
adata: AnnData | None = None,
registry: dict | None = None,
n_hidden: int = 128,
n_latent: int = 10,
n_layers: int = 1,
Expand All @@ -128,23 +129,29 @@ def __init__(
linear_classifier: bool = False,
**model_kwargs,
):
super().__init__(adata)
super().__init__(adata, registry)
scanvae_model_kwargs = dict(model_kwargs)

self._set_indices_and_labels()

# ignores unlabeled catgegory
n_labels = self.summary_stats.n_labels - 1
n_cats_per_cov = (
self.adata_manager.get_state_registry(REGISTRY_KEYS.CAT_COVS_KEY).n_cats_per_key
if REGISTRY_KEYS.CAT_COVS_KEY in self.adata_manager.data_registry
else None
)
if adata is not None:
n_cats_per_cov = (
self.adata_manager.get_state_registry(REGISTRY_KEYS.CAT_COVS_KEY).n_cats_per_key
if REGISTRY_KEYS.CAT_COVS_KEY in self.adata_manager.data_registry
else None
)
else:
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We should include the case that CAT_COVS_KEY is in registry.

# custom datamodule
n_cats_per_cov = self.summary_stats[f"n_{REGISTRY_KEYS.CAT_COVS_KEY}"]
if n_cats_per_cov == 0:
n_cats_per_cov = None

n_batch = self.summary_stats.n_batch
use_size_factor_key = REGISTRY_KEYS.SIZE_FACTOR_KEY in self.adata_manager.data_registry
use_size_factor_key = self.registry_["setup_args"][f"{REGISTRY_KEYS.SIZE_FACTOR_KEY}_key"]
library_log_means, library_log_vars = None, None
if not use_size_factor_key and self.minified_data_type is None:
if self.adata is not None and not use_size_factor_key and self.minified_data_type is None:
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can you create a seperate PR and merge it that changes the if statement to use_observed_lib_size=False (see vae code).

library_log_means, library_log_vars = _init_library_size(self.adata_manager, n_batch)

self.module = self._module_cls(
Expand Down Expand Up @@ -187,6 +194,7 @@ def from_scvi_model(
unlabeled_category: str,
labels_key: str | None = None,
adata: AnnData | None = None,
registry: dict | None = None,
**scanvi_kwargs,
):
"""Initialize scanVI model with weights from pretrained :class:`~scvi.model.SCVI` model.
Expand All @@ -203,6 +211,8 @@ def from_scvi_model(
Value used for unlabeled cells in `labels_key` used to setup AnnData with scvi.
adata
AnnData object that has been registered via :meth:`~scvi.model.SCANVI.setup_anndata`.
registry
Registry of the datamodule used to train scANVI model.
scanvi_kwargs
kwargs for scANVI model
"""
Expand Down Expand Up @@ -237,7 +247,7 @@ def from_scvi_model(
# validate new anndata against old model
scvi_model._validate_anndata(adata)

scvi_setup_args = deepcopy(scvi_model.adata_manager.registry[_SETUP_ARGS_KEY])
scvi_setup_args = deepcopy(scvi_model.registry[_SETUP_ARGS_KEY])
scvi_labels_key = scvi_setup_args["labels_key"]
if labels_key is None and scvi_labels_key is None:
raise ValueError(
Expand All @@ -250,7 +260,8 @@ def from_scvi_model(
unlabeled_category=unlabeled_category,
**scvi_setup_args,
)
scanvi_model = cls(adata, **non_kwargs, **kwargs, **scanvi_kwargs)

scanvi_model = cls(adata, scvi_model.registry, **non_kwargs, **kwargs, **scanvi_kwargs)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is it supported to have both adata and registry?

scvi_state_dict = scvi_model.module.state_dict()
scanvi_model.module.load_state_dict(scvi_state_dict, strict=False)
scanvi_model.was_pretrained = True
Expand All @@ -259,7 +270,7 @@ def from_scvi_model(

def _set_indices_and_labels(self):
"""Set indices for labeled and unlabeled cells."""
labels_state_registry = self.adata_manager.get_state_registry(REGISTRY_KEYS.LABELS_KEY)
labels_state_registry = self.get_state_registry(REGISTRY_KEYS.LABELS_KEY)
self.original_label_key = labels_state_registry.original_key
self.unlabeled_category_ = labels_state_registry.unlabeled_category

Expand Down Expand Up @@ -479,12 +490,15 @@ def setup_anndata(
NumericalJointObsField(REGISTRY_KEYS.CONT_COVS_KEY, continuous_covariate_keys),
]
# register new fields if the adata is minified
adata_minify_type = _get_adata_minify_type(adata)
if adata_minify_type is not None:
anndata_fields += cls._get_fields_for_adata_minification(adata_minify_type)
adata_manager = AnnDataManager(fields=anndata_fields, setup_method_args=setup_method_args)
adata_manager.register_fields(adata, **kwargs)
cls.register_manager(adata_manager)
if adata:
adata_minify_type = _get_adata_minify_type(adata)
if adata_minify_type is not None:
anndata_fields += cls._get_fields_for_adata_minification(adata_minify_type)
adata_manager = AnnDataManager(
fields=anndata_fields, setup_method_args=setup_method_args
)
adata_manager.register_fields(adata, **kwargs)
cls.register_manager(adata_manager)

@staticmethod
def _get_fields_for_adata_minification(
Expand Down
Loading
Loading