-
Notifications
You must be signed in to change notification settings - Fork 346
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Custom dataloader registry support #2932
base: main
Are you sure you want to change the base?
Changes from 60 commits
7088e4b
cc72b05
46048e3
14f343d
17282cd
a4143f5
69abc47
dc21a3d
a1098b3
b07216b
a578af1
42434ec
3d0c890
eff5b1e
cbdc26e
18d65a6
4fe3ee1
1110966
7972bdc
2d86c43
3c44d86
c0889d8
8fe043c
d8cf0f6
c41e8b2
6ec5d4d
6bce317
1f4ae9d
de1f30b
49fa01e
e33a935
48627d9
609094d
8cf3517
ba5a028
a7dc3fe
f3ff0f8
083c76e
70bba69
8c75662
b6eb2f1
2979ea2
67e9b34
11fe33a
4a648ff
e3831cb
e1837bd
bf4d3bf
2cc8ff9
e62dc3a
41fd877
10ada9c
dd3649c
c6acb5a
b35c6eb
1801604
ed77a65
fc831d5
7400621
47376ca
f94f7fa
962f043
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,89 @@ | ||
name: test (custom dataloaders) | ||
|
||
on: | ||
push: | ||
branches: [main, "[0-9]+.[0-9]+.x"] | ||
pull_request: | ||
branches: [main, "[0-9]+.[0-9]+.x"] | ||
types: [labeled, synchronize, opened] | ||
schedule: | ||
- cron: "0 10 * * *" # runs at 10:00 UTC (03:00 PST) every day | ||
workflow_dispatch: | ||
|
||
concurrency: | ||
group: ${{ github.workflow }}-${{ github.ref }} | ||
cancel-in-progress: true | ||
|
||
jobs: | ||
test: | ||
# if PR has label "custom_dataloader" or "all tests" or if scheduled or manually triggered | ||
if: >- | ||
( | ||
contains(github.event.pull_request.labels.*.name, 'custom_dataloader') || | ||
contains(github.event.pull_request.labels.*.name, 'all tests') || | ||
contains(github.event_name, 'schedule') || | ||
contains(github.event_name, 'workflow_dispatch') | ||
) | ||
|
||
runs-on: ${{ matrix.os }} | ||
|
||
defaults: | ||
run: | ||
shell: bash -e {0} # -e to fail on error | ||
|
||
strategy: | ||
fail-fast: false | ||
matrix: | ||
os: [ubuntu-latest] | ||
python: ["3.11"] | ||
|
||
name: integration | ||
|
||
env: | ||
OS: ${{ matrix.os }} | ||
PYTHON: ${{ matrix.python }} | ||
|
||
steps: | ||
- uses: actions/checkout@v4 | ||
|
||
- uses: actions/setup-python@v5 | ||
with: | ||
python-version: ${{ matrix.python }} | ||
cache: "pip" | ||
cache-dependency-path: "**/pyproject.toml" | ||
|
||
- name: Install dependencies | ||
run: | | ||
python -m pip install --upgrade pip wheel uv | ||
python -m uv pip install --system "scvi-tools[tests] @ ." | ||
python -m pip install scdataloader | ||
python -m pip install cellxgene-census | ||
python -m pip install tiledbsoma | ||
python -m pip install s3fs | ||
python -m pip install torchdata | ||
python -m pip install psutil | ||
python -m pip install lamindb | ||
python -m pip install bionty | ||
python -m pip install biomart | ||
|
||
- name: Install Specific Branch of Repository | ||
env: | ||
GH_TOKEN: ${{ secrets.GH_TOKEN }} | ||
run: | | ||
git config --global url."https://${GH_TOKEN}:[email protected]/".insteadOf "https://github.com/" | ||
git clone --single-branch --branch ebezzi/census-scvi-datamodule https://github.com/ori-kron-wis/cellxgene-census.git | ||
git clone --single-branch --branch main https://github.com/jkobject/scDataLoader.git | ||
|
||
- name: Run specific custom dataloader pytest | ||
env: | ||
MPLBACKEND: agg | ||
PLATFORM: ${{ matrix.os }} | ||
DISPLAY: :42 | ||
COLUMNS: 120 | ||
run: | | ||
coverage run -m pytest tests/dataloaders/test_custom_dataloader.py -v --color=yes --custom-dataloader-tests | ||
coverage report | ||
|
||
- uses: codecov/codecov-action@v4 | ||
with: | ||
token: ${{ secrets.CODECOV_TOKEN }} |
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -118,7 +118,8 @@ class SCANVI(RNASeqMixin, VAEMixin, ArchesMixin, BaseMinifiedModeModelClass): | |
|
||
def __init__( | ||
self, | ||
adata: AnnData, | ||
adata: AnnData | None = None, | ||
registry: dict | None = None, | ||
n_hidden: int = 128, | ||
n_latent: int = 10, | ||
n_layers: int = 1, | ||
|
@@ -128,23 +129,29 @@ def __init__( | |
linear_classifier: bool = False, | ||
**model_kwargs, | ||
): | ||
super().__init__(adata) | ||
super().__init__(adata, registry) | ||
scanvae_model_kwargs = dict(model_kwargs) | ||
|
||
self._set_indices_and_labels() | ||
|
||
# ignores unlabeled catgegory | ||
n_labels = self.summary_stats.n_labels - 1 | ||
n_cats_per_cov = ( | ||
self.adata_manager.get_state_registry(REGISTRY_KEYS.CAT_COVS_KEY).n_cats_per_key | ||
if REGISTRY_KEYS.CAT_COVS_KEY in self.adata_manager.data_registry | ||
else None | ||
) | ||
if adata is not None: | ||
n_cats_per_cov = ( | ||
self.adata_manager.get_state_registry(REGISTRY_KEYS.CAT_COVS_KEY).n_cats_per_key | ||
if REGISTRY_KEYS.CAT_COVS_KEY in self.adata_manager.data_registry | ||
else None | ||
) | ||
else: | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. We should include the case that CAT_COVS_KEY is in registry. |
||
# custom datamodule | ||
n_cats_per_cov = self.summary_stats[f"n_{REGISTRY_KEYS.CAT_COVS_KEY}"] | ||
if n_cats_per_cov == 0: | ||
n_cats_per_cov = None | ||
|
||
n_batch = self.summary_stats.n_batch | ||
use_size_factor_key = REGISTRY_KEYS.SIZE_FACTOR_KEY in self.adata_manager.data_registry | ||
use_size_factor_key = self.registry_["setup_args"][f"{REGISTRY_KEYS.SIZE_FACTOR_KEY}_key"] | ||
library_log_means, library_log_vars = None, None | ||
if not use_size_factor_key and self.minified_data_type is None: | ||
if self.adata is not None and not use_size_factor_key and self.minified_data_type is None: | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Can you create a seperate PR and merge it that changes the if statement to use_observed_lib_size=False (see vae code). |
||
library_log_means, library_log_vars = _init_library_size(self.adata_manager, n_batch) | ||
|
||
self.module = self._module_cls( | ||
|
@@ -187,6 +194,7 @@ def from_scvi_model( | |
unlabeled_category: str, | ||
labels_key: str | None = None, | ||
adata: AnnData | None = None, | ||
registry: dict | None = None, | ||
**scanvi_kwargs, | ||
): | ||
"""Initialize scanVI model with weights from pretrained :class:`~scvi.model.SCVI` model. | ||
|
@@ -203,6 +211,8 @@ def from_scvi_model( | |
Value used for unlabeled cells in `labels_key` used to setup AnnData with scvi. | ||
adata | ||
AnnData object that has been registered via :meth:`~scvi.model.SCANVI.setup_anndata`. | ||
registry | ||
Registry of the datamodule used to train scANVI model. | ||
scanvi_kwargs | ||
kwargs for scANVI model | ||
""" | ||
|
@@ -237,7 +247,7 @@ def from_scvi_model( | |
# validate new anndata against old model | ||
scvi_model._validate_anndata(adata) | ||
|
||
scvi_setup_args = deepcopy(scvi_model.adata_manager.registry[_SETUP_ARGS_KEY]) | ||
scvi_setup_args = deepcopy(scvi_model.registry[_SETUP_ARGS_KEY]) | ||
scvi_labels_key = scvi_setup_args["labels_key"] | ||
if labels_key is None and scvi_labels_key is None: | ||
raise ValueError( | ||
|
@@ -250,7 +260,8 @@ def from_scvi_model( | |
unlabeled_category=unlabeled_category, | ||
**scvi_setup_args, | ||
) | ||
scanvi_model = cls(adata, **non_kwargs, **kwargs, **scanvi_kwargs) | ||
|
||
scanvi_model = cls(adata, scvi_model.registry, **non_kwargs, **kwargs, **scanvi_kwargs) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Is it supported to have both adata and registry? |
||
scvi_state_dict = scvi_model.module.state_dict() | ||
scanvi_model.module.load_state_dict(scvi_state_dict, strict=False) | ||
scanvi_model.was_pretrained = True | ||
|
@@ -259,7 +270,7 @@ def from_scvi_model( | |
|
||
def _set_indices_and_labels(self): | ||
"""Set indices for labeled and unlabeled cells.""" | ||
labels_state_registry = self.adata_manager.get_state_registry(REGISTRY_KEYS.LABELS_KEY) | ||
labels_state_registry = self.get_state_registry(REGISTRY_KEYS.LABELS_KEY) | ||
self.original_label_key = labels_state_registry.original_key | ||
self.unlabeled_category_ = labels_state_registry.unlabeled_category | ||
|
||
|
@@ -479,12 +490,15 @@ def setup_anndata( | |
NumericalJointObsField(REGISTRY_KEYS.CONT_COVS_KEY, continuous_covariate_keys), | ||
] | ||
# register new fields if the adata is minified | ||
adata_minify_type = _get_adata_minify_type(adata) | ||
if adata_minify_type is not None: | ||
anndata_fields += cls._get_fields_for_adata_minification(adata_minify_type) | ||
adata_manager = AnnDataManager(fields=anndata_fields, setup_method_args=setup_method_args) | ||
adata_manager.register_fields(adata, **kwargs) | ||
cls.register_manager(adata_manager) | ||
if adata: | ||
adata_minify_type = _get_adata_minify_type(adata) | ||
if adata_minify_type is not None: | ||
anndata_fields += cls._get_fields_for_adata_minification(adata_minify_type) | ||
adata_manager = AnnDataManager( | ||
fields=anndata_fields, setup_method_args=setup_method_args | ||
) | ||
adata_manager.register_fields(adata, **kwargs) | ||
cls.register_manager(adata_manager) | ||
|
||
@staticmethod | ||
def _get_fields_for_adata_minification( | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
what's this?