[pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
pre-commit-ci[bot] committed Oct 15, 2024
1 parent e48fe3b commit 609785f
Showing 5 changed files with 27 additions and 74 deletions.
16 changes: 5 additions & 11 deletions src/scvi/external/decipher/_components.py
@@ -1,4 +1,4 @@
-from typing import Sequence
+from collections.abc import Sequence

import numpy as np
import torch
@@ -52,7 +52,7 @@ def __init__(

# The multiple outputs are computed as a single output layer, and then split
indices = np.concatenate(([0], np.cumsum(self.output_dims)))
-self.output_slices = [slice(s, e) for s, e in zip(indices[:-1], indices[1:])]
+self.output_slices = [slice(s, e) for s, e in zip(indices[:-1], indices[1:], strict=False)]

# Create masked layers
deep_context_dim = self.context_dim if self.deep_context_injection else 0
@@ -63,21 +63,15 @@ def __init__(
batch_norms.append(nn.BatchNorm1d(hidden_dims[0]))
for i in range(1, len(hidden_dims)):
layers.append(
-torch.nn.Linear(
-    hidden_dims[i - 1] + deep_context_dim, hidden_dims[i]
-)
+torch.nn.Linear(hidden_dims[i - 1] + deep_context_dim, hidden_dims[i])
)
batch_norms.append(nn.BatchNorm1d(hidden_dims[i]))

layers.append(
-torch.nn.Linear(
-    hidden_dims[-1] + deep_context_dim, self.output_total_dim
-)
+torch.nn.Linear(hidden_dims[-1] + deep_context_dim, self.output_total_dim)
)
else:
-layers.append(
-    torch.nn.Linear(input_dim + context_dim, self.output_total_dim)
-)
+layers.append(torch.nn.Linear(input_dim + context_dim, self.output_total_dim))

self.layers = torch.nn.ModuleList(layers)
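
Aside, not part of the commit: beyond the pure line re-wrapping, this file has two mechanical edits. `Sequence` moves from `typing` to `collections.abc` (the `typing` alias is deprecated), and `zip(...)` gains an explicit `strict=False`, which matches the default behaviour and presumably silences a zip-strictness lint such as flake8-bugbear/ruff B905. A minimal sketch of what the flag controls, assuming Python 3.10+; `first` and `second` are made-up names:

first, second = [1, 2, 3], ["a", "b"]

print(list(zip(first, second, strict=False)))  # [(1, 'a'), (2, 'b')]: stops at the shorter input
try:
    list(zip(first, second, strict=True))
except ValueError:
    print("strict=True raises when the iterables differ in length")

In this particular call the two index slices always have equal length, so the added flag does not change behaviour.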

8 changes: 2 additions & 6 deletions src/scvi/external/decipher/_model.py
@@ -62,9 +62,7 @@ def setup_anndata(
anndata_fields = [
LayerField(REGISTRY_KEYS.X_KEY, layer, is_count_data=True),
]
-adata_manager = AnnDataManager(
-    fields=anndata_fields, setup_method_args=setup_method_args
-)
+adata_manager = AnnDataManager(fields=anndata_fields, setup_method_args=setup_method_args)
adata_manager.register_fields(adata, **kwargs)
cls.register_manager(adata_manager)

@@ -113,9 +111,7 @@ def get_latent_representation(
self._check_if_trained(warn=False)
adata = self._validate_anndata(adata)

-scdl = self._make_data_loader(
-    adata=adata, indices=indices, batch_size=batch_size
-)
+scdl = self._make_data_loader(adata=adata, indices=indices, batch_size=batch_size)
latent_locs = []
for tensors in scdl:
x = tensors[REGISTRY_KEYS.X_KEY]
12 changes: 3 additions & 9 deletions src/scvi/external/decipher/_module.py
@@ -82,9 +82,7 @@ def device(self):
return self._dummy_param.device

@staticmethod
-def _get_fn_args_from_batch(
-    tensor_dict: dict[str, torch.Tensor]
-) -> Iterable | dict:
+def _get_fn_args_from_batch(tensor_dict: dict[str, torch.Tensor]) -> Iterable | dict:
x = tensor_dict[REGISTRY_KEYS.X_KEY]
return (x,), {}

@@ -125,9 +123,7 @@ def model(self, x: torch.Tensor):
self.theta + self._epsilon
)
# noinspection PyUnresolvedReferences
-x_dist = dist.NegativeBinomial(
-    total_count=self.theta + self._epsilon, logits=logit
-)
+x_dist = dist.NegativeBinomial(total_count=self.theta + self._epsilon, logits=logit)
pyro.sample("x", x_dist.to_event(1), obs=x)

@auto_move_data
@@ -188,9 +184,7 @@ def predictive_log_likelihood(self, x: torch.Tensor, n_samples=5):
model_trace = poutine.trace(
poutine.replay(self.model, trace=guide_trace)
).get_trace(x)
-log_weights.append(
-    model_trace.log_prob_sum() - guide_trace.log_prob_sum()
-)
+log_weights.append(model_trace.log_prob_sum() - guide_trace.log_prob_sum())

finally:
self.beta = old_beta
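
Aside, not part of the commit: the `log_weights` accumulated in `predictive_log_likelihood` are per-sample importance weights, i.e. log p(x, z) - log q(z | x) under the replayed model and guide traces. How they are reduced to a single estimate lies outside the lines shown here; a minimal sketch of the usual convention, with `reduce_log_weights` as a hypothetical helper:

import math

import torch

def reduce_log_weights(log_weights: list[torch.Tensor]) -> torch.Tensor:
    # Importance-sampling estimate: log p(x) ~= logsumexp_i(log w_i) - log(n_samples)
    stacked = torch.stack(log_weights)
    return torch.logsumexp(stacked, dim=0) - math.log(len(log_weights))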
4 changes: 1 addition & 3 deletions src/scvi/external/decipher/_trainingplan.py
@@ -41,9 +41,7 @@ def __init__(
optim_kwargs.update({"lr": 5e-3})
if "weight_decay" not in optim_kwargs.keys():
optim_kwargs.update({"weight_decay": 1e-4})
-self.optim = (
-    pyro.optim.ClippedAdam(optim_args=optim_kwargs) if optim is None else optim
-)
+self.optim = pyro.optim.ClippedAdam(optim_args=optim_kwargs) if optim is None else optim
# We let SVI take care of all optimization
self.automatic_optimization = False

61 changes: 16 additions & 45 deletions src/scvi/train/_trainingplans.py
@@ -182,9 +182,7 @@ def __init__(
self.optimizer_creator = optimizer_creator

if self.optimizer_name == "Custom" and self.optimizer_creator is None:
-raise ValueError(
-    "If optimizer is 'Custom', `optimizer_creator` must be provided."
-)
+raise ValueError("If optimizer is 'Custom', `optimizer_creator` must be provided.")

self._n_obs_training = None
self._n_obs_validation = None
@@ -221,9 +219,7 @@ def initialize_train_metrics(self):
self.kl_local_train,
self.kl_global_train,
self.train_metrics,
-) = self._create_elbo_metric_components(
-    mode="train", n_total=self.n_obs_training
-)
+) = self._create_elbo_metric_components(mode="train", n_total=self.n_obs_training)
self.elbo_train.reset()

def initialize_val_metrics(self):
@@ -234,9 +230,7 @@ def initialize_val_metrics(self):
self.kl_local_val,
self.kl_global_val,
self.val_metrics,
-) = self._create_elbo_metric_components(
-    mode="validation", n_total=self.n_obs_validation
-)
+) = self._create_elbo_metric_components(mode="validation", n_total=self.n_obs_validation)
self.elbo_val.reset()

@property
@@ -372,9 +366,7 @@ def validation_step(self, batch, batch_idx):
)
self.compute_and_log_metrics(scvi_loss, self.val_metrics, "validation")

-def _optimizer_creator_fn(
-    self, optimizer_cls: torch.optim.Adam | torch.optim.AdamW
-):
+def _optimizer_creator_fn(self, optimizer_cls: torch.optim.Adam | torch.optim.AdamW):
"""Create optimizer for the model.
This type of function can be passed as the `optimizer_creator`
@@ -552,9 +544,7 @@ def loss_adversarial_classifier(self, z, batch_index, predict_true_class=True):
if predict_true_class:
cls_target = torch.nn.functional.one_hot(batch_index.squeeze(-1), n_classes)
else:
-one_hot_batch = torch.nn.functional.one_hot(
-    batch_index.squeeze(-1), n_classes
-)
+one_hot_batch = torch.nn.functional.one_hot(batch_index.squeeze(-1), n_classes)
# place zeroes where true label is
cls_target = (~one_hot_batch.bool()).float()
cls_target = cls_target / (n_classes - 1)
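
Aside, not part of the commit: the `cls_target` built above for the adversarial branch places zero mass on the true batch label and uniform mass 1/(n_classes - 1) on every other label, the target used to fool the classifier. A tiny illustration with made-up values:

import torch

n_classes = 4
batch_index = torch.tensor([[2], [0]])  # made-up batch labels

one_hot_batch = torch.nn.functional.one_hot(batch_index.squeeze(-1), n_classes)
cls_target = (~one_hot_batch.bool()).float() / (n_classes - 1)
# cls_target[0] -> [1/3, 1/3, 0, 1/3]; cls_target[1] -> [0, 1/3, 1/3, 1/3]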
@@ -582,9 +572,7 @@ def training_step(self, batch, batch_idx):
else:
opt1, opt2 = opts

-inference_outputs, _, scvi_loss = self.forward(
-    batch, loss_kwargs=self.loss_kwargs
-)
+inference_outputs, _, scvi_loss = self.forward(batch, loss_kwargs=self.loss_kwargs)
z = inference_outputs["z"]
loss = scvi_loss.loss
# fool classifier if doing adversarial training
@@ -617,10 +605,7 @@ def on_train_epoch_end(self):

def on_validation_epoch_end(self) -> None:
"""Update the learning rate via scheduler steps."""
-if (
-    not self.reduce_lr_on_plateau
-    or "validation" not in self.lr_scheduler_metric
-):
+if not self.reduce_lr_on_plateau or "validation" not in self.lr_scheduler_metric:
return
else:
sch = self.lr_schedulers()
@@ -651,9 +636,7 @@ def configure_optimizers(self):
)

if self.adversarial_classifier is not False:
-params2 = filter(
-    lambda p: p.requires_grad, self.adversarial_classifier.parameters()
-)
+params2 = filter(lambda p: p.requires_grad, self.adversarial_classifier.parameters())
optimizer2 = torch.optim.Adam(
params2, lr=1e-3, eps=0.01, weight_decay=self.weight_decay
)
@@ -919,9 +902,7 @@ def __init__(
self.n_epochs_kl_warmup = n_epochs_kl_warmup
self.use_kl_weight = False
if isinstance(self.module.model, PyroModule):
-self.use_kl_weight = (
-    "kl_weight" in signature(self.module.model.forward).parameters
-)
+self.use_kl_weight = "kl_weight" in signature(self.module.model.forward).parameters
elif callable(self.module.model):
self.use_kl_weight = "kl_weight" in signature(self.module.model).parameters
self.scale_elbo = scale_elbo
@@ -1102,9 +1083,7 @@ def __init__(
optim_kwargs = optim_kwargs if isinstance(optim_kwargs, dict) else {}
if "lr" not in optim_kwargs.keys():
optim_kwargs.update({"lr": 1e-3})
-self.optim = (
-    pyro.optim.Adam(optim_args=optim_kwargs) if optim is None else optim
-)
+self.optim = pyro.optim.Adam(optim_args=optim_kwargs) if optim is None else optim
# We let SVI take care of all optimization
self.automatic_optimization = False

@@ -1200,9 +1179,7 @@ def __init__(
self.loss_fn = loss()

if self.module.logits is False and loss == torch.nn.CrossEntropyLoss:
-raise UserWarning(
-    "classifier should return logits when using CrossEntropyLoss."
-)
+raise UserWarning("classifier should return logits when using CrossEntropyLoss.")

def forward(self, *args, **kwargs):
"""Passthrough to the module's forward function."""
@@ -1232,9 +1209,7 @@ def configure_optimizers(self):
optim_cls = torch.optim.AdamW
else:
raise ValueError("Optimizer not understood.")
-optimizer = optim_cls(
-    params, lr=self.lr, eps=self.eps, weight_decay=self.weight_decay
-)
+optimizer = optim_cls(params, lr=self.lr, eps=self.eps, weight_decay=self.weight_decay)

return optimizer

@@ -1300,11 +1275,7 @@ def __init__(

def get_optimizer_creator(self) -> JaxOptimizerCreator:
"""Get optimizer creator for the model."""
-clip_by = (
-    optax.clip_by_global_norm(self.max_norm)
-    if self.max_norm
-    else optax.identity()
-)
+clip_by = optax.clip_by_global_norm(self.max_norm) if self.max_norm else optax.identity()
if self.optimizer_name == "Adam":
# Replicates PyTorch Adam defaults
optim = optax.chain(
@@ -1358,9 +1329,9 @@ def loss_fn(params):
loss = loss_output.loss
return loss, (loss_output, new_model_state)

-(loss, (loss_output, new_model_state)), grads = jax.value_and_grad(
-    loss_fn, has_aux=True
-)(state.params)
+(loss, (loss_output, new_model_state)), grads = jax.value_and_grad(loss_fn, has_aux=True)(
+    state.params
+)
new_state = state.apply_gradients(grads=grads, state=new_model_state)
return new_state, loss, loss_output
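
Aside, not part of the commit: the reflowed call above uses `jax.value_and_grad(loss_fn, has_aux=True)`, which expects the wrapped function to return `(loss, aux)` and itself returns `((loss, aux), grads)`, differentiating only the loss. A self-contained sketch with made-up names:

import jax
import jax.numpy as jnp

def toy_loss_fn(params, x):
    residual = params["w"] * x - 1.0   # made-up toy model
    loss = jnp.mean(residual**2)
    aux = {"residual": residual}       # auxiliary output, returned but not differentiated
    return loss, aux

params = {"w": jnp.array([0.5, 2.0])}
(loss, aux), grads = jax.value_and_grad(toy_loss_fn, has_aux=True)(params, jnp.ones(2))
# grads has the same pytree structure as params, i.e. {"w": array of shape (2,)}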

