From 7088e4bd7909841896af418f99b1c87d58e1dd21 Mon Sep 17 00:00:00 2001 From: ori-kron-wis Date: Sun, 28 Jul 2024 17:04:41 +0300 Subject: [PATCH 01/53] copying CZI custom dataloader into our repo --- src/scvi/dataloaders/_custom_dataloader.py | 1298 +++++++++++++++++++ src/scvi/model/_scvi.py | 1 + tests/dataloaders/test_custom_dataloader.py | 1 + 3 files changed, 1300 insertions(+) create mode 100644 src/scvi/dataloaders/_custom_dataloader.py create mode 100644 tests/dataloaders/test_custom_dataloader.py diff --git a/src/scvi/dataloaders/_custom_dataloader.py b/src/scvi/dataloaders/_custom_dataloader.py new file mode 100644 index 0000000000..b22c697c2a --- /dev/null +++ b/src/scvi/dataloaders/_custom_dataloader.py @@ -0,0 +1,1298 @@ +from __future__ import annotations + +import abc +import gc +import logging +import os +import threading +from collections import deque +from collections.abc import Iterator, Sequence +from concurrent import futures +from concurrent.futures import Future +from contextlib import contextmanager +from datetime import timedelta +from math import ceil +from time import time +from typing import Any, TypeVar + +import numpy as np +import numpy.typing as npt +import pandas as pd +import psutil +import scipy +import tiledbsoma as soma +import torch +import torchdata.datapipes.iter as pipes +from attr import define +from lightning.pytorch import LightningDataModule +from numpy.random import Generator +from scipy import sparse +from sklearn.preprocessing import LabelEncoder +from torch import Tensor +from torch import distributed as dist +from torch.utils.data import DataLoader +from torch.utils.data.dataset import Dataset + +pytorch_logger = logging.getLogger("cellxgene_census.experimental.pytorch") + +# TODO: Rename to reflect the correct order of the Tensors within the tuple: (X, obs) +ObsAndXDatum = tuple[Tensor, Tensor] +"""Return type of ``ExperimentDataPipe`` that pairs a Tensor of ``obs`` row(s) with a Tensor of +``X`` matrix row(s).The Tensors are rank 1 if ``batch_size`` is 1, +otherwise the Tensors are rank 2.""" + +util_logger = logging.getLogger("cellxgene_census.experimental.util") + +_T = TypeVar("_T") + + +DEFAULT_TILEDB_CONFIGURATION: dict[str, Any] = { + # https://docs.tiledb.com/main/how-to/configuration#configuration-parameters + "py.init_buffer_bytes": 1 * 1024**3, + "soma.init_buffer_bytes": 1 * 1024**3, + # S3 requests should not be signed, since we want to allow anonymous access + "vfs.s3.no_sign_request": "true", + "vfs.s3.region": "us-west-2", +} + + +def get_default_soma_context( + tiledb_config: dict[str, Any] | None = None, +) -> soma.options.SOMATileDBContext: + """Return a :class:`tiledbsoma.SOMATileDBContext` with sensible defaults that can be further + + customized by the user. The customized context can then be passed to + :func:`cellxgene_census.open_soma` with the ``context`` argument or to + :meth:`somacore.SOMAObject.open` with the ``context`` argument, such as + :meth:`tiledbsoma.Experiment.open`. Use the :meth:`tiledbsoma.SOMATileDBContext.replace` + method on the returned object to customize its settings further. + + Args: + tiledb_config: + A dictionary of TileDB configuration parameters. If specified, the parameters will + override the defaults. If not specified, the default configuration will be returned. + + Returns + ------- + A :class:`tiledbsoma.SOMATileDBContext` object with sensible defaults. + + Examples + -------- + To reduce the amount of memory used by TileDB-SOMA I/O operations: + + .. highlight:: python + .. code-block:: python + + ctx = cellxgene_census.get_default_soma_context( + tiledb_config={ + "py.init_buffer_bytes": 128 * 1024**2, + "soma.init_buffer_bytes": 128 * 1024**2, + } + ) + c = census.open_soma(uri="s3://my-private-bucket/census/soma", context=ctx) + + To access a copy of the Census located in a private bucket that is located in a different + S3 region, use: + + .. highlight:: python + .. code-block:: python + + ctx = cellxgene_census.get_default_soma_context( + tiledb_config={"vfs.s3.no_sign_request": "false", "vfs.s3.region": "us-east-1"} + ) + c = census.open_soma(uri="s3://my-private-bucket/census/soma", context=ctx) + + Lifecycle: + experimental + """ + tiledb_config = dict(DEFAULT_TILEDB_CONFIGURATION, **(tiledb_config or {})) + return soma.options.SOMATileDBContext().replace(tiledb_config=tiledb_config) + + +class _EagerIterator(Iterator[_T]): + def __init__( + self, + iterator: Iterator[_T], + pool: futures.Executor | None = None, + ): + super().__init__() + self.iterator = iterator + self._pool = pool or futures.ThreadPoolExecutor() + self._own_pool = pool is None + self._future: Future[_T] | None = None + self._begin_next() + + def _begin_next(self) -> None: + self._future = self._pool.submit(self.iterator.__next__) + util_logger.debug("Fetching next iterator element, eagerly") + + def __next__(self) -> _T: + try: + assert self._future + res = self._future.result() + self._begin_next() + return res + except StopIteration: + self._cleanup() + raise + + def _cleanup(self) -> None: + util_logger.debug("Cleaning up eager iterator") + if self._own_pool: + self._pool.shutdown() + + def __del__(self) -> None: + # Ensure the threadpool is cleaned up in the case where the + # iterator is not exhausted. For more information on __del__: + # https://docs.python.org/3/reference/datamodel.html#object.__del__ + self._cleanup() + super_del = getattr(super(), "__del__", lambda: None) + super_del() + + +class _EagerBufferedIterator(Iterator[_T]): + def __init__( + self, + iterator: Iterator[_T], + max_pending: int = 1, + pool: futures.Executor | None = None, + ): + super().__init__() + self.iterator = iterator + self.max_pending = max_pending + self._pool = pool or futures.ThreadPoolExecutor() + self._own_pool = pool is None + self._pending_results: deque[futures.Future[_T]] = deque() + self._lock = threading.Lock() + self._begin_next() + + def __next__(self) -> _T: + try: + res = self._pending_results[0].result() + self._pending_results.popleft() + self._begin_next() + return res + except StopIteration: + self._cleanup() + raise + + def _begin_next(self) -> None: + def _fut_done(fut: futures.Future[_T]) -> None: + util_logger.debug("Finished fetching next iterator element, eagerly") + if fut.exception() is None: + self._begin_next() + + with self._lock: + not_running = len(self._pending_results) == 0 or self._pending_results[-1].done() + if len(self._pending_results) < self.max_pending and not_running: + _future = self._pool.submit(self.iterator.__next__) + util_logger.debug("Fetching next iterator element, eagerly") + _future.add_done_callback(_fut_done) + self._pending_results.append(_future) + assert len(self._pending_results) <= self.max_pending + + def _cleanup(self) -> None: + util_logger.debug("Cleaning up eager iterator") + if self._own_pool: + self._pool.shutdown() + + def __del__(self) -> None: + # Ensure the threadpool is cleaned up in the case where the + # iterator is not exhausted. For more information on __del__: + # https://docs.python.org/3/reference/datamodel.html#object.__del__ + self._cleanup() + super_del = getattr(super(), "__del__", lambda: None) + super_del() + + +class Encoder(abc.ABC): + """Base class for obs encoders. + + To define a custom encoder, two methods must be implemented: + + - ``register``: defines how the encoder will be fitted to the data. + - ``transform``: defines how the encoder will be applied to the data + in order to create an obs_tensor. + + See the implementation of ``DefaultEncoder`` for an example. + """ + + @abc.abstractmethod + def register(self, obs: pd.DataFrame) -> None: + """Register the encoder with obs.""" + pass + + @abc.abstractmethod + def transform(self, df: pd.DataFrame) -> pd.DataFrame: + """Transform the obs DataFrame into a DataFrame of encoded values.""" + pass + + @property + def name(self) -> str: + return self.__class__.__name__ + + +class DefaultEncoder(Encoder): + """Default encoder based on LabelEncoder.""" + + def __init__(self, col: str) -> None: + self._encoder = LabelEncoder() + self.col = col + + def register(self, obs: pd.DataFrame) -> None: + self._encoder.fit(obs[self.col].unique()) + + def transform(self, df: pd.DataFrame) -> pd.DataFrame: + return self._encoder.transform(df[self.col]) # type: ignore + + @property + def name(self) -> str: + return self.col + + @property + def classes_(self): # type: ignore + return self._encoder.classes_ + + +@define +class _SOMAChunk: + """Return type of ``_ObsAndXSOMAIterator`` that pairs a chunk of ``obs`` rows with the + + respective rows from the ``X`` matrix. + + Lifecycle: + experimental + """ + + obs: pd.DataFrame + X: scipy.sparse.spmatrix + stats: Stats + + def __len__(self) -> int: + return len(self.obs) + + +Encoders = dict[str, LabelEncoder] +"""A dictionary of ``LabelEncoder``s keyed by the ``obs`` column name.""" + + +@define +class Stats: + """Statistics about the data retrieved by ``ExperimentDataPipe`` via SOMA API. This is useful + + for assessing the read throughput of SOMA data. + + Lifecycle: + experimental + """ + + n_obs: int = 0 + """The total number of obs rows retrieved""" + + nnz: int = 0 + """The total number of values retrieved""" + + elapsed: int = 0 + """The total elapsed time in seconds for retrieving all batches""" + + n_soma_chunks: int = 0 + """The number of chunks retrieved""" + + def __str__(self) -> str: + return ( + f"{self.n_soma_chunks=}, {self.n_obs=}, {self.nnz=}, " + f"elapsed={timedelta(seconds=self.elapsed)}" + ) + + def __add__(self, other: Stats) -> Stats: + self.n_obs += other.n_obs + self.nnz += other.nnz + self.elapsed += other.elapsed + self.n_soma_chunks += other.n_soma_chunks + return self + + +@contextmanager +def _open_experiment( + uri: str, + aws_region: str | None = None, +) -> soma.Experiment: + """Internal method for opening a SOMA ``Experiment`` as a context manager.""" + context = get_default_soma_context().replace( + tiledb_config={"vfs.s3.region": aws_region} if aws_region else {} + ) + + with soma.Experiment.open(uri, context=context) as exp: + yield exp + + +class _ObsAndXSOMAIterator(Iterator[_SOMAChunk]): + """Iterates the SOMA chunks of corresponding ``obs`` and ``X`` data. This is an internal class, + + not intended for public use. + """ + + X: soma.SparseNDArray + """A handle to the full X data of the SOMA ``Experiment``""" + + obs_joinids_chunks_iter: Iterator[npt.NDArray[np.int64]] + + var_joinids: npt.NDArray[np.int64] + """The ``var`` joinids to be retrieved from the SOMA ``Experiment``""" + + def __init__( + self, + obs: soma.DataFrame, + X: soma.SparseNDArray, + obs_column_names: Sequence[str], + obs_joinids_chunked: list[npt.NDArray[np.int64]], + var_joinids: npt.NDArray[np.int64], + shuffle_chunk_count: int | None = None, + shuffle_rng: Generator | None = None, + ): + self.obs = obs + self.X = X + self.obs_column_names = obs_column_names + if shuffle_chunk_count: + assert shuffle_rng is not None + + # At the start of this step, `obs_joinids_chunked` is a list of one dimensional + # numpy arrays. Each numpy array corresponds to a chunk of contiguous rows in `obs`. + # Critically, `obs_joinids_chunked` is randomly ordered where each chunk is + # from a random section of `obs`. + # We then take `shuffle_chunk_count` of these in order, concatenate them into + # a larger numpy array and shuffle this larger numpy array. + # The result is again a list of numpy arrays. + self.obs_joinids_chunks_iter = ( + shuffle_rng.permutation(np.concatenate(grouped_chunks)) + for grouped_chunks in list_split(obs_joinids_chunked, shuffle_chunk_count) + ) + else: + self.obs_joinids_chunks_iter = iter(obs_joinids_chunked) + self.var_joinids = var_joinids + self.shuffle_chunk_count = shuffle_chunk_count + + def __next__(self) -> _SOMAChunk: + pytorch_logger.debug("Retrieving next SOMA chunk...") + start_time = time() + + # If no more chunks to iterate through, raise StopIteration, as all iterators + # do when at end + obs_joinids_chunk = next(self.obs_joinids_chunks_iter) + + obs_batch = ( + self.obs.read( + coords=(obs_joinids_chunk,), + column_names=self.obs_column_names, + ) + .concat() + .to_pandas() + .set_index("soma_joinid") + ) + assert obs_batch.shape[0] == obs_joinids_chunk.shape[0] + + # handle case of empty result (first batch has 0 rows) + if len(obs_batch) == 0: + raise StopIteration + + # reorder obs rows to match obs_joinids_chunk ordering, which may be shuffled + obs_batch = obs_batch.reindex(obs_joinids_chunk, copy=False) + + # note: the `blockwise` call is employed for its ability to reindex the axes of the sparse + # matrix, but the blockwise iteration feature is not used (block_size is set to retrieve + # the chunk as a single block) + scipy_iter = ( + self.X.read(coords=(obs_joinids_chunk, self.var_joinids)) + .blockwise(axis=0, size=len(obs_joinids_chunk), eager=False) + .scipy(compress=True) + ) + X_batch, _ = next(scipy_iter) + assert obs_batch.shape[0] == X_batch.shape[0] + + stats = Stats() + stats.n_obs += X_batch.shape[0] + stats.nnz += X_batch.nnz + stats.elapsed += int(time() - start_time) + stats.n_soma_chunks += 1 + + pytorch_logger.debug(f"Retrieved SOMA chunk: {stats}") + return _SOMAChunk(obs=obs_batch, X=X_batch, stats=stats) + + +def list_split(arr_list: list[Any], sublist_len: int) -> list[list[Any]]: + """Splits a python list into a list of sublists where each sublist is of size `sublist_len`. + + TODO: Replace with `itertools.batched` when Python 3.12 becomes the minimum supported version. + """ + i = 0 + result = [] + while i < len(arr_list): + if (i + sublist_len) >= len(arr_list): + result.append(arr_list[i:]) + else: + result.append(arr_list[i : i + sublist_len]) + + i += sublist_len + + return result + + +def run_gc() -> tuple[tuple[Any, Any, Any], tuple[Any, Any, Any]]: + proc = psutil.Process(os.getpid()) + + pre_gc = proc.memory_full_info(), psutil.virtual_memory(), psutil.swap_memory() + gc.collect() + post_gc = proc.memory_full_info(), psutil.virtual_memory(), psutil.swap_memory() + + pytorch_logger.debug(f"gc: pre={pre_gc}") + pytorch_logger.debug(f"gc: post={post_gc}") + + return pre_gc, post_gc + + +class _ObsAndXIterator(Iterator[ObsAndXDatum]): + """Iterates through a set of ``obs`` and corresponding ``X`` rows, where the rows to be + + returned are specified by the ``obs_tables_iter`` argument. For the specified ``obs` rows, + the corresponding ``X`` data is loaded and joined together. It is returned from this iterator + as 2-tuples of ``X`` and obs Tensors. + + Internally manages the retrieval of data in SOMA-sized chunks, fetching the next chunk of SOMA + data as needed. Supports fetching the data in an eager manner, where the next SOMA chunk is + fetched while the current chunk is being read. This is an internal class, not intended for + public use. + """ + + soma_chunk_iter: _SOMAChunk | None + """The iterator for SOMA chunks of paired obs and X data""" + + soma_chunk: _SOMAChunk | None + """The current SOMA chunk of obs and X data""" + + i: int = -1 + """Index into current obs ``SOMA`` chunk""" + + def __init__( + self, + obs: soma.DataFrame, + X: soma.SparseNDArray, + obs_column_names: Sequence[str], + obs_joinids_chunked: list[npt.NDArray[np.int64]], + var_joinids: npt.NDArray[np.int64], + batch_size: int, + encoders: list[Encoder], + stats: Stats, + return_sparse_X: bool, + use_eager_fetch: bool, + shuffle_chunk_count: int | None = None, + shuffle_rng: Generator | None = None, + ) -> None: + self.soma_chunk_iter = _ObsAndXSOMAIterator( + obs, + X, + obs_column_names, + obs_joinids_chunked, + var_joinids, + shuffle_chunk_count, + shuffle_rng, + ) + if use_eager_fetch: + self.soma_chunk_iter = _EagerIterator(self.soma_chunk_iter) + self.soma_chunk = None + self.var_joinids = var_joinids + self.batch_size = batch_size + self.return_sparse_X = return_sparse_X + self.encoders = encoders + self.stats = stats + self.max_process_mem_usage_bytes = 0 + self.X_dtype = X.schema[2].type.to_pandas_dtype() + + def __next__(self) -> ObsAndXDatum: + """Read the next torch batch, possibly across multiple soma chunks.""" + obs: pd.DataFrame = pd.DataFrame() + X: sparse.csr_matrix = sparse.csr_matrix((0, len(self.var_joinids)), dtype=self.X_dtype) + + while len(obs) < self.batch_size: + try: + obs_partial, X_partial = self._read_partial_torch_batch(self.batch_size - len(obs)) + obs = pd.concat([obs, obs_partial], axis=0) + X = sparse.vstack([X, X_partial]) + except StopIteration: + break + + if len(obs) == 0: + raise StopIteration + + obs_encoded = pd.DataFrame() + + for enc in self.encoders: + obs_encoded[enc.name] = enc.transform(obs) + + # `to_numpy()` avoids copying the numpy array data + obs_tensor = torch.from_numpy(obs_encoded.to_numpy()) + + if not self.return_sparse_X: + X_tensor = torch.from_numpy(X.todense()) + else: + coo = X.tocoo() + + X_tensor = torch.sparse_coo_tensor( + # Note: The `np.array` seems unnecessary, but PyTorch warns bare array + # is "extremely slow" + indices=torch.from_numpy(np.array([coo.row, coo.col])), + values=coo.data, + size=coo.shape, + ) + + if self.batch_size == 1: + X_tensor = X_tensor[0] + obs_tensor = obs_tensor[0] + + return X_tensor, obs_tensor + + def _read_partial_torch_batch(self, batch_size: int) -> ObsAndXDatum: + """Reads a torch-size batch of data from the current SOMA chunk, returning a torch-size + + batch whose size may contain fewer rows than the requested ``batch_size``. This can happen + when the remaining rows in the current SOMA chunk are fewer than the requested + ``batch_size``. + """ + if self.soma_chunk is None or not (0 <= self.i < len(self.soma_chunk)): + # GC memory from previous soma_chunk + self.soma_chunk = None + mem_info = run_gc() + self.max_process_mem_usage_bytes = max( + self.max_process_mem_usage_bytes, mem_info[0][0].uss + ) + + self.soma_chunk: _SOMAChunk = next(self.soma_chunk_iter) + self.stats += self.soma_chunk.stats + self.i = 0 + + pytorch_logger.debug(f"Retrieved SOMA chunk totals: {self.stats}") + + obs_batch = self.soma_chunk.obs + X_batch = self.soma_chunk.X + + safe_batch_size = min(batch_size, len(obs_batch) - self.i) + slice_ = slice(self.i, self.i + safe_batch_size) + assert slice_.stop <= obs_batch.shape[0] + + obs_rows = obs_batch.iloc[slice_] + assert obs_rows.index.is_unique + assert safe_batch_size == obs_rows.shape[0] + + X_csr_scipy = X_batch[slice_] + assert obs_rows.shape[0] == X_csr_scipy.shape[0] + + self.i += safe_batch_size + + return obs_rows, X_csr_scipy + + +class ExperimentDataPipe(pipes.IterDataPipe[Dataset[ObsAndXDatum]]): # type: ignore + r"""An :class:`torchdata.datapipes.iter.IterDataPipe` that reads ``obs`` and ``X`` data from a + + :class:`tiledbsoma.Experiment`, based upon the specified queries along the ``obs`` and ``var`` + axes. Provides an iterator over these data when the object is passed to Python's built-in + ``iter`` function. + + >>> for batch in iter(ExperimentDataPipe(...)): + X_batch, y_batch = batch + + The ``batch_size`` parameter controls the number of rows of ``obs`` and ``X`` data that are + returned in each iteration. If the ``batch_size`` is 1, then each Tensor will have rank 1: + + >>> (tensor([0., 0., 0., 0., 0., 1., 0., 0., 0.]), # X data + tensor([2415, 0, 0], dtype=torch.int64)) # obs data, encoded + + For larger ``batch_size`` values, the returned Tensors will have rank 2: + + >>> DataLoader(..., batch_size=3, ...): + (tensor([[0., 0., 0., 0., 0., 1., 0., 0., 0.], # X batch + [0., 0., 0., 0., 0., 0., 0., 0., 0.], + [0., 0., 0., 0., 0., 0., 0., 0., 0.]]), + tensor([[2415, 0, 0], # obs batch + [2416, 0, 4], + [2417, 0, 3]], dtype=torch.int64)) + + The ``return_sparse_X`` parameter controls whether the ``X`` data is returned as a dense or + sparse :class:`torch.Tensor`. If the model supports use of sparse :class:`torch.Tensor`\ s, + this will reduce memory usage. + + The ``obs_column_names`` parameter determines the data columns that are returned in the + ``obs`` Tensor. The first element is always the ``soma_joinid`` of the ``obs`` + :class:`pandas.DataFrame` (or, equivalently, the ``soma_dim_0`` of the ``X`` matrix). + The remaining elements are the ``obs`` columns specified by ``obs_column_names``, + and string-typed columns are encoded as integer values. If needed, these values can be decoded + by obtaining the encoder for a given ``obs`` column name and calling its ``inverse_transform`` + method: + + >>> exp_data_pipe.obs_encoders[""].inverse_transform(encoded_values) + + Lifecycle: + experimental + """ + + _initialized: bool + + _obs_joinids: npt.NDArray[np.int64] | None + + _var_joinids: npt.NDArray[np.int64] | None + + _encoders: list[Encoder] + + _stats: Stats + + _shuffle_rng: Generator | None + + # TODO: Consider adding another convenience method wrapper to construct this object whose + # signature is more closely aligned with get_anndata() params + # (i.e. "exploded" AxisQuery params). + def __init__( + self, + experiment: soma.Experiment, + measurement_name: str = "RNA", + X_name: str = "raw", + obs_query: soma.AxisQuery | None = None, + var_query: soma.AxisQuery | None = None, + obs_column_names: Sequence[str] = (), + batch_size: int = 1, + shuffle: bool = True, + seed: int | None = None, + return_sparse_X: bool = False, + soma_chunk_size: int | None = 64, + use_eager_fetch: bool = True, + encoders: list[Encoder] | None = None, + shuffle_chunk_count: int | None = 2000, + ) -> None: + r"""Construct a new ``ExperimentDataPipe``. + + Args: + experiment: + The :class:`tiledbsoma.Experiment` from which to read data. + measurement_name: + The name of the :class:`tiledbsoma.Measurement` to read. Defaults to ``"RNA"``. + X_name: + The name of the X layer to read. Defaults to ``"raw"``. + obs_query: + The query used to filter along the ``obs`` axis. If not specified, all ``obs`` and + ``X`` data will be returned, which can be very large. + var_query: + The query used to filter along the ``var`` axis. If not specified, all ``var`` + columns (genes/features) will be returned. + obs_column_names: + The names of the ``obs`` columns to return. The ``soma_joinid`` index "column" does + not need to be specified and will always be returned. If not specified, only the + ``soma_joinid`` will be returned. + batch_size: + The number of rows of ``obs`` and ``X`` data to return in each iteration. Defaults + to ``1``. A value of ``1`` will result in :class:`torch.Tensor` of rank 1 being + returns (a single row); larger values will result in :class:`torch.Tensor`\ s of + rank 2 (multiple rows). + shuffle: + Whether to shuffle the ``obs`` and ``X`` data being returned. Defaults to ``True``. + For performance reasons, shuffling is not performed globally across all rows, but + rather in chunks. More specifically, we select ``shuffle_chunk_count`` + non-contiguous chunks across all the observations + in the query, concatenate the chunks and shuffle the associated observations. + The randomness of the shuffling is therefore determined by the + (``soma_chunk_size``, ``shuffle_chunk_count``) selection. The default values have + been determined to yield a good trade-off between randomness and performance. + Further tuning may be required for different type of models. Note that memory usage + is correlated to the product ``soma_chunk_size * shuffle_chunk_count``. + seed: + The random seed used for shuffling. Defaults to ``None`` (no seed). This *must* be + specified when using :class:`torch.nn.parallel.DistributedDataParallel` to ensure + data partitions are disjoint across worker processes. + return_sparse_X: + Controls whether the ``X`` data is returned as a dense or sparse + :class:`torch.Tensor`. As ``X`` data is very sparse, setting this to ``True`` will + reduce memory usage, if the model supports use of sparse :class:`torch.Tensor`\ s. + Defaults to ``False``, since sparse :class:`torch.Tensor`\ s are still experimental + in PyTorch. + soma_chunk_size: + The number of ``obs``/``X`` rows to retrieve when reading data from SOMA. This + impacts two aspects of this class's behavior: 1) The maximum memory utilization, + with larger values providing better read performance, but also requiring more + memory; 2) The granularity of the global shuffling step (see ``shuffle`` parameter + for details). The default value of 64 works well in conjunction with the default + ``shuffle_chunk_count`` value. + use_eager_fetch: + Fetch the next SOMA chunk of ``obs`` and ``X`` data immediately after a previously + fetched SOMA chunk is made available for processing via the iterator. This allows + network (or filesystem) requests to be made in parallel with client-side processing + of the SOMA data, potentially improving overall performance at the cost of + doubling memory utilization. Defaults to ``True``. + shuffle_chunk_count: + The number of contiguous blocks (chunks) of rows sampled to then concatenate + and shuffle. Larger numbers correspond to more randomness per training batch. + If ``shuffle == False``, this parameter is ignored. Defaults to ``2000``. + encoders: + Specify custom encoders to be used. If not specified, a LabelEncoder will be + created and used for each column in ``obs_column_names``. If specified, only + columns for which an encoder has been registered will be returned in the + ``obs`` tensor. + + Lifecycle: + experimental + """ + self.exp_uri = experiment.uri + self.aws_region = experiment.context.tiledb_ctx.config().get("vfs.s3.region") + self.measurement_name = measurement_name + self.layer_name = X_name + self.obs_query = obs_query + self.var_query = var_query + self.obs_column_names = obs_column_names + self.batch_size = batch_size + self.return_sparse_X = return_sparse_X + self.soma_chunk_size = soma_chunk_size + self.use_eager_fetch = use_eager_fetch + self._stats = Stats() + self._custom_encoders = encoders + self._encoders = [] + self._obs_joinids = None + self._var_joinids = None + self._shuffle_chunk_count = shuffle_chunk_count if shuffle else None + self._shuffle_rng = np.random.default_rng(seed) if shuffle else None + self._initialized = False + + if "soma_joinid" not in self.obs_column_names: + self.obs_column_names = ["soma_joinid", *self.obs_column_names] + + def _init(self) -> None: + if self._initialized: + return + + pytorch_logger.debug("Initializing ExperimentDataPipe") + + with _open_experiment(self.exp_uri, self.aws_region) as exp: + query = exp.axis_query( + measurement_name=self.measurement_name, + obs_query=self.obs_query, + var_query=self.var_query, + ) + + # The to_numpy() call is a workaround for a possible bug in TileDB-SOMA: + # https://github.com/single-cell-data/TileDB-SOMA/issues/1456 + self._obs_joinids = query.obs_joinids().to_numpy() + self._var_joinids = query.var_joinids().to_numpy() + + self._encoders = self._build_obs_encoders(query) + + self._initialized = True + + @staticmethod + def _subset_ids_to_partition( + ids_chunked: list[npt.NDArray[np.int64]], + partition_index: int, + num_partitions: int, + ) -> list[npt.NDArray[np.int64]]: + """Returns a single partition of the obs_joinids_chunked (a 2D ndarray), + + based upon the current process's distributed rank and world size. + """ + # subset to a single partition + # typing does not reflect that is actually a list of 2D NDArrays + partition_indices = np.array_split(range(len(ids_chunked)), num_partitions) + partition = [ids_chunked[i] for i in partition_indices[partition_index]] + + if pytorch_logger.isEnabledFor(logging.DEBUG) and len(partition) > 0: + pytorch_logger.debug( + f"Process {os.getpid()} handling partition {partition_index + 1} " + f"of {num_partitions}, partition_size={sum([len(chunk) for chunk in partition])}" + ) + + return partition + + @staticmethod + def _compute_partitions( + loader_partition: int, + loader_partitions: int, + dist_partition: int, + num_dist_partitions: int, + ) -> tuple[int, int]: + # NOTE: Can alternately use a `worker_init_fn` to split among workers split workload + total_partitions = num_dist_partitions * loader_partitions + partition = dist_partition * loader_partitions + loader_partition + return partition, total_partitions + + def __iter__(self) -> Iterator[ObsAndXDatum]: + self._init() + assert self._obs_joinids is not None + assert self._var_joinids is not None + + if self.soma_chunk_size is None: + # set soma_chunk_size to utilize ~1 GiB of RAM per SOMA chunk; assumes 95% X data + # sparsity, 8 bytes for the X value and 8 bytes for the sparse matrix indices, + # and a 100% working memory overhead (2x). + X_row_memory_size = 0.05 * len(self._var_joinids) * 8 * 3 * 2 + self.soma_chunk_size = int((1 * 1024**3) / X_row_memory_size) + pytorch_logger.debug(f"Using {self.soma_chunk_size=}") + + if ( + self.return_sparse_X + and torch.utils.data.get_worker_info() + and torch.utils.data.get_worker_info().num_workers > 0 + ): + raise NotImplementedError( + "torch does not work with sparse tensors in multi-processing mode " + "(see https://github.com/pytorch/pytorch/issues/20248)" + ) + + # chunk the obs joinids into batches of size soma_chunk_size + obs_joinids_chunked = self._chunk_ids(self._obs_joinids, self.soma_chunk_size) + + # globally shuffle the chunks, if requested + if self._shuffle_rng: + self._shuffle_rng.shuffle(obs_joinids_chunked) + + # subset to a single partition, as needed for distributed training and multi-processing + # data loading + worker_info = torch.utils.data.get_worker_info() + partition, partitions = self._compute_partitions( + loader_partition=worker_info.id if worker_info else 0, + loader_partitions=worker_info.num_workers if worker_info else 1, + dist_partition=dist.get_rank() if dist.is_initialized() else 0, + num_dist_partitions=dist.get_world_size() if dist.is_initialized() else 1, + ) + obs_joinids_chunked_partition: list[npt.NDArray[np.int64]] = self._subset_ids_to_partition( + obs_joinids_chunked, partition, partitions + ) + + with _open_experiment(self.exp_uri, self.aws_region) as exp: + obs_and_x_iter = _ObsAndXIterator( + obs=exp.obs, + X=exp.ms[self.measurement_name].X[self.layer_name], + obs_column_names=self.obs_column_names, + obs_joinids_chunked=obs_joinids_chunked_partition, + var_joinids=self._var_joinids, + batch_size=self.batch_size, + encoders=self._encoders, + stats=self._stats, + return_sparse_X=self.return_sparse_X, + use_eager_fetch=self.use_eager_fetch, + shuffle_rng=self._shuffle_rng, + shuffle_chunk_count=self._shuffle_chunk_count, + ) + + yield from obs_and_x_iter + + pytorch_logger.debug( + "max process memory usage=" + f"{obs_and_x_iter.max_process_mem_usage_bytes / (1024 ** 3):.3f} GiB" + ) + + @staticmethod + def _chunk_ids(ids: npt.NDArray[np.int64], chunk_size: int) -> list[npt.NDArray[np.int64]]: + num_chunks = max(1, ceil(len(ids) / chunk_size)) + pytorch_logger.debug( + f"Shuffling {len(ids)} obs joinids into {num_chunks} chunks of {chunk_size}" + ) + return np.array_split(ids, num_chunks) + + def __len__(self) -> int: + self._init() + assert self._obs_joinids is not None + + return len(self._obs_joinids) + + def __getitem__(self, index: int) -> ObsAndXDatum: + raise NotImplementedError("IterDataPipe can only be iterated") + + def _build_obs_encoders(self, query: soma.ExperimentAxisQuery) -> list[Encoder]: + pytorch_logger.debug("Initializing encoders") + + encoders = [] + obs = query.obs(column_names=self.obs_column_names).concat().to_pandas() + + if self._custom_encoders: + # Register all the custom encoders with obs + for enc in self._custom_encoders: + enc.register(obs) + encoders.append(enc) + else: + # Create one DefaultEncoder for each column, and register it with obs + for col in self.obs_column_names: + if obs[col].dtype in [object]: + enc = DefaultEncoder(col) + enc.register(obs) + encoders.append(enc) + + return encoders + + # TODO: This does not work in multiprocessing mode, as child process's stats are not collected + def stats(self) -> Stats: + """Get data loading stats for this + + :class:`cellxgene_census.experimental.ml.pytorch.ExperimentDataPipe`. + + Returns + ------- + The :class:`cellxgene_census.experimental.ml.pytorch.Stats` object for this + :class:`cellxgene_census.experimental.ml.pytorch.ExperimentDataPipe`. + + Lifecycle: + experimental + """ + return self._stats + + @property + def shape(self) -> tuple[int, int]: + """Get the shape of the data that will be returned by this + + :class:`cellxgene_census.experimental.ml.pytorch.ExperimentDataPipe`. + This is the number of obs (cell) and var (feature) counts in the returned data. If used in + multiprocessing mode (i.e. :class:`torch.utils.data.DataLoader` + instantiated with num_workers > 0), the obs (cell) count will reflect + the size of the partition of the data assigned to the active process. + + Returns + ------- + A 2-tuple of ``int``s, for obs and var counts, respectively. + + Lifecycle: + experimental + """ + self._init() + assert self._obs_joinids is not None + assert self._var_joinids is not None + + return len(self._obs_joinids), len(self._var_joinids) + + @property + def obs_encoders(self) -> Encoders: + """Returns a dictionary of :class:`sklearn.preprocessing.LabelEncoder` objects, keyed on + + ``obs`` column names, which were used to encode the ``obs`` column values. + + These encoders can be used to decode the encoded values as follows: + + >>> exp_data_pipe.obs_encoders[""].inverse_transform(encoded_values) + + Returns + ------- + A ``dict[str, LabelEncoder]``, mapping column names to :class:`sklearn.preprocessing. + LabelEncoder` objects. + """ + self._init() + assert self._encoders is not None + + return {enc.name: enc for enc in self._encoders} + + +# Note: must be a top-level function (and not a lambda), to play nice with multiprocessing pickling +def _collate_noop(x: Any) -> Any: + return x + + +# TODO: Move into somacore.ExperimentAxisQuery +def experiment_dataloader( + datapipe: pipes.IterDataPipe, + num_workers: int = 0, + **dataloader_kwargs: Any, +) -> DataLoader: + """Factory method for :class:`torch.utils.data.DataLoader`. This method can be used to safely + + instantiate a :class:`torch.utils.data.DataLoader` that works with + :class:`cellxgene_census.experimental.ml.pytorch.ExperimentDataPipe`, since some of the + :class:`torch.utils.data.DataLoader` constructor parameters are not applicable when using a + :class:`torchdata.datapipes.iter.IterDataPipe` (``shuffle``, ``batch_size``, ``sampler``, + ``batch_sampler``,``collate_fn``). + + Args: + datapipe: + An :class:`torchdata.datapipes.iter.IterDataPipe`, which can be an + :class:`cellxgene_census.experimental.ml.pytorch.ExperimentDataPipe` or any other + :class:`torchdata.datapipes.iter.IterDataPipe` that has been chained to the + :class:`cellxgene_census.experimental.ml.pytorch.ExperimentDataPipe`. + num_workers: + Number of worker processes to use for data loading. If ``0``, data will be loaded in + the main process. + **dataloader_kwargs: + Additional keyword arguments to pass to the :class:`torch.utils.data.DataLoader` + constructor, except for ``shuffle``, ``batch_size``, ``sampler``, ``batch_sampler``, + and ``collate_fn``, which are not supported when using + :class:`cellxgene_census.experimental.ml.pytorch.ExperimentDataPipe`. + + Returns + ------- + A :class:`torch.utils.data.DataLoader`. + + Raises + ------ + ValueError: if any of the ``shuffle``, ``batch_size``, ``sampler``, ``batch_sampler``, + or ``collate_fn`` params are passed as keyword arguments. + + Lifecycle: + experimental + """ + unsupported_dataloader_args = [ + "shuffle", + "batch_size", + "sampler", + "batch_sampler", + "collate_fn", + ] + if set(unsupported_dataloader_args).intersection(dataloader_kwargs.keys()): + raise ValueError( + f"The {','.join(unsupported_dataloader_args)} DataLoader params are not supported" + ) + + if num_workers > 0: + _init_multiprocessing() + + return DataLoader( + datapipe, + batch_size=None, # batching is handled by our ExperimentDataPipe + num_workers=num_workers, + # avoid use of default collator, which adds an extra (3rd) dimension to the tensor batches + collate_fn=_collate_noop, + # shuffling is handled by our ExperimentDataPipe + shuffle=False, + **dataloader_kwargs, + ) + + +def _init_multiprocessing() -> None: + """Ensures use of "spawn" for starting child processes with multiprocessing. + + Forked processes are known to be problematic: + https://pytorch.org/docs/stable/notes/multiprocessing.html#avoiding-and-fighting-deadlocks + Also, CUDA does not support forked child processes: + https://pytorch.org/docs/stable/notes/multiprocessing.html#cuda-in-multiprocessing + + """ + torch.multiprocessing.set_start_method("fork", force=True) + orig_start_method = torch.multiprocessing.get_start_method() + if orig_start_method != "spawn": + if orig_start_method: + pytorch_logger.warning( + "switching torch multiprocessing start method from " + f'"{torch.multiprocessing.get_start_method()}" to "spawn"' + ) + torch.multiprocessing.set_start_method("spawn", force=True) + + +class BatchEncoder(Encoder): + """Concatenates and encodes several columns.""" + + def __init__(self, cols: list[str]): + self.cols = cols + from sklearn.preprocessing import LabelEncoder + + self._encoder = LabelEncoder() + + def transform(self, df: pd.DataFrame): + import functools + + arr = functools.reduce(lambda a, b: a + b, [df[c].astype(str) for c in self.cols]) + return self._encoder.transform(arr) + + def register(self, obs: pd.DataFrame): + import functools + + arr = functools.reduce(lambda a, b: a + b, [obs[c].astype(str) for c in self.cols]) + self._encoder.fit(arr.unique()) + + @property + def name(self) -> str: + return "batch" + + @property + def classes_(self): + return self._encoder.classes_ + + +class CensusSCVIDataModule(LightningDataModule): + """Lightning data module for CxG Census. + + Parameters + ---------- + *args + Positional arguments passed to + :class:`~cellxgene_census.experimental.ml.pytorch.ExperimentDataPipe`. + batch_keys + List of obs column names concatenated to form the batch column. + train_size + Fraction of data to use for training. + split_seed + Seed for data split. + dataloader_kwargs + Keyword arguments passed into + :func:`~cellxgene_census.experimental.ml.pytorch.experiment_dataloader`. + **kwargs + Additional keyword arguments passed into + :class:`~cellxgene_census.experimental.ml.pytorch.ExperimentDataPipe`. Must not include + ``obs_column_names``. + """ + + _TRAIN_KEY = "train" + _VALIDATION_KEY = "validation" + + def __init__( + self, + *args, + batch_keys: list[str] | None = None, + train_size: float | None = None, + split_seed: int | None = None, + dataloader_kwargs: dict[str, any] | None = None, + **kwargs, + ): + super().__init__() + self.datapipe_args = args + self.datapipe_kwargs = kwargs + self.batch_keys = batch_keys + self.train_size = train_size + self.split_seed = split_seed + self.dataloader_kwargs = dataloader_kwargs or {} + + @property + def batch_keys(self) -> list[str]: + """List of obs column names concatenated to form the batch column.""" + if not hasattr(self, "_batch_keys"): + raise AttributeError("`batch_keys` not set.") + return self._batch_keys + + @batch_keys.setter + def batch_keys(self, value: list[str] | None): + if value is None or not isinstance(value, list): + raise ValueError("`batch_keys` must be a list of strings.") + self._batch_keys = value + + @property + def obs_column_names(self) -> list[str]: + """Passed to :class:`~cellxgene_census.experimental.ml.pytorch.ExperimentDataPipe`.""" + if hasattr(self, "_obs_column_names"): + return self._obs_column_names + + obs_column_names = [] + if self.batch_keys is not None: + obs_column_names.extend(self.batch_keys) + + self._obs_column_names = obs_column_names + return self._obs_column_names + + @property + def split_seed(self) -> int: + """Seed for data split.""" + if not hasattr(self, "_split_seed"): + raise AttributeError("`split_seed` not set.") + return self._split_seed + + @split_seed.setter + def split_seed(self, value: int | None): + if value is not None and not isinstance(value, int): + raise ValueError("`split_seed` must be an integer.") + self._split_seed = value or 0 + + @property + def train_size(self) -> float: + """Fraction of data to use for training.""" + if not hasattr(self, "_train_size"): + raise AttributeError("`train_size` not set.") + return self._train_size + + @train_size.setter + def train_size(self, value: float | None): + if value is not None and not isinstance(value, float): + raise ValueError("`train_size` must be a float.") + elif value is not None and (value < 0.0 or value > 1.0): + raise ValueError("`train_size` must be between 0.0 and 1.0.") + self._train_size = value or 1.0 + + @property + def validation_size(self) -> float: + """Fraction of data to use for validation.""" + if not hasattr(self, "_train_size"): + raise AttributeError("`validation_size` not available.") + return 1.0 - self.train_size + + @property + def weights(self) -> dict[str, float]: + """Passed to :meth:`~cellxgene_census.experimental.ml.ExperimentDataPipe.random_split`.""" + if not hasattr(self, "_weights"): + self._weights = {self._TRAIN_KEY: self.train_size} + if self.validation_size > 0.0: + self._weights[self._VALIDATION_KEY] = self.validation_size + return self._weights + + @property + def datapipe(self) -> ExperimentDataPipe: + """Experiment data pipe.""" + if not hasattr(self, "_datapipe"): + encoder = BatchEncoder(self.obs_column_names) + self._datapipe = ExperimentDataPipe( + *self.datapipe_args, + obs_column_names=self.obs_column_names, + encoders=[encoder], + **self.datapipe_kwargs, + ) + return self._datapipe + + def setup(self, stage: str | None = None): + """Set up the train and validation data pipes.""" + datapipes = self.datapipe.random_split(weights=self.weights, seed=self.split_seed) + self._train_datapipe = datapipes[0] + if self.validation_size > 0.0: + self._validation_datapipe = datapipes[1] + else: + self._validation_datapipe = None + + def train_dataloader(self): + """Training data loader.""" + return experiment_dataloader(self._train_datapipe, **self.dataloader_kwargs) + + def val_dataloader(self): + """Validation data loader.""" + if self._validation_datapipe is not None: + return experiment_dataloader(self._validation_datapipe, **self.dataloader_kwargs) + + @property + def n_obs(self) -> int: + """Number of observations in the query. + + Necessary in scvi-tools to compute a heuristic of ``max_epochs``. + """ + return self.datapipe.shape[0] + + @property + def n_vars(self) -> int: + """Number of features in the query. + + Necessary in scvi-tools to initialize the actual layers in the model. + + """ + return self.datapipe.shape[1] + + @property + def n_batch(self) -> int: + """ + Number of unique batches (after concatenation of ``batch_keys``). Necessary in scvi-tools + + so that the model knows how to one-hot encode batches. + + """ + return self.get_n_classes("batch") + + def get_n_classes(self, key: str) -> int: + """Return the number of classes for a given obs column.""" + return len(self.datapipe.obs_encoders[key].classes_) + + def on_before_batch_transfer( + self, + batch: tuple[torch.Tensor, torch.Tensor], + dataloader_idx: int, + ) -> dict[str, torch.Tensor | None]: + """Format the datapipe output with registry keys for scvi-tools.""" + X, obs = batch + + X_KEY: str = "X" + BATCH_KEY: str = "batch" + LABELS_KEY: str = "labels" + + return { + X_KEY: X, + BATCH_KEY: obs, + LABELS_KEY: None, + } diff --git a/src/scvi/model/_scvi.py b/src/scvi/model/_scvi.py index 95c88c3541..e9bc5ec694 100644 --- a/src/scvi/model/_scvi.py +++ b/src/scvi/model/_scvi.py @@ -141,6 +141,7 @@ def __init__( ) if self._module_init_on_train: + # Here we need to adjust given the new custom data loader self.module = None warnings.warn( "Model was initialized without `adata`. The module will be initialized when " diff --git a/tests/dataloaders/test_custom_dataloader.py b/tests/dataloaders/test_custom_dataloader.py new file mode 100644 index 0000000000..9d48db4f9f --- /dev/null +++ b/tests/dataloaders/test_custom_dataloader.py @@ -0,0 +1 @@ +from __future__ import annotations From cc72b05f27f75f349b5946aeaca5e30a10bdcb21 Mon Sep 17 00:00:00 2001 From: ori-kron-wis Date: Tue, 30 Jul 2024 16:18:49 +0300 Subject: [PATCH 02/53] added some fixes to the custom dataloader stuff --- src/scvi/data/_manager.py | 49 +++++ src/scvi/model/_scvi.py | 66 +++++- src/scvi/model/base/_base_model.py | 17 ++ src/scvi/model/base/_training_mixin.py | 4 + tests/dataloaders/test_custom_dataloader.py | 63 ++++++ tests/dataloaders/test_custom_dataloader2.py | 213 +++++++++++++++++++ tests/model/test_scvi.py | 1 + 7 files changed, 412 insertions(+), 1 deletion(-) create mode 100644 tests/dataloaders/test_custom_dataloader2.py diff --git a/src/scvi/data/_manager.py b/src/scvi/data/_manager.py index 10d0219041..88b7be838b 100644 --- a/src/scvi/data/_manager.py +++ b/src/scvi/data/_manager.py @@ -192,6 +192,55 @@ def register_fields( self._assign_uuid() self._assign_most_recent_manager_uuid() + def register_data_module_fields( + self, + datamodule, + source_registry: dict | None = None, + **transfer_kwargs, + ): + """Registers each field associated with this instance with the AnnData object. + + Either registers or transfers the setup from `source_setup_dict` if passed in. + Sets ``self.adata``. + + Parameters + ---------- + adata + AnnData object to be registered. + source_registry + Registry created after registering an AnnData using an + :class:`~scvi.data.AnnDataManager` object. + transfer_kwargs + Additional keywords which modify transfer behavior. Only applicable if + ``source_registry`` is set. + """ + if self.adata is not None: + raise AssertionError("Existing AnnData object registered with this Manager instance.") + + if source_registry is None and transfer_kwargs: + raise TypeError( + f"register_fields() got unexpected keyword arguments {transfer_kwargs} passed " + "without a source_registry." + ) + + self._validate_anndata_object(datamodule) + + for field in self.fields: + self._add_field( + field=field, + adata=datamodule, + source_registry=source_registry, + **transfer_kwargs, + ) + + # Save arguments for register_fields. + self._source_registry = deepcopy(source_registry) + self._transfer_kwargs = deepcopy(transfer_kwargs) + + self.adata = datamodule + self._assign_uuid() + self._assign_most_recent_manager_uuid() + def _add_field( self, field: AnnDataField, diff --git a/src/scvi/model/_scvi.py b/src/scvi/model/_scvi.py index e9bc5ec694..4f9f645ad7 100644 --- a/src/scvi/model/_scvi.py +++ b/src/scvi/model/_scvi.py @@ -140,8 +140,10 @@ def __init__( f"gene_likelihood: {gene_likelihood}, latent_distribution: {latent_distribution}." ) + # in the next part we need to construct the same module no mather the way + # dataloader was given if self._module_init_on_train: - # Here we need to adjust given the new custom data loader + # Here we need to adjust given the new custom data loader like CZI case self.module = None warnings.warn( "Model was initialized without `adata`. The module will be initialized when " @@ -225,6 +227,68 @@ def setup_anndata( adata_manager = AnnDataManager(fields=anndata_fields, setup_method_args=setup_method_args) adata_manager.register_fields(adata, **kwargs) cls.register_manager(adata_manager) + # adata_manager.get_state_registry(SCVI.REGISTRY_KEYS.X_KEY).to_dict() + # adata_manager.registry[_constants._FIELD_REGISTRIES_KEY] + # pprint(adata_manager.registry) + + @classmethod + @setup_anndata_dsp.dedent + def setup_datamodule( + cls, + datamodule, + layer: str | None = None, + batch_key: str | None = None, + labels_key: str | None = None, + size_factor_key: str | None = None, + categorical_covariate_keys: list[str] | None = None, + continuous_covariate_keys: list[str] | None = None, + **kwargs, + ): + """%(summary)s. + + Parameters + ---------- + %(param_adata)s + %(param_layer)s + %(param_batch_key)s + %(param_labels_key)s + %(param_size_factor_key)s + %(param_cat_cov_keys)s + %(param_cont_cov_keys)s + """ + setup_method_args = cls._get_setup_method_args(**locals()) + anndata_fields = [ + LayerField(REGISTRY_KEYS.X_KEY, layer, is_count_data=True), + CategoricalObsField(REGISTRY_KEYS.BATCH_KEY, batch_key), + CategoricalObsField(REGISTRY_KEYS.LABELS_KEY, labels_key), + NumericalObsField(REGISTRY_KEYS.SIZE_FACTOR_KEY, size_factor_key, required=False), + CategoricalJointObsField(REGISTRY_KEYS.CAT_COVS_KEY, categorical_covariate_keys), + NumericalJointObsField(REGISTRY_KEYS.CONT_COVS_KEY, continuous_covariate_keys), + ] + # register new fields if the adata is minified + # adata_minify_type = _get_adata_minify_type(adata) + # if adata_minify_type is not None: + # anndata_fields += cls._get_fields_for_adata_minification(adata_minify_type) + adata_manager = AnnDataManager(fields=anndata_fields, setup_method_args=setup_method_args) + adata_manager.registry["setup_method_name"] = "setup_datamodule" + adata_manager.registry["setup_args"]["layer"] = datamodule.datapipe.layer_name + adata_manager.registry["setup_args"]["batch_key"] = datamodule.batch_keys + adata_manager.registry["setup_args"]["labels_key"] + adata_manager.registry["setup_args"]["batch_key"] + adata_manager.registry["setup_args"]["batch_key"] + adata_manager.registry["setup_args"]["batch_key"] + # datamodule._datapipe.obs_column_names + # datamodule._datapipe.obs_encoders + # adata_manager.register_fields(adata, **kwargs) + # how to etract the information we need from the datamodule + adata_manager.register_data_module_fields( + datamodule, **kwargs + ) # here we need a new function for data module + + cls.register_manager(adata_manager) + # adata_manager.get_state_registry(SCVI.REGISTRY_KEYS.X_KEY).to_dict() + # adata_manager.registry[_constants._FIELD_REGISTRIES_KEY] + # pprint(adata_manager.registry) @staticmethod def _get_fields_for_adata_minification( diff --git a/src/scvi/model/base/_base_model.py b/src/scvi/model/base/_base_model.py index e991da3eb3..aaef7f84a5 100644 --- a/src/scvi/model/base/_base_model.py +++ b/src/scvi/model/base/_base_model.py @@ -815,6 +815,23 @@ def setup_anndata( on a model-specific instance of :class:`~scvi.data.AnnDataManager`. """ + @classmethod + @abstractmethod + @setup_anndata_dsp.dedent + def setup_datamodule( + cls, + datamodule, + *args, + **kwargs, + ): + """%(summary)s. + + Each model class deriving from this class provides parameters to this method + according to its needs. To operate correctly with the model initialization, + the implementation must call :meth:`~scvi.model.base.BaseModelClass.register_manager` + on a model-specific instance of :class:`~scvi.data.AnnDataManager`. + """ + @staticmethod def view_setup_args(dir_path: str, prefix: str | None = None) -> None: """Print args used to setup a saved model. diff --git a/src/scvi/model/base/_training_mixin.py b/src/scvi/model/base/_training_mixin.py index de3efd9fcb..86da6b5019 100644 --- a/src/scvi/model/base/_training_mixin.py +++ b/src/scvi/model/base/_training_mixin.py @@ -102,6 +102,7 @@ def train( ) if datamodule is None: + # In the general case we enter here datasplitter_kwargs = datasplitter_kwargs or {} datamodule = self._data_splitter_cls( self.adata_manager, @@ -114,6 +115,7 @@ def train( **datasplitter_kwargs, ) elif self.module is None: + # in CZI case we enter here self.module = self._module_cls( datamodule.n_vars, n_batch=datamodule.n_batch, @@ -122,6 +124,8 @@ def train( n_cats_per_cov=getattr(datamodule, "n_cats_per_cov", None), **self._module_kwargs, ) + # after either of the cases we should be here with the same self.module + # and same datamodule plan_kwargs = plan_kwargs or {} training_plan = self._training_plan_cls(self.module, **plan_kwargs) diff --git a/tests/dataloaders/test_custom_dataloader.py b/tests/dataloaders/test_custom_dataloader.py index 9d48db4f9f..c80cb843b5 100644 --- a/tests/dataloaders/test_custom_dataloader.py +++ b/tests/dataloaders/test_custom_dataloader.py @@ -1 +1,64 @@ from __future__ import annotations + +import os + +import numpy as np +import scanpy as sc + +import scvi +from scvi.data import _constants, synthetic_iid +from scvi.model import SCVI + +# We will now create the SCVI model object: +# Its parameters: +n_layers = 1 +n_latent = 10 +batch_size = 1024 +train_size = 0.9 +max_epochs = 1 + + +# COMAPRE TO THE ORIGINAL METHOD!!! - use the same data!!! +# We first create a registry using the orignal way of anndata in order to compare and add +# what is missing +adata = synthetic_iid() +adata.obs["size_factor"] = np.random.randint(1, 5, size=(adata.shape[0],)) +SCVI.setup_anndata( + adata, + batch_key="batch", + labels_key="labels", + size_factor_key="size_factor", +) +# +model_orig = SCVI(adata, n_latent=n_latent) +model_orig.train(1, check_val_every_n_epoch=1, train_size=0.5) + +# Saving the model +save_dir = "/Users/orikr/runs/290724/" # tempfile.TemporaryDirectory() +model_dir = os.path.join(save_dir, "scvi_orig_model") +model_orig.save(model_dir, overwrite=True) + +# Loading the model (just as a compariosn) +model_orig_loaded = scvi.model.SCVI.load(model_dir, adata=adata) + +# Obtaining model outputs +SCVI_LATENT_KEY = "X_scVI" +latent = model_orig.get_latent_representation() +adata.obsm[SCVI_LATENT_KEY] = latent +# latent.shape + +# You can see all necessary entries and the structure at +adata_manager = model_orig.adata_manager +model_orig.view_anndata_setup(hide_state_registries=True) +# adata_manager.get_state_registry(SCVI.REGISTRY_KEYS.X_KEY).to_dict() +adata_manager.registry[_constants._FIELD_REGISTRIES_KEY] + +# Plot UMAP and save the figure for later check +sc.pp.neighbors(adata, use_rep="scvi", key_added="scvi") +sc.tl.umap(adata, neighbors_key="scvi") +sc.pl.umap(adata, color="dataset_id", title="SCVI") + +# Now return and add all the registry stuff that we will need + +# Now add the missing stuff from the current CZI implemenation in order for us to have the exact +# same steps like the original way (except than setup_anndata) diff --git a/tests/dataloaders/test_custom_dataloader2.py b/tests/dataloaders/test_custom_dataloader2.py new file mode 100644 index 0000000000..5b741ebfad --- /dev/null +++ b/tests/dataloaders/test_custom_dataloader2.py @@ -0,0 +1,213 @@ +from __future__ import annotations + +import os + +import cellxgene_census +import pandas as pd +import scanpy as sc +import tiledbsoma as soma +import torch +from cellxgene_census.experimental.pp import highly_variable_genes + +import scvi +from scvi.dataloaders._custom_dataloader import CensusSCVIDataModule, experiment_dataloader +from scvi.model import SCVI + +# We will now create the SCVI model object: +# Its parameters: +n_layers = 1 +n_latent = 10 +batch_size = 1024 +train_size = 0.9 +max_epochs = 1 + +# We have to create a registry without setup_anndata that contains the same elements +# The other way will be to fill the model ,LIKE IN CELLXGENE NOTEBOOK +# need to pass here new object of registry taht contains everything we will need + + +# First lets see CELLXGENE example using pytorch loaders implemented now in our repo +census = cellxgene_census.open_soma(census_version="stable") +experiment_name = "mus_musculus" +obs_value_filter = 'is_primary_data == True and tissue_general in ["spleen"] and nnz >= 300' +top_n_hvg = 8000 +hvg_batch = ["assay", "suspension_type"] +# THIS WILL TAKE FEW MINUTES TO RUN! +query = census["census_data"][experiment_name].axis_query( + measurement_name="RNA", obs_query=soma.AxisQuery(value_filter=obs_value_filter) +) +hvgs_df = highly_variable_genes(query, n_top_genes=top_n_hvg, batch_key=hvg_batch) +hv = hvgs_df.highly_variable +hv_idx = hv[hv].index + +# Now load the custom data module CZI did that now exists in our db +# (and we will later want to elaborate with more info from our original anndata registry) +# This thing is done by the user in any form they want +datamodule = CensusSCVIDataModule( + census["census_data"][experiment_name], + measurement_name="RNA", + X_name="raw", + obs_query=soma.AxisQuery(value_filter=obs_value_filter), + var_query=soma.AxisQuery(coords=(list(hv_idx),)), + batch_size=1024, + shuffle=True, + batch_keys=["dataset_id", "assay", "suspension_type", "donor_id"], + dataloader_kwargs={"num_workers": 0, "persistent_workers": False}, +) +# This is a new func to implement +SCVI.setup_datamodule(datamodule) +# +model = SCVI(n_layers=n_layers, n_latent=n_latent, gene_likelihood="nb", encode_covariates=False) + + +# The CZI data module is a refined data module while SCVI is a lighting datamodule +# Altough this is only 1 epoch it will take few mins on local machine +model.train( + datamodule=datamodule, + max_epochs=max_epochs, + batch_size=batch_size, + train_size=train_size, + early_stopping=False, +) + +# We can now save the trained model. As of the current writing date (June 2024), +# scvi-tools doesn't support saving a model that wasn't generated through an AnnData loader, +# so we'll use some custom code: +model_state_dict = model.module.state_dict() +var_names = hv_idx.to_numpy() +user_attributes = model._get_user_attributes() +user_attributes = {a[0]: a[1] for a in user_attributes if a[0][-1] == "_"} + +user_attributes.update( + { + "n_batch": datamodule.n_batch, + "n_extra_categorical_covs": 0, + "n_extra_continuous_covs": 0, + "n_labels": 1, + "n_vars": datamodule.n_vars, + } +) + +with open("model.pt", "wb") as f: + torch.save( + { + "model_state_dict": model_state_dict, + "var_names": var_names, + "attr_dict": user_attributes, + }, + f, + ) + +# Saving the model the original way +save_dir = "/Users/orikr/runs/290724/" # tempfile.TemporaryDirectory() +model_dir = os.path.join(save_dir, "scvi_czi_model") +model.save(model_dir, overwrite=True) + + +# We will now load the model back and use it to generate cell embeddings (the latent space), +# which can then be used for further analysis. Note that we still need to use some custom code for +# loading the model, which includes loading the parameters from the `attr_dict` node stored in +# the model. +with open("model.pt", "rb") as f: + torch_model = torch.load(f) + + adict = torch_model["attr_dict"] + params = adict["init_params_"]["non_kwargs"] + + n_batch = adict["n_batch"] + n_extra_categorical_covs = adict["n_extra_categorical_covs"] + n_extra_continuous_covs = adict["n_extra_continuous_covs"] + n_labels = adict["n_labels"] + n_vars = adict["n_vars"] + + latent_distribution = params["latent_distribution"] + dispersion = params["dispersion"] + n_hidden = params["n_hidden"] + dropout_rate = params["dropout_rate"] + gene_likelihood = params["gene_likelihood"] + + model = scvi.model.SCVI( + n_layers=params["n_layers"], + n_latent=params["n_latent"], + gene_likelihood=params["gene_likelihood"], + encode_covariates=False, + ) + + module = model._module_cls( + n_input=n_vars, + n_batch=n_batch, + n_labels=n_labels, + n_continuous_cov=n_extra_continuous_covs, + n_cats_per_cov=None, + n_hidden=n_hidden, + n_latent=n_latent, + n_layers=n_layers, + dropout_rate=dropout_rate, + dispersion=dispersion, + gene_likelihood=gene_likelihood, + latent_distribution=latent_distribution, + ) + model.module = module + + model.module.load_state_dict(torch_model["model_state_dict"]) + + device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") + + model.to_device(device) + model.module.eval() + model.is_trained = True + +# We will now generate the cell embeddings for this model, using the `get_latent_representation` +# function available in scvi-tools. +# We can use another instance of the `ExperimentDataPipe` for the forward pass, so we don't need +# to load the whole dataset in memory. + +# Needs to have shuffle=False for inference +datamodule_inference = CensusSCVIDataModule( + census["census_data"][experiment_name], + measurement_name="RNA", + X_name="raw", + obs_query=soma.AxisQuery(value_filter=obs_value_filter), + var_query=soma.AxisQuery(coords=(list(hv_idx),)), + batch_size=1024, + shuffle=False, + soma_chunk_size=50_000, + batch_keys=["dataset_id", "assay", "suspension_type", "donor_id"], + dataloader_kwargs={"num_workers": 0, "persistent_workers": False}, +) + +# We can simply feed the datapipe to `get_latent_representation` to obtain the embeddings - +# will take a while +datapipe = datamodule_inference.datapipe +dataloader = experiment_dataloader(datapipe, num_workers=0, persistent_workers=False) +mapped_dataloader = ( + datamodule_inference.on_before_batch_transfer(tensor, None) for tensor in dataloader +) +latent = model.get_latent_representation(dataloader=mapped_dataloader) +emb_idx = datapipe._obs_joinids + +# We will now take a look at the UMAP for the generated embedding +# (will be later comapred to what we got) +adata = cellxgene_census.get_anndata( + census, + organism=experiment_name, + obs_value_filter=obs_value_filter, +) +obs_soma_joinids = adata.obs["soma_joinid"] +obs_indexer = pd.Index(emb_idx) +idx = obs_indexer.get_indexer(obs_soma_joinids) +# Reindexing is necessary to ensure that the cells in the embedding match the +# ones in the anndata object. +adata.obsm["scvi"] = latent[idx] + +# Plot UMAP and save the figure for later check +sc.pp.neighbors(adata, use_rep="scvi", key_added="scvi") +sc.tl.umap(adata, neighbors_key="scvi") +sc.pl.umap(adata, color="dataset_id", title="SCVI") + + +# Now return and add all the registry stuff that we will need + + +# Now add the missing stuff from the current CZI implemenation in order for us to have the exact +# same steps like the original way (except than setup_anndata) diff --git a/tests/model/test_scvi.py b/tests/model/test_scvi.py index ba3976f64c..1826f55da1 100644 --- a/tests/model/test_scvi.py +++ b/tests/model/test_scvi.py @@ -930,6 +930,7 @@ def test_scvi_no_anndata(n_batches: int = 3, n_latent: int = 5): model.train(datamodule=datamodule) model = SCVI(adata, n_latent=5) + # Add an example for external custom dataloader? assert not model._module_init_on_train assert model.module is not None assert hasattr(model, "adata") From 46048e3064c66ad52f744638ce088d9512495de1 Mon Sep 17 00:00:00 2001 From: Can Ergen Date: Tue, 30 Jul 2024 12:20:51 -0700 Subject: [PATCH 03/53] Some suggestions --- src/scvi/data/_manager.py | 49 --------------------------------------- src/scvi/model/_scvi.py | 25 +++++++++++++------- tests/model/test_scvi.py | 18 ++++++++++++++ 3 files changed, 34 insertions(+), 58 deletions(-) diff --git a/src/scvi/data/_manager.py b/src/scvi/data/_manager.py index 88b7be838b..10d0219041 100644 --- a/src/scvi/data/_manager.py +++ b/src/scvi/data/_manager.py @@ -192,55 +192,6 @@ def register_fields( self._assign_uuid() self._assign_most_recent_manager_uuid() - def register_data_module_fields( - self, - datamodule, - source_registry: dict | None = None, - **transfer_kwargs, - ): - """Registers each field associated with this instance with the AnnData object. - - Either registers or transfers the setup from `source_setup_dict` if passed in. - Sets ``self.adata``. - - Parameters - ---------- - adata - AnnData object to be registered. - source_registry - Registry created after registering an AnnData using an - :class:`~scvi.data.AnnDataManager` object. - transfer_kwargs - Additional keywords which modify transfer behavior. Only applicable if - ``source_registry`` is set. - """ - if self.adata is not None: - raise AssertionError("Existing AnnData object registered with this Manager instance.") - - if source_registry is None and transfer_kwargs: - raise TypeError( - f"register_fields() got unexpected keyword arguments {transfer_kwargs} passed " - "without a source_registry." - ) - - self._validate_anndata_object(datamodule) - - for field in self.fields: - self._add_field( - field=field, - adata=datamodule, - source_registry=source_registry, - **transfer_kwargs, - ) - - # Save arguments for register_fields. - self._source_registry = deepcopy(source_registry) - self._transfer_kwargs = deepcopy(transfer_kwargs) - - self.adata = datamodule - self._assign_uuid() - self._assign_most_recent_manager_uuid() - def _add_field( self, field: AnnDataField, diff --git a/src/scvi/model/_scvi.py b/src/scvi/model/_scvi.py index 4f9f645ad7..9ba332f91f 100644 --- a/src/scvi/model/_scvi.py +++ b/src/scvi/model/_scvi.py @@ -256,6 +256,8 @@ def setup_datamodule( %(param_cat_cov_keys)s %(param_cont_cov_keys)s """ + + # Remove these lines. We don't need an adata_manager. setup_method_args = cls._get_setup_method_args(**locals()) anndata_fields = [ LayerField(REGISTRY_KEYS.X_KEY, layer, is_count_data=True), @@ -271,20 +273,25 @@ def setup_datamodule( # anndata_fields += cls._get_fields_for_adata_minification(adata_minify_type) adata_manager = AnnDataManager(fields=anndata_fields, setup_method_args=setup_method_args) adata_manager.registry["setup_method_name"] = "setup_datamodule" + + """ + ORI check here the elements are used in the datamodule. + We can stick to their solution for now. But we should check for all setup things whether + they are present in the datamodule. + These checks can adfterwards go to a new class. But implement them here. And ignore all adata things. + We just want to have the same dictionary + """ + if datamodule.get_batch_keys() is not None: + adata_manager.registry["setup_args"]["batch_key"] = datamodule.get_batch_keys() + if datamodule.get_labels_keys() is not None: + adata_manager.registry["setup_args"]["labels_key"] = datamodule.get_labels_keys() adata_manager.registry["setup_args"]["layer"] = datamodule.datapipe.layer_name - adata_manager.registry["setup_args"]["batch_key"] = datamodule.batch_keys - adata_manager.registry["setup_args"]["labels_key"] - adata_manager.registry["setup_args"]["batch_key"] - adata_manager.registry["setup_args"]["batch_key"] - adata_manager.registry["setup_args"]["batch_key"] - # datamodule._datapipe.obs_column_names - # datamodule._datapipe.obs_encoders - # adata_manager.register_fields(adata, **kwargs) - # how to etract the information we need from the datamodule + datamodule.get_var_names() # ORI this has to be provided no check otherwise raise error. adata_manager.register_data_module_fields( datamodule, **kwargs ) # here we need a new function for data module + # ORI No need to register here using adata manager. Instead populate dictionary. It will be sufficient. cls.register_manager(adata_manager) # adata_manager.get_state_registry(SCVI.REGISTRY_KEYS.X_KEY).to_dict() # adata_manager.registry[_constants._FIELD_REGISTRIES_KEY] diff --git a/tests/model/test_scvi.py b/tests/model/test_scvi.py index 1826f55da1..b1d36959eb 100644 --- a/tests/model/test_scvi.py +++ b/tests/model/test_scvi.py @@ -1050,6 +1050,24 @@ def test_scvi_inference_custom_dataloader(n_latent: int = 5): _ = model.get_latent_representation(dataloader=dataloader) +def test_scvi_train_custom_dataloader(n_latent: int = 5): + # ORI this function could help get started. + adata = synthetic_iid() + SCVI.setup_anndata(adata, batch_key="batch") + + model = SCVI(adata, n_latent=n_latent) + model.train(max_epochs=1) + dataloader = model._make_data_loader(adata) + """ + SCVI.setup_datamodule(dataloader) + # continue from here. Datamodule will always require to pass it into all downstream functions. + model.train(max_epochs=1, datamodule=dataloader) + _ = model.get_elbo(dataloader=dataloader) + _ = model.get_marginal_ll(dataloader=dataloader) + _ = model.get_reconstruction_error(dataloader=dataloader) + _ = model.get_latent_representation(dataloader=dataloader) + """ + def test_scvi_normal_likelihood(): import scanpy as sc From 14f343d655bdad6f6b228beec2c469d38405994f Mon Sep 17 00:00:00 2001 From: Can Ergen Date: Wed, 31 Jul 2024 13:53:44 -0700 Subject: [PATCH 04/53] Changes to datamodule pipeline --- cellxgene-census | 1 + src/scvi/model/_scanvi.py | 4 + src/scvi/model/_scvi.py | 114 +++++++---------------- src/scvi/model/base/_base_model.py | 123 ++++++++++++++++++++----- src/scvi/model/base/_save_load.py | 7 +- src/scvi/model/base/_training_mixin.py | 31 +------ 6 files changed, 142 insertions(+), 138 deletions(-) create mode 160000 cellxgene-census diff --git a/cellxgene-census b/cellxgene-census new file mode 160000 index 0000000000..6edd123100 --- /dev/null +++ b/cellxgene-census @@ -0,0 +1 @@ +Subproject commit 6edd123100716f6a434403b74db58c5379bb0d5d diff --git a/src/scvi/model/_scanvi.py b/src/scvi/model/_scanvi.py index 55c6e7a980..8630ff80d2 100644 --- a/src/scvi/model/_scanvi.py +++ b/src/scvi/model/_scanvi.py @@ -178,6 +178,7 @@ def from_scvi_model( unlabeled_category: str, labels_key: str | None = None, adata: AnnData | None = None, + datamodule: LightningDataModule | None = None, **scanvi_kwargs, ): """Initialize scanVI model with weights from pretrained :class:`~scvi.model.SCVI` model. @@ -194,6 +195,8 @@ def from_scvi_model( Value used for unlabeled cells in `labels_key` used to setup AnnData with scvi. adata AnnData object that has been registered via :meth:`~scvi.model.SCANVI.setup_anndata`. + datamodule + LightningDataModule object that has been registered. scanvi_kwargs kwargs for scANVI model """ @@ -242,6 +245,7 @@ def from_scvi_model( **scvi_setup_args, ) scanvi_model = cls(adata, **non_kwargs, **kwargs, **scanvi_kwargs) + print('TTTT', scanvi_model.registry) scvi_state_dict = scvi_model.module.state_dict() scanvi_model.module.load_state_dict(scvi_state_dict, strict=False) scanvi_model.was_pretrained = True diff --git a/src/scvi/model/_scvi.py b/src/scvi/model/_scvi.py index 9ba332f91f..ffc79775a5 100644 --- a/src/scvi/model/_scvi.py +++ b/src/scvi/model/_scvi.py @@ -6,6 +6,7 @@ import numpy as np from anndata import AnnData +from lightning import LightningDataModule from scvi import REGISTRY_KEYS, settings from scvi._types import MinifiedDataType @@ -112,6 +113,7 @@ class SCVI( def __init__( self, adata: AnnData | None = None, + datamodule: LightningDataModule | None = None, n_hidden: int = 128, n_latent: int = 10, n_layers: int = 1, @@ -121,7 +123,7 @@ def __init__( latent_distribution: Literal["normal", "ln"] = "normal", **kwargs, ): - super().__init__(adata) + super().__init__(adata, datamodule) self._module_kwargs = { "n_hidden": n_hidden, @@ -140,49 +142,35 @@ def __init__( f"gene_likelihood: {gene_likelihood}, latent_distribution: {latent_distribution}." ) - # in the next part we need to construct the same module no mather the way - # dataloader was given - if self._module_init_on_train: - # Here we need to adjust given the new custom data loader like CZI case - self.module = None - warnings.warn( - "Model was initialized without `adata`. The module will be initialized when " - "calling `train`. This behavior is experimental and may change in the future.", - UserWarning, - stacklevel=settings.warnings_stacklevel, + n_cats_per_cov = self.summary_stats[f'n_{REGISTRY_KEYS.CAT_COVS_KEY}'] + if n_cats_per_cov == 0: + n_cats_per_cov = None + n_batch = self.summary_stats.n_batch + use_size_factor_key = self.registry_['setup_args'][f'{REGISTRY_KEYS.SIZE_FACTOR_KEY}_key'] + library_log_means, library_log_vars = None, None + if self.adata is not None and not use_size_factor_key and self.minified_data_type is None: + library_log_means, library_log_vars = _init_library_size( + self.adata_manager, n_batch ) - else: - n_cats_per_cov = ( - self.adata_manager.get_state_registry(REGISTRY_KEYS.CAT_COVS_KEY).n_cats_per_key - if REGISTRY_KEYS.CAT_COVS_KEY in self.adata_manager.data_registry - else None - ) - n_batch = self.summary_stats.n_batch - use_size_factor_key = REGISTRY_KEYS.SIZE_FACTOR_KEY in self.adata_manager.data_registry - library_log_means, library_log_vars = None, None - if not use_size_factor_key and self.minified_data_type is None: - library_log_means, library_log_vars = _init_library_size( - self.adata_manager, n_batch - ) - self.module = self._module_cls( - n_input=self.summary_stats.n_vars, - n_batch=n_batch, - n_labels=self.summary_stats.n_labels, - n_continuous_cov=self.summary_stats.get("n_extra_continuous_covs", 0), - n_cats_per_cov=n_cats_per_cov, - n_hidden=n_hidden, - n_latent=n_latent, - n_layers=n_layers, - dropout_rate=dropout_rate, - dispersion=dispersion, - gene_likelihood=gene_likelihood, - latent_distribution=latent_distribution, - use_size_factor_key=use_size_factor_key, - library_log_means=library_log_means, - library_log_vars=library_log_vars, - **kwargs, - ) - self.module.minified_data_type = self.minified_data_type + self.module = self._module_cls( + n_input=self.summary_stats.n_vars, + n_batch=n_batch, + n_labels=self.summary_stats.n_labels, + n_continuous_cov=self.summary_stats.get("n_extra_continuous_covs", 0), + n_cats_per_cov=n_cats_per_cov, + n_hidden=n_hidden, + n_latent=n_latent, + n_layers=n_layers, + dropout_rate=dropout_rate, + dispersion=dispersion, + gene_likelihood=gene_likelihood, + latent_distribution=latent_distribution, + use_size_factor_key=use_size_factor_key, + library_log_means=library_log_means, + library_log_vars=library_log_vars, + **kwargs, + ) + self.module.minified_data_type = self.minified_data_type self.init_params_ = self._get_init_params(locals()) @@ -257,45 +245,7 @@ def setup_datamodule( %(param_cont_cov_keys)s """ - # Remove these lines. We don't need an adata_manager. - setup_method_args = cls._get_setup_method_args(**locals()) - anndata_fields = [ - LayerField(REGISTRY_KEYS.X_KEY, layer, is_count_data=True), - CategoricalObsField(REGISTRY_KEYS.BATCH_KEY, batch_key), - CategoricalObsField(REGISTRY_KEYS.LABELS_KEY, labels_key), - NumericalObsField(REGISTRY_KEYS.SIZE_FACTOR_KEY, size_factor_key, required=False), - CategoricalJointObsField(REGISTRY_KEYS.CAT_COVS_KEY, categorical_covariate_keys), - NumericalJointObsField(REGISTRY_KEYS.CONT_COVS_KEY, continuous_covariate_keys), - ] - # register new fields if the adata is minified - # adata_minify_type = _get_adata_minify_type(adata) - # if adata_minify_type is not None: - # anndata_fields += cls._get_fields_for_adata_minification(adata_minify_type) - adata_manager = AnnDataManager(fields=anndata_fields, setup_method_args=setup_method_args) - adata_manager.registry["setup_method_name"] = "setup_datamodule" - - """ - ORI check here the elements are used in the datamodule. - We can stick to their solution for now. But we should check for all setup things whether - they are present in the datamodule. - These checks can adfterwards go to a new class. But implement them here. And ignore all adata things. - We just want to have the same dictionary - """ - if datamodule.get_batch_keys() is not None: - adata_manager.registry["setup_args"]["batch_key"] = datamodule.get_batch_keys() - if datamodule.get_labels_keys() is not None: - adata_manager.registry["setup_args"]["labels_key"] = datamodule.get_labels_keys() - adata_manager.registry["setup_args"]["layer"] = datamodule.datapipe.layer_name - datamodule.get_var_names() # ORI this has to be provided no check otherwise raise error. - adata_manager.register_data_module_fields( - datamodule, **kwargs - ) # here we need a new function for data module - - # ORI No need to register here using adata manager. Instead populate dictionary. It will be sufficient. - cls.register_manager(adata_manager) - # adata_manager.get_state_registry(SCVI.REGISTRY_KEYS.X_KEY).to_dict() - # adata_manager.registry[_constants._FIELD_REGISTRIES_KEY] - # pprint(adata_manager.registry) + pass @staticmethod def _get_fields_for_adata_minification( diff --git a/src/scvi/model/base/_base_model.py b/src/scvi/model/base/_base_model.py index aaef7f84a5..8794679e15 100644 --- a/src/scvi/model/base/_base_model.py +++ b/src/scvi/model/base/_base_model.py @@ -9,9 +9,11 @@ from uuid import uuid4 import numpy as np +import pandas as pd import rich import torch from anndata import AnnData +from lightning import LightningDataModule from mudata import MuData from scvi import REGISTRY_KEYS, settings @@ -85,7 +87,7 @@ class BaseModelClass(metaclass=BaseModelMetaClass): _data_loader_cls = AnnDataLoader - def __init__(self, adata: AnnOrMuData | None = None): + def __init__(self, adata: AnnOrMuData | None = None, datamodule: object | None = None): # check if the given adata is minified and check if the model being created # supports minified-data mode (i.e. inherits from the abstract BaseMinifiedModeModelClass). # If not, raise an error to inform the user of the lack of minified-data functionality @@ -98,13 +100,22 @@ def __init__(self, adata: AnnOrMuData | None = None): self.id = str(uuid4()) # Used for cls._manager_store keys. if adata is not None: self._adata = adata + self._datamodule = None self._adata_manager = self._get_most_recent_anndata_manager(adata, required=True) self._register_manager_for_instance(self.adata_manager) # Suffix registry instance variable with _ to include it when saving the model. self.registry_ = self._adata_manager.registry self.summary_stats = self._adata_manager.summary_stats + elif datamodule is not None: + self._adata = None + self._datamodule = datamodule + self._adata_manager = None + # Suffix registry instance variable with _ to include it when saving the model. + self.registry_ = datamodule.registry + self.summary_stats = datamodule.summary_stats + else: + raise ValueError("adata or datamodule must be provided.") - self._module_init_on_train = adata is None self.is_trained_ = False self._model_summary_string = "" self.train_indices_ = None @@ -113,10 +124,20 @@ def __init__(self, adata: AnnOrMuData | None = None): self.history_ = None @property - def adata(self) -> AnnOrMuData: + def adata(self) -> None | AnnOrMuData: """Data attached to model instance.""" return self._adata + @property + def datamodule(self) -> None | LightningDataModule: + """Data attached to model instance.""" + return self._datamodule + + @property + def registry(self) -> dict: + """Data attached to model instance.""" + return self.registry_ + @adata.setter def adata(self, adata: AnnOrMuData): if adata is None: @@ -127,6 +148,14 @@ def adata(self, adata: AnnOrMuData): self.registry_ = self._adata_manager.registry self.summary_stats = self._adata_manager.summary_stats + @datamodule.setter + def datamodule(self, datamodule: LightningDataModule): + if datamodule is None: + raise ValueError("datamodule cannot be None.") + self._datamodule = datamodule + self.registry_ = datamodule.registry + self.summary_stats = datamodule.summary_stats + @property def adata_manager(self) -> AnnDataManager: """Manager instance associated with self.adata.""" @@ -238,6 +267,40 @@ def _register_manager_for_instance(self, adata_manager: AnnDataManager): instance_manager_store = self._per_instance_manager_store[self.id] instance_manager_store[adata_id] = adata_manager + def data_registry(self, registry_key: str) -> np.ndarray | pd.DataFrame: + """Returns the object in AnnData associated with the key in the data registry. + + Parameters + ---------- + registry_key + key of object to get from ``self.data_registry`` + + Returns + ------- + The requested data. + """ + if not self.adata: + raise ValueError("self.adata is None. Please register AnnData object to access data.") + else: + return self._adata_manager.get_from_registry(registry_key) + + def get_from_registry(self, registry_key: str) -> np.ndarray | pd.DataFrame: + """Returns the object in AnnData associated with the key in the data registry. + + Parameters + ---------- + registry_key + key of object to get from ``self.data_registry`` + + Returns + ------- + The requested data. + """ + if not self.adata: + raise ValueError("self.adata is None. Please registry AnnData object.") + else: + return self._adata_manager.get_from_registry(registry_key) + def deregister_manager(self, adata: AnnData | None = None): """Deregisters the :class:`~scvi.data.AnnDataManager` instance associated with `adata`. @@ -528,7 +591,7 @@ def _get_user_attributes(self): def _get_init_params(self, locals): """Returns the model init signature with associated passed in values. - Ignores the initial AnnData. + Ignores the initial AnnData or DataModule. """ init = self.__init__ sig = inspect.signature(init) @@ -540,6 +603,8 @@ def _get_init_params(self, locals): k: v for (k, v) in all_params.items() if not isinstance(v, AnnData) and not isinstance(v, MuData) + and not isinstance(v, LightningDataModule) + and k not in ("adata", "datamodule") } # not very efficient but is explicit # separates variable params (**kwargs) from non variable params into two dicts @@ -622,7 +687,10 @@ def save( # save the model state dict and the trainer state dict only model_state_dict = self.module.state_dict() - var_names = _get_var_names(self.adata, legacy_mudata_format=legacy_mudata_format) + if self.adata: + var_names = _get_var_names(self.adata, legacy_mudata_format=legacy_mudata_format) + else: + var_names = self.datamodule.var_names # get all the user attributes user_attributes = self._get_user_attributes() @@ -645,6 +713,7 @@ def load( cls, dir_path: str, adata: AnnOrMuData | None = None, + datamodule: LightningDataModule | None = None, accelerator: str = "auto", device: int | str = "auto", prefix: str | None = None, @@ -677,7 +746,7 @@ def load( >>> model = ModelClass.load(save_path, adata) >>> model.get_.... """ - load_adata = adata is None + load_adata = adata is None and datamodule is None _, _, device = parse_device_args( accelerator=accelerator, devices=device, @@ -699,31 +768,35 @@ def load( ) adata = new_adata if new_adata is not None else adata - _validate_var_names(adata, var_names) - registry = attr_dict.pop("registry_") + if datamodule is not None: + registry['setup_method_name'] = 'setup_datamodule' + else: + registry['setup_method_name'] = 'setup_anndata' if _MODEL_NAME_KEY in registry and registry[_MODEL_NAME_KEY] != cls.__name__: raise ValueError("It appears you are loading a model from a different class.") - if _SETUP_ARGS_KEY not in registry: - raise ValueError( - "Saved model does not contain original setup inputs. " - "Cannot load the original setup." - ) - # Calling ``setup_anndata`` method with the original arguments passed into # the saved model. This enables simple backwards compatibility in the case of # newly introduced fields or parameters. - method_name = registry.get(_SETUP_METHOD_NAME, "setup_anndata") - getattr(cls, method_name)(adata, source_registry=registry, **registry[_SETUP_ARGS_KEY]) + if adata is not None: + if _SETUP_ARGS_KEY not in registry: + raise ValueError( + "Saved model does not contain original setup inputs. " + "Cannot load the original setup." + ) + _validate_var_names(adata, var_names) + method_name = registry.get(_SETUP_METHOD_NAME, "setup_anndata") + getattr(cls, method_name)(adata, source_registry=registry, **registry[_SETUP_ARGS_KEY]) - model = _initialize_model(cls, adata, attr_dict) + model = _initialize_model(cls, adata, datamodule, attr_dict) model.module.on_load(model) model.module.load_state_dict(model_state_dict) model.to_device(device) model.module.eval() - model._validate_anndata(adata) + if adata is not None: + model._validate_anndata(adata) return model @classmethod @@ -816,7 +889,6 @@ def setup_anndata( """ @classmethod - @abstractmethod @setup_anndata_dsp.dedent def setup_datamodule( cls, @@ -903,11 +975,14 @@ class BaseMinifiedModeModelClass(BaseModelClass): @property def minified_data_type(self) -> MinifiedDataType | None: """The type of minified data associated with this model, if applicable.""" - return ( - self.adata_manager.get_from_registry(REGISTRY_KEYS.MINIFY_TYPE_KEY) - if REGISTRY_KEYS.MINIFY_TYPE_KEY in self.adata_manager.data_registry - else None - ) + if self.adata_manager: + return ( + self.adata_manager.get_from_registry(REGISTRY_KEYS.MINIFY_TYPE_KEY) + if REGISTRY_KEYS.MINIFY_TYPE_KEY in self.adata_manager.data_registry + else None + ) + else: + return None @abstractmethod def minify_adata( diff --git a/src/scvi/model/base/_save_load.py b/src/scvi/model/base/_save_load.py index 63c41adfda..aa00e807f6 100644 --- a/src/scvi/model/base/_save_load.py +++ b/src/scvi/model/base/_save_load.py @@ -97,7 +97,7 @@ def _load_saved_files( return attr_dict, var_names, model_state_dict, adata -def _initialize_model(cls, adata, attr_dict): +def _initialize_model(cls, adata, datamodule, attr_dict): """Helper to initialize a model.""" if "init_params_" not in attr_dict.keys(): raise ValueError( @@ -121,6 +121,9 @@ def _initialize_model(cls, adata, attr_dict): kwargs = {k: v for k, v in init_params.items() if isinstance(v, dict)} kwargs = {k: v for (i, j) in kwargs.items() for (k, v) in j.items()} non_kwargs.pop("use_cuda") + # adata and datamodule None is stored in the registry + non_kwargs.pop("adata", None) + non_kwargs.pop("datamodule", None) # backwards compat for scANVI if "unlabeled_category" in non_kwargs.keys(): @@ -128,7 +131,7 @@ def _initialize_model(cls, adata, attr_dict): if "pretrained_model" in non_kwargs.keys(): non_kwargs.pop("pretrained_model") - model = cls(adata, **non_kwargs, **kwargs) + model = cls(adata=adata, datamodule=datamodule, **non_kwargs, **kwargs) for attr, val in attr_dict.items(): setattr(model, attr, val) diff --git a/src/scvi/model/base/_training_mixin.py b/src/scvi/model/base/_training_mixin.py index 86da6b5019..21c6ed6b59 100644 --- a/src/scvi/model/base/_training_mixin.py +++ b/src/scvi/model/base/_training_mixin.py @@ -81,25 +81,8 @@ def train( **kwargs Additional keyword arguments passed into :class:`~scvi.train.Trainer`. """ - if datamodule is not None and not self._module_init_on_train: - raise ValueError( - "Cannot pass in `datamodule` if the model was initialized with `adata`." - ) - elif datamodule is None and self._module_init_on_train: - raise ValueError( - "If the model was not initialized with `adata`, a `datamodule` must be passed in." - ) - if max_epochs is None: - if datamodule is None: - max_epochs = get_max_epochs_heuristic(self.adata.n_obs) - elif hasattr(datamodule, "n_obs"): - max_epochs = get_max_epochs_heuristic(datamodule.n_obs) - else: - raise ValueError( - "If `datamodule` does not have `n_obs` attribute, `max_epochs` must be " - "passed in." - ) + max_epochs = get_max_epochs_heuristic(self.summary_stats.n_obs) if datamodule is None: # In the general case we enter here @@ -114,18 +97,6 @@ def train( load_sparse_tensor=load_sparse_tensor, **datasplitter_kwargs, ) - elif self.module is None: - # in CZI case we enter here - self.module = self._module_cls( - datamodule.n_vars, - n_batch=datamodule.n_batch, - n_labels=getattr(datamodule, "n_labels", 1), - n_continuous_cov=getattr(datamodule, "n_continuous_cov", 0), - n_cats_per_cov=getattr(datamodule, "n_cats_per_cov", None), - **self._module_kwargs, - ) - # after either of the cases we should be here with the same self.module - # and same datamodule plan_kwargs = plan_kwargs or {} training_plan = self._training_plan_cls(self.module, **plan_kwargs) From 17282cd57db7a7585f2de4fbb7a2f1aa5e6a0d4d Mon Sep 17 00:00:00 2001 From: Can Ergen Date: Wed, 31 Jul 2024 14:38:18 -0700 Subject: [PATCH 05/53] Fixed attr_dict --- src/scvi/model/base/_archesmixin.py | 66 ++++++++++++++++------------- 1 file changed, 37 insertions(+), 29 deletions(-) diff --git a/src/scvi/model/base/_archesmixin.py b/src/scvi/model/base/_archesmixin.py index 117188f3cf..12253f08ca 100644 --- a/src/scvi/model/base/_archesmixin.py +++ b/src/scvi/model/base/_archesmixin.py @@ -8,6 +8,7 @@ import pandas as pd import torch from anndata import AnnData +from lightning import LightningDataModule from mudata import MuData from scipy.sparse import csr_matrix @@ -39,8 +40,9 @@ class ArchesMixin: @devices_dsp.dedent def load_query_data( cls, - adata: AnnOrMuData, - reference_model: Union[str, BaseModelClass], + adata: None | AnnOrMuData = None, + reference_model: Union[str, BaseModelClass] = None, + datamodule: None | LightningDataModule = None, inplace_subset_query_vars: bool = False, accelerator: str = "auto", device: Union[int, str] = "auto", @@ -83,6 +85,11 @@ def load_query_data( freeze_classifier Whether to freeze classifier completely. Only applies to `SCANVI`. """ + if reference_model is None: + raise ValueError("Please provide a reference model as string or loaded model.") + if adata is None and datamodule is None: + raise ValueError("Please provide either an AnnData or a datamodule.") + _, _, device = parse_device_args( accelerator=accelerator, devices=device, @@ -92,44 +99,45 @@ def load_query_data( attr_dict, var_names, load_state_dict = _get_loaded_data(reference_model, device=device) - if isinstance(adata, MuData): - for modality in adata.mod: + if adata is not None: + if isinstance(adata, MuData): + for modality in adata.mod: + if inplace_subset_query_vars: + logger.debug(f"Subsetting {modality} query vars to reference vars.") + adata[modality]._inplace_subset_var(var_names[modality]) + _validate_var_names(adata[modality], var_names[modality]) + + else: if inplace_subset_query_vars: - logger.debug(f"Subsetting {modality} query vars to reference vars.") - adata[modality]._inplace_subset_var(var_names[modality]) - _validate_var_names(adata[modality], var_names[modality]) + logger.debug("Subsetting query vars to reference vars.") + adata._inplace_subset_var(var_names) + _validate_var_names(adata, var_names) - else: if inplace_subset_query_vars: logger.debug("Subsetting query vars to reference vars.") adata._inplace_subset_var(var_names) _validate_var_names(adata, var_names) - if inplace_subset_query_vars: - logger.debug("Subsetting query vars to reference vars.") - adata._inplace_subset_var(var_names) - _validate_var_names(adata, var_names) + registry = attr_dict.pop("registry_") + if _MODEL_NAME_KEY in registry and registry[_MODEL_NAME_KEY] != cls.__name__: + raise ValueError("It appears you are loading a model from a different class.") - registry = attr_dict.pop("registry_") - if _MODEL_NAME_KEY in registry and registry[_MODEL_NAME_KEY] != cls.__name__: - raise ValueError("It appears you are loading a model from a different class.") + if _SETUP_ARGS_KEY not in registry: + raise ValueError( + "Saved model does not contain original setup inputs. " + "Cannot load the original setup." + ) - if _SETUP_ARGS_KEY not in registry: - raise ValueError( - "Saved model does not contain original setup inputs. " - "Cannot load the original setup." + setup_method = getattr(cls, registry[_SETUP_METHOD_NAME]) + setup_method( + adata, + source_registry=registry, + extend_categories=True, + allow_missing_labels=True, + **registry[_SETUP_ARGS_KEY], ) - setup_method = getattr(cls, registry[_SETUP_METHOD_NAME]) - setup_method( - adata, - source_registry=registry, - extend_categories=True, - allow_missing_labels=True, - **registry[_SETUP_ARGS_KEY], - ) - - model = _initialize_model(cls, adata, attr_dict) + model = _initialize_model(cls, adata, datamodule, attr_dict) adata_manager = model.get_anndata_manager(adata, required=True) if REGISTRY_KEYS.CAT_COVS_KEY in adata_manager.data_registry: From a4143f520c07b3fbd3c4a56f572ea82a3dad78b4 Mon Sep 17 00:00:00 2001 From: ori-kron-wis Date: Thu, 1 Aug 2024 17:54:39 +0300 Subject: [PATCH 06/53] added some fixes based on custom data loader test --- src/scvi/model/base/_archesmixin.py | 14 +- tests/dataloaders/test_custom_dataloader.py | 8 + tests/dataloaders/test_custom_dataloader2.py | 382 ++++++++++++------- tests/model/test_scvi.py | 3 +- 4 files changed, 268 insertions(+), 139 deletions(-) diff --git a/src/scvi/model/base/_archesmixin.py b/src/scvi/model/base/_archesmixin.py index 12253f08ca..475ab15661 100644 --- a/src/scvi/model/base/_archesmixin.py +++ b/src/scvi/model/base/_archesmixin.py @@ -97,7 +97,9 @@ def load_query_data( validate_single_device=True, ) - attr_dict, var_names, load_state_dict = _get_loaded_data(reference_model, device=device) + attr_dict, var_names, load_state_dict = _get_loaded_data( + reference_model, device=device, adata=adata + ) if adata is not None: if isinstance(adata, MuData): @@ -216,7 +218,7 @@ def prepare_query_anndata( Query adata ready to use in `load_query_data` unless `return_reference_var_names` in which case a pd.Index of reference var names is returned. """ - _, var_names, _ = _get_loaded_data(reference_model, device="cpu") + _, var_names, _ = _get_loaded_data(reference_model, device="cpu", adata=adata) var_names = pd.Index(var_names) if return_reference_var_names: @@ -364,7 +366,7 @@ def requires_grad(key): par.requires_grad = False -def _get_loaded_data(reference_model, device=None): +def _get_loaded_data(reference_model, device=None, adata=None): if isinstance(reference_model, str): attr_dict, var_names, load_state_dict, _ = _load_saved_files( reference_model, load_adata=False, map_location=device @@ -372,7 +374,11 @@ def _get_loaded_data(reference_model, device=None): else: attr_dict = reference_model._get_user_attributes() attr_dict = {a[0]: a[1] for a in attr_dict if a[0][-1] == "_"} - var_names = _get_var_names(reference_model.adata) + var_names = ( + _get_var_names(reference_model.adata) + if attr_dict["registry_"]["setup_method_name"] != "setup_datamodule" + else _get_var_names(adata) + ) load_state_dict = deepcopy(reference_model.module.state_dict()) return attr_dict, var_names, load_state_dict diff --git a/tests/dataloaders/test_custom_dataloader.py b/tests/dataloaders/test_custom_dataloader.py index c80cb843b5..8ef4b038af 100644 --- a/tests/dataloaders/test_custom_dataloader.py +++ b/tests/dataloaders/test_custom_dataloader.py @@ -1,6 +1,7 @@ from __future__ import annotations import os +from pprint import pprint import numpy as np import scanpy as sc @@ -41,6 +42,11 @@ # Loading the model (just as a compariosn) model_orig_loaded = scvi.model.SCVI.load(model_dir, adata=adata) +# when loading from disk +scvi.model.SCVI.prepare_query_anndata(adata, model_dir) +# O +scvi.model.SCVI.prepare_query_anndata(adata, model_orig_loaded) + # Obtaining model outputs SCVI_LATENT_KEY = "X_scVI" latent = model_orig.get_latent_representation() @@ -53,6 +59,8 @@ # adata_manager.get_state_registry(SCVI.REGISTRY_KEYS.X_KEY).to_dict() adata_manager.registry[_constants._FIELD_REGISTRIES_KEY] +pprint(adata_manager.registry) + # Plot UMAP and save the figure for later check sc.pp.neighbors(adata, use_rep="scvi", key_added="scvi") sc.tl.umap(adata, neighbors_key="scvi") diff --git a/tests/dataloaders/test_custom_dataloader2.py b/tests/dataloaders/test_custom_dataloader2.py index 5b741ebfad..9c3b625d6c 100644 --- a/tests/dataloaders/test_custom_dataloader2.py +++ b/tests/dataloaders/test_custom_dataloader2.py @@ -1,17 +1,23 @@ from __future__ import annotations -import os +import sys + +sys.path.insert(0, "/Users/orikr/Documents/cellxgene-census/api/python/cellxgene_census/src") +sys.path.insert(0, "src") import cellxgene_census -import pandas as pd -import scanpy as sc +import numpy as np import tiledbsoma as soma -import torch +from cellxgene_census.experimental.ml.datamodule import ( + CensusSCVIDataModule, # WE RAN FROM LOCAL LIB +) from cellxgene_census.experimental.pp import highly_variable_genes import scvi -from scvi.dataloaders._custom_dataloader import CensusSCVIDataModule, experiment_dataloader -from scvi.model import SCVI +from scvi.data import _constants, synthetic_iid +from scvi.utils import attrdict + +# cellxgene_census.__file__, scvi.__file__ # We will now create the SCVI model object: # Its parameters: @@ -25,7 +31,6 @@ # The other way will be to fill the model ,LIKE IN CELLXGENE NOTEBOOK # need to pass here new object of registry taht contains everything we will need - # First lets see CELLXGENE example using pytorch loaders implemented now in our repo census = cellxgene_census.open_soma(census_version="stable") experiment_name = "mus_musculus" @@ -54,15 +59,106 @@ batch_keys=["dataset_id", "assay", "suspension_type", "donor_id"], dataloader_kwargs={"num_workers": 0, "persistent_workers": False}, ) -# This is a new func to implement -SCVI.setup_datamodule(datamodule) -# -model = SCVI(n_layers=n_layers, n_latent=n_latent, gene_likelihood="nb", encode_covariates=False) +datamodule.vars = hv_idx + + +def _get_summary_stats_from_registry(registry: dict) -> attrdict: + summary_stats = {} + for field_registry in registry[_constants._FIELD_REGISTRIES_KEY].values(): + field_summary_stats = field_registry[_constants._SUMMARY_STATS_KEY] + summary_stats.update(field_summary_stats) + return attrdict(summary_stats) + + +def setup_datamodule(datamodule: CensusSCVIDataModule): + datamodule.registry = { + "scvi_version": scvi.__version__, + "model_name": "SCVI", + "setup_args": { + "layer": None, + "batch_key": "batch", + "labels_key": None, + "size_factor_key": None, + "categorical_covariate_keys": None, + "continuous_covariate_keys": None, + }, + "field_registries": { + "X": { + "data_registry": {"attr_name": "X", "attr_key": None}, + "state_registry": { + "n_obs": datamodule.n_obs, + "n_vars": datamodule.n_vars, + "column_names": datamodule.vars, + }, + "summary_stats": {"n_vars": datamodule.n_vars, "n_cells": datamodule.n_obs}, + }, + "batch": { + "data_registry": {"attr_name": "obs", "attr_key": "_scvi_batch"}, + "state_registry": { + "categorical_mapping": datamodule.datapipe.obs_encoders["batch"].classes_, + "original_key": "batch", + }, + "summary_stats": {"n_batch": datamodule.n_batch}, + }, + "labels": { + "data_registry": {"attr_name": "obs", "attr_key": "_scvi_labels"}, + "state_registry": { + "categorical_mapping": np.array([0]), + "original_key": "_scvi_labels", + }, + "summary_stats": {"n_labels": 1}, + }, + "size_factor": {"data_registry": {}, "state_registry": {}, "summary_stats": {}}, + "extra_categorical_covs": { + "data_registry": {}, + "state_registry": {}, + "summary_stats": {"n_extra_categorical_covs": 0}, + }, + "extra_continuous_covs": { + "data_registry": {}, + "state_registry": {}, + "summary_stats": {"n_extra_continuous_covs": 0}, + }, + }, + "setup_method_name": "setup_datamodule", + } + datamodule.summary_stats = _get_summary_stats_from_registry(datamodule.registry) + datamodule.var_names = [str(i) for i in datamodule.vars] + + +# This is a new func to implement (Implemented Above but we need in our code base as well) +# will take a bit of time to end +setup_datamodule(datamodule) + +# The next part is the same as test_scvi_train_custom_dataloader + +adata = synthetic_iid() +scvi.model.SCVI.setup_anndata(adata, batch_key="batch") +model = scvi.model.SCVI(adata, n_latent=10) +model.train(max_epochs=1) +dataloader = model._make_data_loader(adata) +_ = model.get_elbo(dataloader=dataloader) +_ = model.get_marginal_ll(dataloader=dataloader) +_ = model.get_reconstruction_error(dataloader=dataloader) +_ = model.get_latent_representation(dataloader=dataloader) + +# ORI I broke the code here also for standard models. Please first fix this. - it is fixed +scvi.model.SCVI.prepare_query_anndata(adata, model) +query_model = scvi.model.SCVI.load_query_data(adata, model) + +# We will now create the SCVI model object: +model_census = scvi.model.SCVI( + datamodule=datamodule, + n_layers=n_layers, + n_latent=n_latent, + gene_likelihood="nb", + encode_covariates=False, +) # The CZI data module is a refined data module while SCVI is a lighting datamodule # Altough this is only 1 epoch it will take few mins on local machine -model.train( +model_census.train( datamodule=datamodule, max_epochs=max_epochs, batch_size=batch_size, @@ -73,141 +169,161 @@ # We can now save the trained model. As of the current writing date (June 2024), # scvi-tools doesn't support saving a model that wasn't generated through an AnnData loader, # so we'll use some custom code: -model_state_dict = model.module.state_dict() -var_names = hv_idx.to_numpy() -user_attributes = model._get_user_attributes() -user_attributes = {a[0]: a[1] for a in user_attributes if a[0][-1] == "_"} - -user_attributes.update( - { - "n_batch": datamodule.n_batch, - "n_extra_categorical_covs": 0, - "n_extra_continuous_covs": 0, - "n_labels": 1, - "n_vars": datamodule.n_vars, - } +# model_state_dict = model_census.module.state_dict() +# var_names = hv_idx.to_numpy() +# user_attributes = model_census._get_user_attributes() +# user_attributes = {a[0]: a[1] for a in user_attributes if a[0][-1] == "_"} +model_census.save("dataloader_model2", overwrite=True) + +# We are now turning this data module back to AnnData +adata = cellxgene_census.get_anndata( + census, + organism=experiment_name, + obs_value_filter=obs_value_filter, ) -with open("model.pt", "wb") as f: - torch.save( - { - "model_state_dict": model_state_dict, - "var_names": var_names, - "attr_dict": user_attributes, - }, - f, - ) +adata = adata[:, datamodule.vars].copy() -# Saving the model the original way -save_dir = "/Users/orikr/runs/290724/" # tempfile.TemporaryDirectory() -model_dir = os.path.join(save_dir, "scvi_czi_model") -model.save(model_dir, overwrite=True) +adata.obs.head() +# ORI Replace this with the function to generate batch key used in the datamodule. +# "12967895-3d58-4e93-be2c-4e1bcf4388d510x 5' v1cellHCA_Mou_3" +adata.obs["batch"] = ("batch_" + adata.obs[datamodule.batch_keys[0]].cat.codes.astype(str)).astype( + "category" +) +# adata.var_names = 'gene_'+adata.var_names #not sure we need it # We will now load the model back and use it to generate cell embeddings (the latent space), # which can then be used for further analysis. Note that we still need to use some custom code for # loading the model, which includes loading the parameters from the `attr_dict` node stored in # the model. -with open("model.pt", "rb") as f: - torch_model = torch.load(f) - - adict = torch_model["attr_dict"] - params = adict["init_params_"]["non_kwargs"] - - n_batch = adict["n_batch"] - n_extra_categorical_covs = adict["n_extra_categorical_covs"] - n_extra_continuous_covs = adict["n_extra_continuous_covs"] - n_labels = adict["n_labels"] - n_vars = adict["n_vars"] - - latent_distribution = params["latent_distribution"] - dispersion = params["dispersion"] - n_hidden = params["n_hidden"] - dropout_rate = params["dropout_rate"] - gene_likelihood = params["gene_likelihood"] - - model = scvi.model.SCVI( - n_layers=params["n_layers"], - n_latent=params["n_latent"], - gene_likelihood=params["gene_likelihood"], - encode_covariates=False, - ) - - module = model._module_cls( - n_input=n_vars, - n_batch=n_batch, - n_labels=n_labels, - n_continuous_cov=n_extra_continuous_covs, - n_cats_per_cov=None, - n_hidden=n_hidden, - n_latent=n_latent, - n_layers=n_layers, - dropout_rate=dropout_rate, - dispersion=dispersion, - gene_likelihood=gene_likelihood, - latent_distribution=latent_distribution, - ) - model.module = module - - model.module.load_state_dict(torch_model["model_state_dict"]) - - device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") - - model.to_device(device) - model.module.eval() - model.is_trained = True - -# We will now generate the cell embeddings for this model, using the `get_latent_representation` -# function available in scvi-tools. -# We can use another instance of the `ExperimentDataPipe` for the forward pass, so we don't need -# to load the whole dataset in memory. -# Needs to have shuffle=False for inference -datamodule_inference = CensusSCVIDataModule( - census["census_data"][experiment_name], - measurement_name="RNA", - X_name="raw", - obs_query=soma.AxisQuery(value_filter=obs_value_filter), - var_query=soma.AxisQuery(coords=(list(hv_idx),)), - batch_size=1024, - shuffle=False, - soma_chunk_size=50_000, - batch_keys=["dataset_id", "assay", "suspension_type", "donor_id"], - dataloader_kwargs={"num_workers": 0, "persistent_workers": False}, +model_census2 = scvi.model.SCVI.load("dataloader_model2", datamodule=datamodule) +model_census2.setup_anndata(adata, batch_key="batch") +# model_census2.adata = deepcopy(adata) +# ORI Works when loading from disk +scvi.model.SCVI.prepare_query_anndata(adata, "dataloader_model2") +# ORI This one still needs to be fixed. +scvi.model.SCVI.prepare_query_anndata(adata, model_census2) + +# ORI Should work when setting up the AnnData correctly. scANVI with DataModule is not yet +# supported as DataModule can't take a labels_key. +scanvae = scvi.model.SCANVI.from_scvi_model( + model_census2, + adata=adata, + unlabeled_category="Unknown", + labels_key="cell_type", ) -# We can simply feed the datapipe to `get_latent_representation` to obtain the embeddings - -# will take a while -datapipe = datamodule_inference.datapipe -dataloader = experiment_dataloader(datapipe, num_workers=0, persistent_workers=False) -mapped_dataloader = ( - datamodule_inference.on_before_batch_transfer(tensor, None) for tensor in dataloader -) -latent = model.get_latent_representation(dataloader=mapped_dataloader) -emb_idx = datapipe._obs_joinids +# ORI - check it should work with a model initialized with AnnData. See below not fully working yet +model_census3 = scvi.model.SCVI.load("dataloader_model2", adata=adata) -# We will now take a look at the UMAP for the generated embedding -# (will be later comapred to what we got) -adata = cellxgene_census.get_anndata( - census, - organism=experiment_name, - obs_value_filter=obs_value_filter, -) -obs_soma_joinids = adata.obs["soma_joinid"] -obs_indexer = pd.Index(emb_idx) -idx = obs_indexer.get_indexer(obs_soma_joinids) -# Reindexing is necessary to ensure that the cells in the embedding match the -# ones in the anndata object. -adata.obsm["scvi"] = latent[idx] - -# Plot UMAP and save the figure for later check -sc.pp.neighbors(adata, use_rep="scvi", key_added="scvi") -sc.tl.umap(adata, neighbors_key="scvi") -sc.pl.umap(adata, color="dataset_id", title="SCVI") +scvi.model.SCVI.prepare_query_anndata(adata, "dataloader_model2") +query_model = scvi.model.SCVI.load_query_data(adata, "dataloader_model2") +scvi.model.SCVI.prepare_query_anndata(adata, model_census3) +query_model = scvi.model.SCVI.load_query_data(adata, model_census3) -# Now return and add all the registry stuff that we will need +# with open("model.pt", "rb") as f: +# torch_model = torch.load(f) +# +# adict = torch_model["attr_dict"] +# params = adict["init_params_"]["non_kwargs"] +# +# n_batch = adict["n_batch"] +# n_extra_categorical_covs = adict["n_extra_categorical_covs"] +# n_extra_continuous_covs = adict["n_extra_continuous_covs"] +# n_labels = adict["n_labels"] +# n_vars = adict["n_vars"] +# +# latent_distribution = params["latent_distribution"] +# dispersion = params["dispersion"] +# n_hidden = params["n_hidden"] +# dropout_rate = params["dropout_rate"] +# gene_likelihood = params["gene_likelihood"] +# +# model = scvi.model.SCVI( +# n_layers=params["n_layers"], +# n_latent=params["n_latent"], +# gene_likelihood=params["gene_likelihood"], +# encode_covariates=False, +# ) +# +# module = model._module_cls( +# n_input=n_vars, +# n_batch=n_batch, +# n_labels=n_labels, +# n_continuous_cov=n_extra_continuous_covs, +# n_cats_per_cov=None, +# n_hidden=n_hidden, +# n_latent=n_latent, +# n_layers=n_layers, +# dropout_rate=dropout_rate, +# dispersion=dispersion, +# gene_likelihood=gene_likelihood, +# latent_distribution=latent_distribution, +# ) +# model.module = module +# +# model.module.load_state_dict(torch_model["model_state_dict"]) +# +# device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") +# +# model.to_device(device) +# model.module.eval() +# model.is_trained = True +# We will now generate the cell embeddings for this model, using the `get_latent_representation` +# function available in scvi-tools. +# We can use another instance of the `ExperimentDataPipe` for the forward pass, so we don't need +# to load the whole dataset in memory. -# Now add the missing stuff from the current CZI implemenation in order for us to have the exact -# same steps like the original way (except than setup_anndata) +# # Needs to have shuffle=False for inference +# datamodule_inference = CensusSCVIDataModule( +# census["census_data"][experiment_name], +# measurement_name="RNA", +# X_name="raw", +# obs_query=soma.AxisQuery(value_filter=obs_value_filter), +# var_query=soma.AxisQuery(coords=(list(hv_idx),)), +# batch_size=1024, +# shuffle=False, +# soma_chunk_size=50_000, +# batch_keys=["dataset_id", "assay", "suspension_type", "donor_id"], +# dataloader_kwargs={"num_workers": 0, "persistent_workers": False}, +# ) +# +# # We can simply feed the datapipe to `get_latent_representation` to obtain the embeddings - +# # will take a while +# datapipe = datamodule_inference.datapipe +# dataloader = experiment_dataloader(datapipe, num_workers=0, persistent_workers=False) +# mapped_dataloader = ( +# datamodule_inference.on_before_batch_transfer(tensor, None) for tensor in dataloader +# ) +# latent = model.get_latent_representation(dataloader=mapped_dataloader) +# emb_idx = datapipe._obs_joinids +# +# # We will now take a look at the UMAP for the generated embedding +# # (will be later comapred to what we got) +# adata = cellxgene_census.get_anndata( +# census, +# organism=experiment_name, +# obs_value_filter=obs_value_filter, +# ) +# obs_soma_joinids = adata.obs["soma_joinid"] +# obs_indexer = pd.Index(emb_idx) +# idx = obs_indexer.get_indexer(obs_soma_joinids) +# # Reindexing is necessary to ensure that the cells in the embedding match the +# # ones in the anndata object. +# adata.obsm["scvi"] = latent[idx] +# +# # Plot UMAP and save the figure for later check +# sc.pp.neighbors(adata, use_rep="scvi", key_added="scvi") +# sc.tl.umap(adata, neighbors_key="scvi") +# sc.pl.umap(adata, color="dataset_id", title="SCVI") +# +# +# # Now return and add all the registry stuff that we will need +# +# +# # Now add the missing stuff from the current CZI implemenation in order for us to have the exact +# # same steps like the original way (except than setup_anndata) diff --git a/tests/model/test_scvi.py b/tests/model/test_scvi.py index b1d36959eb..dde3a6c115 100644 --- a/tests/model/test_scvi.py +++ b/tests/model/test_scvi.py @@ -1058,7 +1058,6 @@ def test_scvi_train_custom_dataloader(n_latent: int = 5): model = SCVI(adata, n_latent=n_latent) model.train(max_epochs=1) dataloader = model._make_data_loader(adata) - """ SCVI.setup_datamodule(dataloader) # continue from here. Datamodule will always require to pass it into all downstream functions. model.train(max_epochs=1, datamodule=dataloader) @@ -1066,7 +1065,7 @@ def test_scvi_train_custom_dataloader(n_latent: int = 5): _ = model.get_marginal_ll(dataloader=dataloader) _ = model.get_reconstruction_error(dataloader=dataloader) _ = model.get_latent_representation(dataloader=dataloader) - """ + def test_scvi_normal_likelihood(): import scanpy as sc From 69abc478085dd50a3a4b3bc03033d162aab044cf Mon Sep 17 00:00:00 2001 From: Can Ergen Date: Mon, 5 Aug 2024 23:08:13 -0700 Subject: [PATCH 07/53] Changes to dataloader --- src/scvi/data/_manager.py | 169 ------------------ src/scvi/data/_utils.py | 10 ++ src/scvi/model/_scanvi.py | 46 ++--- src/scvi/model/_scvi.py | 6 +- src/scvi/model/base/_archesmixin.py | 20 +-- src/scvi/model/base/_base_model.py | 256 +++++++++++++++++++++------- src/scvi/model/base/_save_load.py | 10 +- 7 files changed, 245 insertions(+), 272 deletions(-) diff --git a/src/scvi/data/_manager.py b/src/scvi/data/_manager.py index 10d0219041..8b1f37b846 100644 --- a/src/scvi/data/_manager.py +++ b/src/scvi/data/_manager.py @@ -1,19 +1,14 @@ from __future__ import annotations -import sys from collections import defaultdict from collections.abc import Sequence from copy import deepcopy from dataclasses import dataclass -from io import StringIO from uuid import uuid4 import numpy as np import pandas as pd -import rich from mudata import MuData -from rich import box -from rich.console import Console from torch.utils.data import Subset import scvi @@ -292,18 +287,6 @@ def validate(self) -> None: adata, self.adata = self.adata, None # Reset self.adata. self.register_fields(adata, self._source_registry, **self._transfer_kwargs) - def update_setup_method_args(self, setup_method_args: dict): - """Update setup method args. - - Parameters - ---------- - setup_method_args - This is a bit of a misnomer, this is a dict representing kwargs - of the setup method that will be used to update the existing values - in the registry of this instance. - """ - self._registry[_constants._SETUP_ARGS_KEY].update(setup_method_args) - @property def adata_uuid(self) -> str: """Returns the UUID for the AnnData object registered with this instance.""" @@ -311,11 +294,6 @@ def adata_uuid(self) -> str: return self._registry[_constants._SCVI_UUID_KEY] - @property - def registry(self) -> dict: - """Returns the top-level registry dictionary for the AnnData object.""" - return self._registry - @property def data_registry(self) -> attrdict: """Returns the data registry for the AnnData object registered with this instance.""" @@ -369,20 +347,6 @@ def _get_data_registry_from_registry(registry: dict) -> attrdict: data_registry[registry_key] = field_data_registry return attrdict(data_registry) - @property - def summary_stats(self) -> attrdict: - """Returns the summary stats for the AnnData object registered with this instance.""" - self._assert_anndata_registered() - return self._get_summary_stats_from_registry(self._registry) - - @staticmethod - def _get_summary_stats_from_registry(registry: dict) -> attrdict: - summary_stats = {} - for field_registry in registry[_constants._FIELD_REGISTRIES_KEY].values(): - field_summary_stats = field_registry[_constants._SUMMARY_STATS_KEY] - summary_stats.update(field_summary_stats) - return attrdict(summary_stats) - def get_from_registry(self, registry_key: str) -> np.ndarray | pd.DataFrame: """Returns the object in AnnData associated with the key in the data registry. @@ -404,136 +368,3 @@ def get_from_registry(self, registry_key: str) -> np.ndarray | pd.DataFrame: return get_anndata_attribute(self.adata, attr_name, attr_key, mod_key=mod_key) - def get_state_registry(self, registry_key: str) -> attrdict: - """Returns the state registry for the AnnDataField registered with this instance.""" - self._assert_anndata_registered() - - return attrdict( - self._registry[_constants._FIELD_REGISTRIES_KEY][registry_key][ - _constants._STATE_REGISTRY_KEY - ] - ) - - @staticmethod - def _view_summary_stats( - summary_stats: attrdict, as_markdown: bool = False - ) -> rich.table.Table | str: - """Prints summary stats.""" - if not as_markdown: - t = rich.table.Table(title="Summary Statistics") - else: - t = rich.table.Table(box=box.MARKDOWN) - - t.add_column( - "Summary Stat Key", - justify="center", - style="dodger_blue1", - no_wrap=True, - overflow="fold", - ) - t.add_column( - "Value", - justify="center", - style="dark_violet", - no_wrap=True, - overflow="fold", - ) - for stat_key, count in summary_stats.items(): - t.add_row(stat_key, str(count)) - - if as_markdown: - console = Console(file=StringIO(), force_jupyter=False) - console.print(t) - return console.file.getvalue().strip() - - return t - - @staticmethod - def _view_data_registry( - data_registry: attrdict, as_markdown: bool = False - ) -> rich.table.Table | str: - """Prints data registry.""" - if not as_markdown: - t = rich.table.Table(title="Data Registry") - else: - t = rich.table.Table(box=box.MARKDOWN) - - t.add_column( - "Registry Key", - justify="center", - style="dodger_blue1", - no_wrap=True, - overflow="fold", - ) - t.add_column( - "scvi-tools Location", - justify="center", - style="dark_violet", - no_wrap=True, - overflow="fold", - ) - - for registry_key, data_loc in data_registry.items(): - mod_key = getattr(data_loc, _constants._DR_MOD_KEY, None) - attr_name = data_loc.attr_name - attr_key = data_loc.attr_key - scvi_data_str = "adata" - if mod_key is not None: - scvi_data_str += f".mod['{mod_key}']" - if attr_key is None: - scvi_data_str += f".{attr_name}" - else: - scvi_data_str += f".{attr_name}['{attr_key}']" - t.add_row(registry_key, scvi_data_str) - - if as_markdown: - console = Console(file=StringIO(), force_jupyter=False) - console.print(t) - return console.file.getvalue().strip() - - return t - - @staticmethod - def view_setup_method_args(registry: dict) -> None: - """Prints setup kwargs used to produce a given registry. - - Parameters - ---------- - registry - Registry produced by an AnnDataManager. - """ - model_name = registry[_constants._MODEL_NAME_KEY] - setup_args = registry[_constants._SETUP_ARGS_KEY] - if model_name is not None and setup_args is not None: - rich.print(f"Setup via `{model_name}.setup_anndata` with arguments:") - rich.pretty.pprint(setup_args) - rich.print() - - def view_registry(self, hide_state_registries: bool = False) -> None: - """Prints summary of the registry. - - Parameters - ---------- - hide_state_registries - If True, prints a shortened summary without details of each state registry. - """ - version = self._registry[_constants._SCVI_VERSION_KEY] - rich.print(f"Anndata setup with scvi-tools version {version}.") - rich.print() - self.view_setup_method_args(self._registry) - - in_colab = "google.colab" in sys.modules - force_jupyter = None if not in_colab else True - console = rich.console.Console(force_jupyter=force_jupyter) - - ss = self._get_summary_stats_from_registry(self._registry) - dr = self._get_data_registry_from_registry(self._registry) - console.print(self._view_summary_stats(ss)) - console.print(self._view_data_registry(dr)) - - if not hide_state_registries: - for field in self.fields: - state_registry = self.get_state_registry(field.registry_key) - t = field.view_state_registry(state_registry) - if t is not None: - console.print(t) diff --git a/src/scvi/data/_utils.py b/src/scvi/data/_utils.py index 20fbfef293..4e47982a74 100644 --- a/src/scvi/data/_utils.py +++ b/src/scvi/data/_utils.py @@ -11,6 +11,8 @@ import scipy.sparse as sp_sparse from anndata import AnnData +from scvi.utils import attrdict + try: # anndata >= 0.10 from anndata.experimental import CSCDataset, CSRDataset @@ -156,6 +158,14 @@ def _set_data_in_registry( setattr(adata, attr_name, attribute) +def _get_summary_stats_from_registry(registry: dict) -> attrdict: + summary_stats = {} + for field_registry in registry[_constants._FIELD_REGISTRIES_KEY].values(): + field_summary_stats = field_registry[_constants._SUMMARY_STATS_KEY] + summary_stats.update(field_summary_stats) + return attrdict(summary_stats) + + def _verify_and_correct_data_format(adata: AnnData, attr_name: str, attr_key: str | None): """Check data format and correct if necessary. diff --git a/src/scvi/model/_scanvi.py b/src/scvi/model/_scanvi.py index 8630ff80d2..5a09cec0f9 100644 --- a/src/scvi/model/_scanvi.py +++ b/src/scvi/model/_scanvi.py @@ -110,6 +110,7 @@ class SCANVI(RNASeqMixin, VAEMixin, ArchesMixin, BaseMinifiedModeModelClass): def __init__( self, adata: AnnData, + registry: dict, n_hidden: int = 128, n_latent: int = 10, n_layers: int = 1, @@ -119,24 +120,24 @@ def __init__( linear_classifier: bool = False, **model_kwargs, ): - super().__init__(adata) + super().__init__(adata, registry) scanvae_model_kwargs = dict(model_kwargs) self._set_indices_and_labels() # ignores unlabeled catgegory n_labels = self.summary_stats.n_labels - 1 - n_cats_per_cov = ( - self.adata_manager.get_state_registry(REGISTRY_KEYS.CAT_COVS_KEY).n_cats_per_key - if REGISTRY_KEYS.CAT_COVS_KEY in self.adata_manager.data_registry - else None - ) + n_cats_per_cov = self.summary_stats[f'n_{REGISTRY_KEYS.CAT_COVS_KEY}'] + if n_cats_per_cov == 0: + n_cats_per_cov = None n_batch = self.summary_stats.n_batch - use_size_factor_key = REGISTRY_KEYS.SIZE_FACTOR_KEY in self.adata_manager.data_registry + use_size_factor_key = self.registry_['setup_args'][f'{REGISTRY_KEYS.SIZE_FACTOR_KEY}_key'] library_log_means, library_log_vars = None, None - if not use_size_factor_key and self.minified_data_type is None: - library_log_means, library_log_vars = _init_library_size(self.adata_manager, n_batch) + if self.adata is not None and not use_size_factor_key and self.minified_data_type is None: + library_log_means, library_log_vars = _init_library_size( + self.adata_manager, n_batch + ) self.module = self._module_cls( n_input=self.summary_stats.n_vars, @@ -178,7 +179,7 @@ def from_scvi_model( unlabeled_category: str, labels_key: str | None = None, adata: AnnData | None = None, - datamodule: LightningDataModule | None = None, + registry: dict | None = None, **scanvi_kwargs, ): """Initialize scanVI model with weights from pretrained :class:`~scvi.model.SCVI` model. @@ -195,8 +196,8 @@ def from_scvi_model( Value used for unlabeled cells in `labels_key` used to setup AnnData with scvi. adata AnnData object that has been registered via :meth:`~scvi.model.SCANVI.setup_anndata`. - datamodule - LightningDataModule object that has been registered. + registry + Registry of the datamodule used to train scANVI model. scanvi_kwargs kwargs for scANVI model """ @@ -231,7 +232,7 @@ def from_scvi_model( # validate new anndata against old model scvi_model._validate_anndata(adata) - scvi_setup_args = deepcopy(scvi_model.adata_manager.registry[_SETUP_ARGS_KEY]) + scvi_setup_args = deepcopy(scvi_model.registry[_SETUP_ARGS_KEY]) scvi_labels_key = scvi_setup_args["labels_key"] if labels_key is None and scvi_labels_key is None: raise ValueError( @@ -244,8 +245,8 @@ def from_scvi_model( unlabeled_category=unlabeled_category, **scvi_setup_args, ) - scanvi_model = cls(adata, **non_kwargs, **kwargs, **scanvi_kwargs) - print('TTTT', scanvi_model.registry) + + scanvi_model = cls(adata, scvi_model.registry, **non_kwargs, **kwargs, **scanvi_kwargs) scvi_state_dict = scvi_model.module.state_dict() scanvi_model.module.load_state_dict(scvi_state_dict, strict=False) scanvi_model.was_pretrained = True @@ -254,7 +255,7 @@ def from_scvi_model( def _set_indices_and_labels(self): """Set indices for labeled and unlabeled cells.""" - labels_state_registry = self.adata_manager.get_state_registry(REGISTRY_KEYS.LABELS_KEY) + labels_state_registry = self.get_state_registry(REGISTRY_KEYS.LABELS_KEY) self.original_label_key = labels_state_registry.original_key self.unlabeled_category_ = labels_state_registry.unlabeled_category @@ -474,12 +475,13 @@ def setup_anndata( NumericalJointObsField(REGISTRY_KEYS.CONT_COVS_KEY, continuous_covariate_keys), ] # register new fields if the adata is minified - adata_minify_type = _get_adata_minify_type(adata) - if adata_minify_type is not None: - anndata_fields += cls._get_fields_for_adata_minification(adata_minify_type) - adata_manager = AnnDataManager(fields=anndata_fields, setup_method_args=setup_method_args) - adata_manager.register_fields(adata, **kwargs) - cls.register_manager(adata_manager) + if adata: + adata_minify_type = _get_adata_minify_type(adata) + if adata_minify_type is not None: + anndata_fields += cls._get_fields_for_adata_minification(adata_minify_type) + adata_manager = AnnDataManager(fields=anndata_fields, setup_method_args=setup_method_args) + adata_manager.register_fields(adata, **kwargs) + cls.register_manager(adata_manager) @staticmethod def _get_fields_for_adata_minification( diff --git a/src/scvi/model/_scvi.py b/src/scvi/model/_scvi.py index ffc79775a5..7467e61945 100644 --- a/src/scvi/model/_scvi.py +++ b/src/scvi/model/_scvi.py @@ -113,7 +113,7 @@ class SCVI( def __init__( self, adata: AnnData | None = None, - datamodule: LightningDataModule | None = None, + registry: dict | None = None, n_hidden: int = 128, n_latent: int = 10, n_layers: int = 1, @@ -123,7 +123,7 @@ def __init__( latent_distribution: Literal["normal", "ln"] = "normal", **kwargs, ): - super().__init__(adata, datamodule) + super().__init__(adata, registry) self._module_kwargs = { "n_hidden": n_hidden, @@ -146,7 +146,7 @@ def __init__( if n_cats_per_cov == 0: n_cats_per_cov = None n_batch = self.summary_stats.n_batch - use_size_factor_key = self.registry_['setup_args'][f'{REGISTRY_KEYS.SIZE_FACTOR_KEY}_key'] + use_size_factor_key = self.get_setup_arg(f'{REGISTRY_KEYS.SIZE_FACTOR_KEY}_key') library_log_means, library_log_vars = None, None if self.adata is not None and not use_size_factor_key and self.minified_data_type is None: library_log_means, library_log_vars = _init_library_size( diff --git a/src/scvi/model/base/_archesmixin.py b/src/scvi/model/base/_archesmixin.py index 475ab15661..02607e6f93 100644 --- a/src/scvi/model/base/_archesmixin.py +++ b/src/scvi/model/base/_archesmixin.py @@ -8,7 +8,6 @@ import pandas as pd import torch from anndata import AnnData -from lightning import LightningDataModule from mudata import MuData from scipy.sparse import csr_matrix @@ -42,7 +41,7 @@ def load_query_data( cls, adata: None | AnnOrMuData = None, reference_model: Union[str, BaseModelClass] = None, - datamodule: None | LightningDataModule = None, + registry: None | dict = None, inplace_subset_query_vars: bool = False, accelerator: str = "auto", device: Union[int, str] = "auto", @@ -87,8 +86,8 @@ def load_query_data( """ if reference_model is None: raise ValueError("Please provide a reference model as string or loaded model.") - if adata is None and datamodule is None: - raise ValueError("Please provide either an AnnData or a datamodule.") + if adata is None and registry is None: + raise ValueError("Please provide either an AnnData or a registry dictionary.") _, _, device = parse_device_args( accelerator=accelerator, @@ -139,15 +138,14 @@ def load_query_data( **registry[_SETUP_ARGS_KEY], ) - model = _initialize_model(cls, adata, datamodule, attr_dict) - adata_manager = model.get_anndata_manager(adata, required=True) + model = _initialize_model(cls, adata, registry, attr_dict) - if REGISTRY_KEYS.CAT_COVS_KEY in adata_manager.data_registry: + if model.summary_stats[f'n_{REGISTRY_KEYS.CAT_COVS_KEY}'] > 0: raise NotImplementedError( "scArches currently does not support models with extra categorical covariates." ) - version_split = adata_manager.registry[_constants._SCVI_VERSION_KEY].split(".") + version_split = model.registry[_constants._SCVI_VERSION_KEY].split(".") if int(version_split[1]) < 8 and int(version_split[0]) == 0: warnings.warn( "Query integration should be performed using models trained with " @@ -374,11 +372,7 @@ def _get_loaded_data(reference_model, device=None, adata=None): else: attr_dict = reference_model._get_user_attributes() attr_dict = {a[0]: a[1] for a in attr_dict if a[0][-1] == "_"} - var_names = ( - _get_var_names(reference_model.adata) - if attr_dict["registry_"]["setup_method_name"] != "setup_datamodule" - else _get_var_names(adata) - ) + var_names = reference_model.get_var_names() load_state_dict = deepcopy(reference_model.module.state_dict()) return attr_dict, var_names, load_state_dict diff --git a/src/scvi/model/base/_base_model.py b/src/scvi/model/base/_base_model.py index 8794679e15..3b4cee9fdf 100644 --- a/src/scvi/model/base/_base_model.py +++ b/src/scvi/model/base/_base_model.py @@ -3,9 +3,11 @@ import inspect import logging import os +import sys import warnings from abc import ABCMeta, abstractmethod from collections.abc import Sequence +from io import StringIO from uuid import uuid4 import numpy as np @@ -13,20 +15,28 @@ import rich import torch from anndata import AnnData -from lightning import LightningDataModule from mudata import MuData +from rich import box +from rich.console import Console from scvi import REGISTRY_KEYS, settings from scvi._types import AnnOrMuData, MinifiedDataType from scvi.data import AnnDataManager from scvi.data._compat import registry_from_setup_dict from scvi.data._constants import ( + _FIELD_REGISTRIES_KEY, _MODEL_NAME_KEY, _SCVI_UUID_KEY, _SETUP_ARGS_KEY, _SETUP_METHOD_NAME, + _STATE_REGISTRY_KEY, +) +from scvi.data._utils import ( + _assign_adata_uuid, + _check_if_view, + _get_adata_minify_type, + _get_summary_stats_from_registry, ) -from scvi.data._utils import _assign_adata_uuid, _check_if_view, _get_adata_minify_type from scvi.dataloaders import AnnDataLoader from scvi.model._utils import parse_device_args from scvi.model.base._constants import SAVE_KEYS @@ -39,6 +49,8 @@ from scvi.utils import attrdict, setup_anndata_dsp from scvi.utils._docstrings import devices_dsp +from . import _constants + logger = logging.getLogger(__name__) @@ -87,7 +99,7 @@ class BaseModelClass(metaclass=BaseModelMetaClass): _data_loader_cls = AnnDataLoader - def __init__(self, adata: AnnOrMuData | None = None, datamodule: object | None = None): + def __init__(self, adata: AnnOrMuData | None = None, registry: object | None = None): # check if the given adata is minified and check if the model being created # supports minified-data mode (i.e. inherits from the abstract BaseMinifiedModeModelClass). # If not, raise an error to inform the user of the lack of minified-data functionality @@ -100,21 +112,19 @@ def __init__(self, adata: AnnOrMuData | None = None, datamodule: object | None = self.id = str(uuid4()) # Used for cls._manager_store keys. if adata is not None: self._adata = adata - self._datamodule = None self._adata_manager = self._get_most_recent_anndata_manager(adata, required=True) self._register_manager_for_instance(self.adata_manager) # Suffix registry instance variable with _ to include it when saving the model. - self.registry_ = self._adata_manager.registry - self.summary_stats = self._adata_manager.summary_stats - elif datamodule is not None: + self.registry_ = self._adata_manager._registry + self.summary_stats = _get_summary_stats_from_registry(self.registry_) + elif registry is not None: self._adata = None - self._datamodule = datamodule self._adata_manager = None # Suffix registry instance variable with _ to include it when saving the model. - self.registry_ = datamodule.registry - self.summary_stats = datamodule.summary_stats + self.registry_ = registry + self.summary_stats = _get_summary_stats_from_registry(registry) else: - raise ValueError("adata or datamodule must be provided.") + raise ValueError("adata or registry must be provided.") self.is_trained_ = False self._model_summary_string = "" @@ -128,16 +138,20 @@ def adata(self) -> None | AnnOrMuData: """Data attached to model instance.""" return self._adata - @property - def datamodule(self) -> None | LightningDataModule: - """Data attached to model instance.""" - return self._datamodule - @property def registry(self) -> dict: """Data attached to model instance.""" return self.registry_ + def get_var_names(self, legacy_mudata_format=False) -> dict: + """Variable names of input data.""" + from scvi.model.base._save_load import _get_var_names + if self.adata: + return _get_var_names(self.adata, legacy_mudata_format=legacy_mudata_format) + else: + return self.registry[ + _FIELD_REGISTRIES_KEY]['X'][_STATE_REGISTRY_KEY]['column_names'] + @adata.setter def adata(self, adata: AnnOrMuData): if adata is None: @@ -148,14 +162,6 @@ def adata(self, adata: AnnOrMuData): self.registry_ = self._adata_manager.registry self.summary_stats = self._adata_manager.summary_stats - @datamodule.setter - def datamodule(self, datamodule: LightningDataModule): - if datamodule is None: - raise ValueError("datamodule cannot be None.") - self._datamodule = datamodule - self.registry_ = datamodule.registry - self.summary_stats = datamodule.summary_stats - @property def adata_manager(self) -> AnnDataManager: """Manager instance associated with self.adata.""" @@ -393,6 +399,9 @@ def get_anndata_manager( If True, errors on missing manager. Otherwise, returns None when manager is missing. """ cls = self.__class__ + if not adata: + return None + if _SCVI_UUID_KEY not in adata.uns: if required: raise ValueError( @@ -522,13 +531,20 @@ def _validate_anndata( "Input AnnData not setup with scvi-tools. " + "attempting to transfer AnnData setup" ) - self._register_manager_for_instance(self.adata_manager.transfer_fields(adata)) + self._register_manager_for_instance(self.transfer_fields(adata)) else: # Case where correct AnnDataManager is found, replay registration as necessary. adata_manager.validate() return adata + def transfer_fields(self, adata: AnnOrMuData, **kwargs) -> AnnData: + """Transfer fields from a model to an AnnData object.""" + if self.adata: + return self.adata_manager.transfer_fields(adata, **kwargs) + else: + raise ValueError("Model need to be initialized with AnnData to transfer fields.") + def _check_if_trained(self, warn: bool = True, message: str = _UNTRAINED_WARNING_MESSAGE): """Check if the model is trained. @@ -591,7 +607,7 @@ def _get_user_attributes(self): def _get_init_params(self, locals): """Returns the model init signature with associated passed in values. - Ignores the initial AnnData or DataModule. + Ignores the initial AnnData or Registry. """ init = self.__init__ sig = inspect.signature(init) @@ -603,8 +619,7 @@ def _get_init_params(self, locals): k: v for (k, v) in all_params.items() if not isinstance(v, AnnData) and not isinstance(v, MuData) - and not isinstance(v, LightningDataModule) - and k not in ("adata", "datamodule") + and k not in ("adata", "registry") } # not very efficient but is explicit # separates variable params (**kwargs) from non variable params into two dicts @@ -659,8 +674,6 @@ def save( anndata_write_kwargs Kwargs for :meth:`~anndata.AnnData.write` """ - from scvi.model.base._save_load import _get_var_names - if not os.path.exists(dir_path) or overwrite: os.makedirs(dir_path, exist_ok=overwrite) else: @@ -686,11 +699,7 @@ def save( # save the model state dict and the trainer state dict only model_state_dict = self.module.state_dict() - - if self.adata: - var_names = _get_var_names(self.adata, legacy_mudata_format=legacy_mudata_format) - else: - var_names = self.datamodule.var_names + var_names = self.get_var_names(legacy_mudata_format=legacy_mudata_format) # get all the user attributes user_attributes = self._get_user_attributes() @@ -713,7 +722,6 @@ def load( cls, dir_path: str, adata: AnnOrMuData | None = None, - datamodule: LightningDataModule | None = None, accelerator: str = "auto", device: int | str = "auto", prefix: str | None = None, @@ -730,6 +738,7 @@ def load( It is not necessary to run setup_anndata, as AnnData is validated against the saved `scvi` setup dictionary. If None, will check for and load anndata saved with the model. + If False, will load the model without AnnData. %(param_accelerator)s %(param_device)s prefix @@ -746,7 +755,7 @@ def load( >>> model = ModelClass.load(save_path, adata) >>> model.get_.... """ - load_adata = adata is None and datamodule is None + load_adata = adata is None _, _, device = parse_device_args( accelerator=accelerator, devices=device, @@ -769,17 +778,14 @@ def load( adata = new_adata if new_adata is not None else adata registry = attr_dict.pop("registry_") - if datamodule is not None: - registry['setup_method_name'] = 'setup_datamodule' - else: - registry['setup_method_name'] = 'setup_anndata' + registry['setup_method_name'] = 'setup_anndata' if _MODEL_NAME_KEY in registry and registry[_MODEL_NAME_KEY] != cls.__name__: raise ValueError("It appears you are loading a model from a different class.") # Calling ``setup_anndata`` method with the original arguments passed into # the saved model. This enables simple backwards compatibility in the case of # newly introduced fields or parameters. - if adata is not None: + if adata: if _SETUP_ARGS_KEY not in registry: raise ValueError( "Saved model does not contain original setup inputs. " @@ -789,13 +795,13 @@ def load( method_name = registry.get(_SETUP_METHOD_NAME, "setup_anndata") getattr(cls, method_name)(adata, source_registry=registry, **registry[_SETUP_ARGS_KEY]) - model = _initialize_model(cls, adata, datamodule, attr_dict) + model = _initialize_model(cls, adata, registry, attr_dict) model.module.on_load(model) model.module.load_state_dict(model_state_dict) model.to_device(device) model.module.eval() - if adata is not None: + if adata: model._validate_anndata(adata) return model @@ -888,22 +894,6 @@ def setup_anndata( on a model-specific instance of :class:`~scvi.data.AnnDataManager`. """ - @classmethod - @setup_anndata_dsp.dedent - def setup_datamodule( - cls, - datamodule, - *args, - **kwargs, - ): - """%(summary)s. - - Each model class deriving from this class provides parameters to this method - according to its needs. To operate correctly with the model initialization, - the implementation must call :meth:`~scvi.model.base.BaseModelClass.register_manager` - on a model-specific instance of :class:`~scvi.data.AnnDataManager`. - """ - @staticmethod def view_setup_args(dir_path: str, prefix: str | None = None) -> None: """Print args used to setup a saved model. @@ -968,6 +958,152 @@ def view_anndata_setup( ) from err adata_manager.view_registry(hide_state_registries=hide_state_registries) + def view_setup_method_args(self) -> None: + """Prints setup kwargs used to produce a given registry. + + Parameters + ---------- + registry + Registry produced by an AnnDataManager. + """ + model_name = self.registry_[_MODEL_NAME_KEY] + setup_args = self.registry_[_SETUP_ARGS_KEY] + if model_name is not None and setup_args is not None: + rich.print(f"Setup via `{model_name}.setup_anndata` with arguments:") + rich.pretty.pprint(setup_args) + rich.print() + + def view_registry(self, hide_state_registries: bool = False) -> None: + """Prints summary of the registry. + + Parameters + ---------- + hide_state_registries + If True, prints a shortened summary without details of each state registry. + """ + version = self.registry_[_SCVI_VERSION_KEY] + rich.print(f"Anndata setup with scvi-tools version {version}.") + rich.print() + self.view_setup_method_args(self._registry) + + in_colab = "google.colab" in sys.modules + force_jupyter = None if not in_colab else True + console = rich.console.Console(force_jupyter=force_jupyter) + + ss = _get_summary_stats_from_registry(self._registry) + dr = self._get_data_registry_from_registry(self._registry) + console.print(self._view_summary_stats(ss)) + console.print(self._view_data_registry(dr)) + + if not hide_state_registries: + for field in self.fields: + state_registry = self.get_state_registry(field.registry_key) + t = field.view_state_registry(state_registry) + if t is not None: + console.print(t) + + def get_state_registry(self, registry_key: str) -> attrdict: + """Returns the state registry for the AnnDataField registered with this instance.""" + return attrdict( + self.registry_[_FIELD_REGISTRIES_KEY][registry_key][ + _STATE_REGISTRY_KEY + ] + ) + + def get_setup_arg(self, setup_arg: str) -> attrdict: + """Returns the string provided to setup of a specific setup_arg.""" + return self.registry_[_SETUP_ARGS_KEY][setup_arg] + + @staticmethod + def _view_summary_stats( + summary_stats: attrdict, as_markdown: bool = False + ) -> rich.table.Table | str: + """Prints summary stats.""" + if not as_markdown: + t = rich.table.Table(title="Summary Statistics") + else: + t = rich.table.Table(box=box.MARKDOWN) + + t.add_column( + "Summary Stat Key", + justify="center", + style="dodger_blue1", + no_wrap=True, + overflow="fold", + ) + t.add_column( + "Value", + justify="center", + style="dark_violet", + no_wrap=True, + overflow="fold", + ) + for stat_key, count in summary_stats.items(): + t.add_row(stat_key, str(count)) + + if as_markdown: + console = Console(file=StringIO(), force_jupyter=False) + console.print(t) + return console.file.getvalue().strip() + + return t + + @staticmethod + def _view_data_registry( + data_registry: attrdict, as_markdown: bool = False + ) -> rich.table.Table | str: + """Prints data registry.""" + if not as_markdown: + t = rich.table.Table(title="Data Registry") + else: + t = rich.table.Table(box=box.MARKDOWN) + + t.add_column( + "Registry Key", + justify="center", + style="dodger_blue1", + no_wrap=True, + overflow="fold", + ) + t.add_column( + "scvi-tools Location", + justify="center", + style="dark_violet", + no_wrap=True, + overflow="fold", + ) + + for registry_key, data_loc in data_registry.items(): + mod_key = getattr(data_loc, _constants._DR_MOD_KEY, None) + attr_name = data_loc.attr_name + attr_key = data_loc.attr_key + scvi_data_str = "adata" + if mod_key is not None: + scvi_data_str += f".mod['{mod_key}']" + if attr_key is None: + scvi_data_str += f".{attr_name}" + else: + scvi_data_str += f".{attr_name}['{attr_key}']" + t.add_row(registry_key, scvi_data_str) + + if as_markdown: + console = Console(file=StringIO(), force_jupyter=False) + console.print(t) + return console.file.getvalue().strip() + + return t + + def update_setup_method_args(self, setup_method_args: dict): + """Update setup method args. + + Parameters + ---------- + setup_method_args + This is a bit of a misnomer, this is a dict representing kwargs + of the setup method that will be used to update the existing values + in the registry of this instance. + """ + self._registry[_SETUP_ARGS_KEY].update(setup_method_args) class BaseMinifiedModeModelClass(BaseModelClass): """Abstract base class for scvi-tools models that can handle minified data.""" diff --git a/src/scvi/model/base/_save_load.py b/src/scvi/model/base/_save_load.py index aa00e807f6..02af66efe0 100644 --- a/src/scvi/model/base/_save_load.py +++ b/src/scvi/model/base/_save_load.py @@ -97,7 +97,7 @@ def _load_saved_files( return attr_dict, var_names, model_state_dict, adata -def _initialize_model(cls, adata, datamodule, attr_dict): +def _initialize_model(cls, adata, registry, attr_dict): """Helper to initialize a model.""" if "init_params_" not in attr_dict.keys(): raise ValueError( @@ -121,9 +121,6 @@ def _initialize_model(cls, adata, datamodule, attr_dict): kwargs = {k: v for k, v in init_params.items() if isinstance(v, dict)} kwargs = {k: v for (i, j) in kwargs.items() for (k, v) in j.items()} non_kwargs.pop("use_cuda") - # adata and datamodule None is stored in the registry - non_kwargs.pop("adata", None) - non_kwargs.pop("datamodule", None) # backwards compat for scANVI if "unlabeled_category" in non_kwargs.keys(): @@ -131,7 +128,10 @@ def _initialize_model(cls, adata, datamodule, attr_dict): if "pretrained_model" in non_kwargs.keys(): non_kwargs.pop("pretrained_model") - model = cls(adata=adata, datamodule=datamodule, **non_kwargs, **kwargs) + if not adata: + adata = None + + model = cls(adata=adata, registry=registry, **non_kwargs, **kwargs) for attr, val in attr_dict.items(): setattr(model, attr, val) From dc21a3dc09c2c95fb99b45f0a03fa611ea75c006 Mon Sep 17 00:00:00 2001 From: ori-kron-wis Date: Sun, 28 Jul 2024 17:04:41 +0300 Subject: [PATCH 08/53] copying CZI custom dataloader into our repo --- src/scvi/dataloaders/_custom_dataloader.py | 1298 +++++++++++++++++++ src/scvi/model/_scvi.py | 1 + tests/dataloaders/test_custom_dataloader.py | 1 + 3 files changed, 1300 insertions(+) create mode 100644 src/scvi/dataloaders/_custom_dataloader.py create mode 100644 tests/dataloaders/test_custom_dataloader.py diff --git a/src/scvi/dataloaders/_custom_dataloader.py b/src/scvi/dataloaders/_custom_dataloader.py new file mode 100644 index 0000000000..b22c697c2a --- /dev/null +++ b/src/scvi/dataloaders/_custom_dataloader.py @@ -0,0 +1,1298 @@ +from __future__ import annotations + +import abc +import gc +import logging +import os +import threading +from collections import deque +from collections.abc import Iterator, Sequence +from concurrent import futures +from concurrent.futures import Future +from contextlib import contextmanager +from datetime import timedelta +from math import ceil +from time import time +from typing import Any, TypeVar + +import numpy as np +import numpy.typing as npt +import pandas as pd +import psutil +import scipy +import tiledbsoma as soma +import torch +import torchdata.datapipes.iter as pipes +from attr import define +from lightning.pytorch import LightningDataModule +from numpy.random import Generator +from scipy import sparse +from sklearn.preprocessing import LabelEncoder +from torch import Tensor +from torch import distributed as dist +from torch.utils.data import DataLoader +from torch.utils.data.dataset import Dataset + +pytorch_logger = logging.getLogger("cellxgene_census.experimental.pytorch") + +# TODO: Rename to reflect the correct order of the Tensors within the tuple: (X, obs) +ObsAndXDatum = tuple[Tensor, Tensor] +"""Return type of ``ExperimentDataPipe`` that pairs a Tensor of ``obs`` row(s) with a Tensor of +``X`` matrix row(s).The Tensors are rank 1 if ``batch_size`` is 1, +otherwise the Tensors are rank 2.""" + +util_logger = logging.getLogger("cellxgene_census.experimental.util") + +_T = TypeVar("_T") + + +DEFAULT_TILEDB_CONFIGURATION: dict[str, Any] = { + # https://docs.tiledb.com/main/how-to/configuration#configuration-parameters + "py.init_buffer_bytes": 1 * 1024**3, + "soma.init_buffer_bytes": 1 * 1024**3, + # S3 requests should not be signed, since we want to allow anonymous access + "vfs.s3.no_sign_request": "true", + "vfs.s3.region": "us-west-2", +} + + +def get_default_soma_context( + tiledb_config: dict[str, Any] | None = None, +) -> soma.options.SOMATileDBContext: + """Return a :class:`tiledbsoma.SOMATileDBContext` with sensible defaults that can be further + + customized by the user. The customized context can then be passed to + :func:`cellxgene_census.open_soma` with the ``context`` argument or to + :meth:`somacore.SOMAObject.open` with the ``context`` argument, such as + :meth:`tiledbsoma.Experiment.open`. Use the :meth:`tiledbsoma.SOMATileDBContext.replace` + method on the returned object to customize its settings further. + + Args: + tiledb_config: + A dictionary of TileDB configuration parameters. If specified, the parameters will + override the defaults. If not specified, the default configuration will be returned. + + Returns + ------- + A :class:`tiledbsoma.SOMATileDBContext` object with sensible defaults. + + Examples + -------- + To reduce the amount of memory used by TileDB-SOMA I/O operations: + + .. highlight:: python + .. code-block:: python + + ctx = cellxgene_census.get_default_soma_context( + tiledb_config={ + "py.init_buffer_bytes": 128 * 1024**2, + "soma.init_buffer_bytes": 128 * 1024**2, + } + ) + c = census.open_soma(uri="s3://my-private-bucket/census/soma", context=ctx) + + To access a copy of the Census located in a private bucket that is located in a different + S3 region, use: + + .. highlight:: python + .. code-block:: python + + ctx = cellxgene_census.get_default_soma_context( + tiledb_config={"vfs.s3.no_sign_request": "false", "vfs.s3.region": "us-east-1"} + ) + c = census.open_soma(uri="s3://my-private-bucket/census/soma", context=ctx) + + Lifecycle: + experimental + """ + tiledb_config = dict(DEFAULT_TILEDB_CONFIGURATION, **(tiledb_config or {})) + return soma.options.SOMATileDBContext().replace(tiledb_config=tiledb_config) + + +class _EagerIterator(Iterator[_T]): + def __init__( + self, + iterator: Iterator[_T], + pool: futures.Executor | None = None, + ): + super().__init__() + self.iterator = iterator + self._pool = pool or futures.ThreadPoolExecutor() + self._own_pool = pool is None + self._future: Future[_T] | None = None + self._begin_next() + + def _begin_next(self) -> None: + self._future = self._pool.submit(self.iterator.__next__) + util_logger.debug("Fetching next iterator element, eagerly") + + def __next__(self) -> _T: + try: + assert self._future + res = self._future.result() + self._begin_next() + return res + except StopIteration: + self._cleanup() + raise + + def _cleanup(self) -> None: + util_logger.debug("Cleaning up eager iterator") + if self._own_pool: + self._pool.shutdown() + + def __del__(self) -> None: + # Ensure the threadpool is cleaned up in the case where the + # iterator is not exhausted. For more information on __del__: + # https://docs.python.org/3/reference/datamodel.html#object.__del__ + self._cleanup() + super_del = getattr(super(), "__del__", lambda: None) + super_del() + + +class _EagerBufferedIterator(Iterator[_T]): + def __init__( + self, + iterator: Iterator[_T], + max_pending: int = 1, + pool: futures.Executor | None = None, + ): + super().__init__() + self.iterator = iterator + self.max_pending = max_pending + self._pool = pool or futures.ThreadPoolExecutor() + self._own_pool = pool is None + self._pending_results: deque[futures.Future[_T]] = deque() + self._lock = threading.Lock() + self._begin_next() + + def __next__(self) -> _T: + try: + res = self._pending_results[0].result() + self._pending_results.popleft() + self._begin_next() + return res + except StopIteration: + self._cleanup() + raise + + def _begin_next(self) -> None: + def _fut_done(fut: futures.Future[_T]) -> None: + util_logger.debug("Finished fetching next iterator element, eagerly") + if fut.exception() is None: + self._begin_next() + + with self._lock: + not_running = len(self._pending_results) == 0 or self._pending_results[-1].done() + if len(self._pending_results) < self.max_pending and not_running: + _future = self._pool.submit(self.iterator.__next__) + util_logger.debug("Fetching next iterator element, eagerly") + _future.add_done_callback(_fut_done) + self._pending_results.append(_future) + assert len(self._pending_results) <= self.max_pending + + def _cleanup(self) -> None: + util_logger.debug("Cleaning up eager iterator") + if self._own_pool: + self._pool.shutdown() + + def __del__(self) -> None: + # Ensure the threadpool is cleaned up in the case where the + # iterator is not exhausted. For more information on __del__: + # https://docs.python.org/3/reference/datamodel.html#object.__del__ + self._cleanup() + super_del = getattr(super(), "__del__", lambda: None) + super_del() + + +class Encoder(abc.ABC): + """Base class for obs encoders. + + To define a custom encoder, two methods must be implemented: + + - ``register``: defines how the encoder will be fitted to the data. + - ``transform``: defines how the encoder will be applied to the data + in order to create an obs_tensor. + + See the implementation of ``DefaultEncoder`` for an example. + """ + + @abc.abstractmethod + def register(self, obs: pd.DataFrame) -> None: + """Register the encoder with obs.""" + pass + + @abc.abstractmethod + def transform(self, df: pd.DataFrame) -> pd.DataFrame: + """Transform the obs DataFrame into a DataFrame of encoded values.""" + pass + + @property + def name(self) -> str: + return self.__class__.__name__ + + +class DefaultEncoder(Encoder): + """Default encoder based on LabelEncoder.""" + + def __init__(self, col: str) -> None: + self._encoder = LabelEncoder() + self.col = col + + def register(self, obs: pd.DataFrame) -> None: + self._encoder.fit(obs[self.col].unique()) + + def transform(self, df: pd.DataFrame) -> pd.DataFrame: + return self._encoder.transform(df[self.col]) # type: ignore + + @property + def name(self) -> str: + return self.col + + @property + def classes_(self): # type: ignore + return self._encoder.classes_ + + +@define +class _SOMAChunk: + """Return type of ``_ObsAndXSOMAIterator`` that pairs a chunk of ``obs`` rows with the + + respective rows from the ``X`` matrix. + + Lifecycle: + experimental + """ + + obs: pd.DataFrame + X: scipy.sparse.spmatrix + stats: Stats + + def __len__(self) -> int: + return len(self.obs) + + +Encoders = dict[str, LabelEncoder] +"""A dictionary of ``LabelEncoder``s keyed by the ``obs`` column name.""" + + +@define +class Stats: + """Statistics about the data retrieved by ``ExperimentDataPipe`` via SOMA API. This is useful + + for assessing the read throughput of SOMA data. + + Lifecycle: + experimental + """ + + n_obs: int = 0 + """The total number of obs rows retrieved""" + + nnz: int = 0 + """The total number of values retrieved""" + + elapsed: int = 0 + """The total elapsed time in seconds for retrieving all batches""" + + n_soma_chunks: int = 0 + """The number of chunks retrieved""" + + def __str__(self) -> str: + return ( + f"{self.n_soma_chunks=}, {self.n_obs=}, {self.nnz=}, " + f"elapsed={timedelta(seconds=self.elapsed)}" + ) + + def __add__(self, other: Stats) -> Stats: + self.n_obs += other.n_obs + self.nnz += other.nnz + self.elapsed += other.elapsed + self.n_soma_chunks += other.n_soma_chunks + return self + + +@contextmanager +def _open_experiment( + uri: str, + aws_region: str | None = None, +) -> soma.Experiment: + """Internal method for opening a SOMA ``Experiment`` as a context manager.""" + context = get_default_soma_context().replace( + tiledb_config={"vfs.s3.region": aws_region} if aws_region else {} + ) + + with soma.Experiment.open(uri, context=context) as exp: + yield exp + + +class _ObsAndXSOMAIterator(Iterator[_SOMAChunk]): + """Iterates the SOMA chunks of corresponding ``obs`` and ``X`` data. This is an internal class, + + not intended for public use. + """ + + X: soma.SparseNDArray + """A handle to the full X data of the SOMA ``Experiment``""" + + obs_joinids_chunks_iter: Iterator[npt.NDArray[np.int64]] + + var_joinids: npt.NDArray[np.int64] + """The ``var`` joinids to be retrieved from the SOMA ``Experiment``""" + + def __init__( + self, + obs: soma.DataFrame, + X: soma.SparseNDArray, + obs_column_names: Sequence[str], + obs_joinids_chunked: list[npt.NDArray[np.int64]], + var_joinids: npt.NDArray[np.int64], + shuffle_chunk_count: int | None = None, + shuffle_rng: Generator | None = None, + ): + self.obs = obs + self.X = X + self.obs_column_names = obs_column_names + if shuffle_chunk_count: + assert shuffle_rng is not None + + # At the start of this step, `obs_joinids_chunked` is a list of one dimensional + # numpy arrays. Each numpy array corresponds to a chunk of contiguous rows in `obs`. + # Critically, `obs_joinids_chunked` is randomly ordered where each chunk is + # from a random section of `obs`. + # We then take `shuffle_chunk_count` of these in order, concatenate them into + # a larger numpy array and shuffle this larger numpy array. + # The result is again a list of numpy arrays. + self.obs_joinids_chunks_iter = ( + shuffle_rng.permutation(np.concatenate(grouped_chunks)) + for grouped_chunks in list_split(obs_joinids_chunked, shuffle_chunk_count) + ) + else: + self.obs_joinids_chunks_iter = iter(obs_joinids_chunked) + self.var_joinids = var_joinids + self.shuffle_chunk_count = shuffle_chunk_count + + def __next__(self) -> _SOMAChunk: + pytorch_logger.debug("Retrieving next SOMA chunk...") + start_time = time() + + # If no more chunks to iterate through, raise StopIteration, as all iterators + # do when at end + obs_joinids_chunk = next(self.obs_joinids_chunks_iter) + + obs_batch = ( + self.obs.read( + coords=(obs_joinids_chunk,), + column_names=self.obs_column_names, + ) + .concat() + .to_pandas() + .set_index("soma_joinid") + ) + assert obs_batch.shape[0] == obs_joinids_chunk.shape[0] + + # handle case of empty result (first batch has 0 rows) + if len(obs_batch) == 0: + raise StopIteration + + # reorder obs rows to match obs_joinids_chunk ordering, which may be shuffled + obs_batch = obs_batch.reindex(obs_joinids_chunk, copy=False) + + # note: the `blockwise` call is employed for its ability to reindex the axes of the sparse + # matrix, but the blockwise iteration feature is not used (block_size is set to retrieve + # the chunk as a single block) + scipy_iter = ( + self.X.read(coords=(obs_joinids_chunk, self.var_joinids)) + .blockwise(axis=0, size=len(obs_joinids_chunk), eager=False) + .scipy(compress=True) + ) + X_batch, _ = next(scipy_iter) + assert obs_batch.shape[0] == X_batch.shape[0] + + stats = Stats() + stats.n_obs += X_batch.shape[0] + stats.nnz += X_batch.nnz + stats.elapsed += int(time() - start_time) + stats.n_soma_chunks += 1 + + pytorch_logger.debug(f"Retrieved SOMA chunk: {stats}") + return _SOMAChunk(obs=obs_batch, X=X_batch, stats=stats) + + +def list_split(arr_list: list[Any], sublist_len: int) -> list[list[Any]]: + """Splits a python list into a list of sublists where each sublist is of size `sublist_len`. + + TODO: Replace with `itertools.batched` when Python 3.12 becomes the minimum supported version. + """ + i = 0 + result = [] + while i < len(arr_list): + if (i + sublist_len) >= len(arr_list): + result.append(arr_list[i:]) + else: + result.append(arr_list[i : i + sublist_len]) + + i += sublist_len + + return result + + +def run_gc() -> tuple[tuple[Any, Any, Any], tuple[Any, Any, Any]]: + proc = psutil.Process(os.getpid()) + + pre_gc = proc.memory_full_info(), psutil.virtual_memory(), psutil.swap_memory() + gc.collect() + post_gc = proc.memory_full_info(), psutil.virtual_memory(), psutil.swap_memory() + + pytorch_logger.debug(f"gc: pre={pre_gc}") + pytorch_logger.debug(f"gc: post={post_gc}") + + return pre_gc, post_gc + + +class _ObsAndXIterator(Iterator[ObsAndXDatum]): + """Iterates through a set of ``obs`` and corresponding ``X`` rows, where the rows to be + + returned are specified by the ``obs_tables_iter`` argument. For the specified ``obs` rows, + the corresponding ``X`` data is loaded and joined together. It is returned from this iterator + as 2-tuples of ``X`` and obs Tensors. + + Internally manages the retrieval of data in SOMA-sized chunks, fetching the next chunk of SOMA + data as needed. Supports fetching the data in an eager manner, where the next SOMA chunk is + fetched while the current chunk is being read. This is an internal class, not intended for + public use. + """ + + soma_chunk_iter: _SOMAChunk | None + """The iterator for SOMA chunks of paired obs and X data""" + + soma_chunk: _SOMAChunk | None + """The current SOMA chunk of obs and X data""" + + i: int = -1 + """Index into current obs ``SOMA`` chunk""" + + def __init__( + self, + obs: soma.DataFrame, + X: soma.SparseNDArray, + obs_column_names: Sequence[str], + obs_joinids_chunked: list[npt.NDArray[np.int64]], + var_joinids: npt.NDArray[np.int64], + batch_size: int, + encoders: list[Encoder], + stats: Stats, + return_sparse_X: bool, + use_eager_fetch: bool, + shuffle_chunk_count: int | None = None, + shuffle_rng: Generator | None = None, + ) -> None: + self.soma_chunk_iter = _ObsAndXSOMAIterator( + obs, + X, + obs_column_names, + obs_joinids_chunked, + var_joinids, + shuffle_chunk_count, + shuffle_rng, + ) + if use_eager_fetch: + self.soma_chunk_iter = _EagerIterator(self.soma_chunk_iter) + self.soma_chunk = None + self.var_joinids = var_joinids + self.batch_size = batch_size + self.return_sparse_X = return_sparse_X + self.encoders = encoders + self.stats = stats + self.max_process_mem_usage_bytes = 0 + self.X_dtype = X.schema[2].type.to_pandas_dtype() + + def __next__(self) -> ObsAndXDatum: + """Read the next torch batch, possibly across multiple soma chunks.""" + obs: pd.DataFrame = pd.DataFrame() + X: sparse.csr_matrix = sparse.csr_matrix((0, len(self.var_joinids)), dtype=self.X_dtype) + + while len(obs) < self.batch_size: + try: + obs_partial, X_partial = self._read_partial_torch_batch(self.batch_size - len(obs)) + obs = pd.concat([obs, obs_partial], axis=0) + X = sparse.vstack([X, X_partial]) + except StopIteration: + break + + if len(obs) == 0: + raise StopIteration + + obs_encoded = pd.DataFrame() + + for enc in self.encoders: + obs_encoded[enc.name] = enc.transform(obs) + + # `to_numpy()` avoids copying the numpy array data + obs_tensor = torch.from_numpy(obs_encoded.to_numpy()) + + if not self.return_sparse_X: + X_tensor = torch.from_numpy(X.todense()) + else: + coo = X.tocoo() + + X_tensor = torch.sparse_coo_tensor( + # Note: The `np.array` seems unnecessary, but PyTorch warns bare array + # is "extremely slow" + indices=torch.from_numpy(np.array([coo.row, coo.col])), + values=coo.data, + size=coo.shape, + ) + + if self.batch_size == 1: + X_tensor = X_tensor[0] + obs_tensor = obs_tensor[0] + + return X_tensor, obs_tensor + + def _read_partial_torch_batch(self, batch_size: int) -> ObsAndXDatum: + """Reads a torch-size batch of data from the current SOMA chunk, returning a torch-size + + batch whose size may contain fewer rows than the requested ``batch_size``. This can happen + when the remaining rows in the current SOMA chunk are fewer than the requested + ``batch_size``. + """ + if self.soma_chunk is None or not (0 <= self.i < len(self.soma_chunk)): + # GC memory from previous soma_chunk + self.soma_chunk = None + mem_info = run_gc() + self.max_process_mem_usage_bytes = max( + self.max_process_mem_usage_bytes, mem_info[0][0].uss + ) + + self.soma_chunk: _SOMAChunk = next(self.soma_chunk_iter) + self.stats += self.soma_chunk.stats + self.i = 0 + + pytorch_logger.debug(f"Retrieved SOMA chunk totals: {self.stats}") + + obs_batch = self.soma_chunk.obs + X_batch = self.soma_chunk.X + + safe_batch_size = min(batch_size, len(obs_batch) - self.i) + slice_ = slice(self.i, self.i + safe_batch_size) + assert slice_.stop <= obs_batch.shape[0] + + obs_rows = obs_batch.iloc[slice_] + assert obs_rows.index.is_unique + assert safe_batch_size == obs_rows.shape[0] + + X_csr_scipy = X_batch[slice_] + assert obs_rows.shape[0] == X_csr_scipy.shape[0] + + self.i += safe_batch_size + + return obs_rows, X_csr_scipy + + +class ExperimentDataPipe(pipes.IterDataPipe[Dataset[ObsAndXDatum]]): # type: ignore + r"""An :class:`torchdata.datapipes.iter.IterDataPipe` that reads ``obs`` and ``X`` data from a + + :class:`tiledbsoma.Experiment`, based upon the specified queries along the ``obs`` and ``var`` + axes. Provides an iterator over these data when the object is passed to Python's built-in + ``iter`` function. + + >>> for batch in iter(ExperimentDataPipe(...)): + X_batch, y_batch = batch + + The ``batch_size`` parameter controls the number of rows of ``obs`` and ``X`` data that are + returned in each iteration. If the ``batch_size`` is 1, then each Tensor will have rank 1: + + >>> (tensor([0., 0., 0., 0., 0., 1., 0., 0., 0.]), # X data + tensor([2415, 0, 0], dtype=torch.int64)) # obs data, encoded + + For larger ``batch_size`` values, the returned Tensors will have rank 2: + + >>> DataLoader(..., batch_size=3, ...): + (tensor([[0., 0., 0., 0., 0., 1., 0., 0., 0.], # X batch + [0., 0., 0., 0., 0., 0., 0., 0., 0.], + [0., 0., 0., 0., 0., 0., 0., 0., 0.]]), + tensor([[2415, 0, 0], # obs batch + [2416, 0, 4], + [2417, 0, 3]], dtype=torch.int64)) + + The ``return_sparse_X`` parameter controls whether the ``X`` data is returned as a dense or + sparse :class:`torch.Tensor`. If the model supports use of sparse :class:`torch.Tensor`\ s, + this will reduce memory usage. + + The ``obs_column_names`` parameter determines the data columns that are returned in the + ``obs`` Tensor. The first element is always the ``soma_joinid`` of the ``obs`` + :class:`pandas.DataFrame` (or, equivalently, the ``soma_dim_0`` of the ``X`` matrix). + The remaining elements are the ``obs`` columns specified by ``obs_column_names``, + and string-typed columns are encoded as integer values. If needed, these values can be decoded + by obtaining the encoder for a given ``obs`` column name and calling its ``inverse_transform`` + method: + + >>> exp_data_pipe.obs_encoders[""].inverse_transform(encoded_values) + + Lifecycle: + experimental + """ + + _initialized: bool + + _obs_joinids: npt.NDArray[np.int64] | None + + _var_joinids: npt.NDArray[np.int64] | None + + _encoders: list[Encoder] + + _stats: Stats + + _shuffle_rng: Generator | None + + # TODO: Consider adding another convenience method wrapper to construct this object whose + # signature is more closely aligned with get_anndata() params + # (i.e. "exploded" AxisQuery params). + def __init__( + self, + experiment: soma.Experiment, + measurement_name: str = "RNA", + X_name: str = "raw", + obs_query: soma.AxisQuery | None = None, + var_query: soma.AxisQuery | None = None, + obs_column_names: Sequence[str] = (), + batch_size: int = 1, + shuffle: bool = True, + seed: int | None = None, + return_sparse_X: bool = False, + soma_chunk_size: int | None = 64, + use_eager_fetch: bool = True, + encoders: list[Encoder] | None = None, + shuffle_chunk_count: int | None = 2000, + ) -> None: + r"""Construct a new ``ExperimentDataPipe``. + + Args: + experiment: + The :class:`tiledbsoma.Experiment` from which to read data. + measurement_name: + The name of the :class:`tiledbsoma.Measurement` to read. Defaults to ``"RNA"``. + X_name: + The name of the X layer to read. Defaults to ``"raw"``. + obs_query: + The query used to filter along the ``obs`` axis. If not specified, all ``obs`` and + ``X`` data will be returned, which can be very large. + var_query: + The query used to filter along the ``var`` axis. If not specified, all ``var`` + columns (genes/features) will be returned. + obs_column_names: + The names of the ``obs`` columns to return. The ``soma_joinid`` index "column" does + not need to be specified and will always be returned. If not specified, only the + ``soma_joinid`` will be returned. + batch_size: + The number of rows of ``obs`` and ``X`` data to return in each iteration. Defaults + to ``1``. A value of ``1`` will result in :class:`torch.Tensor` of rank 1 being + returns (a single row); larger values will result in :class:`torch.Tensor`\ s of + rank 2 (multiple rows). + shuffle: + Whether to shuffle the ``obs`` and ``X`` data being returned. Defaults to ``True``. + For performance reasons, shuffling is not performed globally across all rows, but + rather in chunks. More specifically, we select ``shuffle_chunk_count`` + non-contiguous chunks across all the observations + in the query, concatenate the chunks and shuffle the associated observations. + The randomness of the shuffling is therefore determined by the + (``soma_chunk_size``, ``shuffle_chunk_count``) selection. The default values have + been determined to yield a good trade-off between randomness and performance. + Further tuning may be required for different type of models. Note that memory usage + is correlated to the product ``soma_chunk_size * shuffle_chunk_count``. + seed: + The random seed used for shuffling. Defaults to ``None`` (no seed). This *must* be + specified when using :class:`torch.nn.parallel.DistributedDataParallel` to ensure + data partitions are disjoint across worker processes. + return_sparse_X: + Controls whether the ``X`` data is returned as a dense or sparse + :class:`torch.Tensor`. As ``X`` data is very sparse, setting this to ``True`` will + reduce memory usage, if the model supports use of sparse :class:`torch.Tensor`\ s. + Defaults to ``False``, since sparse :class:`torch.Tensor`\ s are still experimental + in PyTorch. + soma_chunk_size: + The number of ``obs``/``X`` rows to retrieve when reading data from SOMA. This + impacts two aspects of this class's behavior: 1) The maximum memory utilization, + with larger values providing better read performance, but also requiring more + memory; 2) The granularity of the global shuffling step (see ``shuffle`` parameter + for details). The default value of 64 works well in conjunction with the default + ``shuffle_chunk_count`` value. + use_eager_fetch: + Fetch the next SOMA chunk of ``obs`` and ``X`` data immediately after a previously + fetched SOMA chunk is made available for processing via the iterator. This allows + network (or filesystem) requests to be made in parallel with client-side processing + of the SOMA data, potentially improving overall performance at the cost of + doubling memory utilization. Defaults to ``True``. + shuffle_chunk_count: + The number of contiguous blocks (chunks) of rows sampled to then concatenate + and shuffle. Larger numbers correspond to more randomness per training batch. + If ``shuffle == False``, this parameter is ignored. Defaults to ``2000``. + encoders: + Specify custom encoders to be used. If not specified, a LabelEncoder will be + created and used for each column in ``obs_column_names``. If specified, only + columns for which an encoder has been registered will be returned in the + ``obs`` tensor. + + Lifecycle: + experimental + """ + self.exp_uri = experiment.uri + self.aws_region = experiment.context.tiledb_ctx.config().get("vfs.s3.region") + self.measurement_name = measurement_name + self.layer_name = X_name + self.obs_query = obs_query + self.var_query = var_query + self.obs_column_names = obs_column_names + self.batch_size = batch_size + self.return_sparse_X = return_sparse_X + self.soma_chunk_size = soma_chunk_size + self.use_eager_fetch = use_eager_fetch + self._stats = Stats() + self._custom_encoders = encoders + self._encoders = [] + self._obs_joinids = None + self._var_joinids = None + self._shuffle_chunk_count = shuffle_chunk_count if shuffle else None + self._shuffle_rng = np.random.default_rng(seed) if shuffle else None + self._initialized = False + + if "soma_joinid" not in self.obs_column_names: + self.obs_column_names = ["soma_joinid", *self.obs_column_names] + + def _init(self) -> None: + if self._initialized: + return + + pytorch_logger.debug("Initializing ExperimentDataPipe") + + with _open_experiment(self.exp_uri, self.aws_region) as exp: + query = exp.axis_query( + measurement_name=self.measurement_name, + obs_query=self.obs_query, + var_query=self.var_query, + ) + + # The to_numpy() call is a workaround for a possible bug in TileDB-SOMA: + # https://github.com/single-cell-data/TileDB-SOMA/issues/1456 + self._obs_joinids = query.obs_joinids().to_numpy() + self._var_joinids = query.var_joinids().to_numpy() + + self._encoders = self._build_obs_encoders(query) + + self._initialized = True + + @staticmethod + def _subset_ids_to_partition( + ids_chunked: list[npt.NDArray[np.int64]], + partition_index: int, + num_partitions: int, + ) -> list[npt.NDArray[np.int64]]: + """Returns a single partition of the obs_joinids_chunked (a 2D ndarray), + + based upon the current process's distributed rank and world size. + """ + # subset to a single partition + # typing does not reflect that is actually a list of 2D NDArrays + partition_indices = np.array_split(range(len(ids_chunked)), num_partitions) + partition = [ids_chunked[i] for i in partition_indices[partition_index]] + + if pytorch_logger.isEnabledFor(logging.DEBUG) and len(partition) > 0: + pytorch_logger.debug( + f"Process {os.getpid()} handling partition {partition_index + 1} " + f"of {num_partitions}, partition_size={sum([len(chunk) for chunk in partition])}" + ) + + return partition + + @staticmethod + def _compute_partitions( + loader_partition: int, + loader_partitions: int, + dist_partition: int, + num_dist_partitions: int, + ) -> tuple[int, int]: + # NOTE: Can alternately use a `worker_init_fn` to split among workers split workload + total_partitions = num_dist_partitions * loader_partitions + partition = dist_partition * loader_partitions + loader_partition + return partition, total_partitions + + def __iter__(self) -> Iterator[ObsAndXDatum]: + self._init() + assert self._obs_joinids is not None + assert self._var_joinids is not None + + if self.soma_chunk_size is None: + # set soma_chunk_size to utilize ~1 GiB of RAM per SOMA chunk; assumes 95% X data + # sparsity, 8 bytes for the X value and 8 bytes for the sparse matrix indices, + # and a 100% working memory overhead (2x). + X_row_memory_size = 0.05 * len(self._var_joinids) * 8 * 3 * 2 + self.soma_chunk_size = int((1 * 1024**3) / X_row_memory_size) + pytorch_logger.debug(f"Using {self.soma_chunk_size=}") + + if ( + self.return_sparse_X + and torch.utils.data.get_worker_info() + and torch.utils.data.get_worker_info().num_workers > 0 + ): + raise NotImplementedError( + "torch does not work with sparse tensors in multi-processing mode " + "(see https://github.com/pytorch/pytorch/issues/20248)" + ) + + # chunk the obs joinids into batches of size soma_chunk_size + obs_joinids_chunked = self._chunk_ids(self._obs_joinids, self.soma_chunk_size) + + # globally shuffle the chunks, if requested + if self._shuffle_rng: + self._shuffle_rng.shuffle(obs_joinids_chunked) + + # subset to a single partition, as needed for distributed training and multi-processing + # data loading + worker_info = torch.utils.data.get_worker_info() + partition, partitions = self._compute_partitions( + loader_partition=worker_info.id if worker_info else 0, + loader_partitions=worker_info.num_workers if worker_info else 1, + dist_partition=dist.get_rank() if dist.is_initialized() else 0, + num_dist_partitions=dist.get_world_size() if dist.is_initialized() else 1, + ) + obs_joinids_chunked_partition: list[npt.NDArray[np.int64]] = self._subset_ids_to_partition( + obs_joinids_chunked, partition, partitions + ) + + with _open_experiment(self.exp_uri, self.aws_region) as exp: + obs_and_x_iter = _ObsAndXIterator( + obs=exp.obs, + X=exp.ms[self.measurement_name].X[self.layer_name], + obs_column_names=self.obs_column_names, + obs_joinids_chunked=obs_joinids_chunked_partition, + var_joinids=self._var_joinids, + batch_size=self.batch_size, + encoders=self._encoders, + stats=self._stats, + return_sparse_X=self.return_sparse_X, + use_eager_fetch=self.use_eager_fetch, + shuffle_rng=self._shuffle_rng, + shuffle_chunk_count=self._shuffle_chunk_count, + ) + + yield from obs_and_x_iter + + pytorch_logger.debug( + "max process memory usage=" + f"{obs_and_x_iter.max_process_mem_usage_bytes / (1024 ** 3):.3f} GiB" + ) + + @staticmethod + def _chunk_ids(ids: npt.NDArray[np.int64], chunk_size: int) -> list[npt.NDArray[np.int64]]: + num_chunks = max(1, ceil(len(ids) / chunk_size)) + pytorch_logger.debug( + f"Shuffling {len(ids)} obs joinids into {num_chunks} chunks of {chunk_size}" + ) + return np.array_split(ids, num_chunks) + + def __len__(self) -> int: + self._init() + assert self._obs_joinids is not None + + return len(self._obs_joinids) + + def __getitem__(self, index: int) -> ObsAndXDatum: + raise NotImplementedError("IterDataPipe can only be iterated") + + def _build_obs_encoders(self, query: soma.ExperimentAxisQuery) -> list[Encoder]: + pytorch_logger.debug("Initializing encoders") + + encoders = [] + obs = query.obs(column_names=self.obs_column_names).concat().to_pandas() + + if self._custom_encoders: + # Register all the custom encoders with obs + for enc in self._custom_encoders: + enc.register(obs) + encoders.append(enc) + else: + # Create one DefaultEncoder for each column, and register it with obs + for col in self.obs_column_names: + if obs[col].dtype in [object]: + enc = DefaultEncoder(col) + enc.register(obs) + encoders.append(enc) + + return encoders + + # TODO: This does not work in multiprocessing mode, as child process's stats are not collected + def stats(self) -> Stats: + """Get data loading stats for this + + :class:`cellxgene_census.experimental.ml.pytorch.ExperimentDataPipe`. + + Returns + ------- + The :class:`cellxgene_census.experimental.ml.pytorch.Stats` object for this + :class:`cellxgene_census.experimental.ml.pytorch.ExperimentDataPipe`. + + Lifecycle: + experimental + """ + return self._stats + + @property + def shape(self) -> tuple[int, int]: + """Get the shape of the data that will be returned by this + + :class:`cellxgene_census.experimental.ml.pytorch.ExperimentDataPipe`. + This is the number of obs (cell) and var (feature) counts in the returned data. If used in + multiprocessing mode (i.e. :class:`torch.utils.data.DataLoader` + instantiated with num_workers > 0), the obs (cell) count will reflect + the size of the partition of the data assigned to the active process. + + Returns + ------- + A 2-tuple of ``int``s, for obs and var counts, respectively. + + Lifecycle: + experimental + """ + self._init() + assert self._obs_joinids is not None + assert self._var_joinids is not None + + return len(self._obs_joinids), len(self._var_joinids) + + @property + def obs_encoders(self) -> Encoders: + """Returns a dictionary of :class:`sklearn.preprocessing.LabelEncoder` objects, keyed on + + ``obs`` column names, which were used to encode the ``obs`` column values. + + These encoders can be used to decode the encoded values as follows: + + >>> exp_data_pipe.obs_encoders[""].inverse_transform(encoded_values) + + Returns + ------- + A ``dict[str, LabelEncoder]``, mapping column names to :class:`sklearn.preprocessing. + LabelEncoder` objects. + """ + self._init() + assert self._encoders is not None + + return {enc.name: enc for enc in self._encoders} + + +# Note: must be a top-level function (and not a lambda), to play nice with multiprocessing pickling +def _collate_noop(x: Any) -> Any: + return x + + +# TODO: Move into somacore.ExperimentAxisQuery +def experiment_dataloader( + datapipe: pipes.IterDataPipe, + num_workers: int = 0, + **dataloader_kwargs: Any, +) -> DataLoader: + """Factory method for :class:`torch.utils.data.DataLoader`. This method can be used to safely + + instantiate a :class:`torch.utils.data.DataLoader` that works with + :class:`cellxgene_census.experimental.ml.pytorch.ExperimentDataPipe`, since some of the + :class:`torch.utils.data.DataLoader` constructor parameters are not applicable when using a + :class:`torchdata.datapipes.iter.IterDataPipe` (``shuffle``, ``batch_size``, ``sampler``, + ``batch_sampler``,``collate_fn``). + + Args: + datapipe: + An :class:`torchdata.datapipes.iter.IterDataPipe`, which can be an + :class:`cellxgene_census.experimental.ml.pytorch.ExperimentDataPipe` or any other + :class:`torchdata.datapipes.iter.IterDataPipe` that has been chained to the + :class:`cellxgene_census.experimental.ml.pytorch.ExperimentDataPipe`. + num_workers: + Number of worker processes to use for data loading. If ``0``, data will be loaded in + the main process. + **dataloader_kwargs: + Additional keyword arguments to pass to the :class:`torch.utils.data.DataLoader` + constructor, except for ``shuffle``, ``batch_size``, ``sampler``, ``batch_sampler``, + and ``collate_fn``, which are not supported when using + :class:`cellxgene_census.experimental.ml.pytorch.ExperimentDataPipe`. + + Returns + ------- + A :class:`torch.utils.data.DataLoader`. + + Raises + ------ + ValueError: if any of the ``shuffle``, ``batch_size``, ``sampler``, ``batch_sampler``, + or ``collate_fn`` params are passed as keyword arguments. + + Lifecycle: + experimental + """ + unsupported_dataloader_args = [ + "shuffle", + "batch_size", + "sampler", + "batch_sampler", + "collate_fn", + ] + if set(unsupported_dataloader_args).intersection(dataloader_kwargs.keys()): + raise ValueError( + f"The {','.join(unsupported_dataloader_args)} DataLoader params are not supported" + ) + + if num_workers > 0: + _init_multiprocessing() + + return DataLoader( + datapipe, + batch_size=None, # batching is handled by our ExperimentDataPipe + num_workers=num_workers, + # avoid use of default collator, which adds an extra (3rd) dimension to the tensor batches + collate_fn=_collate_noop, + # shuffling is handled by our ExperimentDataPipe + shuffle=False, + **dataloader_kwargs, + ) + + +def _init_multiprocessing() -> None: + """Ensures use of "spawn" for starting child processes with multiprocessing. + + Forked processes are known to be problematic: + https://pytorch.org/docs/stable/notes/multiprocessing.html#avoiding-and-fighting-deadlocks + Also, CUDA does not support forked child processes: + https://pytorch.org/docs/stable/notes/multiprocessing.html#cuda-in-multiprocessing + + """ + torch.multiprocessing.set_start_method("fork", force=True) + orig_start_method = torch.multiprocessing.get_start_method() + if orig_start_method != "spawn": + if orig_start_method: + pytorch_logger.warning( + "switching torch multiprocessing start method from " + f'"{torch.multiprocessing.get_start_method()}" to "spawn"' + ) + torch.multiprocessing.set_start_method("spawn", force=True) + + +class BatchEncoder(Encoder): + """Concatenates and encodes several columns.""" + + def __init__(self, cols: list[str]): + self.cols = cols + from sklearn.preprocessing import LabelEncoder + + self._encoder = LabelEncoder() + + def transform(self, df: pd.DataFrame): + import functools + + arr = functools.reduce(lambda a, b: a + b, [df[c].astype(str) for c in self.cols]) + return self._encoder.transform(arr) + + def register(self, obs: pd.DataFrame): + import functools + + arr = functools.reduce(lambda a, b: a + b, [obs[c].astype(str) for c in self.cols]) + self._encoder.fit(arr.unique()) + + @property + def name(self) -> str: + return "batch" + + @property + def classes_(self): + return self._encoder.classes_ + + +class CensusSCVIDataModule(LightningDataModule): + """Lightning data module for CxG Census. + + Parameters + ---------- + *args + Positional arguments passed to + :class:`~cellxgene_census.experimental.ml.pytorch.ExperimentDataPipe`. + batch_keys + List of obs column names concatenated to form the batch column. + train_size + Fraction of data to use for training. + split_seed + Seed for data split. + dataloader_kwargs + Keyword arguments passed into + :func:`~cellxgene_census.experimental.ml.pytorch.experiment_dataloader`. + **kwargs + Additional keyword arguments passed into + :class:`~cellxgene_census.experimental.ml.pytorch.ExperimentDataPipe`. Must not include + ``obs_column_names``. + """ + + _TRAIN_KEY = "train" + _VALIDATION_KEY = "validation" + + def __init__( + self, + *args, + batch_keys: list[str] | None = None, + train_size: float | None = None, + split_seed: int | None = None, + dataloader_kwargs: dict[str, any] | None = None, + **kwargs, + ): + super().__init__() + self.datapipe_args = args + self.datapipe_kwargs = kwargs + self.batch_keys = batch_keys + self.train_size = train_size + self.split_seed = split_seed + self.dataloader_kwargs = dataloader_kwargs or {} + + @property + def batch_keys(self) -> list[str]: + """List of obs column names concatenated to form the batch column.""" + if not hasattr(self, "_batch_keys"): + raise AttributeError("`batch_keys` not set.") + return self._batch_keys + + @batch_keys.setter + def batch_keys(self, value: list[str] | None): + if value is None or not isinstance(value, list): + raise ValueError("`batch_keys` must be a list of strings.") + self._batch_keys = value + + @property + def obs_column_names(self) -> list[str]: + """Passed to :class:`~cellxgene_census.experimental.ml.pytorch.ExperimentDataPipe`.""" + if hasattr(self, "_obs_column_names"): + return self._obs_column_names + + obs_column_names = [] + if self.batch_keys is not None: + obs_column_names.extend(self.batch_keys) + + self._obs_column_names = obs_column_names + return self._obs_column_names + + @property + def split_seed(self) -> int: + """Seed for data split.""" + if not hasattr(self, "_split_seed"): + raise AttributeError("`split_seed` not set.") + return self._split_seed + + @split_seed.setter + def split_seed(self, value: int | None): + if value is not None and not isinstance(value, int): + raise ValueError("`split_seed` must be an integer.") + self._split_seed = value or 0 + + @property + def train_size(self) -> float: + """Fraction of data to use for training.""" + if not hasattr(self, "_train_size"): + raise AttributeError("`train_size` not set.") + return self._train_size + + @train_size.setter + def train_size(self, value: float | None): + if value is not None and not isinstance(value, float): + raise ValueError("`train_size` must be a float.") + elif value is not None and (value < 0.0 or value > 1.0): + raise ValueError("`train_size` must be between 0.0 and 1.0.") + self._train_size = value or 1.0 + + @property + def validation_size(self) -> float: + """Fraction of data to use for validation.""" + if not hasattr(self, "_train_size"): + raise AttributeError("`validation_size` not available.") + return 1.0 - self.train_size + + @property + def weights(self) -> dict[str, float]: + """Passed to :meth:`~cellxgene_census.experimental.ml.ExperimentDataPipe.random_split`.""" + if not hasattr(self, "_weights"): + self._weights = {self._TRAIN_KEY: self.train_size} + if self.validation_size > 0.0: + self._weights[self._VALIDATION_KEY] = self.validation_size + return self._weights + + @property + def datapipe(self) -> ExperimentDataPipe: + """Experiment data pipe.""" + if not hasattr(self, "_datapipe"): + encoder = BatchEncoder(self.obs_column_names) + self._datapipe = ExperimentDataPipe( + *self.datapipe_args, + obs_column_names=self.obs_column_names, + encoders=[encoder], + **self.datapipe_kwargs, + ) + return self._datapipe + + def setup(self, stage: str | None = None): + """Set up the train and validation data pipes.""" + datapipes = self.datapipe.random_split(weights=self.weights, seed=self.split_seed) + self._train_datapipe = datapipes[0] + if self.validation_size > 0.0: + self._validation_datapipe = datapipes[1] + else: + self._validation_datapipe = None + + def train_dataloader(self): + """Training data loader.""" + return experiment_dataloader(self._train_datapipe, **self.dataloader_kwargs) + + def val_dataloader(self): + """Validation data loader.""" + if self._validation_datapipe is not None: + return experiment_dataloader(self._validation_datapipe, **self.dataloader_kwargs) + + @property + def n_obs(self) -> int: + """Number of observations in the query. + + Necessary in scvi-tools to compute a heuristic of ``max_epochs``. + """ + return self.datapipe.shape[0] + + @property + def n_vars(self) -> int: + """Number of features in the query. + + Necessary in scvi-tools to initialize the actual layers in the model. + + """ + return self.datapipe.shape[1] + + @property + def n_batch(self) -> int: + """ + Number of unique batches (after concatenation of ``batch_keys``). Necessary in scvi-tools + + so that the model knows how to one-hot encode batches. + + """ + return self.get_n_classes("batch") + + def get_n_classes(self, key: str) -> int: + """Return the number of classes for a given obs column.""" + return len(self.datapipe.obs_encoders[key].classes_) + + def on_before_batch_transfer( + self, + batch: tuple[torch.Tensor, torch.Tensor], + dataloader_idx: int, + ) -> dict[str, torch.Tensor | None]: + """Format the datapipe output with registry keys for scvi-tools.""" + X, obs = batch + + X_KEY: str = "X" + BATCH_KEY: str = "batch" + LABELS_KEY: str = "labels" + + return { + X_KEY: X, + BATCH_KEY: obs, + LABELS_KEY: None, + } diff --git a/src/scvi/model/_scvi.py b/src/scvi/model/_scvi.py index 95c88c3541..e9bc5ec694 100644 --- a/src/scvi/model/_scvi.py +++ b/src/scvi/model/_scvi.py @@ -141,6 +141,7 @@ def __init__( ) if self._module_init_on_train: + # Here we need to adjust given the new custom data loader self.module = None warnings.warn( "Model was initialized without `adata`. The module will be initialized when " diff --git a/tests/dataloaders/test_custom_dataloader.py b/tests/dataloaders/test_custom_dataloader.py new file mode 100644 index 0000000000..9d48db4f9f --- /dev/null +++ b/tests/dataloaders/test_custom_dataloader.py @@ -0,0 +1 @@ +from __future__ import annotations From a1098b3f2780b7ef03536f428cda3581922de761 Mon Sep 17 00:00:00 2001 From: ori-kron-wis Date: Tue, 30 Jul 2024 16:18:49 +0300 Subject: [PATCH 09/53] added some fixes to the custom dataloader stuff --- src/scvi/data/_manager.py | 49 +++++ src/scvi/model/_scvi.py | 66 +++++- src/scvi/model/base/_base_model.py | 17 ++ src/scvi/model/base/_training_mixin.py | 4 + tests/dataloaders/test_custom_dataloader.py | 63 ++++++ tests/dataloaders/test_custom_dataloader2.py | 213 +++++++++++++++++++ tests/model/test_scvi.py | 1 + 7 files changed, 412 insertions(+), 1 deletion(-) create mode 100644 tests/dataloaders/test_custom_dataloader2.py diff --git a/src/scvi/data/_manager.py b/src/scvi/data/_manager.py index 10d0219041..88b7be838b 100644 --- a/src/scvi/data/_manager.py +++ b/src/scvi/data/_manager.py @@ -192,6 +192,55 @@ def register_fields( self._assign_uuid() self._assign_most_recent_manager_uuid() + def register_data_module_fields( + self, + datamodule, + source_registry: dict | None = None, + **transfer_kwargs, + ): + """Registers each field associated with this instance with the AnnData object. + + Either registers or transfers the setup from `source_setup_dict` if passed in. + Sets ``self.adata``. + + Parameters + ---------- + adata + AnnData object to be registered. + source_registry + Registry created after registering an AnnData using an + :class:`~scvi.data.AnnDataManager` object. + transfer_kwargs + Additional keywords which modify transfer behavior. Only applicable if + ``source_registry`` is set. + """ + if self.adata is not None: + raise AssertionError("Existing AnnData object registered with this Manager instance.") + + if source_registry is None and transfer_kwargs: + raise TypeError( + f"register_fields() got unexpected keyword arguments {transfer_kwargs} passed " + "without a source_registry." + ) + + self._validate_anndata_object(datamodule) + + for field in self.fields: + self._add_field( + field=field, + adata=datamodule, + source_registry=source_registry, + **transfer_kwargs, + ) + + # Save arguments for register_fields. + self._source_registry = deepcopy(source_registry) + self._transfer_kwargs = deepcopy(transfer_kwargs) + + self.adata = datamodule + self._assign_uuid() + self._assign_most_recent_manager_uuid() + def _add_field( self, field: AnnDataField, diff --git a/src/scvi/model/_scvi.py b/src/scvi/model/_scvi.py index e9bc5ec694..4f9f645ad7 100644 --- a/src/scvi/model/_scvi.py +++ b/src/scvi/model/_scvi.py @@ -140,8 +140,10 @@ def __init__( f"gene_likelihood: {gene_likelihood}, latent_distribution: {latent_distribution}." ) + # in the next part we need to construct the same module no mather the way + # dataloader was given if self._module_init_on_train: - # Here we need to adjust given the new custom data loader + # Here we need to adjust given the new custom data loader like CZI case self.module = None warnings.warn( "Model was initialized without `adata`. The module will be initialized when " @@ -225,6 +227,68 @@ def setup_anndata( adata_manager = AnnDataManager(fields=anndata_fields, setup_method_args=setup_method_args) adata_manager.register_fields(adata, **kwargs) cls.register_manager(adata_manager) + # adata_manager.get_state_registry(SCVI.REGISTRY_KEYS.X_KEY).to_dict() + # adata_manager.registry[_constants._FIELD_REGISTRIES_KEY] + # pprint(adata_manager.registry) + + @classmethod + @setup_anndata_dsp.dedent + def setup_datamodule( + cls, + datamodule, + layer: str | None = None, + batch_key: str | None = None, + labels_key: str | None = None, + size_factor_key: str | None = None, + categorical_covariate_keys: list[str] | None = None, + continuous_covariate_keys: list[str] | None = None, + **kwargs, + ): + """%(summary)s. + + Parameters + ---------- + %(param_adata)s + %(param_layer)s + %(param_batch_key)s + %(param_labels_key)s + %(param_size_factor_key)s + %(param_cat_cov_keys)s + %(param_cont_cov_keys)s + """ + setup_method_args = cls._get_setup_method_args(**locals()) + anndata_fields = [ + LayerField(REGISTRY_KEYS.X_KEY, layer, is_count_data=True), + CategoricalObsField(REGISTRY_KEYS.BATCH_KEY, batch_key), + CategoricalObsField(REGISTRY_KEYS.LABELS_KEY, labels_key), + NumericalObsField(REGISTRY_KEYS.SIZE_FACTOR_KEY, size_factor_key, required=False), + CategoricalJointObsField(REGISTRY_KEYS.CAT_COVS_KEY, categorical_covariate_keys), + NumericalJointObsField(REGISTRY_KEYS.CONT_COVS_KEY, continuous_covariate_keys), + ] + # register new fields if the adata is minified + # adata_minify_type = _get_adata_minify_type(adata) + # if adata_minify_type is not None: + # anndata_fields += cls._get_fields_for_adata_minification(adata_minify_type) + adata_manager = AnnDataManager(fields=anndata_fields, setup_method_args=setup_method_args) + adata_manager.registry["setup_method_name"] = "setup_datamodule" + adata_manager.registry["setup_args"]["layer"] = datamodule.datapipe.layer_name + adata_manager.registry["setup_args"]["batch_key"] = datamodule.batch_keys + adata_manager.registry["setup_args"]["labels_key"] + adata_manager.registry["setup_args"]["batch_key"] + adata_manager.registry["setup_args"]["batch_key"] + adata_manager.registry["setup_args"]["batch_key"] + # datamodule._datapipe.obs_column_names + # datamodule._datapipe.obs_encoders + # adata_manager.register_fields(adata, **kwargs) + # how to etract the information we need from the datamodule + adata_manager.register_data_module_fields( + datamodule, **kwargs + ) # here we need a new function for data module + + cls.register_manager(adata_manager) + # adata_manager.get_state_registry(SCVI.REGISTRY_KEYS.X_KEY).to_dict() + # adata_manager.registry[_constants._FIELD_REGISTRIES_KEY] + # pprint(adata_manager.registry) @staticmethod def _get_fields_for_adata_minification( diff --git a/src/scvi/model/base/_base_model.py b/src/scvi/model/base/_base_model.py index 09d597ef05..338ab0805f 100644 --- a/src/scvi/model/base/_base_model.py +++ b/src/scvi/model/base/_base_model.py @@ -817,6 +817,23 @@ def setup_anndata( on a model-specific instance of :class:`~scvi.data.AnnDataManager`. """ + @classmethod + @abstractmethod + @setup_anndata_dsp.dedent + def setup_datamodule( + cls, + datamodule, + *args, + **kwargs, + ): + """%(summary)s. + + Each model class deriving from this class provides parameters to this method + according to its needs. To operate correctly with the model initialization, + the implementation must call :meth:`~scvi.model.base.BaseModelClass.register_manager` + on a model-specific instance of :class:`~scvi.data.AnnDataManager`. + """ + @staticmethod def view_setup_args(dir_path: str, prefix: str | None = None) -> None: """Print args used to setup a saved model. diff --git a/src/scvi/model/base/_training_mixin.py b/src/scvi/model/base/_training_mixin.py index de3efd9fcb..86da6b5019 100644 --- a/src/scvi/model/base/_training_mixin.py +++ b/src/scvi/model/base/_training_mixin.py @@ -102,6 +102,7 @@ def train( ) if datamodule is None: + # In the general case we enter here datasplitter_kwargs = datasplitter_kwargs or {} datamodule = self._data_splitter_cls( self.adata_manager, @@ -114,6 +115,7 @@ def train( **datasplitter_kwargs, ) elif self.module is None: + # in CZI case we enter here self.module = self._module_cls( datamodule.n_vars, n_batch=datamodule.n_batch, @@ -122,6 +124,8 @@ def train( n_cats_per_cov=getattr(datamodule, "n_cats_per_cov", None), **self._module_kwargs, ) + # after either of the cases we should be here with the same self.module + # and same datamodule plan_kwargs = plan_kwargs or {} training_plan = self._training_plan_cls(self.module, **plan_kwargs) diff --git a/tests/dataloaders/test_custom_dataloader.py b/tests/dataloaders/test_custom_dataloader.py index 9d48db4f9f..c80cb843b5 100644 --- a/tests/dataloaders/test_custom_dataloader.py +++ b/tests/dataloaders/test_custom_dataloader.py @@ -1 +1,64 @@ from __future__ import annotations + +import os + +import numpy as np +import scanpy as sc + +import scvi +from scvi.data import _constants, synthetic_iid +from scvi.model import SCVI + +# We will now create the SCVI model object: +# Its parameters: +n_layers = 1 +n_latent = 10 +batch_size = 1024 +train_size = 0.9 +max_epochs = 1 + + +# COMAPRE TO THE ORIGINAL METHOD!!! - use the same data!!! +# We first create a registry using the orignal way of anndata in order to compare and add +# what is missing +adata = synthetic_iid() +adata.obs["size_factor"] = np.random.randint(1, 5, size=(adata.shape[0],)) +SCVI.setup_anndata( + adata, + batch_key="batch", + labels_key="labels", + size_factor_key="size_factor", +) +# +model_orig = SCVI(adata, n_latent=n_latent) +model_orig.train(1, check_val_every_n_epoch=1, train_size=0.5) + +# Saving the model +save_dir = "/Users/orikr/runs/290724/" # tempfile.TemporaryDirectory() +model_dir = os.path.join(save_dir, "scvi_orig_model") +model_orig.save(model_dir, overwrite=True) + +# Loading the model (just as a compariosn) +model_orig_loaded = scvi.model.SCVI.load(model_dir, adata=adata) + +# Obtaining model outputs +SCVI_LATENT_KEY = "X_scVI" +latent = model_orig.get_latent_representation() +adata.obsm[SCVI_LATENT_KEY] = latent +# latent.shape + +# You can see all necessary entries and the structure at +adata_manager = model_orig.adata_manager +model_orig.view_anndata_setup(hide_state_registries=True) +# adata_manager.get_state_registry(SCVI.REGISTRY_KEYS.X_KEY).to_dict() +adata_manager.registry[_constants._FIELD_REGISTRIES_KEY] + +# Plot UMAP and save the figure for later check +sc.pp.neighbors(adata, use_rep="scvi", key_added="scvi") +sc.tl.umap(adata, neighbors_key="scvi") +sc.pl.umap(adata, color="dataset_id", title="SCVI") + +# Now return and add all the registry stuff that we will need + +# Now add the missing stuff from the current CZI implemenation in order for us to have the exact +# same steps like the original way (except than setup_anndata) diff --git a/tests/dataloaders/test_custom_dataloader2.py b/tests/dataloaders/test_custom_dataloader2.py new file mode 100644 index 0000000000..5b741ebfad --- /dev/null +++ b/tests/dataloaders/test_custom_dataloader2.py @@ -0,0 +1,213 @@ +from __future__ import annotations + +import os + +import cellxgene_census +import pandas as pd +import scanpy as sc +import tiledbsoma as soma +import torch +from cellxgene_census.experimental.pp import highly_variable_genes + +import scvi +from scvi.dataloaders._custom_dataloader import CensusSCVIDataModule, experiment_dataloader +from scvi.model import SCVI + +# We will now create the SCVI model object: +# Its parameters: +n_layers = 1 +n_latent = 10 +batch_size = 1024 +train_size = 0.9 +max_epochs = 1 + +# We have to create a registry without setup_anndata that contains the same elements +# The other way will be to fill the model ,LIKE IN CELLXGENE NOTEBOOK +# need to pass here new object of registry taht contains everything we will need + + +# First lets see CELLXGENE example using pytorch loaders implemented now in our repo +census = cellxgene_census.open_soma(census_version="stable") +experiment_name = "mus_musculus" +obs_value_filter = 'is_primary_data == True and tissue_general in ["spleen"] and nnz >= 300' +top_n_hvg = 8000 +hvg_batch = ["assay", "suspension_type"] +# THIS WILL TAKE FEW MINUTES TO RUN! +query = census["census_data"][experiment_name].axis_query( + measurement_name="RNA", obs_query=soma.AxisQuery(value_filter=obs_value_filter) +) +hvgs_df = highly_variable_genes(query, n_top_genes=top_n_hvg, batch_key=hvg_batch) +hv = hvgs_df.highly_variable +hv_idx = hv[hv].index + +# Now load the custom data module CZI did that now exists in our db +# (and we will later want to elaborate with more info from our original anndata registry) +# This thing is done by the user in any form they want +datamodule = CensusSCVIDataModule( + census["census_data"][experiment_name], + measurement_name="RNA", + X_name="raw", + obs_query=soma.AxisQuery(value_filter=obs_value_filter), + var_query=soma.AxisQuery(coords=(list(hv_idx),)), + batch_size=1024, + shuffle=True, + batch_keys=["dataset_id", "assay", "suspension_type", "donor_id"], + dataloader_kwargs={"num_workers": 0, "persistent_workers": False}, +) +# This is a new func to implement +SCVI.setup_datamodule(datamodule) +# +model = SCVI(n_layers=n_layers, n_latent=n_latent, gene_likelihood="nb", encode_covariates=False) + + +# The CZI data module is a refined data module while SCVI is a lighting datamodule +# Altough this is only 1 epoch it will take few mins on local machine +model.train( + datamodule=datamodule, + max_epochs=max_epochs, + batch_size=batch_size, + train_size=train_size, + early_stopping=False, +) + +# We can now save the trained model. As of the current writing date (June 2024), +# scvi-tools doesn't support saving a model that wasn't generated through an AnnData loader, +# so we'll use some custom code: +model_state_dict = model.module.state_dict() +var_names = hv_idx.to_numpy() +user_attributes = model._get_user_attributes() +user_attributes = {a[0]: a[1] for a in user_attributes if a[0][-1] == "_"} + +user_attributes.update( + { + "n_batch": datamodule.n_batch, + "n_extra_categorical_covs": 0, + "n_extra_continuous_covs": 0, + "n_labels": 1, + "n_vars": datamodule.n_vars, + } +) + +with open("model.pt", "wb") as f: + torch.save( + { + "model_state_dict": model_state_dict, + "var_names": var_names, + "attr_dict": user_attributes, + }, + f, + ) + +# Saving the model the original way +save_dir = "/Users/orikr/runs/290724/" # tempfile.TemporaryDirectory() +model_dir = os.path.join(save_dir, "scvi_czi_model") +model.save(model_dir, overwrite=True) + + +# We will now load the model back and use it to generate cell embeddings (the latent space), +# which can then be used for further analysis. Note that we still need to use some custom code for +# loading the model, which includes loading the parameters from the `attr_dict` node stored in +# the model. +with open("model.pt", "rb") as f: + torch_model = torch.load(f) + + adict = torch_model["attr_dict"] + params = adict["init_params_"]["non_kwargs"] + + n_batch = adict["n_batch"] + n_extra_categorical_covs = adict["n_extra_categorical_covs"] + n_extra_continuous_covs = adict["n_extra_continuous_covs"] + n_labels = adict["n_labels"] + n_vars = adict["n_vars"] + + latent_distribution = params["latent_distribution"] + dispersion = params["dispersion"] + n_hidden = params["n_hidden"] + dropout_rate = params["dropout_rate"] + gene_likelihood = params["gene_likelihood"] + + model = scvi.model.SCVI( + n_layers=params["n_layers"], + n_latent=params["n_latent"], + gene_likelihood=params["gene_likelihood"], + encode_covariates=False, + ) + + module = model._module_cls( + n_input=n_vars, + n_batch=n_batch, + n_labels=n_labels, + n_continuous_cov=n_extra_continuous_covs, + n_cats_per_cov=None, + n_hidden=n_hidden, + n_latent=n_latent, + n_layers=n_layers, + dropout_rate=dropout_rate, + dispersion=dispersion, + gene_likelihood=gene_likelihood, + latent_distribution=latent_distribution, + ) + model.module = module + + model.module.load_state_dict(torch_model["model_state_dict"]) + + device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") + + model.to_device(device) + model.module.eval() + model.is_trained = True + +# We will now generate the cell embeddings for this model, using the `get_latent_representation` +# function available in scvi-tools. +# We can use another instance of the `ExperimentDataPipe` for the forward pass, so we don't need +# to load the whole dataset in memory. + +# Needs to have shuffle=False for inference +datamodule_inference = CensusSCVIDataModule( + census["census_data"][experiment_name], + measurement_name="RNA", + X_name="raw", + obs_query=soma.AxisQuery(value_filter=obs_value_filter), + var_query=soma.AxisQuery(coords=(list(hv_idx),)), + batch_size=1024, + shuffle=False, + soma_chunk_size=50_000, + batch_keys=["dataset_id", "assay", "suspension_type", "donor_id"], + dataloader_kwargs={"num_workers": 0, "persistent_workers": False}, +) + +# We can simply feed the datapipe to `get_latent_representation` to obtain the embeddings - +# will take a while +datapipe = datamodule_inference.datapipe +dataloader = experiment_dataloader(datapipe, num_workers=0, persistent_workers=False) +mapped_dataloader = ( + datamodule_inference.on_before_batch_transfer(tensor, None) for tensor in dataloader +) +latent = model.get_latent_representation(dataloader=mapped_dataloader) +emb_idx = datapipe._obs_joinids + +# We will now take a look at the UMAP for the generated embedding +# (will be later comapred to what we got) +adata = cellxgene_census.get_anndata( + census, + organism=experiment_name, + obs_value_filter=obs_value_filter, +) +obs_soma_joinids = adata.obs["soma_joinid"] +obs_indexer = pd.Index(emb_idx) +idx = obs_indexer.get_indexer(obs_soma_joinids) +# Reindexing is necessary to ensure that the cells in the embedding match the +# ones in the anndata object. +adata.obsm["scvi"] = latent[idx] + +# Plot UMAP and save the figure for later check +sc.pp.neighbors(adata, use_rep="scvi", key_added="scvi") +sc.tl.umap(adata, neighbors_key="scvi") +sc.pl.umap(adata, color="dataset_id", title="SCVI") + + +# Now return and add all the registry stuff that we will need + + +# Now add the missing stuff from the current CZI implemenation in order for us to have the exact +# same steps like the original way (except than setup_anndata) diff --git a/tests/model/test_scvi.py b/tests/model/test_scvi.py index 62ddad42f3..1b68e4e172 100644 --- a/tests/model/test_scvi.py +++ b/tests/model/test_scvi.py @@ -960,6 +960,7 @@ def test_scvi_no_anndata(n_batches: int = 3, n_latent: int = 5): model.train(datamodule=datamodule) model = SCVI(adata, n_latent=5) + # Add an example for external custom dataloader? assert not model._module_init_on_train assert model.module is not None assert hasattr(model, "adata") From b07216b1340587f6fc5d88a5073a4b79e3f876c0 Mon Sep 17 00:00:00 2001 From: Can Ergen Date: Tue, 30 Jul 2024 12:20:51 -0700 Subject: [PATCH 10/53] Some suggestions --- src/scvi/data/_manager.py | 49 --------------------------------------- src/scvi/model/_scvi.py | 25 +++++++++++++------- tests/model/test_scvi.py | 18 ++++++++++++++ 3 files changed, 34 insertions(+), 58 deletions(-) diff --git a/src/scvi/data/_manager.py b/src/scvi/data/_manager.py index 88b7be838b..10d0219041 100644 --- a/src/scvi/data/_manager.py +++ b/src/scvi/data/_manager.py @@ -192,55 +192,6 @@ def register_fields( self._assign_uuid() self._assign_most_recent_manager_uuid() - def register_data_module_fields( - self, - datamodule, - source_registry: dict | None = None, - **transfer_kwargs, - ): - """Registers each field associated with this instance with the AnnData object. - - Either registers or transfers the setup from `source_setup_dict` if passed in. - Sets ``self.adata``. - - Parameters - ---------- - adata - AnnData object to be registered. - source_registry - Registry created after registering an AnnData using an - :class:`~scvi.data.AnnDataManager` object. - transfer_kwargs - Additional keywords which modify transfer behavior. Only applicable if - ``source_registry`` is set. - """ - if self.adata is not None: - raise AssertionError("Existing AnnData object registered with this Manager instance.") - - if source_registry is None and transfer_kwargs: - raise TypeError( - f"register_fields() got unexpected keyword arguments {transfer_kwargs} passed " - "without a source_registry." - ) - - self._validate_anndata_object(datamodule) - - for field in self.fields: - self._add_field( - field=field, - adata=datamodule, - source_registry=source_registry, - **transfer_kwargs, - ) - - # Save arguments for register_fields. - self._source_registry = deepcopy(source_registry) - self._transfer_kwargs = deepcopy(transfer_kwargs) - - self.adata = datamodule - self._assign_uuid() - self._assign_most_recent_manager_uuid() - def _add_field( self, field: AnnDataField, diff --git a/src/scvi/model/_scvi.py b/src/scvi/model/_scvi.py index 4f9f645ad7..9ba332f91f 100644 --- a/src/scvi/model/_scvi.py +++ b/src/scvi/model/_scvi.py @@ -256,6 +256,8 @@ def setup_datamodule( %(param_cat_cov_keys)s %(param_cont_cov_keys)s """ + + # Remove these lines. We don't need an adata_manager. setup_method_args = cls._get_setup_method_args(**locals()) anndata_fields = [ LayerField(REGISTRY_KEYS.X_KEY, layer, is_count_data=True), @@ -271,20 +273,25 @@ def setup_datamodule( # anndata_fields += cls._get_fields_for_adata_minification(adata_minify_type) adata_manager = AnnDataManager(fields=anndata_fields, setup_method_args=setup_method_args) adata_manager.registry["setup_method_name"] = "setup_datamodule" + + """ + ORI check here the elements are used in the datamodule. + We can stick to their solution for now. But we should check for all setup things whether + they are present in the datamodule. + These checks can adfterwards go to a new class. But implement them here. And ignore all adata things. + We just want to have the same dictionary + """ + if datamodule.get_batch_keys() is not None: + adata_manager.registry["setup_args"]["batch_key"] = datamodule.get_batch_keys() + if datamodule.get_labels_keys() is not None: + adata_manager.registry["setup_args"]["labels_key"] = datamodule.get_labels_keys() adata_manager.registry["setup_args"]["layer"] = datamodule.datapipe.layer_name - adata_manager.registry["setup_args"]["batch_key"] = datamodule.batch_keys - adata_manager.registry["setup_args"]["labels_key"] - adata_manager.registry["setup_args"]["batch_key"] - adata_manager.registry["setup_args"]["batch_key"] - adata_manager.registry["setup_args"]["batch_key"] - # datamodule._datapipe.obs_column_names - # datamodule._datapipe.obs_encoders - # adata_manager.register_fields(adata, **kwargs) - # how to etract the information we need from the datamodule + datamodule.get_var_names() # ORI this has to be provided no check otherwise raise error. adata_manager.register_data_module_fields( datamodule, **kwargs ) # here we need a new function for data module + # ORI No need to register here using adata manager. Instead populate dictionary. It will be sufficient. cls.register_manager(adata_manager) # adata_manager.get_state_registry(SCVI.REGISTRY_KEYS.X_KEY).to_dict() # adata_manager.registry[_constants._FIELD_REGISTRIES_KEY] diff --git a/tests/model/test_scvi.py b/tests/model/test_scvi.py index 1b68e4e172..eb96b8c34f 100644 --- a/tests/model/test_scvi.py +++ b/tests/model/test_scvi.py @@ -1080,6 +1080,24 @@ def test_scvi_inference_custom_dataloader(n_latent: int = 5): _ = model.get_latent_representation(dataloader=dataloader) +def test_scvi_train_custom_dataloader(n_latent: int = 5): + # ORI this function could help get started. + adata = synthetic_iid() + SCVI.setup_anndata(adata, batch_key="batch") + + model = SCVI(adata, n_latent=n_latent) + model.train(max_epochs=1) + dataloader = model._make_data_loader(adata) + """ + SCVI.setup_datamodule(dataloader) + # continue from here. Datamodule will always require to pass it into all downstream functions. + model.train(max_epochs=1, datamodule=dataloader) + _ = model.get_elbo(dataloader=dataloader) + _ = model.get_marginal_ll(dataloader=dataloader) + _ = model.get_reconstruction_error(dataloader=dataloader) + _ = model.get_latent_representation(dataloader=dataloader) + """ + def test_scvi_normal_likelihood(): import scanpy as sc From a578af1f9f4830db4b891a4a20ce90d6de4388a1 Mon Sep 17 00:00:00 2001 From: Can Ergen Date: Wed, 31 Jul 2024 13:53:44 -0700 Subject: [PATCH 11/53] Changes to datamodule pipeline --- cellxgene-census | 1 + src/scvi/model/_scanvi.py | 4 + src/scvi/model/_scvi.py | 114 +++++++---------------- src/scvi/model/base/_base_model.py | 123 ++++++++++++++++++++----- src/scvi/model/base/_save_load.py | 7 +- src/scvi/model/base/_training_mixin.py | 31 +------ 6 files changed, 142 insertions(+), 138 deletions(-) create mode 160000 cellxgene-census diff --git a/cellxgene-census b/cellxgene-census new file mode 160000 index 0000000000..6edd123100 --- /dev/null +++ b/cellxgene-census @@ -0,0 +1 @@ +Subproject commit 6edd123100716f6a434403b74db58c5379bb0d5d diff --git a/src/scvi/model/_scanvi.py b/src/scvi/model/_scanvi.py index 55c6e7a980..8630ff80d2 100644 --- a/src/scvi/model/_scanvi.py +++ b/src/scvi/model/_scanvi.py @@ -178,6 +178,7 @@ def from_scvi_model( unlabeled_category: str, labels_key: str | None = None, adata: AnnData | None = None, + datamodule: LightningDataModule | None = None, **scanvi_kwargs, ): """Initialize scanVI model with weights from pretrained :class:`~scvi.model.SCVI` model. @@ -194,6 +195,8 @@ def from_scvi_model( Value used for unlabeled cells in `labels_key` used to setup AnnData with scvi. adata AnnData object that has been registered via :meth:`~scvi.model.SCANVI.setup_anndata`. + datamodule + LightningDataModule object that has been registered. scanvi_kwargs kwargs for scANVI model """ @@ -242,6 +245,7 @@ def from_scvi_model( **scvi_setup_args, ) scanvi_model = cls(adata, **non_kwargs, **kwargs, **scanvi_kwargs) + print('TTTT', scanvi_model.registry) scvi_state_dict = scvi_model.module.state_dict() scanvi_model.module.load_state_dict(scvi_state_dict, strict=False) scanvi_model.was_pretrained = True diff --git a/src/scvi/model/_scvi.py b/src/scvi/model/_scvi.py index 9ba332f91f..ffc79775a5 100644 --- a/src/scvi/model/_scvi.py +++ b/src/scvi/model/_scvi.py @@ -6,6 +6,7 @@ import numpy as np from anndata import AnnData +from lightning import LightningDataModule from scvi import REGISTRY_KEYS, settings from scvi._types import MinifiedDataType @@ -112,6 +113,7 @@ class SCVI( def __init__( self, adata: AnnData | None = None, + datamodule: LightningDataModule | None = None, n_hidden: int = 128, n_latent: int = 10, n_layers: int = 1, @@ -121,7 +123,7 @@ def __init__( latent_distribution: Literal["normal", "ln"] = "normal", **kwargs, ): - super().__init__(adata) + super().__init__(adata, datamodule) self._module_kwargs = { "n_hidden": n_hidden, @@ -140,49 +142,35 @@ def __init__( f"gene_likelihood: {gene_likelihood}, latent_distribution: {latent_distribution}." ) - # in the next part we need to construct the same module no mather the way - # dataloader was given - if self._module_init_on_train: - # Here we need to adjust given the new custom data loader like CZI case - self.module = None - warnings.warn( - "Model was initialized without `adata`. The module will be initialized when " - "calling `train`. This behavior is experimental and may change in the future.", - UserWarning, - stacklevel=settings.warnings_stacklevel, + n_cats_per_cov = self.summary_stats[f'n_{REGISTRY_KEYS.CAT_COVS_KEY}'] + if n_cats_per_cov == 0: + n_cats_per_cov = None + n_batch = self.summary_stats.n_batch + use_size_factor_key = self.registry_['setup_args'][f'{REGISTRY_KEYS.SIZE_FACTOR_KEY}_key'] + library_log_means, library_log_vars = None, None + if self.adata is not None and not use_size_factor_key and self.minified_data_type is None: + library_log_means, library_log_vars = _init_library_size( + self.adata_manager, n_batch ) - else: - n_cats_per_cov = ( - self.adata_manager.get_state_registry(REGISTRY_KEYS.CAT_COVS_KEY).n_cats_per_key - if REGISTRY_KEYS.CAT_COVS_KEY in self.adata_manager.data_registry - else None - ) - n_batch = self.summary_stats.n_batch - use_size_factor_key = REGISTRY_KEYS.SIZE_FACTOR_KEY in self.adata_manager.data_registry - library_log_means, library_log_vars = None, None - if not use_size_factor_key and self.minified_data_type is None: - library_log_means, library_log_vars = _init_library_size( - self.adata_manager, n_batch - ) - self.module = self._module_cls( - n_input=self.summary_stats.n_vars, - n_batch=n_batch, - n_labels=self.summary_stats.n_labels, - n_continuous_cov=self.summary_stats.get("n_extra_continuous_covs", 0), - n_cats_per_cov=n_cats_per_cov, - n_hidden=n_hidden, - n_latent=n_latent, - n_layers=n_layers, - dropout_rate=dropout_rate, - dispersion=dispersion, - gene_likelihood=gene_likelihood, - latent_distribution=latent_distribution, - use_size_factor_key=use_size_factor_key, - library_log_means=library_log_means, - library_log_vars=library_log_vars, - **kwargs, - ) - self.module.minified_data_type = self.minified_data_type + self.module = self._module_cls( + n_input=self.summary_stats.n_vars, + n_batch=n_batch, + n_labels=self.summary_stats.n_labels, + n_continuous_cov=self.summary_stats.get("n_extra_continuous_covs", 0), + n_cats_per_cov=n_cats_per_cov, + n_hidden=n_hidden, + n_latent=n_latent, + n_layers=n_layers, + dropout_rate=dropout_rate, + dispersion=dispersion, + gene_likelihood=gene_likelihood, + latent_distribution=latent_distribution, + use_size_factor_key=use_size_factor_key, + library_log_means=library_log_means, + library_log_vars=library_log_vars, + **kwargs, + ) + self.module.minified_data_type = self.minified_data_type self.init_params_ = self._get_init_params(locals()) @@ -257,45 +245,7 @@ def setup_datamodule( %(param_cont_cov_keys)s """ - # Remove these lines. We don't need an adata_manager. - setup_method_args = cls._get_setup_method_args(**locals()) - anndata_fields = [ - LayerField(REGISTRY_KEYS.X_KEY, layer, is_count_data=True), - CategoricalObsField(REGISTRY_KEYS.BATCH_KEY, batch_key), - CategoricalObsField(REGISTRY_KEYS.LABELS_KEY, labels_key), - NumericalObsField(REGISTRY_KEYS.SIZE_FACTOR_KEY, size_factor_key, required=False), - CategoricalJointObsField(REGISTRY_KEYS.CAT_COVS_KEY, categorical_covariate_keys), - NumericalJointObsField(REGISTRY_KEYS.CONT_COVS_KEY, continuous_covariate_keys), - ] - # register new fields if the adata is minified - # adata_minify_type = _get_adata_minify_type(adata) - # if adata_minify_type is not None: - # anndata_fields += cls._get_fields_for_adata_minification(adata_minify_type) - adata_manager = AnnDataManager(fields=anndata_fields, setup_method_args=setup_method_args) - adata_manager.registry["setup_method_name"] = "setup_datamodule" - - """ - ORI check here the elements are used in the datamodule. - We can stick to their solution for now. But we should check for all setup things whether - they are present in the datamodule. - These checks can adfterwards go to a new class. But implement them here. And ignore all adata things. - We just want to have the same dictionary - """ - if datamodule.get_batch_keys() is not None: - adata_manager.registry["setup_args"]["batch_key"] = datamodule.get_batch_keys() - if datamodule.get_labels_keys() is not None: - adata_manager.registry["setup_args"]["labels_key"] = datamodule.get_labels_keys() - adata_manager.registry["setup_args"]["layer"] = datamodule.datapipe.layer_name - datamodule.get_var_names() # ORI this has to be provided no check otherwise raise error. - adata_manager.register_data_module_fields( - datamodule, **kwargs - ) # here we need a new function for data module - - # ORI No need to register here using adata manager. Instead populate dictionary. It will be sufficient. - cls.register_manager(adata_manager) - # adata_manager.get_state_registry(SCVI.REGISTRY_KEYS.X_KEY).to_dict() - # adata_manager.registry[_constants._FIELD_REGISTRIES_KEY] - # pprint(adata_manager.registry) + pass @staticmethod def _get_fields_for_adata_minification( diff --git a/src/scvi/model/base/_base_model.py b/src/scvi/model/base/_base_model.py index 338ab0805f..d1e6ee0dae 100644 --- a/src/scvi/model/base/_base_model.py +++ b/src/scvi/model/base/_base_model.py @@ -9,9 +9,11 @@ from uuid import uuid4 import numpy as np +import pandas as pd import rich import torch from anndata import AnnData +from lightning import LightningDataModule from mudata import MuData from scvi import REGISTRY_KEYS, settings @@ -85,7 +87,7 @@ class BaseModelClass(metaclass=BaseModelMetaClass): _data_loader_cls = AnnDataLoader - def __init__(self, adata: AnnOrMuData | None = None): + def __init__(self, adata: AnnOrMuData | None = None, datamodule: object | None = None): # check if the given adata is minified and check if the model being created # supports minified-data mode (i.e. inherits from the abstract BaseMinifiedModeModelClass). # If not, raise an error to inform the user of the lack of minified-data functionality @@ -98,13 +100,22 @@ def __init__(self, adata: AnnOrMuData | None = None): self.id = str(uuid4()) # Used for cls._manager_store keys. if adata is not None: self._adata = adata + self._datamodule = None self._adata_manager = self._get_most_recent_anndata_manager(adata, required=True) self._register_manager_for_instance(self.adata_manager) # Suffix registry instance variable with _ to include it when saving the model. self.registry_ = self._adata_manager.registry self.summary_stats = self._adata_manager.summary_stats + elif datamodule is not None: + self._adata = None + self._datamodule = datamodule + self._adata_manager = None + # Suffix registry instance variable with _ to include it when saving the model. + self.registry_ = datamodule.registry + self.summary_stats = datamodule.summary_stats + else: + raise ValueError("adata or datamodule must be provided.") - self._module_init_on_train = adata is None self.is_trained_ = False self._model_summary_string = "" self.train_indices_ = None @@ -113,10 +124,20 @@ def __init__(self, adata: AnnOrMuData | None = None): self.history_ = None @property - def adata(self) -> AnnOrMuData: + def adata(self) -> None | AnnOrMuData: """Data attached to model instance.""" return self._adata + @property + def datamodule(self) -> None | LightningDataModule: + """Data attached to model instance.""" + return self._datamodule + + @property + def registry(self) -> dict: + """Data attached to model instance.""" + return self.registry_ + @adata.setter def adata(self, adata: AnnOrMuData): if adata is None: @@ -127,6 +148,14 @@ def adata(self, adata: AnnOrMuData): self.registry_ = self._adata_manager.registry self.summary_stats = self._adata_manager.summary_stats + @datamodule.setter + def datamodule(self, datamodule: LightningDataModule): + if datamodule is None: + raise ValueError("datamodule cannot be None.") + self._datamodule = datamodule + self.registry_ = datamodule.registry + self.summary_stats = datamodule.summary_stats + @property def adata_manager(self) -> AnnDataManager: """Manager instance associated with self.adata.""" @@ -238,6 +267,40 @@ def _register_manager_for_instance(self, adata_manager: AnnDataManager): instance_manager_store = self._per_instance_manager_store[self.id] instance_manager_store[adata_id] = adata_manager + def data_registry(self, registry_key: str) -> np.ndarray | pd.DataFrame: + """Returns the object in AnnData associated with the key in the data registry. + + Parameters + ---------- + registry_key + key of object to get from ``self.data_registry`` + + Returns + ------- + The requested data. + """ + if not self.adata: + raise ValueError("self.adata is None. Please register AnnData object to access data.") + else: + return self._adata_manager.get_from_registry(registry_key) + + def get_from_registry(self, registry_key: str) -> np.ndarray | pd.DataFrame: + """Returns the object in AnnData associated with the key in the data registry. + + Parameters + ---------- + registry_key + key of object to get from ``self.data_registry`` + + Returns + ------- + The requested data. + """ + if not self.adata: + raise ValueError("self.adata is None. Please registry AnnData object.") + else: + return self._adata_manager.get_from_registry(registry_key) + def deregister_manager(self, adata: AnnData | None = None): """Deregisters the :class:`~scvi.data.AnnDataManager` instance associated with `adata`. @@ -530,7 +593,7 @@ def _get_user_attributes(self): def _get_init_params(self, locals): """Returns the model init signature with associated passed in values. - Ignores the initial AnnData. + Ignores the initial AnnData or DataModule. """ init = self.__init__ sig = inspect.signature(init) @@ -542,6 +605,8 @@ def _get_init_params(self, locals): k: v for (k, v) in all_params.items() if not isinstance(v, AnnData) and not isinstance(v, MuData) + and not isinstance(v, LightningDataModule) + and k not in ("adata", "datamodule") } # not very efficient but is explicit # separates variable params (**kwargs) from non variable params into two dicts @@ -624,7 +689,10 @@ def save( # save the model state dict and the trainer state dict only model_state_dict = self.module.state_dict() - var_names = _get_var_names(self.adata, legacy_mudata_format=legacy_mudata_format) + if self.adata: + var_names = _get_var_names(self.adata, legacy_mudata_format=legacy_mudata_format) + else: + var_names = self.datamodule.var_names # get all the user attributes user_attributes = self._get_user_attributes() @@ -647,6 +715,7 @@ def load( cls, dir_path: str, adata: AnnOrMuData | None = None, + datamodule: LightningDataModule | None = None, accelerator: str = "auto", device: int | str = "auto", prefix: str | None = None, @@ -679,7 +748,7 @@ def load( >>> model = ModelClass.load(save_path, adata) >>> model.get_.... """ - load_adata = adata is None + load_adata = adata is None and datamodule is None _, _, device = parse_device_args( accelerator=accelerator, devices=device, @@ -701,31 +770,35 @@ def load( ) adata = new_adata if new_adata is not None else adata - _validate_var_names(adata, var_names) - registry = attr_dict.pop("registry_") + if datamodule is not None: + registry['setup_method_name'] = 'setup_datamodule' + else: + registry['setup_method_name'] = 'setup_anndata' if _MODEL_NAME_KEY in registry and registry[_MODEL_NAME_KEY] != cls.__name__: raise ValueError("It appears you are loading a model from a different class.") - if _SETUP_ARGS_KEY not in registry: - raise ValueError( - "Saved model does not contain original setup inputs. " - "Cannot load the original setup." - ) - # Calling ``setup_anndata`` method with the original arguments passed into # the saved model. This enables simple backwards compatibility in the case of # newly introduced fields or parameters. - method_name = registry.get(_SETUP_METHOD_NAME, "setup_anndata") - getattr(cls, method_name)(adata, source_registry=registry, **registry[_SETUP_ARGS_KEY]) + if adata is not None: + if _SETUP_ARGS_KEY not in registry: + raise ValueError( + "Saved model does not contain original setup inputs. " + "Cannot load the original setup." + ) + _validate_var_names(adata, var_names) + method_name = registry.get(_SETUP_METHOD_NAME, "setup_anndata") + getattr(cls, method_name)(adata, source_registry=registry, **registry[_SETUP_ARGS_KEY]) - model = _initialize_model(cls, adata, attr_dict) + model = _initialize_model(cls, adata, datamodule, attr_dict) model.module.on_load(model) model.module.load_state_dict(model_state_dict) model.to_device(device) model.module.eval() - model._validate_anndata(adata) + if adata is not None: + model._validate_anndata(adata) return model @classmethod @@ -818,7 +891,6 @@ def setup_anndata( """ @classmethod - @abstractmethod @setup_anndata_dsp.dedent def setup_datamodule( cls, @@ -905,11 +977,14 @@ class BaseMinifiedModeModelClass(BaseModelClass): @property def minified_data_type(self) -> MinifiedDataType | None: """The type of minified data associated with this model, if applicable.""" - return ( - self.adata_manager.get_from_registry(REGISTRY_KEYS.MINIFY_TYPE_KEY) - if REGISTRY_KEYS.MINIFY_TYPE_KEY in self.adata_manager.data_registry - else None - ) + if self.adata_manager: + return ( + self.adata_manager.get_from_registry(REGISTRY_KEYS.MINIFY_TYPE_KEY) + if REGISTRY_KEYS.MINIFY_TYPE_KEY in self.adata_manager.data_registry + else None + ) + else: + return None @abstractmethod def minify_adata( diff --git a/src/scvi/model/base/_save_load.py b/src/scvi/model/base/_save_load.py index 63c41adfda..aa00e807f6 100644 --- a/src/scvi/model/base/_save_load.py +++ b/src/scvi/model/base/_save_load.py @@ -97,7 +97,7 @@ def _load_saved_files( return attr_dict, var_names, model_state_dict, adata -def _initialize_model(cls, adata, attr_dict): +def _initialize_model(cls, adata, datamodule, attr_dict): """Helper to initialize a model.""" if "init_params_" not in attr_dict.keys(): raise ValueError( @@ -121,6 +121,9 @@ def _initialize_model(cls, adata, attr_dict): kwargs = {k: v for k, v in init_params.items() if isinstance(v, dict)} kwargs = {k: v for (i, j) in kwargs.items() for (k, v) in j.items()} non_kwargs.pop("use_cuda") + # adata and datamodule None is stored in the registry + non_kwargs.pop("adata", None) + non_kwargs.pop("datamodule", None) # backwards compat for scANVI if "unlabeled_category" in non_kwargs.keys(): @@ -128,7 +131,7 @@ def _initialize_model(cls, adata, attr_dict): if "pretrained_model" in non_kwargs.keys(): non_kwargs.pop("pretrained_model") - model = cls(adata, **non_kwargs, **kwargs) + model = cls(adata=adata, datamodule=datamodule, **non_kwargs, **kwargs) for attr, val in attr_dict.items(): setattr(model, attr, val) diff --git a/src/scvi/model/base/_training_mixin.py b/src/scvi/model/base/_training_mixin.py index 86da6b5019..21c6ed6b59 100644 --- a/src/scvi/model/base/_training_mixin.py +++ b/src/scvi/model/base/_training_mixin.py @@ -81,25 +81,8 @@ def train( **kwargs Additional keyword arguments passed into :class:`~scvi.train.Trainer`. """ - if datamodule is not None and not self._module_init_on_train: - raise ValueError( - "Cannot pass in `datamodule` if the model was initialized with `adata`." - ) - elif datamodule is None and self._module_init_on_train: - raise ValueError( - "If the model was not initialized with `adata`, a `datamodule` must be passed in." - ) - if max_epochs is None: - if datamodule is None: - max_epochs = get_max_epochs_heuristic(self.adata.n_obs) - elif hasattr(datamodule, "n_obs"): - max_epochs = get_max_epochs_heuristic(datamodule.n_obs) - else: - raise ValueError( - "If `datamodule` does not have `n_obs` attribute, `max_epochs` must be " - "passed in." - ) + max_epochs = get_max_epochs_heuristic(self.summary_stats.n_obs) if datamodule is None: # In the general case we enter here @@ -114,18 +97,6 @@ def train( load_sparse_tensor=load_sparse_tensor, **datasplitter_kwargs, ) - elif self.module is None: - # in CZI case we enter here - self.module = self._module_cls( - datamodule.n_vars, - n_batch=datamodule.n_batch, - n_labels=getattr(datamodule, "n_labels", 1), - n_continuous_cov=getattr(datamodule, "n_continuous_cov", 0), - n_cats_per_cov=getattr(datamodule, "n_cats_per_cov", None), - **self._module_kwargs, - ) - # after either of the cases we should be here with the same self.module - # and same datamodule plan_kwargs = plan_kwargs or {} training_plan = self._training_plan_cls(self.module, **plan_kwargs) From 42434ec04e6a10ff8892bf38b3649c9f792ce448 Mon Sep 17 00:00:00 2001 From: Can Ergen Date: Wed, 31 Jul 2024 14:38:18 -0700 Subject: [PATCH 12/53] Fixed attr_dict --- src/scvi/model/base/_archesmixin.py | 66 ++++++++++++++++------------- 1 file changed, 37 insertions(+), 29 deletions(-) diff --git a/src/scvi/model/base/_archesmixin.py b/src/scvi/model/base/_archesmixin.py index 117188f3cf..12253f08ca 100644 --- a/src/scvi/model/base/_archesmixin.py +++ b/src/scvi/model/base/_archesmixin.py @@ -8,6 +8,7 @@ import pandas as pd import torch from anndata import AnnData +from lightning import LightningDataModule from mudata import MuData from scipy.sparse import csr_matrix @@ -39,8 +40,9 @@ class ArchesMixin: @devices_dsp.dedent def load_query_data( cls, - adata: AnnOrMuData, - reference_model: Union[str, BaseModelClass], + adata: None | AnnOrMuData = None, + reference_model: Union[str, BaseModelClass] = None, + datamodule: None | LightningDataModule = None, inplace_subset_query_vars: bool = False, accelerator: str = "auto", device: Union[int, str] = "auto", @@ -83,6 +85,11 @@ def load_query_data( freeze_classifier Whether to freeze classifier completely. Only applies to `SCANVI`. """ + if reference_model is None: + raise ValueError("Please provide a reference model as string or loaded model.") + if adata is None and datamodule is None: + raise ValueError("Please provide either an AnnData or a datamodule.") + _, _, device = parse_device_args( accelerator=accelerator, devices=device, @@ -92,44 +99,45 @@ def load_query_data( attr_dict, var_names, load_state_dict = _get_loaded_data(reference_model, device=device) - if isinstance(adata, MuData): - for modality in adata.mod: + if adata is not None: + if isinstance(adata, MuData): + for modality in adata.mod: + if inplace_subset_query_vars: + logger.debug(f"Subsetting {modality} query vars to reference vars.") + adata[modality]._inplace_subset_var(var_names[modality]) + _validate_var_names(adata[modality], var_names[modality]) + + else: if inplace_subset_query_vars: - logger.debug(f"Subsetting {modality} query vars to reference vars.") - adata[modality]._inplace_subset_var(var_names[modality]) - _validate_var_names(adata[modality], var_names[modality]) + logger.debug("Subsetting query vars to reference vars.") + adata._inplace_subset_var(var_names) + _validate_var_names(adata, var_names) - else: if inplace_subset_query_vars: logger.debug("Subsetting query vars to reference vars.") adata._inplace_subset_var(var_names) _validate_var_names(adata, var_names) - if inplace_subset_query_vars: - logger.debug("Subsetting query vars to reference vars.") - adata._inplace_subset_var(var_names) - _validate_var_names(adata, var_names) + registry = attr_dict.pop("registry_") + if _MODEL_NAME_KEY in registry and registry[_MODEL_NAME_KEY] != cls.__name__: + raise ValueError("It appears you are loading a model from a different class.") - registry = attr_dict.pop("registry_") - if _MODEL_NAME_KEY in registry and registry[_MODEL_NAME_KEY] != cls.__name__: - raise ValueError("It appears you are loading a model from a different class.") + if _SETUP_ARGS_KEY not in registry: + raise ValueError( + "Saved model does not contain original setup inputs. " + "Cannot load the original setup." + ) - if _SETUP_ARGS_KEY not in registry: - raise ValueError( - "Saved model does not contain original setup inputs. " - "Cannot load the original setup." + setup_method = getattr(cls, registry[_SETUP_METHOD_NAME]) + setup_method( + adata, + source_registry=registry, + extend_categories=True, + allow_missing_labels=True, + **registry[_SETUP_ARGS_KEY], ) - setup_method = getattr(cls, registry[_SETUP_METHOD_NAME]) - setup_method( - adata, - source_registry=registry, - extend_categories=True, - allow_missing_labels=True, - **registry[_SETUP_ARGS_KEY], - ) - - model = _initialize_model(cls, adata, attr_dict) + model = _initialize_model(cls, adata, datamodule, attr_dict) adata_manager = model.get_anndata_manager(adata, required=True) if REGISTRY_KEYS.CAT_COVS_KEY in adata_manager.data_registry: From 3d0c890bd90f81fcfbfc5709fb3c32750d458317 Mon Sep 17 00:00:00 2001 From: ori-kron-wis Date: Thu, 1 Aug 2024 17:54:39 +0300 Subject: [PATCH 13/53] added some fixes based on custom data loader test --- src/scvi/model/base/_archesmixin.py | 14 +- tests/dataloaders/test_custom_dataloader.py | 8 + tests/dataloaders/test_custom_dataloader2.py | 382 ++++++++++++------- tests/model/test_scvi.py | 3 +- 4 files changed, 268 insertions(+), 139 deletions(-) diff --git a/src/scvi/model/base/_archesmixin.py b/src/scvi/model/base/_archesmixin.py index 12253f08ca..475ab15661 100644 --- a/src/scvi/model/base/_archesmixin.py +++ b/src/scvi/model/base/_archesmixin.py @@ -97,7 +97,9 @@ def load_query_data( validate_single_device=True, ) - attr_dict, var_names, load_state_dict = _get_loaded_data(reference_model, device=device) + attr_dict, var_names, load_state_dict = _get_loaded_data( + reference_model, device=device, adata=adata + ) if adata is not None: if isinstance(adata, MuData): @@ -216,7 +218,7 @@ def prepare_query_anndata( Query adata ready to use in `load_query_data` unless `return_reference_var_names` in which case a pd.Index of reference var names is returned. """ - _, var_names, _ = _get_loaded_data(reference_model, device="cpu") + _, var_names, _ = _get_loaded_data(reference_model, device="cpu", adata=adata) var_names = pd.Index(var_names) if return_reference_var_names: @@ -364,7 +366,7 @@ def requires_grad(key): par.requires_grad = False -def _get_loaded_data(reference_model, device=None): +def _get_loaded_data(reference_model, device=None, adata=None): if isinstance(reference_model, str): attr_dict, var_names, load_state_dict, _ = _load_saved_files( reference_model, load_adata=False, map_location=device @@ -372,7 +374,11 @@ def _get_loaded_data(reference_model, device=None): else: attr_dict = reference_model._get_user_attributes() attr_dict = {a[0]: a[1] for a in attr_dict if a[0][-1] == "_"} - var_names = _get_var_names(reference_model.adata) + var_names = ( + _get_var_names(reference_model.adata) + if attr_dict["registry_"]["setup_method_name"] != "setup_datamodule" + else _get_var_names(adata) + ) load_state_dict = deepcopy(reference_model.module.state_dict()) return attr_dict, var_names, load_state_dict diff --git a/tests/dataloaders/test_custom_dataloader.py b/tests/dataloaders/test_custom_dataloader.py index c80cb843b5..8ef4b038af 100644 --- a/tests/dataloaders/test_custom_dataloader.py +++ b/tests/dataloaders/test_custom_dataloader.py @@ -1,6 +1,7 @@ from __future__ import annotations import os +from pprint import pprint import numpy as np import scanpy as sc @@ -41,6 +42,11 @@ # Loading the model (just as a compariosn) model_orig_loaded = scvi.model.SCVI.load(model_dir, adata=adata) +# when loading from disk +scvi.model.SCVI.prepare_query_anndata(adata, model_dir) +# O +scvi.model.SCVI.prepare_query_anndata(adata, model_orig_loaded) + # Obtaining model outputs SCVI_LATENT_KEY = "X_scVI" latent = model_orig.get_latent_representation() @@ -53,6 +59,8 @@ # adata_manager.get_state_registry(SCVI.REGISTRY_KEYS.X_KEY).to_dict() adata_manager.registry[_constants._FIELD_REGISTRIES_KEY] +pprint(adata_manager.registry) + # Plot UMAP and save the figure for later check sc.pp.neighbors(adata, use_rep="scvi", key_added="scvi") sc.tl.umap(adata, neighbors_key="scvi") diff --git a/tests/dataloaders/test_custom_dataloader2.py b/tests/dataloaders/test_custom_dataloader2.py index 5b741ebfad..9c3b625d6c 100644 --- a/tests/dataloaders/test_custom_dataloader2.py +++ b/tests/dataloaders/test_custom_dataloader2.py @@ -1,17 +1,23 @@ from __future__ import annotations -import os +import sys + +sys.path.insert(0, "/Users/orikr/Documents/cellxgene-census/api/python/cellxgene_census/src") +sys.path.insert(0, "src") import cellxgene_census -import pandas as pd -import scanpy as sc +import numpy as np import tiledbsoma as soma -import torch +from cellxgene_census.experimental.ml.datamodule import ( + CensusSCVIDataModule, # WE RAN FROM LOCAL LIB +) from cellxgene_census.experimental.pp import highly_variable_genes import scvi -from scvi.dataloaders._custom_dataloader import CensusSCVIDataModule, experiment_dataloader -from scvi.model import SCVI +from scvi.data import _constants, synthetic_iid +from scvi.utils import attrdict + +# cellxgene_census.__file__, scvi.__file__ # We will now create the SCVI model object: # Its parameters: @@ -25,7 +31,6 @@ # The other way will be to fill the model ,LIKE IN CELLXGENE NOTEBOOK # need to pass here new object of registry taht contains everything we will need - # First lets see CELLXGENE example using pytorch loaders implemented now in our repo census = cellxgene_census.open_soma(census_version="stable") experiment_name = "mus_musculus" @@ -54,15 +59,106 @@ batch_keys=["dataset_id", "assay", "suspension_type", "donor_id"], dataloader_kwargs={"num_workers": 0, "persistent_workers": False}, ) -# This is a new func to implement -SCVI.setup_datamodule(datamodule) -# -model = SCVI(n_layers=n_layers, n_latent=n_latent, gene_likelihood="nb", encode_covariates=False) +datamodule.vars = hv_idx + + +def _get_summary_stats_from_registry(registry: dict) -> attrdict: + summary_stats = {} + for field_registry in registry[_constants._FIELD_REGISTRIES_KEY].values(): + field_summary_stats = field_registry[_constants._SUMMARY_STATS_KEY] + summary_stats.update(field_summary_stats) + return attrdict(summary_stats) + + +def setup_datamodule(datamodule: CensusSCVIDataModule): + datamodule.registry = { + "scvi_version": scvi.__version__, + "model_name": "SCVI", + "setup_args": { + "layer": None, + "batch_key": "batch", + "labels_key": None, + "size_factor_key": None, + "categorical_covariate_keys": None, + "continuous_covariate_keys": None, + }, + "field_registries": { + "X": { + "data_registry": {"attr_name": "X", "attr_key": None}, + "state_registry": { + "n_obs": datamodule.n_obs, + "n_vars": datamodule.n_vars, + "column_names": datamodule.vars, + }, + "summary_stats": {"n_vars": datamodule.n_vars, "n_cells": datamodule.n_obs}, + }, + "batch": { + "data_registry": {"attr_name": "obs", "attr_key": "_scvi_batch"}, + "state_registry": { + "categorical_mapping": datamodule.datapipe.obs_encoders["batch"].classes_, + "original_key": "batch", + }, + "summary_stats": {"n_batch": datamodule.n_batch}, + }, + "labels": { + "data_registry": {"attr_name": "obs", "attr_key": "_scvi_labels"}, + "state_registry": { + "categorical_mapping": np.array([0]), + "original_key": "_scvi_labels", + }, + "summary_stats": {"n_labels": 1}, + }, + "size_factor": {"data_registry": {}, "state_registry": {}, "summary_stats": {}}, + "extra_categorical_covs": { + "data_registry": {}, + "state_registry": {}, + "summary_stats": {"n_extra_categorical_covs": 0}, + }, + "extra_continuous_covs": { + "data_registry": {}, + "state_registry": {}, + "summary_stats": {"n_extra_continuous_covs": 0}, + }, + }, + "setup_method_name": "setup_datamodule", + } + datamodule.summary_stats = _get_summary_stats_from_registry(datamodule.registry) + datamodule.var_names = [str(i) for i in datamodule.vars] + + +# This is a new func to implement (Implemented Above but we need in our code base as well) +# will take a bit of time to end +setup_datamodule(datamodule) + +# The next part is the same as test_scvi_train_custom_dataloader + +adata = synthetic_iid() +scvi.model.SCVI.setup_anndata(adata, batch_key="batch") +model = scvi.model.SCVI(adata, n_latent=10) +model.train(max_epochs=1) +dataloader = model._make_data_loader(adata) +_ = model.get_elbo(dataloader=dataloader) +_ = model.get_marginal_ll(dataloader=dataloader) +_ = model.get_reconstruction_error(dataloader=dataloader) +_ = model.get_latent_representation(dataloader=dataloader) + +# ORI I broke the code here also for standard models. Please first fix this. - it is fixed +scvi.model.SCVI.prepare_query_anndata(adata, model) +query_model = scvi.model.SCVI.load_query_data(adata, model) + +# We will now create the SCVI model object: +model_census = scvi.model.SCVI( + datamodule=datamodule, + n_layers=n_layers, + n_latent=n_latent, + gene_likelihood="nb", + encode_covariates=False, +) # The CZI data module is a refined data module while SCVI is a lighting datamodule # Altough this is only 1 epoch it will take few mins on local machine -model.train( +model_census.train( datamodule=datamodule, max_epochs=max_epochs, batch_size=batch_size, @@ -73,141 +169,161 @@ # We can now save the trained model. As of the current writing date (June 2024), # scvi-tools doesn't support saving a model that wasn't generated through an AnnData loader, # so we'll use some custom code: -model_state_dict = model.module.state_dict() -var_names = hv_idx.to_numpy() -user_attributes = model._get_user_attributes() -user_attributes = {a[0]: a[1] for a in user_attributes if a[0][-1] == "_"} - -user_attributes.update( - { - "n_batch": datamodule.n_batch, - "n_extra_categorical_covs": 0, - "n_extra_continuous_covs": 0, - "n_labels": 1, - "n_vars": datamodule.n_vars, - } +# model_state_dict = model_census.module.state_dict() +# var_names = hv_idx.to_numpy() +# user_attributes = model_census._get_user_attributes() +# user_attributes = {a[0]: a[1] for a in user_attributes if a[0][-1] == "_"} +model_census.save("dataloader_model2", overwrite=True) + +# We are now turning this data module back to AnnData +adata = cellxgene_census.get_anndata( + census, + organism=experiment_name, + obs_value_filter=obs_value_filter, ) -with open("model.pt", "wb") as f: - torch.save( - { - "model_state_dict": model_state_dict, - "var_names": var_names, - "attr_dict": user_attributes, - }, - f, - ) +adata = adata[:, datamodule.vars].copy() -# Saving the model the original way -save_dir = "/Users/orikr/runs/290724/" # tempfile.TemporaryDirectory() -model_dir = os.path.join(save_dir, "scvi_czi_model") -model.save(model_dir, overwrite=True) +adata.obs.head() +# ORI Replace this with the function to generate batch key used in the datamodule. +# "12967895-3d58-4e93-be2c-4e1bcf4388d510x 5' v1cellHCA_Mou_3" +adata.obs["batch"] = ("batch_" + adata.obs[datamodule.batch_keys[0]].cat.codes.astype(str)).astype( + "category" +) +# adata.var_names = 'gene_'+adata.var_names #not sure we need it # We will now load the model back and use it to generate cell embeddings (the latent space), # which can then be used for further analysis. Note that we still need to use some custom code for # loading the model, which includes loading the parameters from the `attr_dict` node stored in # the model. -with open("model.pt", "rb") as f: - torch_model = torch.load(f) - - adict = torch_model["attr_dict"] - params = adict["init_params_"]["non_kwargs"] - - n_batch = adict["n_batch"] - n_extra_categorical_covs = adict["n_extra_categorical_covs"] - n_extra_continuous_covs = adict["n_extra_continuous_covs"] - n_labels = adict["n_labels"] - n_vars = adict["n_vars"] - - latent_distribution = params["latent_distribution"] - dispersion = params["dispersion"] - n_hidden = params["n_hidden"] - dropout_rate = params["dropout_rate"] - gene_likelihood = params["gene_likelihood"] - - model = scvi.model.SCVI( - n_layers=params["n_layers"], - n_latent=params["n_latent"], - gene_likelihood=params["gene_likelihood"], - encode_covariates=False, - ) - - module = model._module_cls( - n_input=n_vars, - n_batch=n_batch, - n_labels=n_labels, - n_continuous_cov=n_extra_continuous_covs, - n_cats_per_cov=None, - n_hidden=n_hidden, - n_latent=n_latent, - n_layers=n_layers, - dropout_rate=dropout_rate, - dispersion=dispersion, - gene_likelihood=gene_likelihood, - latent_distribution=latent_distribution, - ) - model.module = module - - model.module.load_state_dict(torch_model["model_state_dict"]) - - device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") - - model.to_device(device) - model.module.eval() - model.is_trained = True - -# We will now generate the cell embeddings for this model, using the `get_latent_representation` -# function available in scvi-tools. -# We can use another instance of the `ExperimentDataPipe` for the forward pass, so we don't need -# to load the whole dataset in memory. -# Needs to have shuffle=False for inference -datamodule_inference = CensusSCVIDataModule( - census["census_data"][experiment_name], - measurement_name="RNA", - X_name="raw", - obs_query=soma.AxisQuery(value_filter=obs_value_filter), - var_query=soma.AxisQuery(coords=(list(hv_idx),)), - batch_size=1024, - shuffle=False, - soma_chunk_size=50_000, - batch_keys=["dataset_id", "assay", "suspension_type", "donor_id"], - dataloader_kwargs={"num_workers": 0, "persistent_workers": False}, +model_census2 = scvi.model.SCVI.load("dataloader_model2", datamodule=datamodule) +model_census2.setup_anndata(adata, batch_key="batch") +# model_census2.adata = deepcopy(adata) +# ORI Works when loading from disk +scvi.model.SCVI.prepare_query_anndata(adata, "dataloader_model2") +# ORI This one still needs to be fixed. +scvi.model.SCVI.prepare_query_anndata(adata, model_census2) + +# ORI Should work when setting up the AnnData correctly. scANVI with DataModule is not yet +# supported as DataModule can't take a labels_key. +scanvae = scvi.model.SCANVI.from_scvi_model( + model_census2, + adata=adata, + unlabeled_category="Unknown", + labels_key="cell_type", ) -# We can simply feed the datapipe to `get_latent_representation` to obtain the embeddings - -# will take a while -datapipe = datamodule_inference.datapipe -dataloader = experiment_dataloader(datapipe, num_workers=0, persistent_workers=False) -mapped_dataloader = ( - datamodule_inference.on_before_batch_transfer(tensor, None) for tensor in dataloader -) -latent = model.get_latent_representation(dataloader=mapped_dataloader) -emb_idx = datapipe._obs_joinids +# ORI - check it should work with a model initialized with AnnData. See below not fully working yet +model_census3 = scvi.model.SCVI.load("dataloader_model2", adata=adata) -# We will now take a look at the UMAP for the generated embedding -# (will be later comapred to what we got) -adata = cellxgene_census.get_anndata( - census, - organism=experiment_name, - obs_value_filter=obs_value_filter, -) -obs_soma_joinids = adata.obs["soma_joinid"] -obs_indexer = pd.Index(emb_idx) -idx = obs_indexer.get_indexer(obs_soma_joinids) -# Reindexing is necessary to ensure that the cells in the embedding match the -# ones in the anndata object. -adata.obsm["scvi"] = latent[idx] - -# Plot UMAP and save the figure for later check -sc.pp.neighbors(adata, use_rep="scvi", key_added="scvi") -sc.tl.umap(adata, neighbors_key="scvi") -sc.pl.umap(adata, color="dataset_id", title="SCVI") +scvi.model.SCVI.prepare_query_anndata(adata, "dataloader_model2") +query_model = scvi.model.SCVI.load_query_data(adata, "dataloader_model2") +scvi.model.SCVI.prepare_query_anndata(adata, model_census3) +query_model = scvi.model.SCVI.load_query_data(adata, model_census3) -# Now return and add all the registry stuff that we will need +# with open("model.pt", "rb") as f: +# torch_model = torch.load(f) +# +# adict = torch_model["attr_dict"] +# params = adict["init_params_"]["non_kwargs"] +# +# n_batch = adict["n_batch"] +# n_extra_categorical_covs = adict["n_extra_categorical_covs"] +# n_extra_continuous_covs = adict["n_extra_continuous_covs"] +# n_labels = adict["n_labels"] +# n_vars = adict["n_vars"] +# +# latent_distribution = params["latent_distribution"] +# dispersion = params["dispersion"] +# n_hidden = params["n_hidden"] +# dropout_rate = params["dropout_rate"] +# gene_likelihood = params["gene_likelihood"] +# +# model = scvi.model.SCVI( +# n_layers=params["n_layers"], +# n_latent=params["n_latent"], +# gene_likelihood=params["gene_likelihood"], +# encode_covariates=False, +# ) +# +# module = model._module_cls( +# n_input=n_vars, +# n_batch=n_batch, +# n_labels=n_labels, +# n_continuous_cov=n_extra_continuous_covs, +# n_cats_per_cov=None, +# n_hidden=n_hidden, +# n_latent=n_latent, +# n_layers=n_layers, +# dropout_rate=dropout_rate, +# dispersion=dispersion, +# gene_likelihood=gene_likelihood, +# latent_distribution=latent_distribution, +# ) +# model.module = module +# +# model.module.load_state_dict(torch_model["model_state_dict"]) +# +# device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") +# +# model.to_device(device) +# model.module.eval() +# model.is_trained = True +# We will now generate the cell embeddings for this model, using the `get_latent_representation` +# function available in scvi-tools. +# We can use another instance of the `ExperimentDataPipe` for the forward pass, so we don't need +# to load the whole dataset in memory. -# Now add the missing stuff from the current CZI implemenation in order for us to have the exact -# same steps like the original way (except than setup_anndata) +# # Needs to have shuffle=False for inference +# datamodule_inference = CensusSCVIDataModule( +# census["census_data"][experiment_name], +# measurement_name="RNA", +# X_name="raw", +# obs_query=soma.AxisQuery(value_filter=obs_value_filter), +# var_query=soma.AxisQuery(coords=(list(hv_idx),)), +# batch_size=1024, +# shuffle=False, +# soma_chunk_size=50_000, +# batch_keys=["dataset_id", "assay", "suspension_type", "donor_id"], +# dataloader_kwargs={"num_workers": 0, "persistent_workers": False}, +# ) +# +# # We can simply feed the datapipe to `get_latent_representation` to obtain the embeddings - +# # will take a while +# datapipe = datamodule_inference.datapipe +# dataloader = experiment_dataloader(datapipe, num_workers=0, persistent_workers=False) +# mapped_dataloader = ( +# datamodule_inference.on_before_batch_transfer(tensor, None) for tensor in dataloader +# ) +# latent = model.get_latent_representation(dataloader=mapped_dataloader) +# emb_idx = datapipe._obs_joinids +# +# # We will now take a look at the UMAP for the generated embedding +# # (will be later comapred to what we got) +# adata = cellxgene_census.get_anndata( +# census, +# organism=experiment_name, +# obs_value_filter=obs_value_filter, +# ) +# obs_soma_joinids = adata.obs["soma_joinid"] +# obs_indexer = pd.Index(emb_idx) +# idx = obs_indexer.get_indexer(obs_soma_joinids) +# # Reindexing is necessary to ensure that the cells in the embedding match the +# # ones in the anndata object. +# adata.obsm["scvi"] = latent[idx] +# +# # Plot UMAP and save the figure for later check +# sc.pp.neighbors(adata, use_rep="scvi", key_added="scvi") +# sc.tl.umap(adata, neighbors_key="scvi") +# sc.pl.umap(adata, color="dataset_id", title="SCVI") +# +# +# # Now return and add all the registry stuff that we will need +# +# +# # Now add the missing stuff from the current CZI implemenation in order for us to have the exact +# # same steps like the original way (except than setup_anndata) diff --git a/tests/model/test_scvi.py b/tests/model/test_scvi.py index eb96b8c34f..6000846770 100644 --- a/tests/model/test_scvi.py +++ b/tests/model/test_scvi.py @@ -1088,7 +1088,6 @@ def test_scvi_train_custom_dataloader(n_latent: int = 5): model = SCVI(adata, n_latent=n_latent) model.train(max_epochs=1) dataloader = model._make_data_loader(adata) - """ SCVI.setup_datamodule(dataloader) # continue from here. Datamodule will always require to pass it into all downstream functions. model.train(max_epochs=1, datamodule=dataloader) @@ -1096,7 +1095,7 @@ def test_scvi_train_custom_dataloader(n_latent: int = 5): _ = model.get_marginal_ll(dataloader=dataloader) _ = model.get_reconstruction_error(dataloader=dataloader) _ = model.get_latent_representation(dataloader=dataloader) - """ + def test_scvi_normal_likelihood(): import scanpy as sc From eff5b1ead286898bb799411f91e42623865be060 Mon Sep 17 00:00:00 2001 From: Can Ergen Date: Mon, 5 Aug 2024 23:08:13 -0700 Subject: [PATCH 14/53] Changes to dataloader --- src/scvi/data/_manager.py | 169 ------------------ src/scvi/data/_utils.py | 10 ++ src/scvi/model/_scanvi.py | 46 ++--- src/scvi/model/_scvi.py | 6 +- src/scvi/model/base/_archesmixin.py | 20 +-- src/scvi/model/base/_base_model.py | 256 +++++++++++++++++++++------- src/scvi/model/base/_save_load.py | 10 +- 7 files changed, 245 insertions(+), 272 deletions(-) diff --git a/src/scvi/data/_manager.py b/src/scvi/data/_manager.py index 10d0219041..8b1f37b846 100644 --- a/src/scvi/data/_manager.py +++ b/src/scvi/data/_manager.py @@ -1,19 +1,14 @@ from __future__ import annotations -import sys from collections import defaultdict from collections.abc import Sequence from copy import deepcopy from dataclasses import dataclass -from io import StringIO from uuid import uuid4 import numpy as np import pandas as pd -import rich from mudata import MuData -from rich import box -from rich.console import Console from torch.utils.data import Subset import scvi @@ -292,18 +287,6 @@ def validate(self) -> None: adata, self.adata = self.adata, None # Reset self.adata. self.register_fields(adata, self._source_registry, **self._transfer_kwargs) - def update_setup_method_args(self, setup_method_args: dict): - """Update setup method args. - - Parameters - ---------- - setup_method_args - This is a bit of a misnomer, this is a dict representing kwargs - of the setup method that will be used to update the existing values - in the registry of this instance. - """ - self._registry[_constants._SETUP_ARGS_KEY].update(setup_method_args) - @property def adata_uuid(self) -> str: """Returns the UUID for the AnnData object registered with this instance.""" @@ -311,11 +294,6 @@ def adata_uuid(self) -> str: return self._registry[_constants._SCVI_UUID_KEY] - @property - def registry(self) -> dict: - """Returns the top-level registry dictionary for the AnnData object.""" - return self._registry - @property def data_registry(self) -> attrdict: """Returns the data registry for the AnnData object registered with this instance.""" @@ -369,20 +347,6 @@ def _get_data_registry_from_registry(registry: dict) -> attrdict: data_registry[registry_key] = field_data_registry return attrdict(data_registry) - @property - def summary_stats(self) -> attrdict: - """Returns the summary stats for the AnnData object registered with this instance.""" - self._assert_anndata_registered() - return self._get_summary_stats_from_registry(self._registry) - - @staticmethod - def _get_summary_stats_from_registry(registry: dict) -> attrdict: - summary_stats = {} - for field_registry in registry[_constants._FIELD_REGISTRIES_KEY].values(): - field_summary_stats = field_registry[_constants._SUMMARY_STATS_KEY] - summary_stats.update(field_summary_stats) - return attrdict(summary_stats) - def get_from_registry(self, registry_key: str) -> np.ndarray | pd.DataFrame: """Returns the object in AnnData associated with the key in the data registry. @@ -404,136 +368,3 @@ def get_from_registry(self, registry_key: str) -> np.ndarray | pd.DataFrame: return get_anndata_attribute(self.adata, attr_name, attr_key, mod_key=mod_key) - def get_state_registry(self, registry_key: str) -> attrdict: - """Returns the state registry for the AnnDataField registered with this instance.""" - self._assert_anndata_registered() - - return attrdict( - self._registry[_constants._FIELD_REGISTRIES_KEY][registry_key][ - _constants._STATE_REGISTRY_KEY - ] - ) - - @staticmethod - def _view_summary_stats( - summary_stats: attrdict, as_markdown: bool = False - ) -> rich.table.Table | str: - """Prints summary stats.""" - if not as_markdown: - t = rich.table.Table(title="Summary Statistics") - else: - t = rich.table.Table(box=box.MARKDOWN) - - t.add_column( - "Summary Stat Key", - justify="center", - style="dodger_blue1", - no_wrap=True, - overflow="fold", - ) - t.add_column( - "Value", - justify="center", - style="dark_violet", - no_wrap=True, - overflow="fold", - ) - for stat_key, count in summary_stats.items(): - t.add_row(stat_key, str(count)) - - if as_markdown: - console = Console(file=StringIO(), force_jupyter=False) - console.print(t) - return console.file.getvalue().strip() - - return t - - @staticmethod - def _view_data_registry( - data_registry: attrdict, as_markdown: bool = False - ) -> rich.table.Table | str: - """Prints data registry.""" - if not as_markdown: - t = rich.table.Table(title="Data Registry") - else: - t = rich.table.Table(box=box.MARKDOWN) - - t.add_column( - "Registry Key", - justify="center", - style="dodger_blue1", - no_wrap=True, - overflow="fold", - ) - t.add_column( - "scvi-tools Location", - justify="center", - style="dark_violet", - no_wrap=True, - overflow="fold", - ) - - for registry_key, data_loc in data_registry.items(): - mod_key = getattr(data_loc, _constants._DR_MOD_KEY, None) - attr_name = data_loc.attr_name - attr_key = data_loc.attr_key - scvi_data_str = "adata" - if mod_key is not None: - scvi_data_str += f".mod['{mod_key}']" - if attr_key is None: - scvi_data_str += f".{attr_name}" - else: - scvi_data_str += f".{attr_name}['{attr_key}']" - t.add_row(registry_key, scvi_data_str) - - if as_markdown: - console = Console(file=StringIO(), force_jupyter=False) - console.print(t) - return console.file.getvalue().strip() - - return t - - @staticmethod - def view_setup_method_args(registry: dict) -> None: - """Prints setup kwargs used to produce a given registry. - - Parameters - ---------- - registry - Registry produced by an AnnDataManager. - """ - model_name = registry[_constants._MODEL_NAME_KEY] - setup_args = registry[_constants._SETUP_ARGS_KEY] - if model_name is not None and setup_args is not None: - rich.print(f"Setup via `{model_name}.setup_anndata` with arguments:") - rich.pretty.pprint(setup_args) - rich.print() - - def view_registry(self, hide_state_registries: bool = False) -> None: - """Prints summary of the registry. - - Parameters - ---------- - hide_state_registries - If True, prints a shortened summary without details of each state registry. - """ - version = self._registry[_constants._SCVI_VERSION_KEY] - rich.print(f"Anndata setup with scvi-tools version {version}.") - rich.print() - self.view_setup_method_args(self._registry) - - in_colab = "google.colab" in sys.modules - force_jupyter = None if not in_colab else True - console = rich.console.Console(force_jupyter=force_jupyter) - - ss = self._get_summary_stats_from_registry(self._registry) - dr = self._get_data_registry_from_registry(self._registry) - console.print(self._view_summary_stats(ss)) - console.print(self._view_data_registry(dr)) - - if not hide_state_registries: - for field in self.fields: - state_registry = self.get_state_registry(field.registry_key) - t = field.view_state_registry(state_registry) - if t is not None: - console.print(t) diff --git a/src/scvi/data/_utils.py b/src/scvi/data/_utils.py index 20fbfef293..4e47982a74 100644 --- a/src/scvi/data/_utils.py +++ b/src/scvi/data/_utils.py @@ -11,6 +11,8 @@ import scipy.sparse as sp_sparse from anndata import AnnData +from scvi.utils import attrdict + try: # anndata >= 0.10 from anndata.experimental import CSCDataset, CSRDataset @@ -156,6 +158,14 @@ def _set_data_in_registry( setattr(adata, attr_name, attribute) +def _get_summary_stats_from_registry(registry: dict) -> attrdict: + summary_stats = {} + for field_registry in registry[_constants._FIELD_REGISTRIES_KEY].values(): + field_summary_stats = field_registry[_constants._SUMMARY_STATS_KEY] + summary_stats.update(field_summary_stats) + return attrdict(summary_stats) + + def _verify_and_correct_data_format(adata: AnnData, attr_name: str, attr_key: str | None): """Check data format and correct if necessary. diff --git a/src/scvi/model/_scanvi.py b/src/scvi/model/_scanvi.py index 8630ff80d2..5a09cec0f9 100644 --- a/src/scvi/model/_scanvi.py +++ b/src/scvi/model/_scanvi.py @@ -110,6 +110,7 @@ class SCANVI(RNASeqMixin, VAEMixin, ArchesMixin, BaseMinifiedModeModelClass): def __init__( self, adata: AnnData, + registry: dict, n_hidden: int = 128, n_latent: int = 10, n_layers: int = 1, @@ -119,24 +120,24 @@ def __init__( linear_classifier: bool = False, **model_kwargs, ): - super().__init__(adata) + super().__init__(adata, registry) scanvae_model_kwargs = dict(model_kwargs) self._set_indices_and_labels() # ignores unlabeled catgegory n_labels = self.summary_stats.n_labels - 1 - n_cats_per_cov = ( - self.adata_manager.get_state_registry(REGISTRY_KEYS.CAT_COVS_KEY).n_cats_per_key - if REGISTRY_KEYS.CAT_COVS_KEY in self.adata_manager.data_registry - else None - ) + n_cats_per_cov = self.summary_stats[f'n_{REGISTRY_KEYS.CAT_COVS_KEY}'] + if n_cats_per_cov == 0: + n_cats_per_cov = None n_batch = self.summary_stats.n_batch - use_size_factor_key = REGISTRY_KEYS.SIZE_FACTOR_KEY in self.adata_manager.data_registry + use_size_factor_key = self.registry_['setup_args'][f'{REGISTRY_KEYS.SIZE_FACTOR_KEY}_key'] library_log_means, library_log_vars = None, None - if not use_size_factor_key and self.minified_data_type is None: - library_log_means, library_log_vars = _init_library_size(self.adata_manager, n_batch) + if self.adata is not None and not use_size_factor_key and self.minified_data_type is None: + library_log_means, library_log_vars = _init_library_size( + self.adata_manager, n_batch + ) self.module = self._module_cls( n_input=self.summary_stats.n_vars, @@ -178,7 +179,7 @@ def from_scvi_model( unlabeled_category: str, labels_key: str | None = None, adata: AnnData | None = None, - datamodule: LightningDataModule | None = None, + registry: dict | None = None, **scanvi_kwargs, ): """Initialize scanVI model with weights from pretrained :class:`~scvi.model.SCVI` model. @@ -195,8 +196,8 @@ def from_scvi_model( Value used for unlabeled cells in `labels_key` used to setup AnnData with scvi. adata AnnData object that has been registered via :meth:`~scvi.model.SCANVI.setup_anndata`. - datamodule - LightningDataModule object that has been registered. + registry + Registry of the datamodule used to train scANVI model. scanvi_kwargs kwargs for scANVI model """ @@ -231,7 +232,7 @@ def from_scvi_model( # validate new anndata against old model scvi_model._validate_anndata(adata) - scvi_setup_args = deepcopy(scvi_model.adata_manager.registry[_SETUP_ARGS_KEY]) + scvi_setup_args = deepcopy(scvi_model.registry[_SETUP_ARGS_KEY]) scvi_labels_key = scvi_setup_args["labels_key"] if labels_key is None and scvi_labels_key is None: raise ValueError( @@ -244,8 +245,8 @@ def from_scvi_model( unlabeled_category=unlabeled_category, **scvi_setup_args, ) - scanvi_model = cls(adata, **non_kwargs, **kwargs, **scanvi_kwargs) - print('TTTT', scanvi_model.registry) + + scanvi_model = cls(adata, scvi_model.registry, **non_kwargs, **kwargs, **scanvi_kwargs) scvi_state_dict = scvi_model.module.state_dict() scanvi_model.module.load_state_dict(scvi_state_dict, strict=False) scanvi_model.was_pretrained = True @@ -254,7 +255,7 @@ def from_scvi_model( def _set_indices_and_labels(self): """Set indices for labeled and unlabeled cells.""" - labels_state_registry = self.adata_manager.get_state_registry(REGISTRY_KEYS.LABELS_KEY) + labels_state_registry = self.get_state_registry(REGISTRY_KEYS.LABELS_KEY) self.original_label_key = labels_state_registry.original_key self.unlabeled_category_ = labels_state_registry.unlabeled_category @@ -474,12 +475,13 @@ def setup_anndata( NumericalJointObsField(REGISTRY_KEYS.CONT_COVS_KEY, continuous_covariate_keys), ] # register new fields if the adata is minified - adata_minify_type = _get_adata_minify_type(adata) - if adata_minify_type is not None: - anndata_fields += cls._get_fields_for_adata_minification(adata_minify_type) - adata_manager = AnnDataManager(fields=anndata_fields, setup_method_args=setup_method_args) - adata_manager.register_fields(adata, **kwargs) - cls.register_manager(adata_manager) + if adata: + adata_minify_type = _get_adata_minify_type(adata) + if adata_minify_type is not None: + anndata_fields += cls._get_fields_for_adata_minification(adata_minify_type) + adata_manager = AnnDataManager(fields=anndata_fields, setup_method_args=setup_method_args) + adata_manager.register_fields(adata, **kwargs) + cls.register_manager(adata_manager) @staticmethod def _get_fields_for_adata_minification( diff --git a/src/scvi/model/_scvi.py b/src/scvi/model/_scvi.py index ffc79775a5..7467e61945 100644 --- a/src/scvi/model/_scvi.py +++ b/src/scvi/model/_scvi.py @@ -113,7 +113,7 @@ class SCVI( def __init__( self, adata: AnnData | None = None, - datamodule: LightningDataModule | None = None, + registry: dict | None = None, n_hidden: int = 128, n_latent: int = 10, n_layers: int = 1, @@ -123,7 +123,7 @@ def __init__( latent_distribution: Literal["normal", "ln"] = "normal", **kwargs, ): - super().__init__(adata, datamodule) + super().__init__(adata, registry) self._module_kwargs = { "n_hidden": n_hidden, @@ -146,7 +146,7 @@ def __init__( if n_cats_per_cov == 0: n_cats_per_cov = None n_batch = self.summary_stats.n_batch - use_size_factor_key = self.registry_['setup_args'][f'{REGISTRY_KEYS.SIZE_FACTOR_KEY}_key'] + use_size_factor_key = self.get_setup_arg(f'{REGISTRY_KEYS.SIZE_FACTOR_KEY}_key') library_log_means, library_log_vars = None, None if self.adata is not None and not use_size_factor_key and self.minified_data_type is None: library_log_means, library_log_vars = _init_library_size( diff --git a/src/scvi/model/base/_archesmixin.py b/src/scvi/model/base/_archesmixin.py index 475ab15661..02607e6f93 100644 --- a/src/scvi/model/base/_archesmixin.py +++ b/src/scvi/model/base/_archesmixin.py @@ -8,7 +8,6 @@ import pandas as pd import torch from anndata import AnnData -from lightning import LightningDataModule from mudata import MuData from scipy.sparse import csr_matrix @@ -42,7 +41,7 @@ def load_query_data( cls, adata: None | AnnOrMuData = None, reference_model: Union[str, BaseModelClass] = None, - datamodule: None | LightningDataModule = None, + registry: None | dict = None, inplace_subset_query_vars: bool = False, accelerator: str = "auto", device: Union[int, str] = "auto", @@ -87,8 +86,8 @@ def load_query_data( """ if reference_model is None: raise ValueError("Please provide a reference model as string or loaded model.") - if adata is None and datamodule is None: - raise ValueError("Please provide either an AnnData or a datamodule.") + if adata is None and registry is None: + raise ValueError("Please provide either an AnnData or a registry dictionary.") _, _, device = parse_device_args( accelerator=accelerator, @@ -139,15 +138,14 @@ def load_query_data( **registry[_SETUP_ARGS_KEY], ) - model = _initialize_model(cls, adata, datamodule, attr_dict) - adata_manager = model.get_anndata_manager(adata, required=True) + model = _initialize_model(cls, adata, registry, attr_dict) - if REGISTRY_KEYS.CAT_COVS_KEY in adata_manager.data_registry: + if model.summary_stats[f'n_{REGISTRY_KEYS.CAT_COVS_KEY}'] > 0: raise NotImplementedError( "scArches currently does not support models with extra categorical covariates." ) - version_split = adata_manager.registry[_constants._SCVI_VERSION_KEY].split(".") + version_split = model.registry[_constants._SCVI_VERSION_KEY].split(".") if int(version_split[1]) < 8 and int(version_split[0]) == 0: warnings.warn( "Query integration should be performed using models trained with " @@ -374,11 +372,7 @@ def _get_loaded_data(reference_model, device=None, adata=None): else: attr_dict = reference_model._get_user_attributes() attr_dict = {a[0]: a[1] for a in attr_dict if a[0][-1] == "_"} - var_names = ( - _get_var_names(reference_model.adata) - if attr_dict["registry_"]["setup_method_name"] != "setup_datamodule" - else _get_var_names(adata) - ) + var_names = reference_model.get_var_names() load_state_dict = deepcopy(reference_model.module.state_dict()) return attr_dict, var_names, load_state_dict diff --git a/src/scvi/model/base/_base_model.py b/src/scvi/model/base/_base_model.py index d1e6ee0dae..1957187c71 100644 --- a/src/scvi/model/base/_base_model.py +++ b/src/scvi/model/base/_base_model.py @@ -3,9 +3,11 @@ import inspect import logging import os +import sys import warnings from abc import ABCMeta, abstractmethod from collections.abc import Sequence +from io import StringIO from uuid import uuid4 import numpy as np @@ -13,20 +15,28 @@ import rich import torch from anndata import AnnData -from lightning import LightningDataModule from mudata import MuData +from rich import box +from rich.console import Console from scvi import REGISTRY_KEYS, settings from scvi._types import AnnOrMuData, MinifiedDataType from scvi.data import AnnDataManager from scvi.data._compat import registry_from_setup_dict from scvi.data._constants import ( + _FIELD_REGISTRIES_KEY, _MODEL_NAME_KEY, _SCVI_UUID_KEY, _SETUP_ARGS_KEY, _SETUP_METHOD_NAME, + _STATE_REGISTRY_KEY, +) +from scvi.data._utils import ( + _assign_adata_uuid, + _check_if_view, + _get_adata_minify_type, + _get_summary_stats_from_registry, ) -from scvi.data._utils import _assign_adata_uuid, _check_if_view, _get_adata_minify_type from scvi.dataloaders import AnnDataLoader from scvi.model._utils import parse_device_args from scvi.model.base._constants import SAVE_KEYS @@ -39,6 +49,8 @@ from scvi.utils import attrdict, setup_anndata_dsp from scvi.utils._docstrings import devices_dsp +from . import _constants + logger = logging.getLogger(__name__) @@ -87,7 +99,7 @@ class BaseModelClass(metaclass=BaseModelMetaClass): _data_loader_cls = AnnDataLoader - def __init__(self, adata: AnnOrMuData | None = None, datamodule: object | None = None): + def __init__(self, adata: AnnOrMuData | None = None, registry: object | None = None): # check if the given adata is minified and check if the model being created # supports minified-data mode (i.e. inherits from the abstract BaseMinifiedModeModelClass). # If not, raise an error to inform the user of the lack of minified-data functionality @@ -100,21 +112,19 @@ def __init__(self, adata: AnnOrMuData | None = None, datamodule: object | None = self.id = str(uuid4()) # Used for cls._manager_store keys. if adata is not None: self._adata = adata - self._datamodule = None self._adata_manager = self._get_most_recent_anndata_manager(adata, required=True) self._register_manager_for_instance(self.adata_manager) # Suffix registry instance variable with _ to include it when saving the model. - self.registry_ = self._adata_manager.registry - self.summary_stats = self._adata_manager.summary_stats - elif datamodule is not None: + self.registry_ = self._adata_manager._registry + self.summary_stats = _get_summary_stats_from_registry(self.registry_) + elif registry is not None: self._adata = None - self._datamodule = datamodule self._adata_manager = None # Suffix registry instance variable with _ to include it when saving the model. - self.registry_ = datamodule.registry - self.summary_stats = datamodule.summary_stats + self.registry_ = registry + self.summary_stats = _get_summary_stats_from_registry(registry) else: - raise ValueError("adata or datamodule must be provided.") + raise ValueError("adata or registry must be provided.") self.is_trained_ = False self._model_summary_string = "" @@ -128,16 +138,20 @@ def adata(self) -> None | AnnOrMuData: """Data attached to model instance.""" return self._adata - @property - def datamodule(self) -> None | LightningDataModule: - """Data attached to model instance.""" - return self._datamodule - @property def registry(self) -> dict: """Data attached to model instance.""" return self.registry_ + def get_var_names(self, legacy_mudata_format=False) -> dict: + """Variable names of input data.""" + from scvi.model.base._save_load import _get_var_names + if self.adata: + return _get_var_names(self.adata, legacy_mudata_format=legacy_mudata_format) + else: + return self.registry[ + _FIELD_REGISTRIES_KEY]['X'][_STATE_REGISTRY_KEY]['column_names'] + @adata.setter def adata(self, adata: AnnOrMuData): if adata is None: @@ -148,14 +162,6 @@ def adata(self, adata: AnnOrMuData): self.registry_ = self._adata_manager.registry self.summary_stats = self._adata_manager.summary_stats - @datamodule.setter - def datamodule(self, datamodule: LightningDataModule): - if datamodule is None: - raise ValueError("datamodule cannot be None.") - self._datamodule = datamodule - self.registry_ = datamodule.registry - self.summary_stats = datamodule.summary_stats - @property def adata_manager(self) -> AnnDataManager: """Manager instance associated with self.adata.""" @@ -393,6 +399,9 @@ def get_anndata_manager( If True, errors on missing manager. Otherwise, returns None when manager is missing. """ cls = self.__class__ + if not adata: + return None + if _SCVI_UUID_KEY not in adata.uns: if required: raise ValueError( @@ -524,13 +533,20 @@ def _validate_anndata( "Input AnnData not setup with scvi-tools. " + "attempting to transfer AnnData setup" ) - self._register_manager_for_instance(self.adata_manager.transfer_fields(adata)) + self._register_manager_for_instance(self.transfer_fields(adata)) else: # Case where correct AnnDataManager is found, replay registration as necessary. adata_manager.validate() return adata + def transfer_fields(self, adata: AnnOrMuData, **kwargs) -> AnnData: + """Transfer fields from a model to an AnnData object.""" + if self.adata: + return self.adata_manager.transfer_fields(adata, **kwargs) + else: + raise ValueError("Model need to be initialized with AnnData to transfer fields.") + def _check_if_trained(self, warn: bool = True, message: str = _UNTRAINED_WARNING_MESSAGE): """Check if the model is trained. @@ -593,7 +609,7 @@ def _get_user_attributes(self): def _get_init_params(self, locals): """Returns the model init signature with associated passed in values. - Ignores the initial AnnData or DataModule. + Ignores the initial AnnData or Registry. """ init = self.__init__ sig = inspect.signature(init) @@ -605,8 +621,7 @@ def _get_init_params(self, locals): k: v for (k, v) in all_params.items() if not isinstance(v, AnnData) and not isinstance(v, MuData) - and not isinstance(v, LightningDataModule) - and k not in ("adata", "datamodule") + and k not in ("adata", "registry") } # not very efficient but is explicit # separates variable params (**kwargs) from non variable params into two dicts @@ -661,8 +676,6 @@ def save( anndata_write_kwargs Kwargs for :meth:`~anndata.AnnData.write` """ - from scvi.model.base._save_load import _get_var_names - if not os.path.exists(dir_path) or overwrite: os.makedirs(dir_path, exist_ok=overwrite) else: @@ -688,11 +701,7 @@ def save( # save the model state dict and the trainer state dict only model_state_dict = self.module.state_dict() - - if self.adata: - var_names = _get_var_names(self.adata, legacy_mudata_format=legacy_mudata_format) - else: - var_names = self.datamodule.var_names + var_names = self.get_var_names(legacy_mudata_format=legacy_mudata_format) # get all the user attributes user_attributes = self._get_user_attributes() @@ -715,7 +724,6 @@ def load( cls, dir_path: str, adata: AnnOrMuData | None = None, - datamodule: LightningDataModule | None = None, accelerator: str = "auto", device: int | str = "auto", prefix: str | None = None, @@ -732,6 +740,7 @@ def load( It is not necessary to run setup_anndata, as AnnData is validated against the saved `scvi` setup dictionary. If None, will check for and load anndata saved with the model. + If False, will load the model without AnnData. %(param_accelerator)s %(param_device)s prefix @@ -748,7 +757,7 @@ def load( >>> model = ModelClass.load(save_path, adata) >>> model.get_.... """ - load_adata = adata is None and datamodule is None + load_adata = adata is None _, _, device = parse_device_args( accelerator=accelerator, devices=device, @@ -771,17 +780,14 @@ def load( adata = new_adata if new_adata is not None else adata registry = attr_dict.pop("registry_") - if datamodule is not None: - registry['setup_method_name'] = 'setup_datamodule' - else: - registry['setup_method_name'] = 'setup_anndata' + registry['setup_method_name'] = 'setup_anndata' if _MODEL_NAME_KEY in registry and registry[_MODEL_NAME_KEY] != cls.__name__: raise ValueError("It appears you are loading a model from a different class.") # Calling ``setup_anndata`` method with the original arguments passed into # the saved model. This enables simple backwards compatibility in the case of # newly introduced fields or parameters. - if adata is not None: + if adata: if _SETUP_ARGS_KEY not in registry: raise ValueError( "Saved model does not contain original setup inputs. " @@ -791,13 +797,13 @@ def load( method_name = registry.get(_SETUP_METHOD_NAME, "setup_anndata") getattr(cls, method_name)(adata, source_registry=registry, **registry[_SETUP_ARGS_KEY]) - model = _initialize_model(cls, adata, datamodule, attr_dict) + model = _initialize_model(cls, adata, registry, attr_dict) model.module.on_load(model) model.module.load_state_dict(model_state_dict) model.to_device(device) model.module.eval() - if adata is not None: + if adata: model._validate_anndata(adata) return model @@ -890,22 +896,6 @@ def setup_anndata( on a model-specific instance of :class:`~scvi.data.AnnDataManager`. """ - @classmethod - @setup_anndata_dsp.dedent - def setup_datamodule( - cls, - datamodule, - *args, - **kwargs, - ): - """%(summary)s. - - Each model class deriving from this class provides parameters to this method - according to its needs. To operate correctly with the model initialization, - the implementation must call :meth:`~scvi.model.base.BaseModelClass.register_manager` - on a model-specific instance of :class:`~scvi.data.AnnDataManager`. - """ - @staticmethod def view_setup_args(dir_path: str, prefix: str | None = None) -> None: """Print args used to setup a saved model. @@ -970,6 +960,152 @@ def view_anndata_setup( ) from err adata_manager.view_registry(hide_state_registries=hide_state_registries) + def view_setup_method_args(self) -> None: + """Prints setup kwargs used to produce a given registry. + + Parameters + ---------- + registry + Registry produced by an AnnDataManager. + """ + model_name = self.registry_[_MODEL_NAME_KEY] + setup_args = self.registry_[_SETUP_ARGS_KEY] + if model_name is not None and setup_args is not None: + rich.print(f"Setup via `{model_name}.setup_anndata` with arguments:") + rich.pretty.pprint(setup_args) + rich.print() + + def view_registry(self, hide_state_registries: bool = False) -> None: + """Prints summary of the registry. + + Parameters + ---------- + hide_state_registries + If True, prints a shortened summary without details of each state registry. + """ + version = self.registry_[_SCVI_VERSION_KEY] + rich.print(f"Anndata setup with scvi-tools version {version}.") + rich.print() + self.view_setup_method_args(self._registry) + + in_colab = "google.colab" in sys.modules + force_jupyter = None if not in_colab else True + console = rich.console.Console(force_jupyter=force_jupyter) + + ss = _get_summary_stats_from_registry(self._registry) + dr = self._get_data_registry_from_registry(self._registry) + console.print(self._view_summary_stats(ss)) + console.print(self._view_data_registry(dr)) + + if not hide_state_registries: + for field in self.fields: + state_registry = self.get_state_registry(field.registry_key) + t = field.view_state_registry(state_registry) + if t is not None: + console.print(t) + + def get_state_registry(self, registry_key: str) -> attrdict: + """Returns the state registry for the AnnDataField registered with this instance.""" + return attrdict( + self.registry_[_FIELD_REGISTRIES_KEY][registry_key][ + _STATE_REGISTRY_KEY + ] + ) + + def get_setup_arg(self, setup_arg: str) -> attrdict: + """Returns the string provided to setup of a specific setup_arg.""" + return self.registry_[_SETUP_ARGS_KEY][setup_arg] + + @staticmethod + def _view_summary_stats( + summary_stats: attrdict, as_markdown: bool = False + ) -> rich.table.Table | str: + """Prints summary stats.""" + if not as_markdown: + t = rich.table.Table(title="Summary Statistics") + else: + t = rich.table.Table(box=box.MARKDOWN) + + t.add_column( + "Summary Stat Key", + justify="center", + style="dodger_blue1", + no_wrap=True, + overflow="fold", + ) + t.add_column( + "Value", + justify="center", + style="dark_violet", + no_wrap=True, + overflow="fold", + ) + for stat_key, count in summary_stats.items(): + t.add_row(stat_key, str(count)) + + if as_markdown: + console = Console(file=StringIO(), force_jupyter=False) + console.print(t) + return console.file.getvalue().strip() + + return t + + @staticmethod + def _view_data_registry( + data_registry: attrdict, as_markdown: bool = False + ) -> rich.table.Table | str: + """Prints data registry.""" + if not as_markdown: + t = rich.table.Table(title="Data Registry") + else: + t = rich.table.Table(box=box.MARKDOWN) + + t.add_column( + "Registry Key", + justify="center", + style="dodger_blue1", + no_wrap=True, + overflow="fold", + ) + t.add_column( + "scvi-tools Location", + justify="center", + style="dark_violet", + no_wrap=True, + overflow="fold", + ) + + for registry_key, data_loc in data_registry.items(): + mod_key = getattr(data_loc, _constants._DR_MOD_KEY, None) + attr_name = data_loc.attr_name + attr_key = data_loc.attr_key + scvi_data_str = "adata" + if mod_key is not None: + scvi_data_str += f".mod['{mod_key}']" + if attr_key is None: + scvi_data_str += f".{attr_name}" + else: + scvi_data_str += f".{attr_name}['{attr_key}']" + t.add_row(registry_key, scvi_data_str) + + if as_markdown: + console = Console(file=StringIO(), force_jupyter=False) + console.print(t) + return console.file.getvalue().strip() + + return t + + def update_setup_method_args(self, setup_method_args: dict): + """Update setup method args. + + Parameters + ---------- + setup_method_args + This is a bit of a misnomer, this is a dict representing kwargs + of the setup method that will be used to update the existing values + in the registry of this instance. + """ + self._registry[_SETUP_ARGS_KEY].update(setup_method_args) class BaseMinifiedModeModelClass(BaseModelClass): """Abstract base class for scvi-tools models that can handle minified data.""" diff --git a/src/scvi/model/base/_save_load.py b/src/scvi/model/base/_save_load.py index aa00e807f6..02af66efe0 100644 --- a/src/scvi/model/base/_save_load.py +++ b/src/scvi/model/base/_save_load.py @@ -97,7 +97,7 @@ def _load_saved_files( return attr_dict, var_names, model_state_dict, adata -def _initialize_model(cls, adata, datamodule, attr_dict): +def _initialize_model(cls, adata, registry, attr_dict): """Helper to initialize a model.""" if "init_params_" not in attr_dict.keys(): raise ValueError( @@ -121,9 +121,6 @@ def _initialize_model(cls, adata, datamodule, attr_dict): kwargs = {k: v for k, v in init_params.items() if isinstance(v, dict)} kwargs = {k: v for (i, j) in kwargs.items() for (k, v) in j.items()} non_kwargs.pop("use_cuda") - # adata and datamodule None is stored in the registry - non_kwargs.pop("adata", None) - non_kwargs.pop("datamodule", None) # backwards compat for scANVI if "unlabeled_category" in non_kwargs.keys(): @@ -131,7 +128,10 @@ def _initialize_model(cls, adata, datamodule, attr_dict): if "pretrained_model" in non_kwargs.keys(): non_kwargs.pop("pretrained_model") - model = cls(adata=adata, datamodule=datamodule, **non_kwargs, **kwargs) + if not adata: + adata = None + + model = cls(adata=adata, registry=registry, **non_kwargs, **kwargs) for attr, val in attr_dict.items(): setattr(model, attr, val) From 18d65a6eb13160e4e721b84c3f07580f74d0325d Mon Sep 17 00:00:00 2001 From: ori-kron-wis Date: Wed, 7 Aug 2024 15:56:16 +0300 Subject: [PATCH 15/53] add changes to tests and some merging with main following custom datamodule / registry big change --- src/scvi/data/_manager.py | 169 +++ src/scvi/dataloaders/_custom_dataloader.py | 1298 ------------------- src/scvi/model/_scanvi.py | 23 +- src/scvi/model/_scvi.py | 97 +- src/scvi/model/base/_base_model.py | 49 +- tests/dataloaders/test_custom_dataloader.py | 72 - tests/model/test_scvi.py | 114 +- 7 files changed, 345 insertions(+), 1477 deletions(-) delete mode 100644 src/scvi/dataloaders/_custom_dataloader.py delete mode 100644 tests/dataloaders/test_custom_dataloader.py diff --git a/src/scvi/data/_manager.py b/src/scvi/data/_manager.py index 8b1f37b846..10d0219041 100644 --- a/src/scvi/data/_manager.py +++ b/src/scvi/data/_manager.py @@ -1,14 +1,19 @@ from __future__ import annotations +import sys from collections import defaultdict from collections.abc import Sequence from copy import deepcopy from dataclasses import dataclass +from io import StringIO from uuid import uuid4 import numpy as np import pandas as pd +import rich from mudata import MuData +from rich import box +from rich.console import Console from torch.utils.data import Subset import scvi @@ -287,6 +292,18 @@ def validate(self) -> None: adata, self.adata = self.adata, None # Reset self.adata. self.register_fields(adata, self._source_registry, **self._transfer_kwargs) + def update_setup_method_args(self, setup_method_args: dict): + """Update setup method args. + + Parameters + ---------- + setup_method_args + This is a bit of a misnomer, this is a dict representing kwargs + of the setup method that will be used to update the existing values + in the registry of this instance. + """ + self._registry[_constants._SETUP_ARGS_KEY].update(setup_method_args) + @property def adata_uuid(self) -> str: """Returns the UUID for the AnnData object registered with this instance.""" @@ -294,6 +311,11 @@ def adata_uuid(self) -> str: return self._registry[_constants._SCVI_UUID_KEY] + @property + def registry(self) -> dict: + """Returns the top-level registry dictionary for the AnnData object.""" + return self._registry + @property def data_registry(self) -> attrdict: """Returns the data registry for the AnnData object registered with this instance.""" @@ -347,6 +369,20 @@ def _get_data_registry_from_registry(registry: dict) -> attrdict: data_registry[registry_key] = field_data_registry return attrdict(data_registry) + @property + def summary_stats(self) -> attrdict: + """Returns the summary stats for the AnnData object registered with this instance.""" + self._assert_anndata_registered() + return self._get_summary_stats_from_registry(self._registry) + + @staticmethod + def _get_summary_stats_from_registry(registry: dict) -> attrdict: + summary_stats = {} + for field_registry in registry[_constants._FIELD_REGISTRIES_KEY].values(): + field_summary_stats = field_registry[_constants._SUMMARY_STATS_KEY] + summary_stats.update(field_summary_stats) + return attrdict(summary_stats) + def get_from_registry(self, registry_key: str) -> np.ndarray | pd.DataFrame: """Returns the object in AnnData associated with the key in the data registry. @@ -368,3 +404,136 @@ def get_from_registry(self, registry_key: str) -> np.ndarray | pd.DataFrame: return get_anndata_attribute(self.adata, attr_name, attr_key, mod_key=mod_key) + def get_state_registry(self, registry_key: str) -> attrdict: + """Returns the state registry for the AnnDataField registered with this instance.""" + self._assert_anndata_registered() + + return attrdict( + self._registry[_constants._FIELD_REGISTRIES_KEY][registry_key][ + _constants._STATE_REGISTRY_KEY + ] + ) + + @staticmethod + def _view_summary_stats( + summary_stats: attrdict, as_markdown: bool = False + ) -> rich.table.Table | str: + """Prints summary stats.""" + if not as_markdown: + t = rich.table.Table(title="Summary Statistics") + else: + t = rich.table.Table(box=box.MARKDOWN) + + t.add_column( + "Summary Stat Key", + justify="center", + style="dodger_blue1", + no_wrap=True, + overflow="fold", + ) + t.add_column( + "Value", + justify="center", + style="dark_violet", + no_wrap=True, + overflow="fold", + ) + for stat_key, count in summary_stats.items(): + t.add_row(stat_key, str(count)) + + if as_markdown: + console = Console(file=StringIO(), force_jupyter=False) + console.print(t) + return console.file.getvalue().strip() + + return t + + @staticmethod + def _view_data_registry( + data_registry: attrdict, as_markdown: bool = False + ) -> rich.table.Table | str: + """Prints data registry.""" + if not as_markdown: + t = rich.table.Table(title="Data Registry") + else: + t = rich.table.Table(box=box.MARKDOWN) + + t.add_column( + "Registry Key", + justify="center", + style="dodger_blue1", + no_wrap=True, + overflow="fold", + ) + t.add_column( + "scvi-tools Location", + justify="center", + style="dark_violet", + no_wrap=True, + overflow="fold", + ) + + for registry_key, data_loc in data_registry.items(): + mod_key = getattr(data_loc, _constants._DR_MOD_KEY, None) + attr_name = data_loc.attr_name + attr_key = data_loc.attr_key + scvi_data_str = "adata" + if mod_key is not None: + scvi_data_str += f".mod['{mod_key}']" + if attr_key is None: + scvi_data_str += f".{attr_name}" + else: + scvi_data_str += f".{attr_name}['{attr_key}']" + t.add_row(registry_key, scvi_data_str) + + if as_markdown: + console = Console(file=StringIO(), force_jupyter=False) + console.print(t) + return console.file.getvalue().strip() + + return t + + @staticmethod + def view_setup_method_args(registry: dict) -> None: + """Prints setup kwargs used to produce a given registry. + + Parameters + ---------- + registry + Registry produced by an AnnDataManager. + """ + model_name = registry[_constants._MODEL_NAME_KEY] + setup_args = registry[_constants._SETUP_ARGS_KEY] + if model_name is not None and setup_args is not None: + rich.print(f"Setup via `{model_name}.setup_anndata` with arguments:") + rich.pretty.pprint(setup_args) + rich.print() + + def view_registry(self, hide_state_registries: bool = False) -> None: + """Prints summary of the registry. + + Parameters + ---------- + hide_state_registries + If True, prints a shortened summary without details of each state registry. + """ + version = self._registry[_constants._SCVI_VERSION_KEY] + rich.print(f"Anndata setup with scvi-tools version {version}.") + rich.print() + self.view_setup_method_args(self._registry) + + in_colab = "google.colab" in sys.modules + force_jupyter = None if not in_colab else True + console = rich.console.Console(force_jupyter=force_jupyter) + + ss = self._get_summary_stats_from_registry(self._registry) + dr = self._get_data_registry_from_registry(self._registry) + console.print(self._view_summary_stats(ss)) + console.print(self._view_data_registry(dr)) + + if not hide_state_registries: + for field in self.fields: + state_registry = self.get_state_registry(field.registry_key) + t = field.view_state_registry(state_registry) + if t is not None: + console.print(t) diff --git a/src/scvi/dataloaders/_custom_dataloader.py b/src/scvi/dataloaders/_custom_dataloader.py deleted file mode 100644 index b22c697c2a..0000000000 --- a/src/scvi/dataloaders/_custom_dataloader.py +++ /dev/null @@ -1,1298 +0,0 @@ -from __future__ import annotations - -import abc -import gc -import logging -import os -import threading -from collections import deque -from collections.abc import Iterator, Sequence -from concurrent import futures -from concurrent.futures import Future -from contextlib import contextmanager -from datetime import timedelta -from math import ceil -from time import time -from typing import Any, TypeVar - -import numpy as np -import numpy.typing as npt -import pandas as pd -import psutil -import scipy -import tiledbsoma as soma -import torch -import torchdata.datapipes.iter as pipes -from attr import define -from lightning.pytorch import LightningDataModule -from numpy.random import Generator -from scipy import sparse -from sklearn.preprocessing import LabelEncoder -from torch import Tensor -from torch import distributed as dist -from torch.utils.data import DataLoader -from torch.utils.data.dataset import Dataset - -pytorch_logger = logging.getLogger("cellxgene_census.experimental.pytorch") - -# TODO: Rename to reflect the correct order of the Tensors within the tuple: (X, obs) -ObsAndXDatum = tuple[Tensor, Tensor] -"""Return type of ``ExperimentDataPipe`` that pairs a Tensor of ``obs`` row(s) with a Tensor of -``X`` matrix row(s).The Tensors are rank 1 if ``batch_size`` is 1, -otherwise the Tensors are rank 2.""" - -util_logger = logging.getLogger("cellxgene_census.experimental.util") - -_T = TypeVar("_T") - - -DEFAULT_TILEDB_CONFIGURATION: dict[str, Any] = { - # https://docs.tiledb.com/main/how-to/configuration#configuration-parameters - "py.init_buffer_bytes": 1 * 1024**3, - "soma.init_buffer_bytes": 1 * 1024**3, - # S3 requests should not be signed, since we want to allow anonymous access - "vfs.s3.no_sign_request": "true", - "vfs.s3.region": "us-west-2", -} - - -def get_default_soma_context( - tiledb_config: dict[str, Any] | None = None, -) -> soma.options.SOMATileDBContext: - """Return a :class:`tiledbsoma.SOMATileDBContext` with sensible defaults that can be further - - customized by the user. The customized context can then be passed to - :func:`cellxgene_census.open_soma` with the ``context`` argument or to - :meth:`somacore.SOMAObject.open` with the ``context`` argument, such as - :meth:`tiledbsoma.Experiment.open`. Use the :meth:`tiledbsoma.SOMATileDBContext.replace` - method on the returned object to customize its settings further. - - Args: - tiledb_config: - A dictionary of TileDB configuration parameters. If specified, the parameters will - override the defaults. If not specified, the default configuration will be returned. - - Returns - ------- - A :class:`tiledbsoma.SOMATileDBContext` object with sensible defaults. - - Examples - -------- - To reduce the amount of memory used by TileDB-SOMA I/O operations: - - .. highlight:: python - .. code-block:: python - - ctx = cellxgene_census.get_default_soma_context( - tiledb_config={ - "py.init_buffer_bytes": 128 * 1024**2, - "soma.init_buffer_bytes": 128 * 1024**2, - } - ) - c = census.open_soma(uri="s3://my-private-bucket/census/soma", context=ctx) - - To access a copy of the Census located in a private bucket that is located in a different - S3 region, use: - - .. highlight:: python - .. code-block:: python - - ctx = cellxgene_census.get_default_soma_context( - tiledb_config={"vfs.s3.no_sign_request": "false", "vfs.s3.region": "us-east-1"} - ) - c = census.open_soma(uri="s3://my-private-bucket/census/soma", context=ctx) - - Lifecycle: - experimental - """ - tiledb_config = dict(DEFAULT_TILEDB_CONFIGURATION, **(tiledb_config or {})) - return soma.options.SOMATileDBContext().replace(tiledb_config=tiledb_config) - - -class _EagerIterator(Iterator[_T]): - def __init__( - self, - iterator: Iterator[_T], - pool: futures.Executor | None = None, - ): - super().__init__() - self.iterator = iterator - self._pool = pool or futures.ThreadPoolExecutor() - self._own_pool = pool is None - self._future: Future[_T] | None = None - self._begin_next() - - def _begin_next(self) -> None: - self._future = self._pool.submit(self.iterator.__next__) - util_logger.debug("Fetching next iterator element, eagerly") - - def __next__(self) -> _T: - try: - assert self._future - res = self._future.result() - self._begin_next() - return res - except StopIteration: - self._cleanup() - raise - - def _cleanup(self) -> None: - util_logger.debug("Cleaning up eager iterator") - if self._own_pool: - self._pool.shutdown() - - def __del__(self) -> None: - # Ensure the threadpool is cleaned up in the case where the - # iterator is not exhausted. For more information on __del__: - # https://docs.python.org/3/reference/datamodel.html#object.__del__ - self._cleanup() - super_del = getattr(super(), "__del__", lambda: None) - super_del() - - -class _EagerBufferedIterator(Iterator[_T]): - def __init__( - self, - iterator: Iterator[_T], - max_pending: int = 1, - pool: futures.Executor | None = None, - ): - super().__init__() - self.iterator = iterator - self.max_pending = max_pending - self._pool = pool or futures.ThreadPoolExecutor() - self._own_pool = pool is None - self._pending_results: deque[futures.Future[_T]] = deque() - self._lock = threading.Lock() - self._begin_next() - - def __next__(self) -> _T: - try: - res = self._pending_results[0].result() - self._pending_results.popleft() - self._begin_next() - return res - except StopIteration: - self._cleanup() - raise - - def _begin_next(self) -> None: - def _fut_done(fut: futures.Future[_T]) -> None: - util_logger.debug("Finished fetching next iterator element, eagerly") - if fut.exception() is None: - self._begin_next() - - with self._lock: - not_running = len(self._pending_results) == 0 or self._pending_results[-1].done() - if len(self._pending_results) < self.max_pending and not_running: - _future = self._pool.submit(self.iterator.__next__) - util_logger.debug("Fetching next iterator element, eagerly") - _future.add_done_callback(_fut_done) - self._pending_results.append(_future) - assert len(self._pending_results) <= self.max_pending - - def _cleanup(self) -> None: - util_logger.debug("Cleaning up eager iterator") - if self._own_pool: - self._pool.shutdown() - - def __del__(self) -> None: - # Ensure the threadpool is cleaned up in the case where the - # iterator is not exhausted. For more information on __del__: - # https://docs.python.org/3/reference/datamodel.html#object.__del__ - self._cleanup() - super_del = getattr(super(), "__del__", lambda: None) - super_del() - - -class Encoder(abc.ABC): - """Base class for obs encoders. - - To define a custom encoder, two methods must be implemented: - - - ``register``: defines how the encoder will be fitted to the data. - - ``transform``: defines how the encoder will be applied to the data - in order to create an obs_tensor. - - See the implementation of ``DefaultEncoder`` for an example. - """ - - @abc.abstractmethod - def register(self, obs: pd.DataFrame) -> None: - """Register the encoder with obs.""" - pass - - @abc.abstractmethod - def transform(self, df: pd.DataFrame) -> pd.DataFrame: - """Transform the obs DataFrame into a DataFrame of encoded values.""" - pass - - @property - def name(self) -> str: - return self.__class__.__name__ - - -class DefaultEncoder(Encoder): - """Default encoder based on LabelEncoder.""" - - def __init__(self, col: str) -> None: - self._encoder = LabelEncoder() - self.col = col - - def register(self, obs: pd.DataFrame) -> None: - self._encoder.fit(obs[self.col].unique()) - - def transform(self, df: pd.DataFrame) -> pd.DataFrame: - return self._encoder.transform(df[self.col]) # type: ignore - - @property - def name(self) -> str: - return self.col - - @property - def classes_(self): # type: ignore - return self._encoder.classes_ - - -@define -class _SOMAChunk: - """Return type of ``_ObsAndXSOMAIterator`` that pairs a chunk of ``obs`` rows with the - - respective rows from the ``X`` matrix. - - Lifecycle: - experimental - """ - - obs: pd.DataFrame - X: scipy.sparse.spmatrix - stats: Stats - - def __len__(self) -> int: - return len(self.obs) - - -Encoders = dict[str, LabelEncoder] -"""A dictionary of ``LabelEncoder``s keyed by the ``obs`` column name.""" - - -@define -class Stats: - """Statistics about the data retrieved by ``ExperimentDataPipe`` via SOMA API. This is useful - - for assessing the read throughput of SOMA data. - - Lifecycle: - experimental - """ - - n_obs: int = 0 - """The total number of obs rows retrieved""" - - nnz: int = 0 - """The total number of values retrieved""" - - elapsed: int = 0 - """The total elapsed time in seconds for retrieving all batches""" - - n_soma_chunks: int = 0 - """The number of chunks retrieved""" - - def __str__(self) -> str: - return ( - f"{self.n_soma_chunks=}, {self.n_obs=}, {self.nnz=}, " - f"elapsed={timedelta(seconds=self.elapsed)}" - ) - - def __add__(self, other: Stats) -> Stats: - self.n_obs += other.n_obs - self.nnz += other.nnz - self.elapsed += other.elapsed - self.n_soma_chunks += other.n_soma_chunks - return self - - -@contextmanager -def _open_experiment( - uri: str, - aws_region: str | None = None, -) -> soma.Experiment: - """Internal method for opening a SOMA ``Experiment`` as a context manager.""" - context = get_default_soma_context().replace( - tiledb_config={"vfs.s3.region": aws_region} if aws_region else {} - ) - - with soma.Experiment.open(uri, context=context) as exp: - yield exp - - -class _ObsAndXSOMAIterator(Iterator[_SOMAChunk]): - """Iterates the SOMA chunks of corresponding ``obs`` and ``X`` data. This is an internal class, - - not intended for public use. - """ - - X: soma.SparseNDArray - """A handle to the full X data of the SOMA ``Experiment``""" - - obs_joinids_chunks_iter: Iterator[npt.NDArray[np.int64]] - - var_joinids: npt.NDArray[np.int64] - """The ``var`` joinids to be retrieved from the SOMA ``Experiment``""" - - def __init__( - self, - obs: soma.DataFrame, - X: soma.SparseNDArray, - obs_column_names: Sequence[str], - obs_joinids_chunked: list[npt.NDArray[np.int64]], - var_joinids: npt.NDArray[np.int64], - shuffle_chunk_count: int | None = None, - shuffle_rng: Generator | None = None, - ): - self.obs = obs - self.X = X - self.obs_column_names = obs_column_names - if shuffle_chunk_count: - assert shuffle_rng is not None - - # At the start of this step, `obs_joinids_chunked` is a list of one dimensional - # numpy arrays. Each numpy array corresponds to a chunk of contiguous rows in `obs`. - # Critically, `obs_joinids_chunked` is randomly ordered where each chunk is - # from a random section of `obs`. - # We then take `shuffle_chunk_count` of these in order, concatenate them into - # a larger numpy array and shuffle this larger numpy array. - # The result is again a list of numpy arrays. - self.obs_joinids_chunks_iter = ( - shuffle_rng.permutation(np.concatenate(grouped_chunks)) - for grouped_chunks in list_split(obs_joinids_chunked, shuffle_chunk_count) - ) - else: - self.obs_joinids_chunks_iter = iter(obs_joinids_chunked) - self.var_joinids = var_joinids - self.shuffle_chunk_count = shuffle_chunk_count - - def __next__(self) -> _SOMAChunk: - pytorch_logger.debug("Retrieving next SOMA chunk...") - start_time = time() - - # If no more chunks to iterate through, raise StopIteration, as all iterators - # do when at end - obs_joinids_chunk = next(self.obs_joinids_chunks_iter) - - obs_batch = ( - self.obs.read( - coords=(obs_joinids_chunk,), - column_names=self.obs_column_names, - ) - .concat() - .to_pandas() - .set_index("soma_joinid") - ) - assert obs_batch.shape[0] == obs_joinids_chunk.shape[0] - - # handle case of empty result (first batch has 0 rows) - if len(obs_batch) == 0: - raise StopIteration - - # reorder obs rows to match obs_joinids_chunk ordering, which may be shuffled - obs_batch = obs_batch.reindex(obs_joinids_chunk, copy=False) - - # note: the `blockwise` call is employed for its ability to reindex the axes of the sparse - # matrix, but the blockwise iteration feature is not used (block_size is set to retrieve - # the chunk as a single block) - scipy_iter = ( - self.X.read(coords=(obs_joinids_chunk, self.var_joinids)) - .blockwise(axis=0, size=len(obs_joinids_chunk), eager=False) - .scipy(compress=True) - ) - X_batch, _ = next(scipy_iter) - assert obs_batch.shape[0] == X_batch.shape[0] - - stats = Stats() - stats.n_obs += X_batch.shape[0] - stats.nnz += X_batch.nnz - stats.elapsed += int(time() - start_time) - stats.n_soma_chunks += 1 - - pytorch_logger.debug(f"Retrieved SOMA chunk: {stats}") - return _SOMAChunk(obs=obs_batch, X=X_batch, stats=stats) - - -def list_split(arr_list: list[Any], sublist_len: int) -> list[list[Any]]: - """Splits a python list into a list of sublists where each sublist is of size `sublist_len`. - - TODO: Replace with `itertools.batched` when Python 3.12 becomes the minimum supported version. - """ - i = 0 - result = [] - while i < len(arr_list): - if (i + sublist_len) >= len(arr_list): - result.append(arr_list[i:]) - else: - result.append(arr_list[i : i + sublist_len]) - - i += sublist_len - - return result - - -def run_gc() -> tuple[tuple[Any, Any, Any], tuple[Any, Any, Any]]: - proc = psutil.Process(os.getpid()) - - pre_gc = proc.memory_full_info(), psutil.virtual_memory(), psutil.swap_memory() - gc.collect() - post_gc = proc.memory_full_info(), psutil.virtual_memory(), psutil.swap_memory() - - pytorch_logger.debug(f"gc: pre={pre_gc}") - pytorch_logger.debug(f"gc: post={post_gc}") - - return pre_gc, post_gc - - -class _ObsAndXIterator(Iterator[ObsAndXDatum]): - """Iterates through a set of ``obs`` and corresponding ``X`` rows, where the rows to be - - returned are specified by the ``obs_tables_iter`` argument. For the specified ``obs` rows, - the corresponding ``X`` data is loaded and joined together. It is returned from this iterator - as 2-tuples of ``X`` and obs Tensors. - - Internally manages the retrieval of data in SOMA-sized chunks, fetching the next chunk of SOMA - data as needed. Supports fetching the data in an eager manner, where the next SOMA chunk is - fetched while the current chunk is being read. This is an internal class, not intended for - public use. - """ - - soma_chunk_iter: _SOMAChunk | None - """The iterator for SOMA chunks of paired obs and X data""" - - soma_chunk: _SOMAChunk | None - """The current SOMA chunk of obs and X data""" - - i: int = -1 - """Index into current obs ``SOMA`` chunk""" - - def __init__( - self, - obs: soma.DataFrame, - X: soma.SparseNDArray, - obs_column_names: Sequence[str], - obs_joinids_chunked: list[npt.NDArray[np.int64]], - var_joinids: npt.NDArray[np.int64], - batch_size: int, - encoders: list[Encoder], - stats: Stats, - return_sparse_X: bool, - use_eager_fetch: bool, - shuffle_chunk_count: int | None = None, - shuffle_rng: Generator | None = None, - ) -> None: - self.soma_chunk_iter = _ObsAndXSOMAIterator( - obs, - X, - obs_column_names, - obs_joinids_chunked, - var_joinids, - shuffle_chunk_count, - shuffle_rng, - ) - if use_eager_fetch: - self.soma_chunk_iter = _EagerIterator(self.soma_chunk_iter) - self.soma_chunk = None - self.var_joinids = var_joinids - self.batch_size = batch_size - self.return_sparse_X = return_sparse_X - self.encoders = encoders - self.stats = stats - self.max_process_mem_usage_bytes = 0 - self.X_dtype = X.schema[2].type.to_pandas_dtype() - - def __next__(self) -> ObsAndXDatum: - """Read the next torch batch, possibly across multiple soma chunks.""" - obs: pd.DataFrame = pd.DataFrame() - X: sparse.csr_matrix = sparse.csr_matrix((0, len(self.var_joinids)), dtype=self.X_dtype) - - while len(obs) < self.batch_size: - try: - obs_partial, X_partial = self._read_partial_torch_batch(self.batch_size - len(obs)) - obs = pd.concat([obs, obs_partial], axis=0) - X = sparse.vstack([X, X_partial]) - except StopIteration: - break - - if len(obs) == 0: - raise StopIteration - - obs_encoded = pd.DataFrame() - - for enc in self.encoders: - obs_encoded[enc.name] = enc.transform(obs) - - # `to_numpy()` avoids copying the numpy array data - obs_tensor = torch.from_numpy(obs_encoded.to_numpy()) - - if not self.return_sparse_X: - X_tensor = torch.from_numpy(X.todense()) - else: - coo = X.tocoo() - - X_tensor = torch.sparse_coo_tensor( - # Note: The `np.array` seems unnecessary, but PyTorch warns bare array - # is "extremely slow" - indices=torch.from_numpy(np.array([coo.row, coo.col])), - values=coo.data, - size=coo.shape, - ) - - if self.batch_size == 1: - X_tensor = X_tensor[0] - obs_tensor = obs_tensor[0] - - return X_tensor, obs_tensor - - def _read_partial_torch_batch(self, batch_size: int) -> ObsAndXDatum: - """Reads a torch-size batch of data from the current SOMA chunk, returning a torch-size - - batch whose size may contain fewer rows than the requested ``batch_size``. This can happen - when the remaining rows in the current SOMA chunk are fewer than the requested - ``batch_size``. - """ - if self.soma_chunk is None or not (0 <= self.i < len(self.soma_chunk)): - # GC memory from previous soma_chunk - self.soma_chunk = None - mem_info = run_gc() - self.max_process_mem_usage_bytes = max( - self.max_process_mem_usage_bytes, mem_info[0][0].uss - ) - - self.soma_chunk: _SOMAChunk = next(self.soma_chunk_iter) - self.stats += self.soma_chunk.stats - self.i = 0 - - pytorch_logger.debug(f"Retrieved SOMA chunk totals: {self.stats}") - - obs_batch = self.soma_chunk.obs - X_batch = self.soma_chunk.X - - safe_batch_size = min(batch_size, len(obs_batch) - self.i) - slice_ = slice(self.i, self.i + safe_batch_size) - assert slice_.stop <= obs_batch.shape[0] - - obs_rows = obs_batch.iloc[slice_] - assert obs_rows.index.is_unique - assert safe_batch_size == obs_rows.shape[0] - - X_csr_scipy = X_batch[slice_] - assert obs_rows.shape[0] == X_csr_scipy.shape[0] - - self.i += safe_batch_size - - return obs_rows, X_csr_scipy - - -class ExperimentDataPipe(pipes.IterDataPipe[Dataset[ObsAndXDatum]]): # type: ignore - r"""An :class:`torchdata.datapipes.iter.IterDataPipe` that reads ``obs`` and ``X`` data from a - - :class:`tiledbsoma.Experiment`, based upon the specified queries along the ``obs`` and ``var`` - axes. Provides an iterator over these data when the object is passed to Python's built-in - ``iter`` function. - - >>> for batch in iter(ExperimentDataPipe(...)): - X_batch, y_batch = batch - - The ``batch_size`` parameter controls the number of rows of ``obs`` and ``X`` data that are - returned in each iteration. If the ``batch_size`` is 1, then each Tensor will have rank 1: - - >>> (tensor([0., 0., 0., 0., 0., 1., 0., 0., 0.]), # X data - tensor([2415, 0, 0], dtype=torch.int64)) # obs data, encoded - - For larger ``batch_size`` values, the returned Tensors will have rank 2: - - >>> DataLoader(..., batch_size=3, ...): - (tensor([[0., 0., 0., 0., 0., 1., 0., 0., 0.], # X batch - [0., 0., 0., 0., 0., 0., 0., 0., 0.], - [0., 0., 0., 0., 0., 0., 0., 0., 0.]]), - tensor([[2415, 0, 0], # obs batch - [2416, 0, 4], - [2417, 0, 3]], dtype=torch.int64)) - - The ``return_sparse_X`` parameter controls whether the ``X`` data is returned as a dense or - sparse :class:`torch.Tensor`. If the model supports use of sparse :class:`torch.Tensor`\ s, - this will reduce memory usage. - - The ``obs_column_names`` parameter determines the data columns that are returned in the - ``obs`` Tensor. The first element is always the ``soma_joinid`` of the ``obs`` - :class:`pandas.DataFrame` (or, equivalently, the ``soma_dim_0`` of the ``X`` matrix). - The remaining elements are the ``obs`` columns specified by ``obs_column_names``, - and string-typed columns are encoded as integer values. If needed, these values can be decoded - by obtaining the encoder for a given ``obs`` column name and calling its ``inverse_transform`` - method: - - >>> exp_data_pipe.obs_encoders[""].inverse_transform(encoded_values) - - Lifecycle: - experimental - """ - - _initialized: bool - - _obs_joinids: npt.NDArray[np.int64] | None - - _var_joinids: npt.NDArray[np.int64] | None - - _encoders: list[Encoder] - - _stats: Stats - - _shuffle_rng: Generator | None - - # TODO: Consider adding another convenience method wrapper to construct this object whose - # signature is more closely aligned with get_anndata() params - # (i.e. "exploded" AxisQuery params). - def __init__( - self, - experiment: soma.Experiment, - measurement_name: str = "RNA", - X_name: str = "raw", - obs_query: soma.AxisQuery | None = None, - var_query: soma.AxisQuery | None = None, - obs_column_names: Sequence[str] = (), - batch_size: int = 1, - shuffle: bool = True, - seed: int | None = None, - return_sparse_X: bool = False, - soma_chunk_size: int | None = 64, - use_eager_fetch: bool = True, - encoders: list[Encoder] | None = None, - shuffle_chunk_count: int | None = 2000, - ) -> None: - r"""Construct a new ``ExperimentDataPipe``. - - Args: - experiment: - The :class:`tiledbsoma.Experiment` from which to read data. - measurement_name: - The name of the :class:`tiledbsoma.Measurement` to read. Defaults to ``"RNA"``. - X_name: - The name of the X layer to read. Defaults to ``"raw"``. - obs_query: - The query used to filter along the ``obs`` axis. If not specified, all ``obs`` and - ``X`` data will be returned, which can be very large. - var_query: - The query used to filter along the ``var`` axis. If not specified, all ``var`` - columns (genes/features) will be returned. - obs_column_names: - The names of the ``obs`` columns to return. The ``soma_joinid`` index "column" does - not need to be specified and will always be returned. If not specified, only the - ``soma_joinid`` will be returned. - batch_size: - The number of rows of ``obs`` and ``X`` data to return in each iteration. Defaults - to ``1``. A value of ``1`` will result in :class:`torch.Tensor` of rank 1 being - returns (a single row); larger values will result in :class:`torch.Tensor`\ s of - rank 2 (multiple rows). - shuffle: - Whether to shuffle the ``obs`` and ``X`` data being returned. Defaults to ``True``. - For performance reasons, shuffling is not performed globally across all rows, but - rather in chunks. More specifically, we select ``shuffle_chunk_count`` - non-contiguous chunks across all the observations - in the query, concatenate the chunks and shuffle the associated observations. - The randomness of the shuffling is therefore determined by the - (``soma_chunk_size``, ``shuffle_chunk_count``) selection. The default values have - been determined to yield a good trade-off between randomness and performance. - Further tuning may be required for different type of models. Note that memory usage - is correlated to the product ``soma_chunk_size * shuffle_chunk_count``. - seed: - The random seed used for shuffling. Defaults to ``None`` (no seed). This *must* be - specified when using :class:`torch.nn.parallel.DistributedDataParallel` to ensure - data partitions are disjoint across worker processes. - return_sparse_X: - Controls whether the ``X`` data is returned as a dense or sparse - :class:`torch.Tensor`. As ``X`` data is very sparse, setting this to ``True`` will - reduce memory usage, if the model supports use of sparse :class:`torch.Tensor`\ s. - Defaults to ``False``, since sparse :class:`torch.Tensor`\ s are still experimental - in PyTorch. - soma_chunk_size: - The number of ``obs``/``X`` rows to retrieve when reading data from SOMA. This - impacts two aspects of this class's behavior: 1) The maximum memory utilization, - with larger values providing better read performance, but also requiring more - memory; 2) The granularity of the global shuffling step (see ``shuffle`` parameter - for details). The default value of 64 works well in conjunction with the default - ``shuffle_chunk_count`` value. - use_eager_fetch: - Fetch the next SOMA chunk of ``obs`` and ``X`` data immediately after a previously - fetched SOMA chunk is made available for processing via the iterator. This allows - network (or filesystem) requests to be made in parallel with client-side processing - of the SOMA data, potentially improving overall performance at the cost of - doubling memory utilization. Defaults to ``True``. - shuffle_chunk_count: - The number of contiguous blocks (chunks) of rows sampled to then concatenate - and shuffle. Larger numbers correspond to more randomness per training batch. - If ``shuffle == False``, this parameter is ignored. Defaults to ``2000``. - encoders: - Specify custom encoders to be used. If not specified, a LabelEncoder will be - created and used for each column in ``obs_column_names``. If specified, only - columns for which an encoder has been registered will be returned in the - ``obs`` tensor. - - Lifecycle: - experimental - """ - self.exp_uri = experiment.uri - self.aws_region = experiment.context.tiledb_ctx.config().get("vfs.s3.region") - self.measurement_name = measurement_name - self.layer_name = X_name - self.obs_query = obs_query - self.var_query = var_query - self.obs_column_names = obs_column_names - self.batch_size = batch_size - self.return_sparse_X = return_sparse_X - self.soma_chunk_size = soma_chunk_size - self.use_eager_fetch = use_eager_fetch - self._stats = Stats() - self._custom_encoders = encoders - self._encoders = [] - self._obs_joinids = None - self._var_joinids = None - self._shuffle_chunk_count = shuffle_chunk_count if shuffle else None - self._shuffle_rng = np.random.default_rng(seed) if shuffle else None - self._initialized = False - - if "soma_joinid" not in self.obs_column_names: - self.obs_column_names = ["soma_joinid", *self.obs_column_names] - - def _init(self) -> None: - if self._initialized: - return - - pytorch_logger.debug("Initializing ExperimentDataPipe") - - with _open_experiment(self.exp_uri, self.aws_region) as exp: - query = exp.axis_query( - measurement_name=self.measurement_name, - obs_query=self.obs_query, - var_query=self.var_query, - ) - - # The to_numpy() call is a workaround for a possible bug in TileDB-SOMA: - # https://github.com/single-cell-data/TileDB-SOMA/issues/1456 - self._obs_joinids = query.obs_joinids().to_numpy() - self._var_joinids = query.var_joinids().to_numpy() - - self._encoders = self._build_obs_encoders(query) - - self._initialized = True - - @staticmethod - def _subset_ids_to_partition( - ids_chunked: list[npt.NDArray[np.int64]], - partition_index: int, - num_partitions: int, - ) -> list[npt.NDArray[np.int64]]: - """Returns a single partition of the obs_joinids_chunked (a 2D ndarray), - - based upon the current process's distributed rank and world size. - """ - # subset to a single partition - # typing does not reflect that is actually a list of 2D NDArrays - partition_indices = np.array_split(range(len(ids_chunked)), num_partitions) - partition = [ids_chunked[i] for i in partition_indices[partition_index]] - - if pytorch_logger.isEnabledFor(logging.DEBUG) and len(partition) > 0: - pytorch_logger.debug( - f"Process {os.getpid()} handling partition {partition_index + 1} " - f"of {num_partitions}, partition_size={sum([len(chunk) for chunk in partition])}" - ) - - return partition - - @staticmethod - def _compute_partitions( - loader_partition: int, - loader_partitions: int, - dist_partition: int, - num_dist_partitions: int, - ) -> tuple[int, int]: - # NOTE: Can alternately use a `worker_init_fn` to split among workers split workload - total_partitions = num_dist_partitions * loader_partitions - partition = dist_partition * loader_partitions + loader_partition - return partition, total_partitions - - def __iter__(self) -> Iterator[ObsAndXDatum]: - self._init() - assert self._obs_joinids is not None - assert self._var_joinids is not None - - if self.soma_chunk_size is None: - # set soma_chunk_size to utilize ~1 GiB of RAM per SOMA chunk; assumes 95% X data - # sparsity, 8 bytes for the X value and 8 bytes for the sparse matrix indices, - # and a 100% working memory overhead (2x). - X_row_memory_size = 0.05 * len(self._var_joinids) * 8 * 3 * 2 - self.soma_chunk_size = int((1 * 1024**3) / X_row_memory_size) - pytorch_logger.debug(f"Using {self.soma_chunk_size=}") - - if ( - self.return_sparse_X - and torch.utils.data.get_worker_info() - and torch.utils.data.get_worker_info().num_workers > 0 - ): - raise NotImplementedError( - "torch does not work with sparse tensors in multi-processing mode " - "(see https://github.com/pytorch/pytorch/issues/20248)" - ) - - # chunk the obs joinids into batches of size soma_chunk_size - obs_joinids_chunked = self._chunk_ids(self._obs_joinids, self.soma_chunk_size) - - # globally shuffle the chunks, if requested - if self._shuffle_rng: - self._shuffle_rng.shuffle(obs_joinids_chunked) - - # subset to a single partition, as needed for distributed training and multi-processing - # data loading - worker_info = torch.utils.data.get_worker_info() - partition, partitions = self._compute_partitions( - loader_partition=worker_info.id if worker_info else 0, - loader_partitions=worker_info.num_workers if worker_info else 1, - dist_partition=dist.get_rank() if dist.is_initialized() else 0, - num_dist_partitions=dist.get_world_size() if dist.is_initialized() else 1, - ) - obs_joinids_chunked_partition: list[npt.NDArray[np.int64]] = self._subset_ids_to_partition( - obs_joinids_chunked, partition, partitions - ) - - with _open_experiment(self.exp_uri, self.aws_region) as exp: - obs_and_x_iter = _ObsAndXIterator( - obs=exp.obs, - X=exp.ms[self.measurement_name].X[self.layer_name], - obs_column_names=self.obs_column_names, - obs_joinids_chunked=obs_joinids_chunked_partition, - var_joinids=self._var_joinids, - batch_size=self.batch_size, - encoders=self._encoders, - stats=self._stats, - return_sparse_X=self.return_sparse_X, - use_eager_fetch=self.use_eager_fetch, - shuffle_rng=self._shuffle_rng, - shuffle_chunk_count=self._shuffle_chunk_count, - ) - - yield from obs_and_x_iter - - pytorch_logger.debug( - "max process memory usage=" - f"{obs_and_x_iter.max_process_mem_usage_bytes / (1024 ** 3):.3f} GiB" - ) - - @staticmethod - def _chunk_ids(ids: npt.NDArray[np.int64], chunk_size: int) -> list[npt.NDArray[np.int64]]: - num_chunks = max(1, ceil(len(ids) / chunk_size)) - pytorch_logger.debug( - f"Shuffling {len(ids)} obs joinids into {num_chunks} chunks of {chunk_size}" - ) - return np.array_split(ids, num_chunks) - - def __len__(self) -> int: - self._init() - assert self._obs_joinids is not None - - return len(self._obs_joinids) - - def __getitem__(self, index: int) -> ObsAndXDatum: - raise NotImplementedError("IterDataPipe can only be iterated") - - def _build_obs_encoders(self, query: soma.ExperimentAxisQuery) -> list[Encoder]: - pytorch_logger.debug("Initializing encoders") - - encoders = [] - obs = query.obs(column_names=self.obs_column_names).concat().to_pandas() - - if self._custom_encoders: - # Register all the custom encoders with obs - for enc in self._custom_encoders: - enc.register(obs) - encoders.append(enc) - else: - # Create one DefaultEncoder for each column, and register it with obs - for col in self.obs_column_names: - if obs[col].dtype in [object]: - enc = DefaultEncoder(col) - enc.register(obs) - encoders.append(enc) - - return encoders - - # TODO: This does not work in multiprocessing mode, as child process's stats are not collected - def stats(self) -> Stats: - """Get data loading stats for this - - :class:`cellxgene_census.experimental.ml.pytorch.ExperimentDataPipe`. - - Returns - ------- - The :class:`cellxgene_census.experimental.ml.pytorch.Stats` object for this - :class:`cellxgene_census.experimental.ml.pytorch.ExperimentDataPipe`. - - Lifecycle: - experimental - """ - return self._stats - - @property - def shape(self) -> tuple[int, int]: - """Get the shape of the data that will be returned by this - - :class:`cellxgene_census.experimental.ml.pytorch.ExperimentDataPipe`. - This is the number of obs (cell) and var (feature) counts in the returned data. If used in - multiprocessing mode (i.e. :class:`torch.utils.data.DataLoader` - instantiated with num_workers > 0), the obs (cell) count will reflect - the size of the partition of the data assigned to the active process. - - Returns - ------- - A 2-tuple of ``int``s, for obs and var counts, respectively. - - Lifecycle: - experimental - """ - self._init() - assert self._obs_joinids is not None - assert self._var_joinids is not None - - return len(self._obs_joinids), len(self._var_joinids) - - @property - def obs_encoders(self) -> Encoders: - """Returns a dictionary of :class:`sklearn.preprocessing.LabelEncoder` objects, keyed on - - ``obs`` column names, which were used to encode the ``obs`` column values. - - These encoders can be used to decode the encoded values as follows: - - >>> exp_data_pipe.obs_encoders[""].inverse_transform(encoded_values) - - Returns - ------- - A ``dict[str, LabelEncoder]``, mapping column names to :class:`sklearn.preprocessing. - LabelEncoder` objects. - """ - self._init() - assert self._encoders is not None - - return {enc.name: enc for enc in self._encoders} - - -# Note: must be a top-level function (and not a lambda), to play nice with multiprocessing pickling -def _collate_noop(x: Any) -> Any: - return x - - -# TODO: Move into somacore.ExperimentAxisQuery -def experiment_dataloader( - datapipe: pipes.IterDataPipe, - num_workers: int = 0, - **dataloader_kwargs: Any, -) -> DataLoader: - """Factory method for :class:`torch.utils.data.DataLoader`. This method can be used to safely - - instantiate a :class:`torch.utils.data.DataLoader` that works with - :class:`cellxgene_census.experimental.ml.pytorch.ExperimentDataPipe`, since some of the - :class:`torch.utils.data.DataLoader` constructor parameters are not applicable when using a - :class:`torchdata.datapipes.iter.IterDataPipe` (``shuffle``, ``batch_size``, ``sampler``, - ``batch_sampler``,``collate_fn``). - - Args: - datapipe: - An :class:`torchdata.datapipes.iter.IterDataPipe`, which can be an - :class:`cellxgene_census.experimental.ml.pytorch.ExperimentDataPipe` or any other - :class:`torchdata.datapipes.iter.IterDataPipe` that has been chained to the - :class:`cellxgene_census.experimental.ml.pytorch.ExperimentDataPipe`. - num_workers: - Number of worker processes to use for data loading. If ``0``, data will be loaded in - the main process. - **dataloader_kwargs: - Additional keyword arguments to pass to the :class:`torch.utils.data.DataLoader` - constructor, except for ``shuffle``, ``batch_size``, ``sampler``, ``batch_sampler``, - and ``collate_fn``, which are not supported when using - :class:`cellxgene_census.experimental.ml.pytorch.ExperimentDataPipe`. - - Returns - ------- - A :class:`torch.utils.data.DataLoader`. - - Raises - ------ - ValueError: if any of the ``shuffle``, ``batch_size``, ``sampler``, ``batch_sampler``, - or ``collate_fn`` params are passed as keyword arguments. - - Lifecycle: - experimental - """ - unsupported_dataloader_args = [ - "shuffle", - "batch_size", - "sampler", - "batch_sampler", - "collate_fn", - ] - if set(unsupported_dataloader_args).intersection(dataloader_kwargs.keys()): - raise ValueError( - f"The {','.join(unsupported_dataloader_args)} DataLoader params are not supported" - ) - - if num_workers > 0: - _init_multiprocessing() - - return DataLoader( - datapipe, - batch_size=None, # batching is handled by our ExperimentDataPipe - num_workers=num_workers, - # avoid use of default collator, which adds an extra (3rd) dimension to the tensor batches - collate_fn=_collate_noop, - # shuffling is handled by our ExperimentDataPipe - shuffle=False, - **dataloader_kwargs, - ) - - -def _init_multiprocessing() -> None: - """Ensures use of "spawn" for starting child processes with multiprocessing. - - Forked processes are known to be problematic: - https://pytorch.org/docs/stable/notes/multiprocessing.html#avoiding-and-fighting-deadlocks - Also, CUDA does not support forked child processes: - https://pytorch.org/docs/stable/notes/multiprocessing.html#cuda-in-multiprocessing - - """ - torch.multiprocessing.set_start_method("fork", force=True) - orig_start_method = torch.multiprocessing.get_start_method() - if orig_start_method != "spawn": - if orig_start_method: - pytorch_logger.warning( - "switching torch multiprocessing start method from " - f'"{torch.multiprocessing.get_start_method()}" to "spawn"' - ) - torch.multiprocessing.set_start_method("spawn", force=True) - - -class BatchEncoder(Encoder): - """Concatenates and encodes several columns.""" - - def __init__(self, cols: list[str]): - self.cols = cols - from sklearn.preprocessing import LabelEncoder - - self._encoder = LabelEncoder() - - def transform(self, df: pd.DataFrame): - import functools - - arr = functools.reduce(lambda a, b: a + b, [df[c].astype(str) for c in self.cols]) - return self._encoder.transform(arr) - - def register(self, obs: pd.DataFrame): - import functools - - arr = functools.reduce(lambda a, b: a + b, [obs[c].astype(str) for c in self.cols]) - self._encoder.fit(arr.unique()) - - @property - def name(self) -> str: - return "batch" - - @property - def classes_(self): - return self._encoder.classes_ - - -class CensusSCVIDataModule(LightningDataModule): - """Lightning data module for CxG Census. - - Parameters - ---------- - *args - Positional arguments passed to - :class:`~cellxgene_census.experimental.ml.pytorch.ExperimentDataPipe`. - batch_keys - List of obs column names concatenated to form the batch column. - train_size - Fraction of data to use for training. - split_seed - Seed for data split. - dataloader_kwargs - Keyword arguments passed into - :func:`~cellxgene_census.experimental.ml.pytorch.experiment_dataloader`. - **kwargs - Additional keyword arguments passed into - :class:`~cellxgene_census.experimental.ml.pytorch.ExperimentDataPipe`. Must not include - ``obs_column_names``. - """ - - _TRAIN_KEY = "train" - _VALIDATION_KEY = "validation" - - def __init__( - self, - *args, - batch_keys: list[str] | None = None, - train_size: float | None = None, - split_seed: int | None = None, - dataloader_kwargs: dict[str, any] | None = None, - **kwargs, - ): - super().__init__() - self.datapipe_args = args - self.datapipe_kwargs = kwargs - self.batch_keys = batch_keys - self.train_size = train_size - self.split_seed = split_seed - self.dataloader_kwargs = dataloader_kwargs or {} - - @property - def batch_keys(self) -> list[str]: - """List of obs column names concatenated to form the batch column.""" - if not hasattr(self, "_batch_keys"): - raise AttributeError("`batch_keys` not set.") - return self._batch_keys - - @batch_keys.setter - def batch_keys(self, value: list[str] | None): - if value is None or not isinstance(value, list): - raise ValueError("`batch_keys` must be a list of strings.") - self._batch_keys = value - - @property - def obs_column_names(self) -> list[str]: - """Passed to :class:`~cellxgene_census.experimental.ml.pytorch.ExperimentDataPipe`.""" - if hasattr(self, "_obs_column_names"): - return self._obs_column_names - - obs_column_names = [] - if self.batch_keys is not None: - obs_column_names.extend(self.batch_keys) - - self._obs_column_names = obs_column_names - return self._obs_column_names - - @property - def split_seed(self) -> int: - """Seed for data split.""" - if not hasattr(self, "_split_seed"): - raise AttributeError("`split_seed` not set.") - return self._split_seed - - @split_seed.setter - def split_seed(self, value: int | None): - if value is not None and not isinstance(value, int): - raise ValueError("`split_seed` must be an integer.") - self._split_seed = value or 0 - - @property - def train_size(self) -> float: - """Fraction of data to use for training.""" - if not hasattr(self, "_train_size"): - raise AttributeError("`train_size` not set.") - return self._train_size - - @train_size.setter - def train_size(self, value: float | None): - if value is not None and not isinstance(value, float): - raise ValueError("`train_size` must be a float.") - elif value is not None and (value < 0.0 or value > 1.0): - raise ValueError("`train_size` must be between 0.0 and 1.0.") - self._train_size = value or 1.0 - - @property - def validation_size(self) -> float: - """Fraction of data to use for validation.""" - if not hasattr(self, "_train_size"): - raise AttributeError("`validation_size` not available.") - return 1.0 - self.train_size - - @property - def weights(self) -> dict[str, float]: - """Passed to :meth:`~cellxgene_census.experimental.ml.ExperimentDataPipe.random_split`.""" - if not hasattr(self, "_weights"): - self._weights = {self._TRAIN_KEY: self.train_size} - if self.validation_size > 0.0: - self._weights[self._VALIDATION_KEY] = self.validation_size - return self._weights - - @property - def datapipe(self) -> ExperimentDataPipe: - """Experiment data pipe.""" - if not hasattr(self, "_datapipe"): - encoder = BatchEncoder(self.obs_column_names) - self._datapipe = ExperimentDataPipe( - *self.datapipe_args, - obs_column_names=self.obs_column_names, - encoders=[encoder], - **self.datapipe_kwargs, - ) - return self._datapipe - - def setup(self, stage: str | None = None): - """Set up the train and validation data pipes.""" - datapipes = self.datapipe.random_split(weights=self.weights, seed=self.split_seed) - self._train_datapipe = datapipes[0] - if self.validation_size > 0.0: - self._validation_datapipe = datapipes[1] - else: - self._validation_datapipe = None - - def train_dataloader(self): - """Training data loader.""" - return experiment_dataloader(self._train_datapipe, **self.dataloader_kwargs) - - def val_dataloader(self): - """Validation data loader.""" - if self._validation_datapipe is not None: - return experiment_dataloader(self._validation_datapipe, **self.dataloader_kwargs) - - @property - def n_obs(self) -> int: - """Number of observations in the query. - - Necessary in scvi-tools to compute a heuristic of ``max_epochs``. - """ - return self.datapipe.shape[0] - - @property - def n_vars(self) -> int: - """Number of features in the query. - - Necessary in scvi-tools to initialize the actual layers in the model. - - """ - return self.datapipe.shape[1] - - @property - def n_batch(self) -> int: - """ - Number of unique batches (after concatenation of ``batch_keys``). Necessary in scvi-tools - - so that the model knows how to one-hot encode batches. - - """ - return self.get_n_classes("batch") - - def get_n_classes(self, key: str) -> int: - """Return the number of classes for a given obs column.""" - return len(self.datapipe.obs_encoders[key].classes_) - - def on_before_batch_transfer( - self, - batch: tuple[torch.Tensor, torch.Tensor], - dataloader_idx: int, - ) -> dict[str, torch.Tensor | None]: - """Format the datapipe output with registry keys for scvi-tools.""" - X, obs = batch - - X_KEY: str = "X" - BATCH_KEY: str = "batch" - LABELS_KEY: str = "labels" - - return { - X_KEY: X, - BATCH_KEY: obs, - LABELS_KEY: None, - } diff --git a/src/scvi/model/_scanvi.py b/src/scvi/model/_scanvi.py index 5a09cec0f9..86fb340927 100644 --- a/src/scvi/model/_scanvi.py +++ b/src/scvi/model/_scanvi.py @@ -110,7 +110,7 @@ class SCANVI(RNASeqMixin, VAEMixin, ArchesMixin, BaseMinifiedModeModelClass): def __init__( self, adata: AnnData, - registry: dict, + registry: dict | None = None, n_hidden: int = 128, n_latent: int = 10, n_layers: int = 1, @@ -127,17 +127,20 @@ def __init__( # ignores unlabeled catgegory n_labels = self.summary_stats.n_labels - 1 - n_cats_per_cov = self.summary_stats[f'n_{REGISTRY_KEYS.CAT_COVS_KEY}'] - if n_cats_per_cov == 0: - n_cats_per_cov = None + # n_cats_per_cov = self.summary_stats[f'n_{REGISTRY_KEYS.CAT_COVS_KEY}'] + # if n_cats_per_cov == 0: + # n_cats_per_cov = None + n_cats_per_cov = ( + self.adata_manager.get_state_registry(REGISTRY_KEYS.CAT_COVS_KEY).n_cats_per_key + if REGISTRY_KEYS.CAT_COVS_KEY in self.adata_manager.data_registry + else None + ) n_batch = self.summary_stats.n_batch - use_size_factor_key = self.registry_['setup_args'][f'{REGISTRY_KEYS.SIZE_FACTOR_KEY}_key'] + use_size_factor_key = self.registry_["setup_args"][f"{REGISTRY_KEYS.SIZE_FACTOR_KEY}_key"] library_log_means, library_log_vars = None, None if self.adata is not None and not use_size_factor_key and self.minified_data_type is None: - library_log_means, library_log_vars = _init_library_size( - self.adata_manager, n_batch - ) + library_log_means, library_log_vars = _init_library_size(self.adata_manager, n_batch) self.module = self._module_cls( n_input=self.summary_stats.n_vars, @@ -479,7 +482,9 @@ def setup_anndata( adata_minify_type = _get_adata_minify_type(adata) if adata_minify_type is not None: anndata_fields += cls._get_fields_for_adata_minification(adata_minify_type) - adata_manager = AnnDataManager(fields=anndata_fields, setup_method_args=setup_method_args) + adata_manager = AnnDataManager( + fields=anndata_fields, setup_method_args=setup_method_args + ) adata_manager.register_fields(adata, **kwargs) cls.register_manager(adata_manager) diff --git a/src/scvi/model/_scvi.py b/src/scvi/model/_scvi.py index 7467e61945..dd0b50bd31 100644 --- a/src/scvi/model/_scvi.py +++ b/src/scvi/model/_scvi.py @@ -1,16 +1,15 @@ from __future__ import annotations import logging -import warnings from typing import Literal import numpy as np from anndata import AnnData -from lightning import LightningDataModule -from scvi import REGISTRY_KEYS, settings +import scvi +from scvi import REGISTRY_KEYS from scvi._types import MinifiedDataType -from scvi.data import AnnDataManager +from scvi.data import AnnDataManager, _constants from scvi.data._constants import _ADATA_MINIFY_TYPE_UNS_KEY, ADATA_MINIFY_TYPE from scvi.data._utils import _get_adata_minify_type from scvi.data.fields import ( @@ -27,7 +26,7 @@ from scvi.model.base import EmbeddingMixin, UnsupervisedTrainingMixin from scvi.model.utils import get_minified_adata_scrna from scvi.module import VAE -from scvi.utils import setup_anndata_dsp +from scvi.utils import attrdict, setup_anndata_dsp from .base import ArchesMixin, BaseMinifiedModeModelClass, RNASeqMixin, VAEMixin @@ -142,16 +141,20 @@ def __init__( f"gene_likelihood: {gene_likelihood}, latent_distribution: {latent_distribution}." ) - n_cats_per_cov = self.summary_stats[f'n_{REGISTRY_KEYS.CAT_COVS_KEY}'] - if n_cats_per_cov == 0: - n_cats_per_cov = None + n_cats_per_cov = ( + self.adata_manager.get_state_registry(REGISTRY_KEYS.CAT_COVS_KEY).n_cats_per_key + if REGISTRY_KEYS.CAT_COVS_KEY in self.adata_manager.data_registry + else None + ) + # n_cats_per_cov = self.summary_stats[f'n_{REGISTRY_KEYS.CAT_COVS_KEY}'] + # if n_cats_per_cov == 0: + # n_cats_per_cov = None + n_batch = self.summary_stats.n_batch - use_size_factor_key = self.get_setup_arg(f'{REGISTRY_KEYS.SIZE_FACTOR_KEY}_key') + use_size_factor_key = self.get_setup_arg(f"{REGISTRY_KEYS.SIZE_FACTOR_KEY}_key") library_log_means, library_log_vars = None, None if self.adata is not None and not use_size_factor_key and self.minified_data_type is None: - library_log_means, library_log_vars = _init_library_size( - self.adata_manager, n_batch - ) + library_log_means, library_log_vars = _init_library_size(self.adata_manager, n_batch) self.module = self._module_cls( n_input=self.summary_stats.n_vars, n_batch=n_batch, @@ -215,17 +218,22 @@ def setup_anndata( adata_manager = AnnDataManager(fields=anndata_fields, setup_method_args=setup_method_args) adata_manager.register_fields(adata, **kwargs) cls.register_manager(adata_manager) - # adata_manager.get_state_registry(SCVI.REGISTRY_KEYS.X_KEY).to_dict() - # adata_manager.registry[_constants._FIELD_REGISTRIES_KEY] - # pprint(adata_manager.registry) + + @staticmethod + def _get_summary_stats_from_registry(registry: dict) -> attrdict: + summary_stats = {} + for field_registry in registry[_constants._FIELD_REGISTRIES_KEY].values(): + field_summary_stats = field_registry[_constants._SUMMARY_STATS_KEY] + summary_stats.update(field_summary_stats) + return attrdict(summary_stats) @classmethod @setup_anndata_dsp.dedent def setup_datamodule( cls, - datamodule, + datamodule, # TODO: what to put here? layer: str | None = None, - batch_key: str | None = None, + batch_key: list[str] | None = None, labels_key: str | None = None, size_factor_key: str | None = None, categorical_covariate_keys: list[str] | None = None, @@ -244,8 +252,59 @@ def setup_datamodule( %(param_cat_cov_keys)s %(param_cont_cov_keys)s """ - - pass + datamodule.registry = { + "scvi_version": scvi.__version__, + "model_name": "SCVI", + "setup_args": { + "layer": layer, + "batch_key": batch_key, + "labels_key": labels_key, + "size_factor_key": size_factor_key, + "categorical_covariate_keys": categorical_covariate_keys, + "continuous_covariate_keys": continuous_covariate_keys, + }, + "field_registries": { + "X": { + "data_registry": {"attr_name": "X", "attr_key": None}, + "state_registry": { + "n_obs": datamodule.n_obs, + "n_vars": datamodule.n_vars, + "column_names": datamodule.vars, + }, + "summary_stats": {"n_vars": datamodule.n_vars, "n_cells": datamodule.n_obs}, + }, + "batch": { + "data_registry": {"attr_name": "obs", "attr_key": "_scvi_batch"}, + "state_registry": { + "categorical_mapping": datamodule.datapipe.obs_encoders["batch"].classes_, + "original_key": "batch", + }, + "summary_stats": {"n_batch": datamodule.n_batch}, + }, + "labels": { + "data_registry": {"attr_name": "obs", "attr_key": "_scvi_labels"}, + "state_registry": { + "categorical_mapping": np.array([0]), + "original_key": "_scvi_labels", + }, + "summary_stats": {"n_labels": 1}, + }, + "size_factor": {"data_registry": {}, "state_registry": {}, "summary_stats": {}}, + "extra_categorical_covs": { + "data_registry": {}, + "state_registry": {}, + "summary_stats": {"n_extra_categorical_covs": 0}, + }, + "extra_continuous_covs": { + "data_registry": {}, + "state_registry": {}, + "summary_stats": {"n_extra_continuous_covs": 0}, + }, + }, + "setup_method_name": "setup_datamodule", + } + datamodule.summary_stats = cls._get_summary_stats_from_registry(datamodule.registry) + datamodule.var_names = [str(i) for i in datamodule.vars] @staticmethod def _get_fields_for_adata_minification( diff --git a/src/scvi/model/base/_base_model.py b/src/scvi/model/base/_base_model.py index 1957187c71..a63db55d96 100644 --- a/src/scvi/model/base/_base_model.py +++ b/src/scvi/model/base/_base_model.py @@ -27,6 +27,7 @@ _FIELD_REGISTRIES_KEY, _MODEL_NAME_KEY, _SCVI_UUID_KEY, + _SCVI_VERSION_KEY, _SETUP_ARGS_KEY, _SETUP_METHOD_NAME, _STATE_REGISTRY_KEY, @@ -146,11 +147,11 @@ def registry(self) -> dict: def get_var_names(self, legacy_mudata_format=False) -> dict: """Variable names of input data.""" from scvi.model.base._save_load import _get_var_names + if self.adata: return _get_var_names(self.adata, legacy_mudata_format=legacy_mudata_format) else: - return self.registry[ - _FIELD_REGISTRIES_KEY]['X'][_STATE_REGISTRY_KEY]['column_names'] + return self.registry[_FIELD_REGISTRIES_KEY]["X"][_STATE_REGISTRY_KEY]["column_names"] @adata.setter def adata(self, adata: AnnOrMuData): @@ -290,22 +291,22 @@ def data_registry(self, registry_key: str) -> np.ndarray | pd.DataFrame: else: return self._adata_manager.get_from_registry(registry_key) - def get_from_registry(self, registry_key: str) -> np.ndarray | pd.DataFrame: - """Returns the object in AnnData associated with the key in the data registry. - - Parameters - ---------- - registry_key - key of object to get from ``self.data_registry`` - - Returns - ------- - The requested data. - """ - if not self.adata: - raise ValueError("self.adata is None. Please registry AnnData object.") - else: - return self._adata_manager.get_from_registry(registry_key) + # def get_from_registry(self, registry_key: str) -> np.ndarray | pd.DataFrame: + # """Returns the object in AnnData associated with the key in the data registry. + # + # Parameters + # ---------- + # registry_key + # key of object to get from ``self.data_registry`` + # + # Returns + # ------- + # The requested data. + # """ + # if not self.adata: + # raise ValueError("self.adata is None. Please registry AnnData object.") + # else: + # return self._adata_manager.get_from_registry(registry_key) def deregister_manager(self, adata: AnnData | None = None): """Deregisters the :class:`~scvi.data.AnnDataManager` instance associated with `adata`. @@ -620,7 +621,8 @@ def _get_init_params(self, locals): all_params = { k: v for (k, v) in all_params.items() - if not isinstance(v, AnnData) and not isinstance(v, MuData) + if not isinstance(v, AnnData) + and not isinstance(v, MuData) and k not in ("adata", "registry") } # not very efficient but is explicit @@ -780,7 +782,7 @@ def load( adata = new_adata if new_adata is not None else adata registry = attr_dict.pop("registry_") - registry['setup_method_name'] = 'setup_anndata' + registry["setup_method_name"] = "setup_anndata" if _MODEL_NAME_KEY in registry and registry[_MODEL_NAME_KEY] != cls.__name__: raise ValueError("It appears you are loading a model from a different class.") @@ -1006,11 +1008,7 @@ def view_registry(self, hide_state_registries: bool = False) -> None: def get_state_registry(self, registry_key: str) -> attrdict: """Returns the state registry for the AnnDataField registered with this instance.""" - return attrdict( - self.registry_[_FIELD_REGISTRIES_KEY][registry_key][ - _STATE_REGISTRY_KEY - ] - ) + return attrdict(self.registry_[_FIELD_REGISTRIES_KEY][registry_key][_STATE_REGISTRY_KEY]) def get_setup_arg(self, setup_arg: str) -> attrdict: """Returns the string provided to setup of a specific setup_arg.""" @@ -1107,6 +1105,7 @@ def update_setup_method_args(self, setup_method_args: dict): """ self._registry[_SETUP_ARGS_KEY].update(setup_method_args) + class BaseMinifiedModeModelClass(BaseModelClass): """Abstract base class for scvi-tools models that can handle minified data.""" diff --git a/tests/dataloaders/test_custom_dataloader.py b/tests/dataloaders/test_custom_dataloader.py deleted file mode 100644 index 8ef4b038af..0000000000 --- a/tests/dataloaders/test_custom_dataloader.py +++ /dev/null @@ -1,72 +0,0 @@ -from __future__ import annotations - -import os -from pprint import pprint - -import numpy as np -import scanpy as sc - -import scvi -from scvi.data import _constants, synthetic_iid -from scvi.model import SCVI - -# We will now create the SCVI model object: -# Its parameters: -n_layers = 1 -n_latent = 10 -batch_size = 1024 -train_size = 0.9 -max_epochs = 1 - - -# COMAPRE TO THE ORIGINAL METHOD!!! - use the same data!!! -# We first create a registry using the orignal way of anndata in order to compare and add -# what is missing -adata = synthetic_iid() -adata.obs["size_factor"] = np.random.randint(1, 5, size=(adata.shape[0],)) -SCVI.setup_anndata( - adata, - batch_key="batch", - labels_key="labels", - size_factor_key="size_factor", -) -# -model_orig = SCVI(adata, n_latent=n_latent) -model_orig.train(1, check_val_every_n_epoch=1, train_size=0.5) - -# Saving the model -save_dir = "/Users/orikr/runs/290724/" # tempfile.TemporaryDirectory() -model_dir = os.path.join(save_dir, "scvi_orig_model") -model_orig.save(model_dir, overwrite=True) - -# Loading the model (just as a compariosn) -model_orig_loaded = scvi.model.SCVI.load(model_dir, adata=adata) - -# when loading from disk -scvi.model.SCVI.prepare_query_anndata(adata, model_dir) -# O -scvi.model.SCVI.prepare_query_anndata(adata, model_orig_loaded) - -# Obtaining model outputs -SCVI_LATENT_KEY = "X_scVI" -latent = model_orig.get_latent_representation() -adata.obsm[SCVI_LATENT_KEY] = latent -# latent.shape - -# You can see all necessary entries and the structure at -adata_manager = model_orig.adata_manager -model_orig.view_anndata_setup(hide_state_registries=True) -# adata_manager.get_state_registry(SCVI.REGISTRY_KEYS.X_KEY).to_dict() -adata_manager.registry[_constants._FIELD_REGISTRIES_KEY] - -pprint(adata_manager.registry) - -# Plot UMAP and save the figure for later check -sc.pp.neighbors(adata, use_rep="scvi", key_added="scvi") -sc.tl.umap(adata, neighbors_key="scvi") -sc.pl.umap(adata, color="dataset_id", title="SCVI") - -# Now return and add all the registry stuff that we will need - -# Now add the missing stuff from the current CZI implemenation in order for us to have the exact -# same steps like the original way (except than setup_anndata) diff --git a/tests/model/test_scvi.py b/tests/model/test_scvi.py index 6000846770..7b0eee7c7a 100644 --- a/tests/model/test_scvi.py +++ b/tests/model/test_scvi.py @@ -941,33 +941,36 @@ def test_scvi_no_anndata(n_batches: int = 3, n_latent: int = 5): datamodule.n_vars = adata.n_vars datamodule.n_batch = n_batches - model = SCVI(n_latent=5) - assert model._module_init_on_train - assert model.module is None - - # cannot infer default max_epochs without n_obs set in datamodule - with pytest.raises(ValueError): - model.train(datamodule=datamodule) - - # must pass in datamodule if not initialized with adata - with pytest.raises(ValueError): - model.train() - - model.train(max_epochs=1, datamodule=datamodule) - - # must set n_obs for defaulting max_epochs - datamodule.n_obs = 100_000_000 # large number for fewer default epochs - model.train(datamodule=datamodule) - - model = SCVI(adata, n_latent=5) - # Add an example for external custom dataloader? - assert not model._module_init_on_train - assert model.module is not None - assert hasattr(model, "adata") - - # initialized with adata, cannot pass in datamodule - with pytest.raises(ValueError): - model.train(datamodule=datamodule) + with pytest.raises(ValueError) as excinfo: + SCVI(n_latent=5) + assert str(excinfo.value) == "adata or registry must be provided." + # model = SCVI(n_latent=5) + # assert model._module_init_on_train + # assert model.module is None + # + # # cannot infer default max_epochs without n_obs set in datamodule + # with pytest.raises(ValueError): + # model.train(datamodule=datamodule) + # + # # must pass in datamodule if not initialized with adata + # with pytest.raises(ValueError): + # model.train() + # + # model.train(max_epochs=1, datamodule=datamodule) + # + # # must set n_obs for defaulting max_epochs + # datamodule.n_obs = 100_000_000 # large number for fewer default epochs + # model.train(datamodule=datamodule) + # + # model = SCVI(adata, n_latent=5) + # # Add an example for external custom dataloader? + # assert not model._module_init_on_train + # assert model.module is not None + # assert hasattr(model, "adata") + # + # # initialized with adata, cannot pass in datamodule + # with pytest.raises(ValueError): + # model.train(datamodule=datamodule) def test_scvi_no_anndata_with_external_indices(n_batches: int = 3, n_latent: int = 5): @@ -990,32 +993,35 @@ def test_scvi_no_anndata_with_external_indices(n_batches: int = 3, n_latent: int datamodule.n_vars = adata.n_vars datamodule.n_batch = n_batches - model = SCVI(n_latent=5) - assert model._module_init_on_train - assert model.module is None - - # cannot infer default max_epochs without n_obs set in datamodule - with pytest.raises(ValueError): - model.train(datamodule=datamodule) - - # must pass in datamodule if not initialized with adata - with pytest.raises(ValueError): - model.train() - - model.train(max_epochs=1, datamodule=datamodule) - - # must set n_obs for defaulting max_epochs - datamodule.n_obs = 100_000_000 # large number for fewer default epochs - model.train(datamodule=datamodule) - - model = SCVI(adata, n_latent=5) - assert not model._module_init_on_train - assert model.module is not None - assert hasattr(model, "adata") - - # initialized with adata, cannot pass in datamodule - with pytest.raises(ValueError): - model.train(datamodule=datamodule) + with pytest.raises(ValueError) as excinfo: + SCVI(n_latent=5) + assert str(excinfo.value) == "adata or registry must be provided." + # model = SCVI(n_latent=5) + # assert model._module_init_on_train + # assert model.module is None + # + # # cannot infer default max_epochs without n_obs set in datamodule + # with pytest.raises(ValueError): + # model.train(datamodule=datamodule) + # + # # must pass in datamodule if not initialized with adata + # with pytest.raises(ValueError): + # model.train() + # + # model.train(max_epochs=1, datamodule=datamodule) + # + # # must set n_obs for defaulting max_epochs + # datamodule.n_obs = 100_000_000 # large number for fewer default epochs + # model.train(datamodule=datamodule) + # + # model = SCVI(adata, n_latent=5) + # assert not model._module_init_on_train + # assert model.module is not None + # assert hasattr(model, "adata") + # + # # initialized with adata, cannot pass in datamodule + # with pytest.raises(ValueError): + # model.train(datamodule=datamodule) @pytest.mark.parametrize("embedding_dim", [5, 10]) @@ -1088,7 +1094,7 @@ def test_scvi_train_custom_dataloader(n_latent: int = 5): model = SCVI(adata, n_latent=n_latent) model.train(max_epochs=1) dataloader = model._make_data_loader(adata) - SCVI.setup_datamodule(dataloader) + # SCVI.setup_datamodule(dataloader) # continue from here. Datamodule will always require to pass it into all downstream functions. model.train(max_epochs=1, datamodule=dataloader) _ = model.get_elbo(dataloader=dataloader) From 4fe3ee13240700da9e302eb2c8da06be499935a1 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Wed, 7 Aug 2024 12:58:07 +0000 Subject: [PATCH 16/53] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- src/scvi/model/base/_archesmixin.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/src/scvi/model/base/_archesmixin.py b/src/scvi/model/base/_archesmixin.py index 02607e6f93..cac253723b 100644 --- a/src/scvi/model/base/_archesmixin.py +++ b/src/scvi/model/base/_archesmixin.py @@ -17,7 +17,6 @@ from scvi.data._constants import _MODEL_NAME_KEY, _SETUP_ARGS_KEY, _SETUP_METHOD_NAME from scvi.model._utils import parse_device_args from scvi.model.base._save_load import ( - _get_var_names, _initialize_model, _load_saved_files, _validate_var_names, @@ -140,7 +139,7 @@ def load_query_data( model = _initialize_model(cls, adata, registry, attr_dict) - if model.summary_stats[f'n_{REGISTRY_KEYS.CAT_COVS_KEY}'] > 0: + if model.summary_stats[f"n_{REGISTRY_KEYS.CAT_COVS_KEY}"] > 0: raise NotImplementedError( "scArches currently does not support models with extra categorical covariates." ) From 1110966a4df00a7971eca972abec53ec1bce2799 Mon Sep 17 00:00:00 2001 From: ori-kron-wis Date: Wed, 7 Aug 2024 16:09:53 +0300 Subject: [PATCH 17/53] just put the cutom dataloder2 test under remarks so hook tests will run, we will later adjust this file --- tests/dataloaders/test_custom_dataloader2.py | 486 +++++++------------ 1 file changed, 173 insertions(+), 313 deletions(-) diff --git a/tests/dataloaders/test_custom_dataloader2.py b/tests/dataloaders/test_custom_dataloader2.py index 9c3b625d6c..9c8f98716e 100644 --- a/tests/dataloaders/test_custom_dataloader2.py +++ b/tests/dataloaders/test_custom_dataloader2.py @@ -1,329 +1,189 @@ -from __future__ import annotations - -import sys - -sys.path.insert(0, "/Users/orikr/Documents/cellxgene-census/api/python/cellxgene_census/src") -sys.path.insert(0, "src") - -import cellxgene_census -import numpy as np -import tiledbsoma as soma -from cellxgene_census.experimental.ml.datamodule import ( - CensusSCVIDataModule, # WE RAN FROM LOCAL LIB -) -from cellxgene_census.experimental.pp import highly_variable_genes - -import scvi -from scvi.data import _constants, synthetic_iid -from scvi.utils import attrdict - -# cellxgene_census.__file__, scvi.__file__ - -# We will now create the SCVI model object: -# Its parameters: -n_layers = 1 -n_latent = 10 -batch_size = 1024 -train_size = 0.9 -max_epochs = 1 - -# We have to create a registry without setup_anndata that contains the same elements -# The other way will be to fill the model ,LIKE IN CELLXGENE NOTEBOOK -# need to pass here new object of registry taht contains everything we will need - -# First lets see CELLXGENE example using pytorch loaders implemented now in our repo -census = cellxgene_census.open_soma(census_version="stable") -experiment_name = "mus_musculus" -obs_value_filter = 'is_primary_data == True and tissue_general in ["spleen"] and nnz >= 300' -top_n_hvg = 8000 -hvg_batch = ["assay", "suspension_type"] -# THIS WILL TAKE FEW MINUTES TO RUN! -query = census["census_data"][experiment_name].axis_query( - measurement_name="RNA", obs_query=soma.AxisQuery(value_filter=obs_value_filter) -) -hvgs_df = highly_variable_genes(query, n_top_genes=top_n_hvg, batch_key=hvg_batch) -hv = hvgs_df.highly_variable -hv_idx = hv[hv].index - -# Now load the custom data module CZI did that now exists in our db -# (and we will later want to elaborate with more info from our original anndata registry) -# This thing is done by the user in any form they want -datamodule = CensusSCVIDataModule( - census["census_data"][experiment_name], - measurement_name="RNA", - X_name="raw", - obs_query=soma.AxisQuery(value_filter=obs_value_filter), - var_query=soma.AxisQuery(coords=(list(hv_idx),)), - batch_size=1024, - shuffle=True, - batch_keys=["dataset_id", "assay", "suspension_type", "donor_id"], - dataloader_kwargs={"num_workers": 0, "persistent_workers": False}, -) - -datamodule.vars = hv_idx - - -def _get_summary_stats_from_registry(registry: dict) -> attrdict: - summary_stats = {} - for field_registry in registry[_constants._FIELD_REGISTRIES_KEY].values(): - field_summary_stats = field_registry[_constants._SUMMARY_STATS_KEY] - summary_stats.update(field_summary_stats) - return attrdict(summary_stats) - - -def setup_datamodule(datamodule: CensusSCVIDataModule): - datamodule.registry = { - "scvi_version": scvi.__version__, - "model_name": "SCVI", - "setup_args": { - "layer": None, - "batch_key": "batch", - "labels_key": None, - "size_factor_key": None, - "categorical_covariate_keys": None, - "continuous_covariate_keys": None, - }, - "field_registries": { - "X": { - "data_registry": {"attr_name": "X", "attr_key": None}, - "state_registry": { - "n_obs": datamodule.n_obs, - "n_vars": datamodule.n_vars, - "column_names": datamodule.vars, - }, - "summary_stats": {"n_vars": datamodule.n_vars, "n_cells": datamodule.n_obs}, - }, - "batch": { - "data_registry": {"attr_name": "obs", "attr_key": "_scvi_batch"}, - "state_registry": { - "categorical_mapping": datamodule.datapipe.obs_encoders["batch"].classes_, - "original_key": "batch", - }, - "summary_stats": {"n_batch": datamodule.n_batch}, - }, - "labels": { - "data_registry": {"attr_name": "obs", "attr_key": "_scvi_labels"}, - "state_registry": { - "categorical_mapping": np.array([0]), - "original_key": "_scvi_labels", - }, - "summary_stats": {"n_labels": 1}, - }, - "size_factor": {"data_registry": {}, "state_registry": {}, "summary_stats": {}}, - "extra_categorical_covs": { - "data_registry": {}, - "state_registry": {}, - "summary_stats": {"n_extra_categorical_covs": 0}, - }, - "extra_continuous_covs": { - "data_registry": {}, - "state_registry": {}, - "summary_stats": {"n_extra_continuous_covs": 0}, - }, - }, - "setup_method_name": "setup_datamodule", - } - datamodule.summary_stats = _get_summary_stats_from_registry(datamodule.registry) - datamodule.var_names = [str(i) for i in datamodule.vars] - - -# This is a new func to implement (Implemented Above but we need in our code base as well) -# will take a bit of time to end -setup_datamodule(datamodule) - -# The next part is the same as test_scvi_train_custom_dataloader - -adata = synthetic_iid() -scvi.model.SCVI.setup_anndata(adata, batch_key="batch") -model = scvi.model.SCVI(adata, n_latent=10) -model.train(max_epochs=1) -dataloader = model._make_data_loader(adata) -_ = model.get_elbo(dataloader=dataloader) -_ = model.get_marginal_ll(dataloader=dataloader) -_ = model.get_reconstruction_error(dataloader=dataloader) -_ = model.get_latent_representation(dataloader=dataloader) - -# ORI I broke the code here also for standard models. Please first fix this. - it is fixed -scvi.model.SCVI.prepare_query_anndata(adata, model) -query_model = scvi.model.SCVI.load_query_data(adata, model) - -# We will now create the SCVI model object: -model_census = scvi.model.SCVI( - datamodule=datamodule, - n_layers=n_layers, - n_latent=n_latent, - gene_likelihood="nb", - encode_covariates=False, -) - -# The CZI data module is a refined data module while SCVI is a lighting datamodule -# Altough this is only 1 epoch it will take few mins on local machine -model_census.train( - datamodule=datamodule, - max_epochs=max_epochs, - batch_size=batch_size, - train_size=train_size, - early_stopping=False, -) - -# We can now save the trained model. As of the current writing date (June 2024), -# scvi-tools doesn't support saving a model that wasn't generated through an AnnData loader, -# so we'll use some custom code: -# model_state_dict = model_census.module.state_dict() -# var_names = hv_idx.to_numpy() -# user_attributes = model_census._get_user_attributes() -# user_attributes = {a[0]: a[1] for a in user_attributes if a[0][-1] == "_"} -model_census.save("dataloader_model2", overwrite=True) - -# We are now turning this data module back to AnnData -adata = cellxgene_census.get_anndata( - census, - organism=experiment_name, - obs_value_filter=obs_value_filter, -) - -adata = adata[:, datamodule.vars].copy() - -adata.obs.head() - -# ORI Replace this with the function to generate batch key used in the datamodule. -# "12967895-3d58-4e93-be2c-4e1bcf4388d510x 5' v1cellHCA_Mou_3" -adata.obs["batch"] = ("batch_" + adata.obs[datamodule.batch_keys[0]].cat.codes.astype(str)).astype( - "category" -) -# adata.var_names = 'gene_'+adata.var_names #not sure we need it - -# We will now load the model back and use it to generate cell embeddings (the latent space), -# which can then be used for further analysis. Note that we still need to use some custom code for -# loading the model, which includes loading the parameters from the `attr_dict` node stored in -# the model. - -model_census2 = scvi.model.SCVI.load("dataloader_model2", datamodule=datamodule) -model_census2.setup_anndata(adata, batch_key="batch") -# model_census2.adata = deepcopy(adata) -# ORI Works when loading from disk -scvi.model.SCVI.prepare_query_anndata(adata, "dataloader_model2") -# ORI This one still needs to be fixed. -scvi.model.SCVI.prepare_query_anndata(adata, model_census2) - -# ORI Should work when setting up the AnnData correctly. scANVI with DataModule is not yet -# supported as DataModule can't take a labels_key. -scanvae = scvi.model.SCANVI.from_scvi_model( - model_census2, - adata=adata, - unlabeled_category="Unknown", - labels_key="cell_type", -) - -# ORI - check it should work with a model initialized with AnnData. See below not fully working yet -model_census3 = scvi.model.SCVI.load("dataloader_model2", adata=adata) - -scvi.model.SCVI.prepare_query_anndata(adata, "dataloader_model2") -query_model = scvi.model.SCVI.load_query_data(adata, "dataloader_model2") - -scvi.model.SCVI.prepare_query_anndata(adata, model_census3) -query_model = scvi.model.SCVI.load_query_data(adata, model_census3) - -# with open("model.pt", "rb") as f: -# torch_model = torch.load(f) -# -# adict = torch_model["attr_dict"] -# params = adict["init_params_"]["non_kwargs"] -# -# n_batch = adict["n_batch"] -# n_extra_categorical_covs = adict["n_extra_categorical_covs"] -# n_extra_continuous_covs = adict["n_extra_continuous_covs"] -# n_labels = adict["n_labels"] -# n_vars = adict["n_vars"] -# -# latent_distribution = params["latent_distribution"] -# dispersion = params["dispersion"] -# n_hidden = params["n_hidden"] -# dropout_rate = params["dropout_rate"] -# gene_likelihood = params["gene_likelihood"] -# -# model = scvi.model.SCVI( -# n_layers=params["n_layers"], -# n_latent=params["n_latent"], -# gene_likelihood=params["gene_likelihood"], -# encode_covariates=False, -# ) +# from __future__ import annotations # -# module = model._module_cls( -# n_input=n_vars, -# n_batch=n_batch, -# n_labels=n_labels, -# n_continuous_cov=n_extra_continuous_covs, -# n_cats_per_cov=None, -# n_hidden=n_hidden, -# n_latent=n_latent, -# n_layers=n_layers, -# dropout_rate=dropout_rate, -# dispersion=dispersion, -# gene_likelihood=gene_likelihood, -# latent_distribution=latent_distribution, -# ) -# model.module = module -# -# model.module.load_state_dict(torch_model["model_state_dict"]) -# -# device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") -# -# model.to_device(device) -# model.module.eval() -# model.is_trained = True - -# We will now generate the cell embeddings for this model, using the `get_latent_representation` -# function available in scvi-tools. -# We can use another instance of the `ExperimentDataPipe` for the forward pass, so we don't need -# to load the whole dataset in memory. - -# # Needs to have shuffle=False for inference -# datamodule_inference = CensusSCVIDataModule( +# import sys +# +# sys.path.insert(0, "/Users/orikr/Documents/cellxgene-census/api/python/cellxgene_census/src") +# sys.path.insert(0, "src") +# +# import cellxgene_census +# import numpy as np +# import tiledbsoma as soma +# from cellxgene_census.experimental.ml.datamodule import ( +# CensusSCVIDataModule, # WE RAN FROM LOCAL LIB +# ) +# from cellxgene_census.experimental.pp import highly_variable_genes +# +# import scvi +# from scvi.data import synthetic_iid +# from scvi.model import SCVI +# +# # cellxgene_census.__file__, scvi.__file__ +# +# # We will now create the SCVI model object: +# # Its parameters: +# n_layers = 1 +# n_latent = 10 +# batch_size = 1024 +# train_size = 0.9 +# max_epochs = 1 +# +# # We have to create a registry without setup_anndata that contains the same elements +# # The other way will be to fill the model ,LIKE IN CELLXGENE NOTEBOOK +# # need to pass here new object of registry taht contains everything we will need +# +# # First lets see CELLXGENE example using pytorch loaders implemented now in our repo +# census = cellxgene_census.open_soma(census_version="stable") +# experiment_name = "mus_musculus" +# obs_value_filter = 'is_primary_data == True and tissue_general in ["spleen"] and nnz >= 300' +# top_n_hvg = 800 +# hvg_batch = ["assay", "suspension_type"] +# +# # THIS WILL TAKE FEW MINUTES TO RUN! +# query = census["census_data"][experiment_name].axis_query( +# measurement_name="RNA", obs_query=soma.AxisQuery(value_filter=obs_value_filter) +# ) +# hvgs_df = highly_variable_genes(query, n_top_genes=top_n_hvg, batch_key=hvg_batch) +# hv = hvgs_df.highly_variable +# hv_idx = hv[hv].index +# hv_idx = np.arange(100) # just randomly select smaller number of indices +# +# # Now load the custom data module CZI did that now exists in our db +# # (and we will later want to elaborate with more info from our original anndata registry) +# # This thing is done by the user in any form they want +# datamodule = CensusSCVIDataModule( # census["census_data"][experiment_name], # measurement_name="RNA", # X_name="raw", # obs_query=soma.AxisQuery(value_filter=obs_value_filter), # var_query=soma.AxisQuery(coords=(list(hv_idx),)), # batch_size=1024, -# shuffle=False, -# soma_chunk_size=50_000, +# shuffle=True, # batch_keys=["dataset_id", "assay", "suspension_type", "donor_id"], # dataloader_kwargs={"num_workers": 0, "persistent_workers": False}, # ) # -# # We can simply feed the datapipe to `get_latent_representation` to obtain the embeddings - -# # will take a while -# datapipe = datamodule_inference.datapipe -# dataloader = experiment_dataloader(datapipe, num_workers=0, persistent_workers=False) -# mapped_dataloader = ( -# datamodule_inference.on_before_batch_transfer(tensor, None) for tensor in dataloader -# ) -# latent = model.get_latent_representation(dataloader=mapped_dataloader) -# emb_idx = datapipe._obs_joinids -# -# # We will now take a look at the UMAP for the generated embedding -# # (will be later comapred to what we got) -# adata = cellxgene_census.get_anndata( -# census, -# organism=experiment_name, -# obs_value_filter=obs_value_filter, -# ) -# obs_soma_joinids = adata.obs["soma_joinid"] -# obs_indexer = pd.Index(emb_idx) -# idx = obs_indexer.get_indexer(obs_soma_joinids) -# # Reindexing is necessary to ensure that the cells in the embedding match the -# # ones in the anndata object. -# adata.obsm["scvi"] = latent[idx] +# datamodule.vars = hv_idx # -# # Plot UMAP and save the figure for later check -# sc.pp.neighbors(adata, use_rep="scvi", key_added="scvi") -# sc.tl.umap(adata, neighbors_key="scvi") -# sc.pl.umap(adata, color="dataset_id", title="SCVI") # +# # The next part is the same as test_scvi_train_custom_dataloader +# def test_scvi_train_custom_dataloader(n_latent: int = 10): +# adata = synthetic_iid() +# scvi.model.SCVI.setup_anndata(adata, batch_key="batch") +# model = scvi.model.SCVI(adata, n_latent=n_latent) +# model.train(max_epochs=1) +# dataloader = model._make_data_loader(adata) +# _ = model.get_elbo(dataloader=dataloader) +# _ = model.get_marginal_ll(dataloader=dataloader) +# _ = model.get_reconstruction_error(dataloader=dataloader) +# _ = model.get_latent_representation(dataloader=dataloader) # -# # Now return and add all the registry stuff that we will need +# scvi.model.SCVI.prepare_query_anndata(adata, model) +# query_model = scvi.model.SCVI.load_query_data(adata, model) # # -# # Now add the missing stuff from the current CZI implemenation in order for us to have the exact -# # same steps like the original way (except than setup_anndata) +# def test_scvi_train_custom_datamodule(datamodule=datamodule): +# # This is a new func to implement +# # will take a bit of time to end +# SCVI.setup_datamodule(datamodule) +# +# # We will now create the SCVI model object from custom data module: +# model_census = scvi.model.SCVI( +# registry=datamodule.registry, +# n_layers=n_layers, +# n_latent=n_latent, +# gene_likelihood="nb", +# encode_covariates=False, +# ) +# +# # The CZI data module is a refined data module while SCVI is a lighting datamodule +# # Altough this is only 1 epoch it will take few mins on local machine +# model_census.train( +# datamodule=datamodule, +# max_epochs=max_epochs, +# batch_size=batch_size, +# train_size=train_size, +# early_stopping=False, +# ) +# +# # We can now save the trained model. As of the current writing date (June 2024), +# # scvi-tools doesn't support saving a model that wasn't generated through an AnnData loader, +# # so we'll use some custom code: +# # model_state_dict = model_census.module.state_dict() +# # var_names = hv_idx.to_numpy() +# # user_attributes = model_census._get_user_attributes() +# # user_attributes = {a[0]: a[1] for a in user_attributes if a[0][-1] == "_"} +# model_census.save("dataloader_model2", overwrite=True) +# model_census_loaded = scvi.model.SCVI.load("dataloader_model2", adata=False) +# +# +# def test_scvi_train_custom_datamodule_from_loaded_model(datamodule=datamodule): +# model_census_loaded = scvi.model.SCVI.load("dataloader_model2", adata=False) +# +# # see if can train from loaded models +# model_census_loaded.train( +# datamodule=datamodule, +# max_epochs=max_epochs, +# batch_size=batch_size, +# train_size=train_size, +# early_stopping=False, +# ) +# +# +# def test_scvi_get_anndata_load_anndatamodule_from_custom_datamodule(datamodule=datamodule): +# # we will perform here several task that deals with transforming custom data module to +# # the regular ann datamodule - will take a bit of time +# adata = cellxgene_census.get_anndata( +# census, organism=experiment_name, obs_value_filter=obs_value_filter, var_coords=hv_idx +# ) +# +# adata = adata[:, datamodule.vars].copy() +# +# adata.obs.head() +# +# # ORI Replace this with the function to generate batch key used in the datamodule. +# # "12967895-3d58-4e93-be2c-4e1bcf4388d510x 5' v1cellHCA_Mou_3" +# adata.obs["batch"] = ( +# "batch_" + adata.obs[datamodule.batch_keys[0]].cat.codes.astype(str) +# ).astype("category") +# # adata.var_names = 'gene_'+adata.var_names #not sure we need it +# +# # We will now load the model back and use it to generate cell embeddings (the latent space), +# # which can then be used for further analysis. Note that we still need to use some custom +# # code for loading the model, which includes loading the parameters from the `attr_dict` node +# # stored in the model. +# +# # loading and setupanndata +# model_census2 = scvi.model.SCVI.load("dataloader_model2", adata=False) +# model_census2.setup_anndata(adata, batch_key="batch") +# # model_census2.adata = deepcopy(adata) +# +# # ORI Works when loading from disk +# scvi.model.SCVI.prepare_query_anndata( +# adata, "dataloader_model2", return_reference_var_names=True +# ) +# # ORI This one still needs to be fixed. +# scvi.model.SCVI.prepare_query_anndata(adata, model_census2, return_reference_var_names=True) +# query_model = scvi.model.SCVI.load_query_data( +# registry=datamodule.registry, reference_model="dataloader_model2" +# ) +# +# # ORI Should work when setting up the AnnData correctly. scANVI with DataModule is not yet +# # supported as DataModule can't take a labels_key. +# scanvae = scvi.model.SCANVI.from_scvi_model( +# model_census2, +# adata=adata, +# unlabeled_category="Unknown", +# labels_key="cell_type", +# ) +# +# # ORI - check it should work with a model initialized with AnnData. +# # See below not fully working yet +# model_census3 = scvi.model.SCVI.load("dataloader_model2", adata=adata) +# +# scvi.model.SCVI.prepare_query_anndata( +# adata, "dataloader_model2", return_reference_var_names=True +# ) +# query_model = scvi.model.SCVI.load_query_data(adata, "dataloader_model2") +# +# scvi.model.SCVI.prepare_query_anndata(adata, model_census3) +# query_model = scvi.model.SCVI.load_query_data(adata, model_census3) From 7972bdcefa2adb88a49ab482a74c9bc993fd3012 Mon Sep 17 00:00:00 2001 From: ori-kron-wis Date: Wed, 7 Aug 2024 17:41:11 +0300 Subject: [PATCH 18/53] fixes --- src/scvi/model/_scanvi.py | 21 +++++++----- src/scvi/model/_scvi.py | 19 ++++++----- src/scvi/model/base/_base_model.py | 3 ++ tests/model/test_scvi.py | 53 ------------------------------ 4 files changed, 26 insertions(+), 70 deletions(-) diff --git a/src/scvi/model/_scanvi.py b/src/scvi/model/_scanvi.py index 86fb340927..6b8e82e376 100644 --- a/src/scvi/model/_scanvi.py +++ b/src/scvi/model/_scanvi.py @@ -109,7 +109,7 @@ class SCANVI(RNASeqMixin, VAEMixin, ArchesMixin, BaseMinifiedModeModelClass): def __init__( self, - adata: AnnData, + adata: AnnData | None = None, registry: dict | None = None, n_hidden: int = 128, n_latent: int = 10, @@ -127,14 +127,17 @@ def __init__( # ignores unlabeled catgegory n_labels = self.summary_stats.n_labels - 1 - # n_cats_per_cov = self.summary_stats[f'n_{REGISTRY_KEYS.CAT_COVS_KEY}'] - # if n_cats_per_cov == 0: - # n_cats_per_cov = None - n_cats_per_cov = ( - self.adata_manager.get_state_registry(REGISTRY_KEYS.CAT_COVS_KEY).n_cats_per_key - if REGISTRY_KEYS.CAT_COVS_KEY in self.adata_manager.data_registry - else None - ) + if adata is not None: + n_cats_per_cov = ( + self.adata_manager.get_state_registry(REGISTRY_KEYS.CAT_COVS_KEY).n_cats_per_key + if REGISTRY_KEYS.CAT_COVS_KEY in self.adata_manager.data_registry + else None + ) + else: + # custom datamodule + n_cats_per_cov = self.summary_stats[f"n_{REGISTRY_KEYS.CAT_COVS_KEY}"] + if n_cats_per_cov == 0: + n_cats_per_cov = None n_batch = self.summary_stats.n_batch use_size_factor_key = self.registry_["setup_args"][f"{REGISTRY_KEYS.SIZE_FACTOR_KEY}_key"] diff --git a/src/scvi/model/_scvi.py b/src/scvi/model/_scvi.py index dd0b50bd31..437d070089 100644 --- a/src/scvi/model/_scvi.py +++ b/src/scvi/model/_scvi.py @@ -141,14 +141,17 @@ def __init__( f"gene_likelihood: {gene_likelihood}, latent_distribution: {latent_distribution}." ) - n_cats_per_cov = ( - self.adata_manager.get_state_registry(REGISTRY_KEYS.CAT_COVS_KEY).n_cats_per_key - if REGISTRY_KEYS.CAT_COVS_KEY in self.adata_manager.data_registry - else None - ) - # n_cats_per_cov = self.summary_stats[f'n_{REGISTRY_KEYS.CAT_COVS_KEY}'] - # if n_cats_per_cov == 0: - # n_cats_per_cov = None + if adata is not None: + n_cats_per_cov = ( + self.adata_manager.get_state_registry(REGISTRY_KEYS.CAT_COVS_KEY).n_cats_per_key + if REGISTRY_KEYS.CAT_COVS_KEY in self.adata_manager.data_registry + else None + ) + else: + # custom datamodule + n_cats_per_cov = self.summary_stats[f"n_{REGISTRY_KEYS.CAT_COVS_KEY}"] + if n_cats_per_cov == 0: + n_cats_per_cov = None n_batch = self.summary_stats.n_batch use_size_factor_key = self.get_setup_arg(f"{REGISTRY_KEYS.SIZE_FACTOR_KEY}_key") diff --git a/src/scvi/model/base/_base_model.py b/src/scvi/model/base/_base_model.py index a63db55d96..738e35a4e1 100644 --- a/src/scvi/model/base/_base_model.py +++ b/src/scvi/model/base/_base_model.py @@ -124,6 +124,9 @@ def __init__(self, adata: AnnOrMuData | None = None, registry: object | None = N # Suffix registry instance variable with _ to include it when saving the model. self.registry_ = registry self.summary_stats = _get_summary_stats_from_registry(registry) + elif self.__class__.__name__ == "GIMVI": + # note some models do accept empty registry/adata (e.g: gimvi) + pass else: raise ValueError("adata or registry must be provided.") diff --git a/tests/model/test_scvi.py b/tests/model/test_scvi.py index 7b0eee7c7a..30adcc3a5b 100644 --- a/tests/model/test_scvi.py +++ b/tests/model/test_scvi.py @@ -944,33 +944,6 @@ def test_scvi_no_anndata(n_batches: int = 3, n_latent: int = 5): with pytest.raises(ValueError) as excinfo: SCVI(n_latent=5) assert str(excinfo.value) == "adata or registry must be provided." - # model = SCVI(n_latent=5) - # assert model._module_init_on_train - # assert model.module is None - # - # # cannot infer default max_epochs without n_obs set in datamodule - # with pytest.raises(ValueError): - # model.train(datamodule=datamodule) - # - # # must pass in datamodule if not initialized with adata - # with pytest.raises(ValueError): - # model.train() - # - # model.train(max_epochs=1, datamodule=datamodule) - # - # # must set n_obs for defaulting max_epochs - # datamodule.n_obs = 100_000_000 # large number for fewer default epochs - # model.train(datamodule=datamodule) - # - # model = SCVI(adata, n_latent=5) - # # Add an example for external custom dataloader? - # assert not model._module_init_on_train - # assert model.module is not None - # assert hasattr(model, "adata") - # - # # initialized with adata, cannot pass in datamodule - # with pytest.raises(ValueError): - # model.train(datamodule=datamodule) def test_scvi_no_anndata_with_external_indices(n_batches: int = 3, n_latent: int = 5): @@ -996,32 +969,6 @@ def test_scvi_no_anndata_with_external_indices(n_batches: int = 3, n_latent: int with pytest.raises(ValueError) as excinfo: SCVI(n_latent=5) assert str(excinfo.value) == "adata or registry must be provided." - # model = SCVI(n_latent=5) - # assert model._module_init_on_train - # assert model.module is None - # - # # cannot infer default max_epochs without n_obs set in datamodule - # with pytest.raises(ValueError): - # model.train(datamodule=datamodule) - # - # # must pass in datamodule if not initialized with adata - # with pytest.raises(ValueError): - # model.train() - # - # model.train(max_epochs=1, datamodule=datamodule) - # - # # must set n_obs for defaulting max_epochs - # datamodule.n_obs = 100_000_000 # large number for fewer default epochs - # model.train(datamodule=datamodule) - # - # model = SCVI(adata, n_latent=5) - # assert not model._module_init_on_train - # assert model.module is not None - # assert hasattr(model, "adata") - # - # # initialized with adata, cannot pass in datamodule - # with pytest.raises(ValueError): - # model.train(datamodule=datamodule) @pytest.mark.parametrize("embedding_dim", [5, 10]) From 2d86c43f5142bb3bae1d20a500a7fc6d2f4b45a8 Mon Sep 17 00:00:00 2001 From: ori-kron-wis Date: Wed, 7 Aug 2024 18:40:13 +0300 Subject: [PATCH 19/53] additional external models fixes once there is a registry --- src/scvi/external/stereoscope/_model.py | 3 ++- src/scvi/external/stereoscope/_module.py | 1 + src/scvi/model/_amortizedlda.py | 3 ++- src/scvi/model/_autozi.py | 3 ++- src/scvi/model/_condscvi.py | 3 ++- src/scvi/model/_jaxscvi.py | 3 ++- src/scvi/model/_linear_scvi.py | 3 ++- src/scvi/model/_multivi.py | 1 + src/scvi/model/_peakvi.py | 3 ++- src/scvi/model/_totalvi.py | 3 ++- src/scvi/model/base/_save_load.py | 2 +- src/scvi/model/base/_training_mixin.py | 5 ++++- 12 files changed, 23 insertions(+), 10 deletions(-) diff --git a/src/scvi/external/stereoscope/_model.py b/src/scvi/external/stereoscope/_model.py index 10f8e1f089..a67b66a9ff 100644 --- a/src/scvi/external/stereoscope/_model.py +++ b/src/scvi/external/stereoscope/_model.py @@ -49,7 +49,8 @@ class RNAStereoscope(UnsupervisedTrainingMixin, BaseModelClass): def __init__( self, - sc_adata: AnnData, + sc_adata: AnnData | None = None, + registry: dict | None = None, **model_kwargs, ): super().__init__(sc_adata) diff --git a/src/scvi/external/stereoscope/_module.py b/src/scvi/external/stereoscope/_module.py index eefb2eb139..f74977d3ec 100644 --- a/src/scvi/external/stereoscope/_module.py +++ b/src/scvi/external/stereoscope/_module.py @@ -140,6 +140,7 @@ def __init__( n_spots: int, sc_params: tuple[np.ndarray], prior_weight: Literal["n_obs", "minibatch"] = "n_obs", + **model_kwargs, ): super().__init__() # unpack and copy parameters diff --git a/src/scvi/model/_amortizedlda.py b/src/scvi/model/_amortizedlda.py index 80a195935b..6cf022cf27 100644 --- a/src/scvi/model/_amortizedlda.py +++ b/src/scvi/model/_amortizedlda.py @@ -56,7 +56,8 @@ class AmortizedLDA(PyroSviTrainMixin, BaseModelClass): def __init__( self, - adata: AnnData, + adata: AnnData | None = None, + registry: dict | None = None, n_topics: int = 20, n_hidden: int = 128, cell_topic_prior: Optional[Union[float, Sequence[float]]] = None, diff --git a/src/scvi/model/_autozi.py b/src/scvi/model/_autozi.py index e88c2afa25..7f0b7ee9d4 100644 --- a/src/scvi/model/_autozi.py +++ b/src/scvi/model/_autozi.py @@ -98,7 +98,8 @@ class AUTOZI(VAEMixin, UnsupervisedTrainingMixin, BaseModelClass): def __init__( self, - adata: AnnData, + adata: AnnData | None = None, + registry: dict | None = None, n_hidden: int = 128, n_latent: int = 10, n_layers: int = 1, diff --git a/src/scvi/model/_condscvi.py b/src/scvi/model/_condscvi.py index b154d5e55e..306d95ed98 100644 --- a/src/scvi/model/_condscvi.py +++ b/src/scvi/model/_condscvi.py @@ -64,7 +64,8 @@ class CondSCVI(RNASeqMixin, VAEMixin, UnsupervisedTrainingMixin, BaseModelClass) def __init__( self, - adata: AnnData, + adata: AnnData | None = None, + registry: dict | None = None, n_hidden: int = 128, n_latent: int = 5, n_layers: int = 2, diff --git a/src/scvi/model/_jaxscvi.py b/src/scvi/model/_jaxscvi.py index 15f961aa5c..e7f5b65d45 100644 --- a/src/scvi/model/_jaxscvi.py +++ b/src/scvi/model/_jaxscvi.py @@ -53,7 +53,8 @@ class JaxSCVI(JaxTrainingMixin, BaseModelClass): def __init__( self, - adata: AnnData, + adata: AnnData | None = None, + registry: dict | None = None, n_hidden: int = 128, n_latent: int = 10, dropout_rate: float = 0.1, diff --git a/src/scvi/model/_linear_scvi.py b/src/scvi/model/_linear_scvi.py index 201c8ac5b9..b735a21f41 100644 --- a/src/scvi/model/_linear_scvi.py +++ b/src/scvi/model/_linear_scvi.py @@ -72,7 +72,8 @@ class LinearSCVI(RNASeqMixin, VAEMixin, UnsupervisedTrainingMixin, BaseModelClas def __init__( self, - adata: AnnData, + adata: AnnData | None = None, + registry: dict | None = None, n_hidden: int = 128, n_latent: int = 10, n_layers: int = 1, diff --git a/src/scvi/model/_multivi.py b/src/scvi/model/_multivi.py index f28d6c1ab4..2fcb6251a1 100644 --- a/src/scvi/model/_multivi.py +++ b/src/scvi/model/_multivi.py @@ -136,6 +136,7 @@ class MULTIVI(VAEMixin, UnsupervisedTrainingMixin, BaseModelClass, ArchesMixin): def __init__( self, adata: AnnData, + registry: dict, n_genes: int, n_regions: int, modality_weights: Literal["equal", "cell", "universal"] = "equal", diff --git a/src/scvi/model/_peakvi.py b/src/scvi/model/_peakvi.py index 890cbde26d..30017cf3c6 100644 --- a/src/scvi/model/_peakvi.py +++ b/src/scvi/model/_peakvi.py @@ -88,7 +88,8 @@ class PEAKVI(ArchesMixin, VAEMixin, UnsupervisedTrainingMixin, BaseModelClass): def __init__( self, - adata: AnnData, + adata: AnnData | None = None, + registry: dict | None = None, n_hidden: int | None = None, n_latent: int | None = None, n_layers_encoder: int = 2, diff --git a/src/scvi/model/_totalvi.py b/src/scvi/model/_totalvi.py index a3d9f4da54..5e6ec90388 100644 --- a/src/scvi/model/_totalvi.py +++ b/src/scvi/model/_totalvi.py @@ -103,7 +103,8 @@ class TOTALVI(RNASeqMixin, VAEMixin, ArchesMixin, BaseModelClass): def __init__( self, - adata: AnnData, + adata: AnnData | None = None, + registry: dict | None = None, n_latent: int = 20, gene_dispersion: Literal["gene", "gene-batch", "gene-label", "gene-cell"] = "gene", protein_dispersion: Literal["protein", "protein-batch", "protein-label"] = "protein", diff --git a/src/scvi/model/base/_save_load.py b/src/scvi/model/base/_save_load.py index 02af66efe0..801f9947f8 100644 --- a/src/scvi/model/base/_save_load.py +++ b/src/scvi/model/base/_save_load.py @@ -131,7 +131,7 @@ def _initialize_model(cls, adata, registry, attr_dict): if not adata: adata = None - model = cls(adata=adata, registry=registry, **non_kwargs, **kwargs) + model = cls(adata, registry=registry, **non_kwargs, **kwargs) for attr, val in attr_dict.items(): setattr(model, attr, val) diff --git a/src/scvi/model/base/_training_mixin.py b/src/scvi/model/base/_training_mixin.py index 21c6ed6b59..0e546484c4 100644 --- a/src/scvi/model/base/_training_mixin.py +++ b/src/scvi/model/base/_training_mixin.py @@ -82,7 +82,10 @@ def train( Additional keyword arguments passed into :class:`~scvi.train.Trainer`. """ if max_epochs is None: - max_epochs = get_max_epochs_heuristic(self.summary_stats.n_obs) + if self.adata is not None: + max_epochs = get_max_epochs_heuristic(self.adata.n_obs) + else: + max_epochs = get_max_epochs_heuristic(self.summary_stats.n_obs) if datamodule is None: # In the general case we enter here From 3c44d863ce625b7b59b5c441d13e5517f9c454d8 Mon Sep 17 00:00:00 2001 From: ori-kron-wis Date: Sun, 11 Aug 2024 14:39:42 +0300 Subject: [PATCH 20/53] fixed a few failed tests --- src/scvi/model/_multivi.py | 8 ++++---- src/scvi/model/_totalvi.py | 5 +++-- src/scvi/model/base/_base_model.py | 2 +- tests/model/test_pyro.py | 3 ++- 4 files changed, 10 insertions(+), 8 deletions(-) diff --git a/src/scvi/model/_multivi.py b/src/scvi/model/_multivi.py index 2fcb6251a1..23a3976a8e 100644 --- a/src/scvi/model/_multivi.py +++ b/src/scvi/model/_multivi.py @@ -135,10 +135,10 @@ class MULTIVI(VAEMixin, UnsupervisedTrainingMixin, BaseModelClass, ArchesMixin): def __init__( self, - adata: AnnData, - registry: dict, - n_genes: int, - n_regions: int, + adata: AnnData | None = None, + registry: dict | None = None, + n_genes: int | None = None, + n_regions: int | None = None, modality_weights: Literal["equal", "cell", "universal"] = "equal", modality_penalty: Literal["Jeffreys", "MMD", "None"] = "Jeffreys", n_hidden: int | None = None, diff --git a/src/scvi/model/_totalvi.py b/src/scvi/model/_totalvi.py index 5e6ec90388..61932dfcab 100644 --- a/src/scvi/model/_totalvi.py +++ b/src/scvi/model/_totalvi.py @@ -1179,8 +1179,9 @@ def get_protein_background_mean(self, adata, indices, batch_size): @setup_anndata_dsp.dedent def setup_anndata( cls, - adata: AnnData, - protein_expression_obsm_key: str, + adata: AnnData | None = None, + registry: dict | None = None, + protein_expression_obsm_key: str | None = None, protein_names_uns_key: str | None = None, batch_key: str | None = None, layer: str | None = None, diff --git a/src/scvi/model/base/_base_model.py b/src/scvi/model/base/_base_model.py index 738e35a4e1..1c47f1ad4c 100644 --- a/src/scvi/model/base/_base_model.py +++ b/src/scvi/model/base/_base_model.py @@ -785,7 +785,7 @@ def load( adata = new_adata if new_adata is not None else adata registry = attr_dict.pop("registry_") - registry["setup_method_name"] = "setup_anndata" + # registry["setup_method_name"] = "setup_anndata" #do we need this line? if _MODEL_NAME_KEY in registry and registry[_MODEL_NAME_KEY] != cls.__name__: raise ValueError("It appears you are loading a model from a different class.") diff --git a/tests/model/test_pyro.py b/tests/model/test_pyro.py index 0115a6bfbd..3ddbbdb1c5 100644 --- a/tests/model/test_pyro.py +++ b/tests/model/test_pyro.py @@ -134,7 +134,8 @@ def list_obs_plate_vars(self): class BayesianRegressionModel(PyroSviTrainMixin, PyroSampleMixin, BaseModelClass): def __init__( self, - adata: AnnData, + adata: AnnData | None = None, + registry: dict | None = None, per_cell_weight=False, ): # in case any other model was created before that shares the same parameter names. From c0889d8c02f7ac6e2d35be7c9e6a5bb14f07b462 Mon Sep 17 00:00:00 2001 From: ori-kron-wis Date: Sun, 11 Aug 2024 19:03:44 +0300 Subject: [PATCH 21/53] fix archesmixin init and added new custom dataloader test and github action --- .../test_linux_custom_dataloader.yml | 65 +++++++ src/scvi/model/_scvi.py | 4 +- src/scvi/model/base/_archesmixin.py | 4 +- tests/dataloaders/test_custom_dataloader.py | 160 ++++++++++++++++++ 4 files changed, 228 insertions(+), 5 deletions(-) create mode 100644 .github/workflows/test_linux_custom_dataloader.yml create mode 100644 tests/dataloaders/test_custom_dataloader.py diff --git a/.github/workflows/test_linux_custom_dataloader.yml b/.github/workflows/test_linux_custom_dataloader.yml new file mode 100644 index 0000000000..fb6b9d14f7 --- /dev/null +++ b/.github/workflows/test_linux_custom_dataloader.yml @@ -0,0 +1,65 @@ +name: test (custom dataloaders) + +on: + push: + branches: [main, "[0-9]+.[0-9]+.x"] + pull_request: + schedule: + - cron: "0 10 * * *" # runs at 10:00 UTC (03:00 PST) every day + workflow_dispatch: + +concurrency: + group: ${{ github.workflow }}-${{ github.ref }} + cancel-in-progress: true + +jobs: + test: + runs-on: ${{ matrix.os }} + + defaults: + run: + shell: bash -e {0} # -e to fail on error + + strategy: + fail-fast: false + matrix: + os: [ubuntu-latest] + python: ["3.11"] + + name: integration + + env: + OS: ${{ matrix.os }} + PYTHON: ${{ matrix.python }} + + steps: + - uses: actions/checkout@v4 + + - uses: actions/setup-python@v5 + with: + python-version: ${{ matrix.python }} + cache: "pip" + cache-dependency-path: "**/pyproject.toml" + + - name: Install dependencies + run: | + python -m pip install --upgrade pip wheel uv + python -m uv pip install --system "scvi-tools[tests] @ ." + + - name: Install Specific Branch of Repository + run: | + pip install git+https://github.com/ebezzi/cellxgene-census.git@census-scvi-datamodule + + - name: Run specific custom dataloader pytest + env: + MPLBACKEND: agg + PLATFORM: ${{ matrix.os }} + DISPLAY: :42 + COLUMNS: 120 + run: | + pytest -m custom.dataloader -v --color=yes + coverage report + + - uses: codecov/codecov-action@v4 + with: + token: ${{ secrets.CODECOV_TOKEN }} diff --git a/src/scvi/model/_scvi.py b/src/scvi/model/_scvi.py index 437d070089..ae6168170e 100644 --- a/src/scvi/model/_scvi.py +++ b/src/scvi/model/_scvi.py @@ -272,7 +272,7 @@ def setup_datamodule( "state_registry": { "n_obs": datamodule.n_obs, "n_vars": datamodule.n_vars, - "column_names": datamodule.vars, + "column_names": [str(i) for i in datamodule.vars], }, "summary_stats": {"n_vars": datamodule.n_vars, "n_cells": datamodule.n_obs}, }, @@ -306,8 +306,6 @@ def setup_datamodule( }, "setup_method_name": "setup_datamodule", } - datamodule.summary_stats = cls._get_summary_stats_from_registry(datamodule.registry) - datamodule.var_names = [str(i) for i in datamodule.vars] @staticmethod def _get_fields_for_adata_minification( diff --git a/src/scvi/model/base/_archesmixin.py b/src/scvi/model/base/_archesmixin.py index cac253723b..ac23d0b6d9 100644 --- a/src/scvi/model/base/_archesmixin.py +++ b/src/scvi/model/base/_archesmixin.py @@ -38,9 +38,9 @@ class ArchesMixin: @devices_dsp.dedent def load_query_data( cls, - adata: None | AnnOrMuData = None, + adata: AnnOrMuData = None, reference_model: Union[str, BaseModelClass] = None, - registry: None | dict = None, + registry: dict = None, inplace_subset_query_vars: bool = False, accelerator: str = "auto", device: Union[int, str] = "auto", diff --git a/tests/dataloaders/test_custom_dataloader.py b/tests/dataloaders/test_custom_dataloader.py new file mode 100644 index 0000000000..a4f474365d --- /dev/null +++ b/tests/dataloaders/test_custom_dataloader.py @@ -0,0 +1,160 @@ +from __future__ import annotations + +import cellxgene_census +import numpy as np +import pandas as pd +import pytest +import tiledbsoma as soma +from cellxgene_census.experimental.ml import experiment_dataloader +from cellxgene_census.experimental.ml.datamodule import CensusSCVIDataModule + +import scvi +from scvi.data import synthetic_iid + + +@pytest.custom.dataloader +def test_custom_dataloader(save_path): + # this test checks the local custom dataloder made by CZI and run several tests with it + census = cellxgene_census.open_soma(census_version="stable") + + experiment_name = "mus_musculus" + obs_value_filter = 'is_primary_data == True and tissue_general in ["kidney"] and nnz >= 3000' + # top_n_hvg = 8000 + # hvg_batch = ["assay", "suspension_type"] + # + # # For HVG, we can use the `highly_variable_genes` function provided in the Census, + # # which can compute HVGs in constant memory: + # + # query = census["census_data"][experiment_name].axis_query( + # measurement_name="RNA", obs_query=soma.AxisQuery(value_filter=obs_value_filter) + # ) + # hvgs_df = highly_variable_genes(query, n_top_genes=top_n_hvg, batch_key=hvg_batch) + # + # hv = hvgs_df.highly_variable + # hv_idx = hv[hv].index + + hv_idx = np.arange(100) + + datamodule = CensusSCVIDataModule( + census["census_data"][experiment_name], + measurement_name="RNA", + X_name="raw", + obs_query=soma.AxisQuery(value_filter=obs_value_filter), + var_query=soma.AxisQuery(coords=(list(hv_idx),)), + batch_size=1024, + shuffle=True, + batch_keys=["dataset_id", "assay", "suspension_type", "donor_id"], + dataloader_kwargs={"num_workers": 0, "persistent_workers": False}, + ) + + datamodule.vars = hv_idx + + scvi.model._scvi.SCVI.setup_datamodule(datamodule) # takes time + + adata = synthetic_iid() + scvi.model.SCVI.setup_anndata(adata, batch_key="batch") + + model = scvi.model.SCVI(adata, n_latent=10) + model.train(max_epochs=1) + + dataloader = model._make_data_loader(adata) + _ = model.get_elbo(dataloader=dataloader) + _ = model.get_marginal_ll(dataloader=dataloader) + _ = model.get_reconstruction_error(dataloader=dataloader) + _ = model.get_latent_representation(dataloader=dataloader) + + scvi.model.SCVI.prepare_query_anndata(adata, reference_model=model) + scvi.model.SCVI.load_query_data(adata, reference_model=model) + + n_layers = 1 + n_latent = 50 + + model_census = scvi.model.SCVI( + registry=datamodule.registry, + n_layers=n_layers, + n_latent=n_latent, + gene_likelihood="nb", + encode_covariates=False, + ) + + batch_size = 1024 + train_size = 0.9 + max_epochs = 1 + + model_census.train( + datamodule=datamodule, + max_epochs=max_epochs, + batch_size=batch_size, + train_size=train_size, + early_stopping=False, + ) + + user_attributes = model_census._get_user_attributes() + user_attributes = {a[0]: a[1] for a in user_attributes if a[0][-1] == "_"} + + model_census.save(save_path, overwrite=True) + model_census2 = scvi.model.SCVI.load(save_path, adata=False) + + model_census2.train( + datamodule=datamodule, + max_epochs=max_epochs, + batch_size=batch_size, + train_size=train_size, + early_stopping=False, + ) + + # takes time + adata = cellxgene_census.get_anndata( + census, + organism=experiment_name, + obs_value_filter=obs_value_filter, + var_coords=hv_idx, + ) + + adata.obs["batch"] = ( + "batch_" + adata.obs[datamodule.batch_keys[0]].cat.codes.astype(str) + ).astype("category") + # adata.var_names = 'gene_'+adata.var_names #not sure we need it + + scvi.model.SCVI.prepare_query_anndata(adata, save_path) + scvi.model.SCVI.load_query_data(registry=datamodule.registry, reference_model=save_path) + + scvi.model.SCVI.prepare_query_anndata(adata, model_census2) + + model_census3 = scvi.model.SCVI.load(save_path, adata=adata) + + scvi.model.SCVI.prepare_query_anndata(adata, save_path, return_reference_var_names=True) + scvi.model.SCVI.load_query_data(adata, save_path) + + scvi.model.SCVI.prepare_query_anndata(adata, model_census3) + scvi.model.SCVI.load_query_data(adata, model_census3) + + datamodule_inference = CensusSCVIDataModule( + census["census_data"][experiment_name], + measurement_name="RNA", + X_name="raw", + obs_query=soma.AxisQuery(value_filter=obs_value_filter), + var_query=soma.AxisQuery(coords=(list(hv_idx),)), + batch_size=1024, + shuffle=False, + soma_chunk_size=50_000, + batch_keys=["dataset_id", "assay", "suspension_type", "donor_id"], + dataloader_kwargs={"num_workers": 0, "persistent_workers": False}, + ) + + datapipe = datamodule_inference.datapipe + dataloader = experiment_dataloader(datapipe, num_workers=0, persistent_workers=False) + mapped_dataloader = ( + datamodule_inference.on_before_batch_transfer(tensor, None) for tensor in dataloader + ) + latent = model.get_latent_representation(dataloader=mapped_dataloader) + + emb_idx = datapipe._obs_joinids + + obs_soma_joinids = adata.obs["soma_joinid"] + + obs_indexer = pd.Index(emb_idx) + idx = obs_indexer.get_indexer(obs_soma_joinids) + # Reindexing is necessary to ensure that the cells in the embedding match the ones in + # the anndata object. + adata.obsm["scvi"] = latent[idx] From 8fe043c0797f3ac56e96b03ae438a5c98cacde64 Mon Sep 17 00:00:00 2001 From: ori-kron-wis Date: Sun, 11 Aug 2024 19:42:05 +0300 Subject: [PATCH 22/53] fix again for from __future__ import annotations and fix the test for custom dataloaders --- .../test_linux_custom_dataloader.yml | 3 ++ src/scvi/model/_amortizedlda.py | 29 ++++++++++--------- src/scvi/model/_autozi.py | 24 +++++++-------- src/scvi/model/_jaxscvi.py | 14 +++++---- src/scvi/model/_linear_scvi.py | 10 ++++--- src/scvi/model/base/_archesmixin.py | 15 +++++----- tests/dataloaders/test_custom_dataloader.py | 2 +- 7 files changed, 53 insertions(+), 44 deletions(-) diff --git a/.github/workflows/test_linux_custom_dataloader.yml b/.github/workflows/test_linux_custom_dataloader.yml index fb6b9d14f7..6ce9b75873 100644 --- a/.github/workflows/test_linux_custom_dataloader.yml +++ b/.github/workflows/test_linux_custom_dataloader.yml @@ -47,7 +47,10 @@ jobs: python -m uv pip install --system "scvi-tools[tests] @ ." - name: Install Specific Branch of Repository + env: + GH_TOKEN: ${{ secrets.GH_TOKEN }} run: | + git config --global url."https://${GH_TOKEN}:x-oauth-basic@github.com/".insteadOf "https://github.com/" pip install git+https://github.com/ebezzi/cellxgene-census.git@census-scvi-datamodule - name: Run specific custom dataloader pytest diff --git a/src/scvi/model/_amortizedlda.py b/src/scvi/model/_amortizedlda.py index 6cf022cf27..34542061d8 100644 --- a/src/scvi/model/_amortizedlda.py +++ b/src/scvi/model/_amortizedlda.py @@ -1,7 +1,8 @@ +from __future__ import annotations + import collections.abc import logging from collections.abc import Sequence -from typing import Optional, Union import numpy as np import pandas as pd @@ -60,8 +61,8 @@ def __init__( registry: dict | None = None, n_topics: int = 20, n_hidden: int = 128, - cell_topic_prior: Optional[Union[float, Sequence[float]]] = None, - topic_feature_prior: Optional[Union[float, Sequence[float]]] = None, + cell_topic_prior: float | Sequence[float] = None, + topic_feature_prior: float | Sequence[float] = None, ): # in case any other model was created before that shares the same parameter names. pyro.clear_param_store() @@ -110,9 +111,9 @@ def __init__( def setup_anndata( cls, adata: AnnData, - layer: Optional[str] = None, + layer: None | str = None, **kwargs, - ) -> Optional[AnnData]: + ) -> None | AnnData: """%(summary)s. Parameters @@ -155,9 +156,9 @@ def get_feature_by_topic(self, n_samples=5000) -> pd.DataFrame: def get_latent_representation( self, - adata: Optional[AnnData] = None, - indices: Optional[Sequence[int]] = None, - batch_size: Optional[int] = None, + adata: None | AnnData = None, + indices: None | Sequence[int] = None, + batch_size: None | int = None, n_samples: int = 5000, ) -> pd.DataFrame: """Converts a count matrix to an inferred topic distribution. @@ -198,9 +199,9 @@ def get_latent_representation( def get_elbo( self, - adata: Optional[AnnData] = None, - indices: Optional[Sequence[int]] = None, - batch_size: Optional[int] = None, + adata: None | AnnData = None, + indices: None | Sequence[int] = None, + batch_size: None | int = None, ) -> float: """Return the ELBO for the data. @@ -235,9 +236,9 @@ def get_elbo( def get_perplexity( self, - adata: Optional[AnnData] = None, - indices: Optional[Sequence[int]] = None, - batch_size: Optional[int] = None, + adata: None | AnnData = None, + indices: None | Sequence[int] = None, + batch_size: None | int = None, ) -> float: """Computes approximate perplexity for `adata`. diff --git a/src/scvi/model/_autozi.py b/src/scvi/model/_autozi.py index 7f0b7ee9d4..86d72ce9a5 100644 --- a/src/scvi/model/_autozi.py +++ b/src/scvi/model/_autozi.py @@ -1,6 +1,8 @@ +from __future__ import annotations + import logging from collections.abc import Sequence -from typing import Literal, Optional, Union +from typing import Literal import numpy as np import torch @@ -106,8 +108,8 @@ def __init__( dropout_rate: float = 0.1, dispersion: Literal["gene", "gene-batch", "gene-label", "gene-cell"] = "gene", latent_distribution: Literal["normal", "ln"] = "normal", - alpha_prior: Optional[float] = 0.5, - beta_prior: Optional[float] = 0.5, + alpha_prior: None | float = 0.5, + beta_prior: None | float = 0.5, minimal_dropout: float = 0.01, zero_inflation: str = "gene", use_observed_lib_size: bool = True, @@ -147,19 +149,17 @@ def __init__( ) self.init_params_ = self._get_init_params(locals()) - def get_alphas_betas( - self, as_numpy: bool = True - ) -> dict[str, Union[torch.Tensor, np.ndarray]]: + def get_alphas_betas(self, as_numpy: bool = True) -> dict[str, torch.Tensor | np.ndarray]: """Return parameters of Bernoulli Beta distributions in a dictionary.""" return self.module.get_alphas_betas(as_numpy=as_numpy) @torch.inference_mode() def get_marginal_ll( self, - adata: Optional[AnnData] = None, - indices: Optional[Sequence[int]] = None, + adata: None | AnnData = None, + indices: None | Sequence[int] = None, n_mc_samples: int = 1000, - batch_size: Optional[int] = None, + batch_size: None | int = None, ) -> float: """Return the marginal LL for the data. @@ -262,9 +262,9 @@ def get_marginal_ll( def setup_anndata( cls, adata: AnnData, - batch_key: Optional[str] = None, - labels_key: Optional[str] = None, - layer: Optional[str] = None, + batch_key: None | str = None, + labels_key: None | str = None, + layer: None | str = None, **kwargs, ): """%(summary)s. diff --git a/src/scvi/model/_jaxscvi.py b/src/scvi/model/_jaxscvi.py index e7f5b65d45..f69e800a69 100644 --- a/src/scvi/model/_jaxscvi.py +++ b/src/scvi/model/_jaxscvi.py @@ -1,6 +1,8 @@ +from __future__ import annotations + import logging from collections.abc import Sequence -from typing import Literal, Optional +from typing import Literal import jax.numpy as jnp import numpy as np @@ -83,8 +85,8 @@ def __init__( def setup_anndata( cls, adata: AnnData, - layer: Optional[str] = None, - batch_key: Optional[str] = None, + layer: None | str = None, + batch_key: None | str = None, **kwargs, ): """%(summary)s. @@ -106,11 +108,11 @@ def setup_anndata( def get_latent_representation( self, - adata: Optional[AnnData] = None, - indices: Optional[Sequence[int]] = None, + adata: None | AnnData = None, + indices: None | Sequence[int] = None, give_mean: bool = True, n_samples: int = 1, - batch_size: Optional[int] = None, + batch_size: None | int = None, ) -> np.ndarray: r"""Return the latent representation for each cell. diff --git a/src/scvi/model/_linear_scvi.py b/src/scvi/model/_linear_scvi.py index b735a21f41..a37d1d76b0 100644 --- a/src/scvi/model/_linear_scvi.py +++ b/src/scvi/model/_linear_scvi.py @@ -1,5 +1,7 @@ +from __future__ import annotations + import logging -from typing import Literal, Optional +from typing import Literal import pandas as pd from anndata import AnnData @@ -128,9 +130,9 @@ def get_loadings(self) -> pd.DataFrame: def setup_anndata( cls, adata: AnnData, - batch_key: Optional[str] = None, - labels_key: Optional[str] = None, - layer: Optional[str] = None, + batch_key: None | str = None, + labels_key: None | str = None, + layer: None | str = None, **kwargs, ): """%(summary)s. diff --git a/src/scvi/model/base/_archesmixin.py b/src/scvi/model/base/_archesmixin.py index ac23d0b6d9..23ccdd8cc4 100644 --- a/src/scvi/model/base/_archesmixin.py +++ b/src/scvi/model/base/_archesmixin.py @@ -1,7 +1,8 @@ +from __future__ import annotations + import logging import warnings from copy import deepcopy -from typing import Optional, Union import anndata import numpy as np @@ -39,11 +40,11 @@ class ArchesMixin: def load_query_data( cls, adata: AnnOrMuData = None, - reference_model: Union[str, BaseModelClass] = None, + reference_model: str | BaseModelClass = None, registry: dict = None, inplace_subset_query_vars: bool = False, accelerator: str = "auto", - device: Union[int, str] = "auto", + device: int | str = "auto", unfrozen: bool = False, freeze_dropout: bool = False, freeze_expression: bool = True, @@ -187,10 +188,10 @@ def load_query_data( @staticmethod def prepare_query_anndata( adata: AnnData, - reference_model: Union[str, BaseModelClass], + reference_model: str | BaseModelClass, return_reference_var_names: bool = False, inplace: bool = True, - ) -> Optional[Union[AnnData, pd.Index]]: + ) -> AnnData | pd.Index: """Prepare data for query integration. This function will return a new AnnData object with padded zeros @@ -226,10 +227,10 @@ def prepare_query_anndata( @staticmethod def prepare_query_mudata( mdata: MuData, - reference_model: Union[str, BaseModelClass], + reference_model: str | BaseModelClass, return_reference_var_names: bool = False, inplace: bool = True, - ) -> Optional[Union[MuData, dict[str, pd.Index]]]: + ) -> None | MuData | dict[str, pd.Index]: """Prepare multimodal dataset for query integration. This function will return a new MuData object such that the diff --git a/tests/dataloaders/test_custom_dataloader.py b/tests/dataloaders/test_custom_dataloader.py index a4f474365d..8d85f89fbf 100644 --- a/tests/dataloaders/test_custom_dataloader.py +++ b/tests/dataloaders/test_custom_dataloader.py @@ -13,7 +13,7 @@ @pytest.custom.dataloader -def test_custom_dataloader(save_path): +def custom_dataloader_test(save_path): # this test checks the local custom dataloder made by CZI and run several tests with it census = cellxgene_census.open_soma(census_version="stable") From d8cf0f6c10661f3555f6f2dd90eb69a460a4491e Mon Sep 17 00:00:00 2001 From: ori-kron-wis Date: Sun, 11 Aug 2024 19:51:25 +0300 Subject: [PATCH 23/53] fix for run custom dataloader in github action --- .github/workflows/test_linux.yml | 2 +- .github/workflows/test_linux_custom_dataloader.yml | 3 +-- tests/dataloaders/test_custom_dataloader.py | 2 +- 3 files changed, 3 insertions(+), 4 deletions(-) diff --git a/.github/workflows/test_linux.yml b/.github/workflows/test_linux.yml index 74765a22b8..454daaf636 100644 --- a/.github/workflows/test_linux.yml +++ b/.github/workflows/test_linux.yml @@ -53,7 +53,7 @@ jobs: DISPLAY: :42 COLUMNS: 120 run: | - coverage run -m pytest -v --color=yes + coverage run -m pytest "not custom.dataloader" -v --color=yes coverage report - uses: codecov/codecov-action@v4 diff --git a/.github/workflows/test_linux_custom_dataloader.yml b/.github/workflows/test_linux_custom_dataloader.yml index 6ce9b75873..7c7ec49294 100644 --- a/.github/workflows/test_linux_custom_dataloader.yml +++ b/.github/workflows/test_linux_custom_dataloader.yml @@ -51,7 +51,7 @@ jobs: GH_TOKEN: ${{ secrets.GH_TOKEN }} run: | git config --global url."https://${GH_TOKEN}:x-oauth-basic@github.com/".insteadOf "https://github.com/" - pip install git+https://github.com/ebezzi/cellxgene-census.git@census-scvi-datamodule + pip install git+https://github.com/ebezzi/chanzuckerberg/cellxgene-census.git@census-scvi-datamodule - name: Run specific custom dataloader pytest env: @@ -61,7 +61,6 @@ jobs: COLUMNS: 120 run: | pytest -m custom.dataloader -v --color=yes - coverage report - uses: codecov/codecov-action@v4 with: diff --git a/tests/dataloaders/test_custom_dataloader.py b/tests/dataloaders/test_custom_dataloader.py index 8d85f89fbf..a4f474365d 100644 --- a/tests/dataloaders/test_custom_dataloader.py +++ b/tests/dataloaders/test_custom_dataloader.py @@ -13,7 +13,7 @@ @pytest.custom.dataloader -def custom_dataloader_test(save_path): +def test_custom_dataloader(save_path): # this test checks the local custom dataloder made by CZI and run several tests with it census = cellxgene_census.open_soma(census_version="stable") From c41e8b2d47ca0cf24042651220811128c1827637 Mon Sep 17 00:00:00 2001 From: ori-kron-wis Date: Sun, 11 Aug 2024 19:59:45 +0300 Subject: [PATCH 24/53] rollback --- .github/workflows/test_linux.yml | 2 +- .github/workflows/test_linux_custom_dataloader.yml | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/.github/workflows/test_linux.yml b/.github/workflows/test_linux.yml index 454daaf636..74765a22b8 100644 --- a/.github/workflows/test_linux.yml +++ b/.github/workflows/test_linux.yml @@ -53,7 +53,7 @@ jobs: DISPLAY: :42 COLUMNS: 120 run: | - coverage run -m pytest "not custom.dataloader" -v --color=yes + coverage run -m pytest -v --color=yes coverage report - uses: codecov/codecov-action@v4 diff --git a/.github/workflows/test_linux_custom_dataloader.yml b/.github/workflows/test_linux_custom_dataloader.yml index 7c7ec49294..cc58e6107a 100644 --- a/.github/workflows/test_linux_custom_dataloader.yml +++ b/.github/workflows/test_linux_custom_dataloader.yml @@ -51,7 +51,7 @@ jobs: GH_TOKEN: ${{ secrets.GH_TOKEN }} run: | git config --global url."https://${GH_TOKEN}:x-oauth-basic@github.com/".insteadOf "https://github.com/" - pip install git+https://github.com/ebezzi/chanzuckerberg/cellxgene-census.git@census-scvi-datamodule + pip install git+https://github.com/ebezzi/cellxgene-census.git@census-scvi-datamodule - name: Run specific custom dataloader pytest env: From 6ec5d4d3081cd5867f7f9a9d448ffff783f14ad5 Mon Sep 17 00:00:00 2001 From: ori-kron-wis Date: Sun, 11 Aug 2024 20:11:58 +0300 Subject: [PATCH 25/53] added label to the new githubaction for custom dataloader --- .github/workflows/test_linux_custom_dataloader.yml | 11 +++++++++++ 1 file changed, 11 insertions(+) diff --git a/.github/workflows/test_linux_custom_dataloader.yml b/.github/workflows/test_linux_custom_dataloader.yml index cc58e6107a..14714bf99f 100644 --- a/.github/workflows/test_linux_custom_dataloader.yml +++ b/.github/workflows/test_linux_custom_dataloader.yml @@ -4,6 +4,8 @@ on: push: branches: [main, "[0-9]+.[0-9]+.x"] pull_request: + branches: [main, "[0-9]+.[0-9]+.x"] + types: [labeled, synchronize, opened] schedule: - cron: "0 10 * * *" # runs at 10:00 UTC (03:00 PST) every day workflow_dispatch: @@ -14,6 +16,15 @@ concurrency: jobs: test: + # if PR has label "custom_dataloader" or "all tests" or if scheduled or manually triggered + if: >- + ( + contains(github.event.pull_request.labels.*.name, 'custom_dataloader') || + contains(github.event.pull_request.labels.*.name, 'all tests') || + contains(github.event_name, 'schedule') || + contains(github.event_name, 'workflow_dispatch') + ) + runs-on: ${{ matrix.os }} defaults: From 6bce3173ca7f11fc38bffe78090cb71d2a3ad87d Mon Sep 17 00:00:00 2001 From: ori-kron-wis Date: Mon, 12 Aug 2024 12:05:07 +0300 Subject: [PATCH 26/53] fix for github action for custom dataloaders --- .../workflows/test_linux_custom_dataloader.yml | 5 +++-- tests/conftest.py | 15 +++++++++++++++ 2 files changed, 18 insertions(+), 2 deletions(-) diff --git a/.github/workflows/test_linux_custom_dataloader.yml b/.github/workflows/test_linux_custom_dataloader.yml index 14714bf99f..8e79305eb1 100644 --- a/.github/workflows/test_linux_custom_dataloader.yml +++ b/.github/workflows/test_linux_custom_dataloader.yml @@ -62,7 +62,7 @@ jobs: GH_TOKEN: ${{ secrets.GH_TOKEN }} run: | git config --global url."https://${GH_TOKEN}:x-oauth-basic@github.com/".insteadOf "https://github.com/" - pip install git+https://github.com/ebezzi/cellxgene-census.git@census-scvi-datamodule + pip install git+https://github.com/ori-kron-wis/cellxgene-census.git@ebezzi/census-scvi-datamodule - name: Run specific custom dataloader pytest env: @@ -71,7 +71,8 @@ jobs: DISPLAY: :42 COLUMNS: 120 run: | - pytest -m custom.dataloader -v --color=yes + coverage run -m pytest -v --color=yes --custom.dataloader-tests + coverage report - uses: codecov/codecov-action@v4 with: diff --git a/tests/conftest.py b/tests/conftest.py index 4fdc21a1c3..72807d8779 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -15,6 +15,12 @@ def pytest_addoption(parser): default=False, help="Run tests that retrieve stuff from the internet. This increases test time.", ) + parser.addoption( + "--custom.dataloader-tests", + action="store_true", + default=False, + help="Run tests that deals with custom dataloaders. This increases test time.", + ) parser.addoption( "--optional", action="store_true", @@ -55,13 +61,22 @@ def pytest_configure(config): def pytest_collection_modifyitems(config, items): """Docstring for pytest_collection_modifyitems.""" run_internet = config.getoption("--internet-tests") + run_custom_dataloader = config.getoption("--custom.dataloader-tests") skip_internet = pytest.mark.skip(reason="need --internet-tests option to run") + skip_custom_dataloader = pytest.mark.skip( + reason="need ---custom.dataloader-tests option to run" + ) for item in items: # All tests marked with `pytest.mark.internet` get skipped unless # `--internet-tests` passed if not run_internet and ("internet" in item.keywords): item.add_marker(skip_internet) + # All tests marked with `pytest.custom.dataloader` get skipped unless + # `--custom.dataloader-tests` passed + if not run_custom_dataloader and ("custom.dataloader" in item.keywords): + item.add_marker(skip_custom_dataloader) + run_optional = config.getoption("--optional") skip_optional = pytest.mark.skip(reason="need --optional option to run") for item in items: From 1f4ae9dff1704a05f3901baa97b107a6d9342758 Mon Sep 17 00:00:00 2001 From: ori-kron-wis Date: Mon, 12 Aug 2024 12:23:26 +0300 Subject: [PATCH 27/53] another fix to custom dataloder test and github action --- .github/workflows/test_linux_custom_dataloader.yml | 5 ++++- tests/conftest.py | 2 +- tests/dataloaders/test_custom_dataloader.py | 6 ++++++ 3 files changed, 11 insertions(+), 2 deletions(-) diff --git a/.github/workflows/test_linux_custom_dataloader.yml b/.github/workflows/test_linux_custom_dataloader.yml index 8e79305eb1..7d3c485a04 100644 --- a/.github/workflows/test_linux_custom_dataloader.yml +++ b/.github/workflows/test_linux_custom_dataloader.yml @@ -62,7 +62,10 @@ jobs: GH_TOKEN: ${{ secrets.GH_TOKEN }} run: | git config --global url."https://${GH_TOKEN}:x-oauth-basic@github.com/".insteadOf "https://github.com/" - pip install git+https://github.com/ori-kron-wis/cellxgene-census.git@ebezzi/census-scvi-datamodule + #pip install git+https://github.com/ebezzi/cellxgene-census.git@ebezzi/census-scvi-datamodule + git clone --single-branch --branch ebezzi/census-scvi-datamodule https://github.com/ebezzi/cellxgene-census.git + #cd cellxgene-census + #python -m pip install - name: Run specific custom dataloader pytest env: diff --git a/tests/conftest.py b/tests/conftest.py index 72807d8779..9d15938979 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -17,7 +17,7 @@ def pytest_addoption(parser): ) parser.addoption( "--custom.dataloader-tests", - action="store_true", + action="store_false", default=False, help="Run tests that deals with custom dataloaders. This increases test time.", ) diff --git a/tests/dataloaders/test_custom_dataloader.py b/tests/dataloaders/test_custom_dataloader.py index a4f474365d..32a4dfacf3 100644 --- a/tests/dataloaders/test_custom_dataloader.py +++ b/tests/dataloaders/test_custom_dataloader.py @@ -1,5 +1,11 @@ from __future__ import annotations +import sys + +# the next should be ready for improting +sys.path.insert(0, "/cellxgene-census/api/python/cellxgene_census/src") +sys.path.insert(0, "src") + import cellxgene_census import numpy as np import pandas as pd From de1f30bf33bb868de62bf787150bb9aab82f016b Mon Sep 17 00:00:00 2001 From: ori-kron-wis Date: Mon, 12 Aug 2024 12:29:57 +0300 Subject: [PATCH 28/53] another fix to custom dataloder test and github action --- .github/workflows/test_linux_custom_dataloader.yml | 4 ++-- tests/conftest.py | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/.github/workflows/test_linux_custom_dataloader.yml b/.github/workflows/test_linux_custom_dataloader.yml index 7d3c485a04..adf0d99c91 100644 --- a/.github/workflows/test_linux_custom_dataloader.yml +++ b/.github/workflows/test_linux_custom_dataloader.yml @@ -62,8 +62,8 @@ jobs: GH_TOKEN: ${{ secrets.GH_TOKEN }} run: | git config --global url."https://${GH_TOKEN}:x-oauth-basic@github.com/".insteadOf "https://github.com/" - #pip install git+https://github.com/ebezzi/cellxgene-census.git@ebezzi/census-scvi-datamodule - git clone --single-branch --branch ebezzi/census-scvi-datamodule https://github.com/ebezzi/cellxgene-census.git + #pip install git+https://github.com/ori-kron-wis/cellxgene-census.git@ebezzi/census-scvi-datamodule + git clone --single-branch --branch ebezzi/census-scvi-datamodule https://github.com/ori-kron-wis/cellxgene-census.git #cd cellxgene-census #python -m pip install diff --git a/tests/conftest.py b/tests/conftest.py index 9d15938979..72807d8779 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -17,7 +17,7 @@ def pytest_addoption(parser): ) parser.addoption( "--custom.dataloader-tests", - action="store_false", + action="store_true", default=False, help="Run tests that deals with custom dataloaders. This increases test time.", ) From 49fa01e8d7e5098fd0cb88a422b5061ac42a7f18 Mon Sep 17 00:00:00 2001 From: ori-kron-wis Date: Mon, 12 Aug 2024 12:39:26 +0300 Subject: [PATCH 29/53] another fix to custom dataloder test and github action --- tests/conftest.py | 2 +- tests/dataloaders/test_custom_dataloader.py | 9 ++++++--- 2 files changed, 7 insertions(+), 4 deletions(-) diff --git a/tests/conftest.py b/tests/conftest.py index 72807d8779..00bc5999df 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -72,7 +72,7 @@ def pytest_collection_modifyitems(config, items): if not run_internet and ("internet" in item.keywords): item.add_marker(skip_internet) - # All tests marked with `pytest.custom.dataloader` get skipped unless + # All tests marked with `pytest.mark.custom.dataloader` get skipped unless # `--custom.dataloader-tests` passed if not run_custom_dataloader and ("custom.dataloader" in item.keywords): item.add_marker(skip_custom_dataloader) diff --git a/tests/dataloaders/test_custom_dataloader.py b/tests/dataloaders/test_custom_dataloader.py index 32a4dfacf3..866f9c2970 100644 --- a/tests/dataloaders/test_custom_dataloader.py +++ b/tests/dataloaders/test_custom_dataloader.py @@ -2,8 +2,11 @@ import sys -# the next should be ready for improting -sys.path.insert(0, "/cellxgene-census/api/python/cellxgene_census/src") +# should be ready for importing the cloned branch on a remote machine that runs github action +sys.path.insert( + 0, + "/home/runner/work/scvi-tools/scvi-tools/" "cellxgene-census/api/python/cellxgene_census/src", +) sys.path.insert(0, "src") import cellxgene_census @@ -18,7 +21,7 @@ from scvi.data import synthetic_iid -@pytest.custom.dataloader +@pytest.mark.custom.dataloader def test_custom_dataloader(save_path): # this test checks the local custom dataloder made by CZI and run several tests with it census = cellxgene_census.open_soma(census_version="stable") From e33a935e3f31c30a2fc40975f0e6a165bde32b5f Mon Sep 17 00:00:00 2001 From: ori-kron-wis Date: Mon, 12 Aug 2024 12:46:03 +0300 Subject: [PATCH 30/53] another fix to custom dataloder test and github action --- .github/workflows/test_linux_custom_dataloader.yml | 3 ++- tests/conftest.py | 8 ++++---- 2 files changed, 6 insertions(+), 5 deletions(-) diff --git a/.github/workflows/test_linux_custom_dataloader.yml b/.github/workflows/test_linux_custom_dataloader.yml index adf0d99c91..23832173ba 100644 --- a/.github/workflows/test_linux_custom_dataloader.yml +++ b/.github/workflows/test_linux_custom_dataloader.yml @@ -56,6 +56,7 @@ jobs: run: | python -m pip install --upgrade pip wheel uv python -m uv pip install --system "scvi-tools[tests] @ ." + python -m pip install tiledbsoma - name: Install Specific Branch of Repository env: @@ -74,7 +75,7 @@ jobs: DISPLAY: :42 COLUMNS: 120 run: | - coverage run -m pytest -v --color=yes --custom.dataloader-tests + coverage run -m pytest -v --color=yes --custom-dataloader-tests coverage report - uses: codecov/codecov-action@v4 diff --git a/tests/conftest.py b/tests/conftest.py index 00bc5999df..afa62a1060 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -16,7 +16,7 @@ def pytest_addoption(parser): help="Run tests that retrieve stuff from the internet. This increases test time.", ) parser.addoption( - "--custom.dataloader-tests", + "--custom-dataloader-tests", action="store_true", default=False, help="Run tests that deals with custom dataloaders. This increases test time.", @@ -61,10 +61,10 @@ def pytest_configure(config): def pytest_collection_modifyitems(config, items): """Docstring for pytest_collection_modifyitems.""" run_internet = config.getoption("--internet-tests") - run_custom_dataloader = config.getoption("--custom.dataloader-tests") + run_custom_dataloader = config.getoption("--custom-dataloader-tests") skip_internet = pytest.mark.skip(reason="need --internet-tests option to run") skip_custom_dataloader = pytest.mark.skip( - reason="need ---custom.dataloader-tests option to run" + reason="need ---custom-dataloader-tests option to run" ) for item in items: # All tests marked with `pytest.mark.internet` get skipped unless @@ -74,7 +74,7 @@ def pytest_collection_modifyitems(config, items): # All tests marked with `pytest.mark.custom.dataloader` get skipped unless # `--custom.dataloader-tests` passed - if not run_custom_dataloader and ("custom.dataloader" in item.keywords): + if not run_custom_dataloader and ("dataloader" in item.keywords): item.add_marker(skip_custom_dataloader) run_optional = config.getoption("--optional") From 48627d93324de2740378d6233e83c59499c0e5fe Mon Sep 17 00:00:00 2001 From: ori-kron-wis Date: Mon, 12 Aug 2024 12:49:13 +0300 Subject: [PATCH 31/53] another fix to custom dataloder test and github action --- .github/workflows/test_linux_custom_dataloader.yml | 4 +--- tests/dataloaders/test_custom_dataloader.py | 2 +- 2 files changed, 2 insertions(+), 4 deletions(-) diff --git a/.github/workflows/test_linux_custom_dataloader.yml b/.github/workflows/test_linux_custom_dataloader.yml index 23832173ba..4a0617f294 100644 --- a/.github/workflows/test_linux_custom_dataloader.yml +++ b/.github/workflows/test_linux_custom_dataloader.yml @@ -57,16 +57,14 @@ jobs: python -m pip install --upgrade pip wheel uv python -m uv pip install --system "scvi-tools[tests] @ ." python -m pip install tiledbsoma + python -m pip install s3fs - name: Install Specific Branch of Repository env: GH_TOKEN: ${{ secrets.GH_TOKEN }} run: | git config --global url."https://${GH_TOKEN}:x-oauth-basic@github.com/".insteadOf "https://github.com/" - #pip install git+https://github.com/ori-kron-wis/cellxgene-census.git@ebezzi/census-scvi-datamodule git clone --single-branch --branch ebezzi/census-scvi-datamodule https://github.com/ori-kron-wis/cellxgene-census.git - #cd cellxgene-census - #python -m pip install - name: Run specific custom dataloader pytest env: diff --git a/tests/dataloaders/test_custom_dataloader.py b/tests/dataloaders/test_custom_dataloader.py index 866f9c2970..93751a69d5 100644 --- a/tests/dataloaders/test_custom_dataloader.py +++ b/tests/dataloaders/test_custom_dataloader.py @@ -21,7 +21,7 @@ from scvi.data import synthetic_iid -@pytest.mark.custom.dataloader +@pytest.mark.internet def test_custom_dataloader(save_path): # this test checks the local custom dataloder made by CZI and run several tests with it census = cellxgene_census.open_soma(census_version="stable") From 609094d568b9f22e8df945c9f7a459ce5c290026 Mon Sep 17 00:00:00 2001 From: ori-kron-wis Date: Mon, 12 Aug 2024 12:53:14 +0300 Subject: [PATCH 32/53] another fix to custom dataloder test and github action --- .github/workflows/test_linux.yml | 2 +- .github/workflows/test_linux_custom_dataloader.yml | 1 + tests/dataloaders/test_custom_dataloader.py | 2 +- 3 files changed, 3 insertions(+), 2 deletions(-) diff --git a/.github/workflows/test_linux.yml b/.github/workflows/test_linux.yml index 74765a22b8..7273d8499b 100644 --- a/.github/workflows/test_linux.yml +++ b/.github/workflows/test_linux.yml @@ -53,7 +53,7 @@ jobs: DISPLAY: :42 COLUMNS: 120 run: | - coverage run -m pytest -v --color=yes + coverage run -m pytest -v --color=yes -m "not custom.dataloader" coverage report - uses: codecov/codecov-action@v4 diff --git a/.github/workflows/test_linux_custom_dataloader.yml b/.github/workflows/test_linux_custom_dataloader.yml index 4a0617f294..0f0bfb7bef 100644 --- a/.github/workflows/test_linux_custom_dataloader.yml +++ b/.github/workflows/test_linux_custom_dataloader.yml @@ -58,6 +58,7 @@ jobs: python -m uv pip install --system "scvi-tools[tests] @ ." python -m pip install tiledbsoma python -m pip install s3fs + python -m pip install torchdata - name: Install Specific Branch of Repository env: diff --git a/tests/dataloaders/test_custom_dataloader.py b/tests/dataloaders/test_custom_dataloader.py index 93751a69d5..866f9c2970 100644 --- a/tests/dataloaders/test_custom_dataloader.py +++ b/tests/dataloaders/test_custom_dataloader.py @@ -21,7 +21,7 @@ from scvi.data import synthetic_iid -@pytest.mark.internet +@pytest.mark.custom.dataloader def test_custom_dataloader(save_path): # this test checks the local custom dataloder made by CZI and run several tests with it census = cellxgene_census.open_soma(census_version="stable") From 8cf3517c613c88d9cfb7cd801b3d9add144f708c Mon Sep 17 00:00:00 2001 From: ori-kron-wis Date: Mon, 12 Aug 2024 13:07:30 +0300 Subject: [PATCH 33/53] another fix to custom dataloder test and github action --- tests/dataloaders/test_custom_dataloader.py | 28 +++++++++++---------- 1 file changed, 15 insertions(+), 13 deletions(-) diff --git a/tests/dataloaders/test_custom_dataloader.py b/tests/dataloaders/test_custom_dataloader.py index 866f9c2970..0f61292ae1 100644 --- a/tests/dataloaders/test_custom_dataloader.py +++ b/tests/dataloaders/test_custom_dataloader.py @@ -1,21 +1,8 @@ from __future__ import annotations -import sys - -# should be ready for importing the cloned branch on a remote machine that runs github action -sys.path.insert( - 0, - "/home/runner/work/scvi-tools/scvi-tools/" "cellxgene-census/api/python/cellxgene_census/src", -) -sys.path.insert(0, "src") - -import cellxgene_census import numpy as np import pandas as pd import pytest -import tiledbsoma as soma -from cellxgene_census.experimental.ml import experiment_dataloader -from cellxgene_census.experimental.ml.datamodule import CensusSCVIDataModule import scvi from scvi.data import synthetic_iid @@ -23,6 +10,21 @@ @pytest.mark.custom.dataloader def test_custom_dataloader(save_path): + # local bracnh with fix only for this test + import sys + + # should be ready for importing the cloned branch on a remote machine that runs github action + sys.path.insert( + 0, + "/home/runner/work/scvi-tools/scvi-tools/" + "cellxgene-census/api/python/cellxgene_census/src", + ) + sys.path.insert(0, "src") + import cellxgene_census + import tiledbsoma as soma + from cellxgene_census.experimental.ml import experiment_dataloader + from cellxgene_census.experimental.ml.datamodule import CensusSCVIDataModule + # this test checks the local custom dataloder made by CZI and run several tests with it census = cellxgene_census.open_soma(census_version="stable") From ba5a0281c6866cb6649161dcf2dbd56e668347e7 Mon Sep 17 00:00:00 2001 From: ori-kron-wis Date: Mon, 12 Aug 2024 13:11:42 +0300 Subject: [PATCH 34/53] another fix to custom dataloder test and github action --- .github/workflows/test_linux.yml | 2 +- tests/conftest.py | 4 ++-- tests/dataloaders/test_custom_dataloader.py | 2 +- 3 files changed, 4 insertions(+), 4 deletions(-) diff --git a/.github/workflows/test_linux.yml b/.github/workflows/test_linux.yml index 7273d8499b..6eb8ff4c13 100644 --- a/.github/workflows/test_linux.yml +++ b/.github/workflows/test_linux.yml @@ -53,7 +53,7 @@ jobs: DISPLAY: :42 COLUMNS: 120 run: | - coverage run -m pytest -v --color=yes -m "not custom.dataloader" + coverage run -m pytest -v --color=yes -m "not custom_dataloader" coverage report - uses: codecov/codecov-action@v4 diff --git a/tests/conftest.py b/tests/conftest.py index afa62a1060..81566f563a 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -72,8 +72,8 @@ def pytest_collection_modifyitems(config, items): if not run_internet and ("internet" in item.keywords): item.add_marker(skip_internet) - # All tests marked with `pytest.mark.custom.dataloader` get skipped unless - # `--custom.dataloader-tests` passed + # All tests marked with `pytest.mark.custom_dataloader` get skipped unless + # `--custom_dataloader-tests` passed if not run_custom_dataloader and ("dataloader" in item.keywords): item.add_marker(skip_custom_dataloader) diff --git a/tests/dataloaders/test_custom_dataloader.py b/tests/dataloaders/test_custom_dataloader.py index 0f61292ae1..ebc13af852 100644 --- a/tests/dataloaders/test_custom_dataloader.py +++ b/tests/dataloaders/test_custom_dataloader.py @@ -8,7 +8,7 @@ from scvi.data import synthetic_iid -@pytest.mark.custom.dataloader +@pytest.mark.custom_dataloader def test_custom_dataloader(save_path): # local bracnh with fix only for this test import sys From a7dc3fea8a510f0db71ed72a7b1aa0d06eb5a5ff Mon Sep 17 00:00:00 2001 From: ori-kron-wis Date: Mon, 12 Aug 2024 13:20:03 +0300 Subject: [PATCH 35/53] another fix to custom dataloder test and github action --- .github/workflows/test_linux_custom_dataloader.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/test_linux_custom_dataloader.yml b/.github/workflows/test_linux_custom_dataloader.yml index 0f0bfb7bef..e00e010f14 100644 --- a/.github/workflows/test_linux_custom_dataloader.yml +++ b/.github/workflows/test_linux_custom_dataloader.yml @@ -74,7 +74,7 @@ jobs: DISPLAY: :42 COLUMNS: 120 run: | - coverage run -m pytest -v --color=yes --custom-dataloader-tests + coverage run -m pytest tests/dataloaders/test_custom_dataloader.py -v --color=yes --custom-dataloader-tests coverage report - uses: codecov/codecov-action@v4 From f3ff0f89f9f44700e15ea89a9a303fc2fc833156 Mon Sep 17 00:00:00 2001 From: ori-kron-wis Date: Mon, 12 Aug 2024 13:22:28 +0300 Subject: [PATCH 36/53] another fix to custom dataloder test and github action --- .github/workflows/test_linux_custom_dataloader.yml | 1 + 1 file changed, 1 insertion(+) diff --git a/.github/workflows/test_linux_custom_dataloader.yml b/.github/workflows/test_linux_custom_dataloader.yml index e00e010f14..09c5ce54b1 100644 --- a/.github/workflows/test_linux_custom_dataloader.yml +++ b/.github/workflows/test_linux_custom_dataloader.yml @@ -59,6 +59,7 @@ jobs: python -m pip install tiledbsoma python -m pip install s3fs python -m pip install torchdata + python -m pip install psutil - name: Install Specific Branch of Repository env: From b6eb2f1221091a944ff70ffba3bf35e2b14d33a5 Mon Sep 17 00:00:00 2001 From: ori-kron-wis Date: Mon, 16 Sep 2024 13:05:06 +0300 Subject: [PATCH 37/53] Returned REGISTRY_KEYS for import, after was drop in recent merges --- src/scvi/model/base/_archesmixin.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/scvi/model/base/_archesmixin.py b/src/scvi/model/base/_archesmixin.py index 793e109265..33fcfdf2b3 100644 --- a/src/scvi/model/base/_archesmixin.py +++ b/src/scvi/model/base/_archesmixin.py @@ -12,7 +12,7 @@ from mudata import MuData from scipy.sparse import csr_matrix -from scvi import settings +from scvi import REGISTRY_KEYS, settings from scvi._types import AnnOrMuData from scvi.data import _constants from scvi.data._constants import _MODEL_NAME_KEY, _SETUP_ARGS_KEY, _SETUP_METHOD_NAME From 2979ea24f152e81dfd337ad8cbef8bb6e14d5fd3 Mon Sep 17 00:00:00 2001 From: ori-kron-wis Date: Mon, 16 Sep 2024 13:22:20 +0300 Subject: [PATCH 38/53] It is ok to drop it after scarches categorial covariates fix --- src/scvi/model/base/_archesmixin.py | 7 +------ 1 file changed, 1 insertion(+), 6 deletions(-) diff --git a/src/scvi/model/base/_archesmixin.py b/src/scvi/model/base/_archesmixin.py index 33fcfdf2b3..bf93d40c51 100644 --- a/src/scvi/model/base/_archesmixin.py +++ b/src/scvi/model/base/_archesmixin.py @@ -12,7 +12,7 @@ from mudata import MuData from scipy.sparse import csr_matrix -from scvi import REGISTRY_KEYS, settings +from scvi import settings from scvi._types import AnnOrMuData from scvi.data import _constants from scvi.data._constants import _MODEL_NAME_KEY, _SETUP_ARGS_KEY, _SETUP_METHOD_NAME @@ -140,11 +140,6 @@ def load_query_data( model = _initialize_model(cls, adata, registry, attr_dict) - if model.summary_stats[f"n_{REGISTRY_KEYS.CAT_COVS_KEY}"] > 0: - raise NotImplementedError( - "scArches currently does not support models with extra categorical covariates." - ) - version_split = model.registry[_constants._SCVI_VERSION_KEY].split(".") if int(version_split[1]) < 8 and int(version_split[0]) == 0: From 4a648ffb8a73271ead898c60754755e30dfc51e6 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Tue, 17 Sep 2024 21:01:25 +0000 Subject: [PATCH 39/53] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- src/scvi/model/_scvi.py | 6 ++---- src/scvi/model/base/_base_model.py | 1 + 2 files changed, 3 insertions(+), 4 deletions(-) diff --git a/src/scvi/model/_scvi.py b/src/scvi/model/_scvi.py index 9232dcb19a..7741697758 100644 --- a/src/scvi/model/_scvi.py +++ b/src/scvi/model/_scvi.py @@ -1,15 +1,13 @@ from __future__ import annotations import logging -import warnings from typing import TYPE_CHECKING import numpy as np import scvi -from scvi.data import _constants -from scvi import REGISTRY_KEYS, settings -from scvi.data import AnnDataManager +from scvi import REGISTRY_KEYS +from scvi.data import AnnDataManager, _constants from scvi.data._constants import _ADATA_MINIFY_TYPE_UNS_KEY, ADATA_MINIFY_TYPE from scvi.data._utils import _get_adata_minify_type from scvi.data.fields import ( diff --git a/src/scvi/model/base/_base_model.py b/src/scvi/model/base/_base_model.py index 664069621c..5ea73b2dab 100644 --- a/src/scvi/model/base/_base_model.py +++ b/src/scvi/model/base/_base_model.py @@ -50,6 +50,7 @@ from scvi.utils._docstrings import devices_dsp from . import _constants + if TYPE_CHECKING: from collections.abc import Sequence From e3831cb446b3faa8410e7efbfdf2532d10f15134 Mon Sep 17 00:00:00 2001 From: ori-kron-wis Date: Wed, 18 Sep 2024 00:19:07 +0300 Subject: [PATCH 40/53] moved to type checking blocks beucase of ruff updates --- src/scvi/model/base/_archesmixin.py | 7 +++++-- src/scvi/model/base/_base_model.py | 3 ++- 2 files changed, 7 insertions(+), 3 deletions(-) diff --git a/src/scvi/model/base/_archesmixin.py b/src/scvi/model/base/_archesmixin.py index bf93d40c51..55867b75e4 100644 --- a/src/scvi/model/base/_archesmixin.py +++ b/src/scvi/model/base/_archesmixin.py @@ -3,6 +3,7 @@ import logging import warnings from copy import deepcopy +from typing import TYPE_CHECKING import anndata import numpy as np @@ -13,7 +14,6 @@ from scipy.sparse import csr_matrix from scvi import settings -from scvi._types import AnnOrMuData from scvi.data import _constants from scvi.data._constants import _MODEL_NAME_KEY, _SETUP_ARGS_KEY, _SETUP_METHOD_NAME from scvi.model._utils import parse_device_args @@ -25,7 +25,10 @@ from scvi.nn import FCLayers from scvi.utils._docstrings import devices_dsp -from ._base_model import BaseModelClass +if TYPE_CHECKING: + from scvi._types import AnnOrMuData + + from ._base_model import BaseModelClass logger = logging.getLogger(__name__) diff --git a/src/scvi/model/base/_base_model.py b/src/scvi/model/base/_base_model.py index 5ea73b2dab..3af9368d5c 100644 --- a/src/scvi/model/base/_base_model.py +++ b/src/scvi/model/base/_base_model.py @@ -11,7 +11,6 @@ from uuid import uuid4 import numpy as np -import pandas as pd import rich import torch from anndata import AnnData @@ -54,6 +53,8 @@ if TYPE_CHECKING: from collections.abc import Sequence + import pandas as pd + from scvi._types import AnnOrMuData, MinifiedDataType logger = logging.getLogger(__name__) From 2cc8ff90b847689b5639d77a7ea94508a7923380 Mon Sep 17 00:00:00 2001 From: ori-kron-wis Date: Wed, 9 Oct 2024 19:02:12 +0300 Subject: [PATCH 41/53] updated for CZI custom dataloader test and backend --- src/scvi/model/_scvi.py | 27 ++++- ...oader.py => test_czi_custom_dataloader.py} | 106 ++++++++++++------ 2 files changed, 96 insertions(+), 37 deletions(-) rename tests/dataloaders/{test_custom_dataloader.py => test_czi_custom_dataloader.py} (55%) diff --git a/src/scvi/model/_scvi.py b/src/scvi/model/_scvi.py index 7741697758..3b9b7708a6 100644 --- a/src/scvi/model/_scvi.py +++ b/src/scvi/model/_scvi.py @@ -242,6 +242,7 @@ def _get_summary_stats_from_registry(registry: dict) -> attrdict: def setup_datamodule( cls, datamodule, # TODO: what to put here? + source_registry=None, layer: str | None = None, batch_key: list[str] | None = None, labels_key: str | None = None, @@ -262,6 +263,26 @@ def setup_datamodule( %(param_cat_cov_keys)s %(param_cont_cov_keys)s """ + # TODO: from adata (czi)? + if datamodule.__class__.__name__ == "CensusSCVIDataModule": + # CZI + categorical_mapping = datamodule.datapipe.obs_encoders["batch"].classes_ + column_names = list( + datamodule.datapipe.var_query.coords[0] + if datamodule.datapipe.var_query is not None + else range(datamodule.n_vars) + ) + n_batch = datamodule.n_batch + else: + # Anndata -> CZI + # if we are here and datamodule is actually an AnnData object + # it means we init the custom dataloder model with anndata + categorical_mapping = source_registry["field_registries"]["batch"]["state_registry"][ + "categorical_mapping" + ] + column_names = datamodule.var.soma_joinid.values + n_batch = source_registry["field_registries"]["batch"]["summary_stats"]["n_batch"] + datamodule.registry = { "scvi_version": scvi.__version__, "model_name": "SCVI", @@ -279,17 +300,17 @@ def setup_datamodule( "state_registry": { "n_obs": datamodule.n_obs, "n_vars": datamodule.n_vars, - "column_names": [str(i) for i in datamodule.vars], + "column_names": [str(i) for i in column_names], # TODO: from adata (czi)? }, "summary_stats": {"n_vars": datamodule.n_vars, "n_cells": datamodule.n_obs}, }, "batch": { "data_registry": {"attr_name": "obs", "attr_key": "_scvi_batch"}, "state_registry": { - "categorical_mapping": datamodule.datapipe.obs_encoders["batch"].classes_, + "categorical_mapping": categorical_mapping, "original_key": "batch", }, - "summary_stats": {"n_batch": datamodule.n_batch}, + "summary_stats": {"n_batch": n_batch}, }, "labels": { "data_registry": {"attr_name": "obs", "attr_key": "_scvi_labels"}, diff --git a/tests/dataloaders/test_custom_dataloader.py b/tests/dataloaders/test_czi_custom_dataloader.py similarity index 55% rename from tests/dataloaders/test_custom_dataloader.py rename to tests/dataloaders/test_czi_custom_dataloader.py index ebc13af852..060b9f6fed 100644 --- a/tests/dataloaders/test_custom_dataloader.py +++ b/tests/dataloaders/test_czi_custom_dataloader.py @@ -1,35 +1,28 @@ from __future__ import annotations +from pprint import pprint + +import cellxgene_census import numpy as np import pandas as pd import pytest +import tiledbsoma as soma +from cellxgene_census.experimental.ml import experiment_dataloader +from cellxgene_census.experimental.ml.datamodule import CensusSCVIDataModule import scvi from scvi.data import synthetic_iid @pytest.mark.custom_dataloader -def test_custom_dataloader(save_path): - # local bracnh with fix only for this test - import sys - - # should be ready for importing the cloned branch on a remote machine that runs github action - sys.path.insert( - 0, - "/home/runner/work/scvi-tools/scvi-tools/" - "cellxgene-census/api/python/cellxgene_census/src", - ) - sys.path.insert(0, "src") - import cellxgene_census - import tiledbsoma as soma - from cellxgene_census.experimental.ml import experiment_dataloader - from cellxgene_census.experimental.ml.datamodule import CensusSCVIDataModule - +def test_czi_custom_dataloader(save_path="."): # this test checks the local custom dataloder made by CZI and run several tests with it census = cellxgene_census.open_soma(census_version="stable") experiment_name = "mus_musculus" obs_value_filter = 'is_primary_data == True and tissue_general in ["kidney"] and nnz >= 3000' + + # This is under comments just to save time (selecting highly varkable genes): # top_n_hvg = 8000 # hvg_batch = ["assay", "suspension_type"] # @@ -44,8 +37,10 @@ def test_custom_dataloader(save_path): # hv = hvgs_df.highly_variable # hv_idx = hv[hv].index - hv_idx = np.arange(100) + hv_idx = np.arange(100) # just ot make it smaller and faster for debug + # this is CZI part to be taken once all is ready + batch_keys = ["dataset_id", "assay", "suspension_type", "donor_id"] datamodule = CensusSCVIDataModule( census["census_data"][experiment_name], measurement_name="RNA", @@ -54,32 +49,41 @@ def test_custom_dataloader(save_path): var_query=soma.AxisQuery(coords=(list(hv_idx),)), batch_size=1024, shuffle=True, - batch_keys=["dataset_id", "assay", "suspension_type", "donor_id"], + batch_keys=batch_keys, dataloader_kwargs={"num_workers": 0, "persistent_workers": False}, ) - datamodule.vars = hv_idx + # table of genes should be filtered by soma_joinid - but we should keep the encoded indexes + # This is nice to have and might be uses in the downstream analysis + # var_df = census["census_data"][experiment_name].ms["RNA"].var.read().concat().to_pandas() + # var_df = var_df.loc[var_df.soma_joinid.isin( + # list(datamodule.datapipe.var_query.coords[0] if datamodule.datapipe.var_query is not None + # else range(datamodule.n_vars)))] - scvi.model._scvi.SCVI.setup_datamodule(datamodule) # takes time - - adata = synthetic_iid() - scvi.model.SCVI.setup_anndata(adata, batch_key="batch") + # basicaly we should mimin everything below to any model census in scvi + adata_orig = synthetic_iid() + scvi.model.SCVI.setup_anndata(adata_orig, batch_key="batch") - model = scvi.model.SCVI(adata, n_latent=10) + model = scvi.model.SCVI(adata_orig, n_latent=10) model.train(max_epochs=1) - dataloader = model._make_data_loader(adata) + # TODO: do we need to apply those functions to any census model as is? + dataloader = model._make_data_loader(adata_orig) _ = model.get_elbo(dataloader=dataloader) _ = model.get_marginal_ll(dataloader=dataloader) _ = model.get_reconstruction_error(dataloader=dataloader) _ = model.get_latent_representation(dataloader=dataloader) - scvi.model.SCVI.prepare_query_anndata(adata, reference_model=model) - scvi.model.SCVI.load_query_data(adata, reference_model=model) + scvi.model.SCVI.prepare_query_anndata(adata_orig, reference_model=model) + scvi.model.SCVI.load_query_data(adata_orig, reference_model=model) + + user_attributes = model._get_user_attributes() + pprint(user_attributes) n_layers = 1 n_latent = 50 + scvi.model._scvi.SCVI.setup_datamodule(datamodule) # takes time model_census = scvi.model.SCVI( registry=datamodule.registry, n_layers=n_layers, @@ -88,6 +92,8 @@ def test_custom_dataloader(save_path): encode_covariates=False, ) + pprint(datamodule.registry) + batch_size = 1024 train_size = 0.9 max_epochs = 1 @@ -100,8 +106,17 @@ def test_custom_dataloader(save_path): early_stopping=False, ) - user_attributes = model_census._get_user_attributes() - user_attributes = {a[0]: a[1] for a in user_attributes if a[0][-1] == "_"} + user_attributes_model_census = model_census._get_user_attributes() + # # TODO: do we need to put inside + # user_attributes_model_census = \ + # {a[0]: a[1] for a in user_attributes_model_census if a[0][-1] == "_"} + pprint(user_attributes_model_census) + # dataloader_census = model_census._make_data_loader(datamodule.datapipe) + # # this casus errors + # _ = model_census.get_elbo(dataloader=dataloader_census) + # _ = model_census.get_marginal_ll(dataloader=dataloader_census) + # _ = model_census.get_reconstruction_error(dataloader=dataloader_census) + # _ = model_census.get_latent_representation(dataloader=dataloader_census) model_census.save(save_path, overwrite=True) model_census2 = scvi.model.SCVI.load(save_path, adata=False) @@ -114,6 +129,15 @@ def test_custom_dataloader(save_path): early_stopping=False, ) + user_attributes_model_census2 = model_census2._get_user_attributes() + pprint(user_attributes_model_census2) + # dataloader_census2 = model_census2._make_data_loader() + # this casus errors + # _ = model_census2.get_elbo() + # _ = model_census2.get_marginal_ll() + # _ = model_census2.get_reconstruction_error() + # _ = model_census2.get_latent_representation() + # takes time adata = cellxgene_census.get_anndata( census, @@ -122,18 +146,32 @@ def test_custom_dataloader(save_path): var_coords=hv_idx, ) - adata.obs["batch"] = ( - "batch_" + adata.obs[datamodule.batch_keys[0]].cat.codes.astype(str) - ).astype("category") - # adata.var_names = 'gene_'+adata.var_names #not sure we need it + # TODO: do we need to put inside (or is it alrady pre-made) - perhaps need to tell CZI + adata.obs["batch"] = adata.obs[batch_keys].agg("".join, axis=1).astype("category") scvi.model.SCVI.prepare_query_anndata(adata, save_path) scvi.model.SCVI.load_query_data(registry=datamodule.registry, reference_model=save_path) scvi.model.SCVI.prepare_query_anndata(adata, model_census2) + scvi.model.SCVI.setup_anndata(adata, batch_key="batch") # needed? model_census3 = scvi.model.SCVI.load(save_path, adata=adata) + model_census3.train( + datamodule=datamodule, + max_epochs=max_epochs, + batch_size=batch_size, + train_size=train_size, + early_stopping=False, + ) + + user_attributes_model_census3 = model_census3._get_user_attributes() + pprint(user_attributes_model_census3) + _ = model_census3.get_elbo() + _ = model_census3.get_marginal_ll() + _ = model_census3.get_reconstruction_error() + _ = model_census3.get_latent_representation() + scvi.model.SCVI.prepare_query_anndata(adata, save_path, return_reference_var_names=True) scvi.model.SCVI.load_query_data(adata, save_path) @@ -158,7 +196,7 @@ def test_custom_dataloader(save_path): mapped_dataloader = ( datamodule_inference.on_before_batch_transfer(tensor, None) for tensor in dataloader ) - latent = model.get_latent_representation(dataloader=mapped_dataloader) + latent = model_census.get_latent_representation(dataloader=mapped_dataloader) emb_idx = datapipe._obs_joinids From 41fd877f7f450ef6b5af4b4b8399e439dc1811e2 Mon Sep 17 00:00:00 2001 From: ori-kron-wis Date: Wed, 9 Oct 2024 19:07:11 +0300 Subject: [PATCH 42/53] added cellxgene-census folder as well for debug (will not be merged) --- cellxgene-census | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/cellxgene-census b/cellxgene-census index 6edd123100..9ae33ae537 160000 --- a/cellxgene-census +++ b/cellxgene-census @@ -1 +1 @@ -Subproject commit 6edd123100716f6a434403b74db58c5379bb0d5d +Subproject commit 9ae33ae53787cb65e03c462f79acb67f7fbfc76c From 10ada9cb588bdcacd5c233cfc95abae92b4f9cba Mon Sep 17 00:00:00 2001 From: ori-kron-wis Date: Wed, 9 Oct 2024 19:26:04 +0300 Subject: [PATCH 43/53] added cellxgene-census packge to run test --- pyproject.toml | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 18d55569d4..d9162fee27 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -53,6 +53,7 @@ dependencies = [ "torch", "torchmetrics>=0.11.0", "tqdm", + "cellxgene-census", "xarray>=2023.2.0", ] @@ -81,8 +82,6 @@ docsbuild = ["scvi-tools[docs,optional]"] autotune = ["hyperopt>=0.2", "ray[tune]>=2.5.0"] # scvi.hub.HubModel.pull_from_s3 aws = ["boto3"] -# scvi.data.cellxgene -census = ["cellxgene-census"] # scvi.hub dependencies hub = ["huggingface_hub"] # scvi.model.utils.mde dependencies From dd3649c051424af1107817a7b9c95f06adf87aab Mon Sep 17 00:00:00 2001 From: ori-kron-wis Date: Wed, 9 Oct 2024 19:32:06 +0300 Subject: [PATCH 44/53] added torchdata packge to run test --- .github/workflows/test_linux_custom_dataloader.yml | 5 +++-- pyproject.toml | 1 + 2 files changed, 4 insertions(+), 2 deletions(-) diff --git a/.github/workflows/test_linux_custom_dataloader.yml b/.github/workflows/test_linux_custom_dataloader.yml index 09c5ce54b1..53d6de3353 100644 --- a/.github/workflows/test_linux_custom_dataloader.yml +++ b/.github/workflows/test_linux_custom_dataloader.yml @@ -66,7 +66,8 @@ jobs: GH_TOKEN: ${{ secrets.GH_TOKEN }} run: | git config --global url."https://${GH_TOKEN}:x-oauth-basic@github.com/".insteadOf "https://github.com/" - git clone --single-branch --branch ebezzi/census-scvi-datamodule https://github.com/ori-kron-wis/cellxgene-census.git + git clone --single-branch --branch ebezzi/census-scvi-datamodule https://github.com/chanzuckerberg/cellxgene-census.git + git clone --single-branch --branch main https://github.com/jkobject/scDataLoader.git - name: Run specific custom dataloader pytest env: @@ -75,7 +76,7 @@ jobs: DISPLAY: :42 COLUMNS: 120 run: | - coverage run -m pytest tests/dataloaders/test_custom_dataloader.py -v --color=yes --custom-dataloader-tests + coverage run -m pytest tests/dataloaders/test_czi_custom_dataloader.py -v --color=yes --custom-dataloader-tests coverage report - uses: codecov/codecov-action@v4 diff --git a/pyproject.toml b/pyproject.toml index d9162fee27..bd8e42ca22 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -54,6 +54,7 @@ dependencies = [ "torchmetrics>=0.11.0", "tqdm", "cellxgene-census", + "torchdata", "xarray>=2023.2.0", ] From c6acb5abec14810606d75c7fe1da3c1668d1199b Mon Sep 17 00:00:00 2001 From: ori-kron-wis Date: Wed, 9 Oct 2024 19:41:53 +0300 Subject: [PATCH 45/53] fixed the test workwflow --- .../test_linux_custom_dataloader.yml | 2 +- .../dataloaders/test_czi_custom_dataloader.py | 21 ++++++++++++++----- 2 files changed, 17 insertions(+), 6 deletions(-) diff --git a/.github/workflows/test_linux_custom_dataloader.yml b/.github/workflows/test_linux_custom_dataloader.yml index 53d6de3353..30e2e1d222 100644 --- a/.github/workflows/test_linux_custom_dataloader.yml +++ b/.github/workflows/test_linux_custom_dataloader.yml @@ -66,7 +66,7 @@ jobs: GH_TOKEN: ${{ secrets.GH_TOKEN }} run: | git config --global url."https://${GH_TOKEN}:x-oauth-basic@github.com/".insteadOf "https://github.com/" - git clone --single-branch --branch ebezzi/census-scvi-datamodule https://github.com/chanzuckerberg/cellxgene-census.git + git clone --single-branch --branch ebezzi/census-scvi-datamodule https://github.com/ori-kron-wis/cellxgene-census.git git clone --single-branch --branch main https://github.com/jkobject/scDataLoader.git - name: Run specific custom dataloader pytest diff --git a/tests/dataloaders/test_czi_custom_dataloader.py b/tests/dataloaders/test_czi_custom_dataloader.py index 060b9f6fed..662c6d6166 100644 --- a/tests/dataloaders/test_czi_custom_dataloader.py +++ b/tests/dataloaders/test_czi_custom_dataloader.py @@ -2,20 +2,31 @@ from pprint import pprint -import cellxgene_census import numpy as np import pandas as pd import pytest -import tiledbsoma as soma -from cellxgene_census.experimental.ml import experiment_dataloader -from cellxgene_census.experimental.ml.datamodule import CensusSCVIDataModule import scvi from scvi.data import synthetic_iid @pytest.mark.custom_dataloader -def test_czi_custom_dataloader(save_path="."): +def test_czi_custom_dataloader(save_path): + # local bracnh with fix only for this test + import sys + + # should be ready for importing the cloned branch on a remote machine that runs github action + sys.path.insert( + 0, + "/home/runner/work/scvi-tools/scvi-tools/" + "cellxgene-census/api/python/cellxgene_census/src", + ) + sys.path.insert(0, "src") + import cellxgene_census + import tiledbsoma as soma + from cellxgene_census.experimental.ml import experiment_dataloader + from cellxgene_census.experimental.ml.datamodule import CensusSCVIDataModule + # this test checks the local custom dataloder made by CZI and run several tests with it census = cellxgene_census.open_soma(census_version="stable") From b35c6eb8fd0eaa50fffc40c19b9a44b730e4e9e6 Mon Sep 17 00:00:00 2001 From: ori-kron-wis Date: Thu, 10 Oct 2024 16:24:01 +0300 Subject: [PATCH 46/53] adding the lamindb as well --- .../test_linux_custom_dataloader.yml | 2 +- pyproject.toml | 7 +- ...ataloader.py => test_custom_dataloader.py} | 189 +++++++++++++++++- 3 files changed, 191 insertions(+), 7 deletions(-) rename tests/dataloaders/{test_czi_custom_dataloader.py => test_custom_dataloader.py} (55%) diff --git a/.github/workflows/test_linux_custom_dataloader.yml b/.github/workflows/test_linux_custom_dataloader.yml index 30e2e1d222..53d6de3353 100644 --- a/.github/workflows/test_linux_custom_dataloader.yml +++ b/.github/workflows/test_linux_custom_dataloader.yml @@ -66,7 +66,7 @@ jobs: GH_TOKEN: ${{ secrets.GH_TOKEN }} run: | git config --global url."https://${GH_TOKEN}:x-oauth-basic@github.com/".insteadOf "https://github.com/" - git clone --single-branch --branch ebezzi/census-scvi-datamodule https://github.com/ori-kron-wis/cellxgene-census.git + git clone --single-branch --branch ebezzi/census-scvi-datamodule https://github.com/chanzuckerberg/cellxgene-census.git git clone --single-branch --branch main https://github.com/jkobject/scDataLoader.git - name: Run specific custom dataloader pytest diff --git a/pyproject.toml b/pyproject.toml index bd8e42ca22..031467a0d9 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -53,8 +53,6 @@ dependencies = [ "torch", "torchmetrics>=0.11.0", "tqdm", - "cellxgene-census", - "torchdata", "xarray>=2023.2.0", ] @@ -83,6 +81,8 @@ docsbuild = ["scvi-tools[docs,optional]"] autotune = ["hyperopt>=0.2", "ray[tune]>=2.5.0"] # scvi.hub.HubModel.pull_from_s3 aws = ["boto3"] +# scvi.data.cellxgene +census = ["cellxgene-census"] # scvi.hub dependencies hub = ["huggingface_hub"] # scvi.model.utils.mde dependencies @@ -112,6 +112,9 @@ tutorials = [ "scvi-tools[optional]", "squidpy", ] +dataloaders = [ + "scdataloader" +] all = ["scvi-tools[dev,docs,tutorials]"] diff --git a/tests/dataloaders/test_czi_custom_dataloader.py b/tests/dataloaders/test_custom_dataloader.py similarity index 55% rename from tests/dataloaders/test_czi_custom_dataloader.py rename to tests/dataloaders/test_custom_dataloader.py index 662c6d6166..b8034bb6af 100644 --- a/tests/dataloaders/test_czi_custom_dataloader.py +++ b/tests/dataloaders/test_custom_dataloader.py @@ -1,18 +1,18 @@ -from __future__ import annotations - +import os from pprint import pprint import numpy as np import pandas as pd import pytest +import scanpy as sc import scvi from scvi.data import synthetic_iid @pytest.mark.custom_dataloader -def test_czi_custom_dataloader(save_path): - # local bracnh with fix only for this test +def test_czi_custom_dataloader(save_path="."): + # local branch with fix only for this test import sys # should be ready for importing the cloned branch on a remote machine that runs github action @@ -26,6 +26,7 @@ def test_czi_custom_dataloader(save_path): import tiledbsoma as soma from cellxgene_census.experimental.ml import experiment_dataloader from cellxgene_census.experimental.ml.datamodule import CensusSCVIDataModule + # from cellxgene_census.experimental.pp import highly_variable_genes # this test checks the local custom dataloder made by CZI and run several tests with it census = cellxgene_census.open_soma(census_version="stable") @@ -202,11 +203,15 @@ def test_czi_custom_dataloader(save_path): dataloader_kwargs={"num_workers": 0, "persistent_workers": False}, ) + # Create a dataloder of a CZI module datapipe = datamodule_inference.datapipe dataloader = experiment_dataloader(datapipe, num_workers=0, persistent_workers=False) mapped_dataloader = ( datamodule_inference.on_before_batch_transfer(tensor, None) for tensor in dataloader ) + _ = model_census.get_elbo(dataloader=mapped_dataloader) + _ = model_census.get_marginal_ll(dataloader=mapped_dataloader) + _ = model_census.get_reconstruction_error(dataloader=mapped_dataloader) latent = model_census.get_latent_representation(dataloader=mapped_dataloader) emb_idx = datapipe._obs_joinids @@ -218,3 +223,179 @@ def test_czi_custom_dataloader(save_path): # Reindexing is necessary to ensure that the cells in the embedding match the ones in # the anndata object. adata.obsm["scvi"] = latent[idx] + + # #We can now generate the neighbors and the UMAP (tutorials) + # sc.pp.neighbors(adata, use_rep="scvi", key_added="scvi") + # sc.tl.umap(adata, neighbors_key="scvi") + # sc.pl.umap(adata, color="dataset_id", title="SCVI") + # + # sc.pl.umap(adata, color="tissue_general", title="SCVI") + # + # sc.pl.umap(adata, color="cell_type", title="SCVI") + + +@pytest.mark.custom_dataloader +def test_lamindb_custom_dataloader(save_path="."): + # initialize a local lamin database + os.system("lamin init --storage ~/scdataloader2 --schema bionty") + # os.system("lamin close") + # os.system("lamin load scdataloader") + + # local branch with fix only for this test + import sys + + # should be ready for importing the cloned branch on a remote machine that runs github action + sys.path.insert( + 0, + "/home/runner/work/scvi-tools/scvi-tools/" "scDataLoader/", + ) + sys.path.insert(0, "src") + import lamindb as ln + import tqdm + from scdataloader import Collator, DataModule, SimpleAnnDataset + + # import bionty as bt + # from scdataloader import utils + # from scdataloader.preprocess import ( + # LaminPreprocessor, + # additional_postprocess, + # additional_preprocess, + # ) + # import numpy as np + # import tiledbsoma as soma + from scdataloader.utils import populate_my_ontology + from torch.utils.data import DataLoader + # from scdataloader.base import NAME + # from cellxgene_census.experimental.ml import experiment_dataloader + + # populate_my_ontology() #to populate everything (recommended) (can take 2-10mns) + + populate_my_ontology( + organisms=["NCBITaxon:10090", "NCBITaxon:9606"], + sex=["PATO:0000384", "PATO:0000383"], + ) + + # preprocess datasets - do we need this part? + # DESCRIPTION = "preprocessed by scDataLoader" + + cx_dataset = ( + ln.Collection.using(instance="laminlabs/cellxgene") + .filter(name="cellxgene-census", version="2023-12-15") + .one() + ) + cx_dataset, len(cx_dataset.artifacts.all()) + + # do_preprocess = LaminPreprocessor( + # additional_postprocess=additional_postprocess, + # additional_preprocess=additional_preprocess, + # skip_validate=True, + # subset_hvg=0, + # ) + + # preprocessed_dataset = do_preprocess( + # cx_dataset, name=DESCRIPTION, description=DESCRIPTION, start_at=1, version="2" + # ) + + # create dataloaders + + datamodule = DataModule( + collection_name="preprocessed dataset", + organisms=["NCBITaxon:9606"], # organism that we will work on + how="most expr", # for the collator (most expr genes only will be selected) / "some" + max_len=1000, # only the 1000 most expressed + batch_size=64, + num_workers=1, + validation_split=0.1, + test_split=0, + ) + + # we setup the datamodule (as exemplified in lightning's good practices, b + # ut there might be some things to improve here) + # testfiles = datamodule.setup() + + for i in tqdm.tqdm(datamodule.train_dataloader()): + # pass #or do pass + print(i) + break + + # with lightning: + # Trainer(model, datamodule) + + # Read adata and create lamindb dataloader + adata_orig = sc.read_h5ad( + "/Users/orikr/PycharmProjects/scvi-tools/scDataLoader/tests/test.h5ad" + ) + # preprocessor = Preprocessor(do_postp=False) + # adata = preprocessor(adata_orig) + adataset = SimpleAnnDataset(adata_orig, obs_to_output=["organism_ontology_term_id"]) + col = Collator( + organisms=["NCBITaxon:9606"], + max_len=1000, + how="random expr", + ) + dataloader = DataLoader( + adataset, + collate_fn=col, + batch_size=4, + num_workers=1, + shuffle=False, + ) + + # We will now create the SCVI model object: + # Its parameters: + # n_layers = 1 + # n_latent = 10 + # batch_size = 1024 + # train_size = 0.9 + # max_epochs = 1 + + # def on_before_batch_transfer( + # batch: tuple[torch.Tensor, torch.Tensor], + # ) -> dict[str, torch.Tensor | None]: + # """Format the datapipe output with registry keys for scvi-tools.""" + # X, obs = batch + # X_KEY: str = "X" + # BATCH_KEY: str = "batch" + # LABELS_KEY: str = "labels" + # return { + # X_KEY: X, + # BATCH_KEY: obs, + # LABELS_KEY: None, + # } + + # Try the lamindb dataloder on a trained scvi-model with adata + # adata = adata.copy() + scvi.model.SCVI.setup_anndata(adata_orig, batch_key="cell_type_ontology_term_id") + model = scvi.model.SCVI(adata_orig, n_latent=10) + model.train(max_epochs=1) + # dataloader2 = experiment_dataloader(dataloader, num_workers=0, persistent_workers=False) + # mapped_dataloader = ( + # on_before_batch_transfer(tensor, None) for tensor in dataloader2 + # ) + # dataloader = model._make_data_loader(mapped_dataloader) + _ = model.get_elbo(dataloader=dataloader) + _ = model.get_marginal_ll(dataloader=dataloader) + _ = model.get_reconstruction_error(dataloader=dataloader) + _ = model.get_latent_representation(dataloader=dataloader) + + # scvi.model._scvi.SCVI.setup_datamodule(datamodule) # takes time + # model_lamindb = scvi.model.SCVI( + # registry=datamodule.registry, + # n_layers=n_layers, + # n_latent=n_latent, + # gene_likelihood="nb", + # encode_covariates=False, + # ) + # + # pprint(datamodule.registry) + # + # model_lamindb.train( + # datamodule=datamodule, + # max_epochs=max_epochs, + # batch_size=batch_size, + # train_size=train_size, + # early_stopping=False, + # ) + # We have to create a registry without setup_anndata that contains the same elements + # The other way will be to fill the model ,LIKE IN CELLXGENE NOTEBOOK + # need to pass here new object of registry taht contains everything we will need From 1801604759dbac8af01cfa8401500c30e31c20e4 Mon Sep 17 00:00:00 2001 From: ori-kron-wis Date: Thu, 10 Oct 2024 16:26:19 +0300 Subject: [PATCH 47/53] fix the c.dataloders test --- .github/workflows/test_linux_custom_dataloader.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/test_linux_custom_dataloader.yml b/.github/workflows/test_linux_custom_dataloader.yml index 53d6de3353..7aec2f1c67 100644 --- a/.github/workflows/test_linux_custom_dataloader.yml +++ b/.github/workflows/test_linux_custom_dataloader.yml @@ -76,7 +76,7 @@ jobs: DISPLAY: :42 COLUMNS: 120 run: | - coverage run -m pytest tests/dataloaders/test_czi_custom_dataloader.py -v --color=yes --custom-dataloader-tests + coverage run -m pytest tests/dataloaders/test_custom_dataloader.py -v --color=yes --custom-dataloader-tests coverage report - uses: codecov/codecov-action@v4 From ed77a650254928cfbc09e89678bcea823beed90a Mon Sep 17 00:00:00 2001 From: ori-kron-wis Date: Thu, 10 Oct 2024 16:31:52 +0300 Subject: [PATCH 48/53] fix the c.dataloders test --- .github/workflows/test_linux_custom_dataloader.yml | 4 +++- tests/dataloaders/test_custom_dataloader.py | 6 ++++-- 2 files changed, 7 insertions(+), 3 deletions(-) diff --git a/.github/workflows/test_linux_custom_dataloader.yml b/.github/workflows/test_linux_custom_dataloader.yml index 7aec2f1c67..20f8645079 100644 --- a/.github/workflows/test_linux_custom_dataloader.yml +++ b/.github/workflows/test_linux_custom_dataloader.yml @@ -35,7 +35,7 @@ jobs: fail-fast: false matrix: os: [ubuntu-latest] - python: ["3.11"] + python: ["3.10"] name: integration @@ -60,6 +60,8 @@ jobs: python -m pip install s3fs python -m pip install torchdata python -m pip install psutil + python -m pip install cellxgene-census + python -m pip install lamindb - name: Install Specific Branch of Repository env: diff --git a/tests/dataloaders/test_custom_dataloader.py b/tests/dataloaders/test_custom_dataloader.py index b8034bb6af..50ad7d6665 100644 --- a/tests/dataloaders/test_custom_dataloader.py +++ b/tests/dataloaders/test_custom_dataloader.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import os from pprint import pprint @@ -11,7 +13,7 @@ @pytest.mark.custom_dataloader -def test_czi_custom_dataloader(save_path="."): +def test_czi_custom_dataloader(save_path): # local branch with fix only for this test import sys @@ -235,7 +237,7 @@ def test_czi_custom_dataloader(save_path="."): @pytest.mark.custom_dataloader -def test_lamindb_custom_dataloader(save_path="."): +def test_lamindb_custom_dataloader(save_path): # initialize a local lamin database os.system("lamin init --storage ~/scdataloader2 --schema bionty") # os.system("lamin close") From fc831d57e52dce9d89a6405552f42c8bf7c8e420 Mon Sep 17 00:00:00 2001 From: ori-kron-wis Date: Thu, 10 Oct 2024 16:39:15 +0300 Subject: [PATCH 49/53] fix the c.dataloders test --- .github/workflows/test_linux_custom_dataloader.yml | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/.github/workflows/test_linux_custom_dataloader.yml b/.github/workflows/test_linux_custom_dataloader.yml index 20f8645079..42f29f9062 100644 --- a/.github/workflows/test_linux_custom_dataloader.yml +++ b/.github/workflows/test_linux_custom_dataloader.yml @@ -35,7 +35,7 @@ jobs: fail-fast: false matrix: os: [ubuntu-latest] - python: ["3.10"] + python: ["3.11"] name: integration @@ -62,13 +62,14 @@ jobs: python -m pip install psutil python -m pip install cellxgene-census python -m pip install lamindb + python -m pip install bionty - name: Install Specific Branch of Repository env: GH_TOKEN: ${{ secrets.GH_TOKEN }} run: | git config --global url."https://${GH_TOKEN}:x-oauth-basic@github.com/".insteadOf "https://github.com/" - git clone --single-branch --branch ebezzi/census-scvi-datamodule https://github.com/chanzuckerberg/cellxgene-census.git + git clone --single-branch --branch ebezzi/census-scvi-datamodule https://github.com/ori-kron-wis/cellxgene-census.git git clone --single-branch --branch main https://github.com/jkobject/scDataLoader.git - name: Run specific custom dataloader pytest From 7400621c6d363537df36429ff99e10c69a9ded85 Mon Sep 17 00:00:00 2001 From: ori-kron-wis Date: Thu, 10 Oct 2024 16:56:39 +0300 Subject: [PATCH 50/53] fix the c.dataloders test --- .github/workflows/test_linux_custom_dataloader.yml | 4 +++- tests/dataloaders/test_custom_dataloader.py | 6 +++--- 2 files changed, 6 insertions(+), 4 deletions(-) diff --git a/.github/workflows/test_linux_custom_dataloader.yml b/.github/workflows/test_linux_custom_dataloader.yml index 42f29f9062..21c42fa3e6 100644 --- a/.github/workflows/test_linux_custom_dataloader.yml +++ b/.github/workflows/test_linux_custom_dataloader.yml @@ -56,13 +56,15 @@ jobs: run: | python -m pip install --upgrade pip wheel uv python -m uv pip install --system "scvi-tools[tests] @ ." + python -m pip install scdataloader + python -m pip install cellxgene-census python -m pip install tiledbsoma python -m pip install s3fs python -m pip install torchdata python -m pip install psutil - python -m pip install cellxgene-census python -m pip install lamindb python -m pip install bionty + python -m pip install biomart - name: Install Specific Branch of Repository env: diff --git a/tests/dataloaders/test_custom_dataloader.py b/tests/dataloaders/test_custom_dataloader.py index 50ad7d6665..b2c49b9afe 100644 --- a/tests/dataloaders/test_custom_dataloader.py +++ b/tests/dataloaders/test_custom_dataloader.py @@ -211,9 +211,9 @@ def test_czi_custom_dataloader(save_path): mapped_dataloader = ( datamodule_inference.on_before_batch_transfer(tensor, None) for tensor in dataloader ) - _ = model_census.get_elbo(dataloader=mapped_dataloader) - _ = model_census.get_marginal_ll(dataloader=mapped_dataloader) - _ = model_census.get_reconstruction_error(dataloader=mapped_dataloader) + # _ = model_census.get_elbo(dataloader=mapped_dataloader) + # _ = model_census.get_marginal_ll(dataloader=mapped_dataloader) + # _ = model_census.get_reconstruction_error(dataloader=mapped_dataloader) latent = model_census.get_latent_representation(dataloader=mapped_dataloader) emb_idx = datapipe._obs_joinids From 47376ca07dae23e7f80157e55adf4236518286e0 Mon Sep 17 00:00:00 2001 From: ori-kron-wis Date: Thu, 10 Oct 2024 17:13:05 +0300 Subject: [PATCH 51/53] fix the c.dataloders test --- tests/dataloaders/test_custom_dataloader.py | 33 ++++++++++----------- 1 file changed, 15 insertions(+), 18 deletions(-) diff --git a/tests/dataloaders/test_custom_dataloader.py b/tests/dataloaders/test_custom_dataloader.py index b2c49b9afe..de68533362 100644 --- a/tests/dataloaders/test_custom_dataloader.py +++ b/tests/dataloaders/test_custom_dataloader.py @@ -258,11 +258,12 @@ def test_lamindb_custom_dataloader(save_path): # import bionty as bt # from scdataloader import utils - # from scdataloader.preprocess import ( - # LaminPreprocessor, - # additional_postprocess, - # additional_preprocess, - # ) + from scdataloader.preprocess import ( + LaminPreprocessor, + additional_postprocess, + additional_preprocess, + ) + # import numpy as np # import tiledbsoma as soma from scdataloader.utils import populate_my_ontology @@ -278,7 +279,7 @@ def test_lamindb_custom_dataloader(save_path): ) # preprocess datasets - do we need this part? - # DESCRIPTION = "preprocessed by scDataLoader" + DESCRIPTION = "preprocessed by scDataLoader" cx_dataset = ( ln.Collection.using(instance="laminlabs/cellxgene") @@ -287,16 +288,14 @@ def test_lamindb_custom_dataloader(save_path): ) cx_dataset, len(cx_dataset.artifacts.all()) - # do_preprocess = LaminPreprocessor( - # additional_postprocess=additional_postprocess, - # additional_preprocess=additional_preprocess, - # skip_validate=True, - # subset_hvg=0, - # ) + do_preprocess = LaminPreprocessor( + additional_postprocess=additional_postprocess, + additional_preprocess=additional_preprocess, + skip_validate=True, + subset_hvg=0, + ) - # preprocessed_dataset = do_preprocess( - # cx_dataset, name=DESCRIPTION, description=DESCRIPTION, start_at=1, version="2" - # ) + do_preprocess(cx_dataset, name=DESCRIPTION, description=DESCRIPTION, start_at=1, version="2") # create dataloaders @@ -324,9 +323,7 @@ def test_lamindb_custom_dataloader(save_path): # Trainer(model, datamodule) # Read adata and create lamindb dataloader - adata_orig = sc.read_h5ad( - "/Users/orikr/PycharmProjects/scvi-tools/scDataLoader/tests/test.h5ad" - ) + adata_orig = sc.read_h5ad("./scDataLoader/tests/test.h5ad") # preprocessor = Preprocessor(do_postp=False) # adata = preprocessor(adata_orig) adataset = SimpleAnnDataset(adata_orig, obs_to_output=["organism_ontology_term_id"]) From f94f7faaf692a801ee78fe559a390b998129b0cd Mon Sep 17 00:00:00 2001 From: ori-kron-wis Date: Sun, 13 Oct 2024 15:06:45 +0300 Subject: [PATCH 52/53] removed redundat functions in code base --- src/scvi/model/_scvi.py | 17 ++++------------- src/scvi/model/base/_base_model.py | 17 ----------------- tests/dataloaders/test_custom_dataloader.py | 1 - 3 files changed, 4 insertions(+), 31 deletions(-) diff --git a/src/scvi/model/_scvi.py b/src/scvi/model/_scvi.py index 3b9b7708a6..ebb51fa97e 100644 --- a/src/scvi/model/_scvi.py +++ b/src/scvi/model/_scvi.py @@ -7,7 +7,7 @@ import scvi from scvi import REGISTRY_KEYS -from scvi.data import AnnDataManager, _constants +from scvi.data import AnnDataManager from scvi.data._constants import _ADATA_MINIFY_TYPE_UNS_KEY, ADATA_MINIFY_TYPE from scvi.data._utils import _get_adata_minify_type from scvi.data.fields import ( @@ -23,7 +23,7 @@ from scvi.model.base import EmbeddingMixin, UnsupervisedTrainingMixin from scvi.model.utils import get_minified_adata_scrna from scvi.module import VAE -from scvi.utils import attrdict, setup_anndata_dsp +from scvi.utils import setup_anndata_dsp from .base import ArchesMixin, BaseMinifiedModeModelClass, RNASeqMixin, VAEMixin @@ -229,19 +229,11 @@ def setup_anndata( adata_manager.register_fields(adata, **kwargs) cls.register_manager(adata_manager) - @staticmethod - def _get_summary_stats_from_registry(registry: dict) -> attrdict: - summary_stats = {} - for field_registry in registry[_constants._FIELD_REGISTRIES_KEY].values(): - field_summary_stats = field_registry[_constants._SUMMARY_STATS_KEY] - summary_stats.update(field_summary_stats) - return attrdict(summary_stats) - @classmethod @setup_anndata_dsp.dedent def setup_datamodule( cls, - datamodule, # TODO: what to put here? + datamodule, source_registry=None, layer: str | None = None, batch_key: list[str] | None = None, @@ -263,7 +255,6 @@ def setup_datamodule( %(param_cat_cov_keys)s %(param_cont_cov_keys)s """ - # TODO: from adata (czi)? if datamodule.__class__.__name__ == "CensusSCVIDataModule": # CZI categorical_mapping = datamodule.datapipe.obs_encoders["batch"].classes_ @@ -300,7 +291,7 @@ def setup_datamodule( "state_registry": { "n_obs": datamodule.n_obs, "n_vars": datamodule.n_vars, - "column_names": [str(i) for i in column_names], # TODO: from adata (czi)? + "column_names": [str(i) for i in column_names], }, "summary_stats": {"n_vars": datamodule.n_vars, "n_cells": datamodule.n_obs}, }, diff --git a/src/scvi/model/base/_base_model.py b/src/scvi/model/base/_base_model.py index 3af9368d5c..bb74f17715 100644 --- a/src/scvi/model/base/_base_model.py +++ b/src/scvi/model/base/_base_model.py @@ -299,23 +299,6 @@ def data_registry(self, registry_key: str) -> np.ndarray | pd.DataFrame: else: return self._adata_manager.get_from_registry(registry_key) - # def get_from_registry(self, registry_key: str) -> np.ndarray | pd.DataFrame: - # """Returns the object in AnnData associated with the key in the data registry. - # - # Parameters - # ---------- - # registry_key - # key of object to get from ``self.data_registry`` - # - # Returns - # ------- - # The requested data. - # """ - # if not self.adata: - # raise ValueError("self.adata is None. Please registry AnnData object.") - # else: - # return self._adata_manager.get_from_registry(registry_key) - def deregister_manager(self, adata: AnnData | None = None): """Deregisters the :class:`~scvi.data.AnnDataManager` instance associated with `adata`. diff --git a/tests/dataloaders/test_custom_dataloader.py b/tests/dataloaders/test_custom_dataloader.py index de68533362..28916a4be7 100644 --- a/tests/dataloaders/test_custom_dataloader.py +++ b/tests/dataloaders/test_custom_dataloader.py @@ -81,7 +81,6 @@ def test_czi_custom_dataloader(save_path): model = scvi.model.SCVI(adata_orig, n_latent=10) model.train(max_epochs=1) - # TODO: do we need to apply those functions to any census model as is? dataloader = model._make_data_loader(adata_orig) _ = model.get_elbo(dataloader=dataloader) _ = model.get_marginal_ll(dataloader=dataloader) From 962f0431b53200dac32985c86b6c9cb9ab900798 Mon Sep 17 00:00:00 2001 From: ori-kron-wis Date: Tue, 15 Oct 2024 18:14:50 +0300 Subject: [PATCH 53/53] Added scanvi support, including CZI datamodule fix for it --- src/scvi/model/_scanvi.py | 126 ++++++++++++++- src/scvi/model/_scvi.py | 7 +- tests/dataloaders/test_custom_dataloader.py | 167 +++++++++++++++----- 3 files changed, 254 insertions(+), 46 deletions(-) diff --git a/src/scvi/model/_scanvi.py b/src/scvi/model/_scanvi.py index d7e343cfd2..630f44f3cb 100644 --- a/src/scvi/model/_scanvi.py +++ b/src/scvi/model/_scanvi.py @@ -10,6 +10,7 @@ import torch from anndata import AnnData +import scvi from scvi import REGISTRY_KEYS, settings from scvi.data import AnnDataManager from scvi.data._constants import ( @@ -44,6 +45,7 @@ from typing import Literal from anndata import AnnData + from lightning import LightningDataModule from scvi._types import MinifiedDataType from scvi.data.fields import ( @@ -127,12 +129,13 @@ def __init__( dispersion: Literal["gene", "gene-batch", "gene-label", "gene-cell"] = "gene", gene_likelihood: Literal["zinb", "nb", "poisson"] = "zinb", linear_classifier: bool = False, + datamodule: LightningDataModule | None = None, **model_kwargs, ): super().__init__(adata, registry) scanvae_model_kwargs = dict(model_kwargs) - self._set_indices_and_labels() + self._set_indices_and_labels(datamodule) # ignores unlabeled catgegory n_labels = self.summary_stats.n_labels - 1 @@ -268,17 +271,21 @@ def from_scvi_model( return scanvi_model - def _set_indices_and_labels(self): + def _set_indices_and_labels(self, datamodule=None): """Set indices for labeled and unlabeled cells.""" labels_state_registry = self.get_state_registry(REGISTRY_KEYS.LABELS_KEY) self.original_label_key = labels_state_registry.original_key self.unlabeled_category_ = labels_state_registry.unlabeled_category - labels = get_anndata_attribute( - self.adata, - self.adata_manager.data_registry.labels.attr_name, - self.original_label_key, - ).ravel() + if datamodule is None: + labels = get_anndata_attribute( + self.adata, + self.adata_manager.data_registry.labels.attr_name, + self.original_label_key, + ).ravel() + else: + # for CZI: + labels = list(datamodule.datapipe.map(lambda x: x["label"])) self._label_mapping = labels_state_registry.categorical_mapping # set unlabeled and labeled indices @@ -500,6 +507,111 @@ def setup_anndata( adata_manager.register_fields(adata, **kwargs) cls.register_manager(adata_manager) + @classmethod + @setup_anndata_dsp.dedent + def setup_datamodule( + cls, + datamodule: LightningDataModule | None = None, + source_registry=None, + layer: str | None = None, + batch_key: list[str] | None = None, + labels_key: str | None = None, + size_factor_key: str | None = None, + categorical_covariate_keys: list[str] | None = None, + continuous_covariate_keys: list[str] | None = None, + **kwargs, + ): + """%(summary)s. + + Parameters + ---------- + %(param_datamodule)s + %(param_source_registry)s + %(param_layer)s + %(param_batch_key)s + %(param_size_factor_key)s + %(param_cat_cov_keys)s + %(param_cont_cov_keys)s + """ + if datamodule.__class__.__name__ == "CensusSCVIDataModule": + # CZI + batch_mapping = datamodule.datapipe.obs_encoders["batch"].classes_ + labels_mapping = datamodule.datapipe.obs_encoders["label"].classes_ + features_names = list( + datamodule.datapipe.var_query.coords[0] + if datamodule.datapipe.var_query is not None + else range(datamodule.n_vars) + ) + n_batch = datamodule.n_batch + n_label = datamodule.n_label + + else: + # Anndata -> CZI + # if we are here and datamodule is actually an AnnData object + # it means we init the custom dataloder model with anndata + batch_mapping = source_registry["field_registries"]["batch"]["state_registry"][ + "categorical_mapping" + ] + labels_mapping = source_registry["field_registries"]["label"]["state_registry"][ + "categorical_mapping" + ] + features_names = datamodule.var.soma_joinid.values + n_batch = source_registry["field_registries"]["batch"]["summary_stats"]["n_batch"] + n_label = 1 # need to change + + datamodule.registry = { + "scvi_version": scvi.__version__, + "model_name": "SCVI", + "setup_args": { + "layer": layer, + "batch_key": batch_key, + "labels_key": labels_key, + "size_factor_key": size_factor_key, + "categorical_covariate_keys": categorical_covariate_keys, + "continuous_covariate_keys": continuous_covariate_keys, + }, + "field_registries": { + "X": { + "data_registry": {"attr_name": "X", "attr_key": None}, + "state_registry": { + "n_obs": datamodule.n_obs, + "n_vars": datamodule.n_vars, + "column_names": [str(i) for i in features_names], + }, + "summary_stats": {"n_vars": datamodule.n_vars, "n_cells": datamodule.n_obs}, + }, + "batch": { + "data_registry": {"attr_name": "obs", "attr_key": "_scvi_batch"}, + "state_registry": { + "categorical_mapping": batch_mapping, + "original_key": "batch", + }, + "summary_stats": {"n_batch": n_batch}, + }, + "labels": { + "data_registry": {"attr_name": "obs", "attr_key": "_scvi_labels"}, + "state_registry": { + "categorical_mapping": labels_mapping, + "original_key": "label", + "unlabeled_category": datamodule.unlabeled_category, + }, + "summary_stats": {"n_labels": n_label}, + }, + "size_factor": {"data_registry": {}, "state_registry": {}, "summary_stats": {}}, + "extra_categorical_covs": { + "data_registry": {}, + "state_registry": {}, + "summary_stats": {"n_extra_categorical_covs": 0}, + }, + "extra_continuous_covs": { + "data_registry": {}, + "state_registry": {}, + "summary_stats": {"n_extra_continuous_covs": 0}, + }, + }, + "setup_method_name": "setup_datamodule", + } + @staticmethod def _get_fields_for_adata_minification( minified_data_type: MinifiedDataType, diff --git a/src/scvi/model/_scvi.py b/src/scvi/model/_scvi.py index ebb51fa97e..d70710e388 100644 --- a/src/scvi/model/_scvi.py +++ b/src/scvi/model/_scvi.py @@ -31,6 +31,7 @@ from typing import Literal from anndata import AnnData + from lightning import LightningDataModule from scvi._types import MinifiedDataType from scvi.data.fields import ( @@ -127,6 +128,7 @@ def __init__( dispersion: Literal["gene", "gene-batch", "gene-label", "gene-cell"] = "gene", gene_likelihood: Literal["zinb", "nb", "poisson", "normal"] = "zinb", latent_distribution: Literal["normal", "ln"] = "normal", + datamodule: LightningDataModule | None = None, **kwargs, ): super().__init__(adata, registry) @@ -233,7 +235,7 @@ def setup_anndata( @setup_anndata_dsp.dedent def setup_datamodule( cls, - datamodule, + datamodule: LightningDataModule | None = None, source_registry=None, layer: str | None = None, batch_key: list[str] | None = None, @@ -247,7 +249,8 @@ def setup_datamodule( Parameters ---------- - %(param_adata)s + %(param_datamodule)s + %(param_source_registry)s %(param_layer)s %(param_batch_key)s %(param_labels_key)s diff --git a/tests/dataloaders/test_custom_dataloader.py b/tests/dataloaders/test_custom_dataloader.py index 28916a4be7..194bfc9dec 100644 --- a/tests/dataloaders/test_custom_dataloader.py +++ b/tests/dataloaders/test_custom_dataloader.py @@ -1,22 +1,23 @@ from __future__ import annotations import os +import sys from pprint import pprint import numpy as np import pandas as pd import pytest import scanpy as sc +import tqdm import scvi from scvi.data import synthetic_iid +# import numpy as np -@pytest.mark.custom_dataloader -def test_czi_custom_dataloader(save_path): - # local branch with fix only for this test - import sys +@pytest.mark.custom_dataloader +def test_czi_custom_dataloader_scvi(save_path="."): # should be ready for importing the cloned branch on a remote machine that runs github action sys.path.insert( 0, @@ -78,7 +79,7 @@ def test_czi_custom_dataloader(save_path): adata_orig = synthetic_iid() scvi.model.SCVI.setup_anndata(adata_orig, batch_key="batch") - model = scvi.model.SCVI(adata_orig, n_latent=10) + model = scvi.model.SCVI(adata_orig) model.train(max_epochs=1) dataloader = model._make_data_loader(adata_orig) @@ -93,29 +94,20 @@ def test_czi_custom_dataloader(save_path): user_attributes = model._get_user_attributes() pprint(user_attributes) - n_layers = 1 - n_latent = 50 - scvi.model._scvi.SCVI.setup_datamodule(datamodule) # takes time model_census = scvi.model.SCVI( registry=datamodule.registry, - n_layers=n_layers, - n_latent=n_latent, gene_likelihood="nb", encode_covariates=False, ) pprint(datamodule.registry) - batch_size = 1024 - train_size = 0.9 max_epochs = 1 model_census.train( datamodule=datamodule, max_epochs=max_epochs, - batch_size=batch_size, - train_size=train_size, early_stopping=False, ) @@ -137,8 +129,6 @@ def test_czi_custom_dataloader(save_path): model_census2.train( datamodule=datamodule, max_epochs=max_epochs, - batch_size=batch_size, - train_size=train_size, early_stopping=False, ) @@ -173,8 +163,6 @@ def test_czi_custom_dataloader(save_path): model_census3.train( datamodule=datamodule, max_epochs=max_epochs, - batch_size=batch_size, - train_size=train_size, early_stopping=False, ) @@ -236,15 +224,126 @@ def test_czi_custom_dataloader(save_path): @pytest.mark.custom_dataloader -def test_lamindb_custom_dataloader(save_path): +def test_czi_custom_dataloader_scanvi(save_path="."): + # should be ready for importing the cloned branch on a remote machine that runs github action + sys.path.insert( + 0, + "/home/runner/work/scvi-tools/scvi-tools/" + "cellxgene-census/api/python/cellxgene_census/src", + ) + sys.path.insert(0, "src") + import cellxgene_census + import tiledbsoma as soma + from cellxgene_census.experimental.ml.datamodule import CensusSCVIDataModule + # from cellxgene_census.experimental.pp import highly_variable_genes + + # this test checks the local custom dataloder made by CZI and run several tests with it + census = cellxgene_census.open_soma(census_version="stable") + + experiment_name = "mus_musculus" + obs_value_filter = ( + 'is_primary_data == True and tissue_general in ["kidney","liver"] and nnz >= 3000' + ) + + # This is under comments just to save time (selecting highly varkable genes): + # top_n_hvg = 8000 + # hvg_batch = ["assay", "suspension_type"] + # + # # For HVG, we can use the `highly_variable_genes` function provided in the Census, + # # which can compute HVGs in constant memory: + # + # query = census["census_data"][experiment_name].axis_query( + # measurement_name="RNA", obs_query=soma.AxisQuery(value_filter=obs_value_filter) + # ) + # hvgs_df = highly_variable_genes(query, n_top_genes=top_n_hvg, batch_key=hvg_batch) + # + # hv = hvgs_df.highly_variable + # hv_idx = hv[hv].index + + hv_idx = np.arange(100) # just ot make it smaller and faster for debug + + # this is CZI part to be taken once all is ready + batch_keys = ["dataset_id", "assay", "suspension_type", "donor_id"] + label_keys = ["tissue_general"] + datamodule = CensusSCVIDataModule( + census["census_data"][experiment_name], + measurement_name="RNA", + X_name="raw", + obs_query=soma.AxisQuery(value_filter=obs_value_filter), + var_query=soma.AxisQuery(coords=(list(hv_idx),)), + batch_size=1024, + shuffle=True, + batch_keys=batch_keys, + label_keys=label_keys, + unlabeled_category="label_0", + dataloader_kwargs={"num_workers": 0, "persistent_workers": False}, + ) + + # table of genes should be filtered by soma_joinid - but we should keep the encoded indexes + # This is nice to have and might be uses in the downstream analysis + # var_df = census["census_data"][experiment_name].ms["RNA"].var.read().concat().to_pandas() + # var_df = var_df.loc[var_df.soma_joinid.isin( + # list(datamodule.datapipe.var_query.coords[0] if datamodule.datapipe.var_query is not None + # else range(datamodule.n_vars)))] + + # scvi.model._scvi.SCVI.setup_datamodule(datamodule) # takes time + # model_census = scvi.model.SCVI( + # registry=datamodule.registry, + # gene_likelihood="nb", + # encode_covariates=False, + # ) + # + # pprint(datamodule.registry) + # + max_epochs = 1 + # + # model_census.train( + # datamodule=datamodule, + # max_epochs=max_epochs, + # early_stopping=False, + # ) + + scvi.model.SCANVI.setup_datamodule(datamodule) + pprint(datamodule.registry) + model = scvi.model.SCANVI(registry=datamodule.registry, datamodule=datamodule) + model.view_anndata_setup(datamodule) + adata_manager = model.adata_manager + pprint(adata_manager.registry) + model.train( + datamodule=datamodule, max_epochs=max_epochs, train_size=0.5, check_val_every_n_epoch=1 + ) + # logged_keys = model.history.keys() + # assert len(model._labeled_indices) == sum(adata.obs["labels"] != "label_0") + # assert len(model._unlabeled_indices) == sum(adata.obs["labels"] == "label_0") + # assert "elbo_validation" in logged_keys + # assert "reconstruction_loss_validation" in logged_keys + # assert "kl_local_validation" in logged_keys + # assert "elbo_train" in logged_keys + # assert "reconstruction_loss_train" in logged_keys + # assert "kl_local_train" in logged_keys + # assert "validation_classification_loss" in logged_keys + # assert "validation_accuracy" in logged_keys + # assert "validation_f1_score" in logged_keys + # assert "validation_calibration_error" in logged_keys + # adata2 = synthetic_iid() + # predictions = model.predict(adata2, indices=[1, 2, 3]) + # assert len(predictions) == 3 + # model.predict() + # df = model.predict(adata2, soft=True) + # assert isinstance(df, pd.DataFrame) + # model.predict(adata2, soft=True, indices=[1, 2, 3]) + # model.get_normalized_expression(adata2) + # model.differential_expression(groupby="labels", group1="label_1") + # model.differential_expression(groupby="labels", group1="label_1", group2="label_2") + + +@pytest.mark.custom_dataloader +def test_scdataloader_custom_dataloader_scvi(save_path="."): # initialize a local lamin database os.system("lamin init --storage ~/scdataloader2 --schema bionty") # os.system("lamin close") # os.system("lamin load scdataloader") - # local branch with fix only for this test - import sys - # should be ready for importing the cloned branch on a remote machine that runs github action sys.path.insert( 0, @@ -252,7 +351,6 @@ def test_lamindb_custom_dataloader(save_path): ) sys.path.insert(0, "src") import lamindb as ln - import tqdm from scdataloader import Collator, DataModule, SimpleAnnDataset # import bionty as bt @@ -263,7 +361,6 @@ def test_lamindb_custom_dataloader(save_path): additional_preprocess, ) - # import numpy as np # import tiledbsoma as soma from scdataloader.utils import populate_my_ontology from torch.utils.data import DataLoader @@ -340,13 +437,6 @@ def test_lamindb_custom_dataloader(save_path): ) # We will now create the SCVI model object: - # Its parameters: - # n_layers = 1 - # n_latent = 10 - # batch_size = 1024 - # train_size = 0.9 - # max_epochs = 1 - # def on_before_batch_transfer( # batch: tuple[torch.Tensor, torch.Tensor], # ) -> dict[str, torch.Tensor | None]: @@ -360,12 +450,13 @@ def test_lamindb_custom_dataloader(save_path): # BATCH_KEY: obs, # LABELS_KEY: None, # } + max_epochs = 1 # Try the lamindb dataloder on a trained scvi-model with adata # adata = adata.copy() scvi.model.SCVI.setup_anndata(adata_orig, batch_key="cell_type_ontology_term_id") - model = scvi.model.SCVI(adata_orig, n_latent=10) - model.train(max_epochs=1) + model = scvi.model.SCVI(adata_orig) + model.train(max_epochs=max_epochs) # dataloader2 = experiment_dataloader(dataloader, num_workers=0, persistent_workers=False) # mapped_dataloader = ( # on_before_batch_transfer(tensor, None) for tensor in dataloader2 @@ -379,8 +470,6 @@ def test_lamindb_custom_dataloader(save_path): # scvi.model._scvi.SCVI.setup_datamodule(datamodule) # takes time # model_lamindb = scvi.model.SCVI( # registry=datamodule.registry, - # n_layers=n_layers, - # n_latent=n_latent, # gene_likelihood="nb", # encode_covariates=False, # ) @@ -390,10 +479,14 @@ def test_lamindb_custom_dataloader(save_path): # model_lamindb.train( # datamodule=datamodule, # max_epochs=max_epochs, - # batch_size=batch_size, - # train_size=train_size, # early_stopping=False, # ) # We have to create a registry without setup_anndata that contains the same elements # The other way will be to fill the model ,LIKE IN CELLXGENE NOTEBOOK # need to pass here new object of registry taht contains everything we will need + + +@pytest.mark.custom_dataloader +def test_lamindb_custom_dataloader_scvi(save_path="."): + # a test for mapped collection + return