scvi.external.contrastivevi.ContrastiveVAE#

class scvi.external.contrastivevi.ContrastiveVAE(n_input, n_batch=0, n_hidden=128, n_background_latent=10, n_salient_latent=10, n_layers=1, dropout_rate=0.1, use_observed_lib_size=True, library_log_means=None, library_log_vars=None, wasserstein_penalty=0)[source]#

Bases: BaseModuleClass

Variational inference for contrastive analysis of RNA-seq data.

Implements the contrastiveVI model of [Weinberger et al., 2023].

Parameters:
  • n_input (int) – Number of input genes.

  • n_batch (int (default: 0)) – Number of batches. If 0, no batch effect correction is performed.

  • n_hidden (int (default: 128)) – Number of nodes per hidden layer.

  • n_background_latent (int (default: 10)) – Dimensionality of the background latent space.

  • n_salient_latent (int (default: 10)) – Dimensionality of the salient latent space.

  • n_layers (int (default: 1)) – Number of hidden layers used for encoder and decoder NNs.

  • dropout_rate (float (default: 0.1)) – Dropout rate for neural networks.

  • use_observed_lib_size (bool (default: True)) – Whether to use the observed library size for RNA as the scaling factor in the mean of the conditional distribution.

  • library_log_means (Optional[ndarray] (default: None)) – 1 x n_batch array of means of the log library sizes. Parameterizes the prior on library size if not using the observed library size.

  • library_log_vars (Optional[ndarray] (default: None)) – 1 x n_batch array of variances of the log library sizes. Parameterizes the prior on library size if not using the observed library size.

  • wasserstein_penalty (float (default: 0)) – Weight of the Wasserstein distance loss that further discourages shared variations from leaking into the salient latent space.
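When use_observed_lib_size=False, the prior on library size is parameterized by per-batch statistics of the log library sizes. A minimal sketch of how such 1 x n_batch arrays can be computed from a raw counts matrix (the counts and batches arrays here are illustrative data, not part of the API):

```python
import numpy as np

# Illustrative inputs: 6 cells x 4 genes, two batches (hypothetical data).
counts = np.array([
    [3, 0, 1, 2],
    [5, 1, 0, 0],
    [2, 2, 2, 2],
    [0, 4, 1, 1],
    [6, 0, 0, 3],
    [1, 1, 1, 1],
])
batches = np.array([0, 0, 0, 1, 1, 1])

log_lib = np.log(counts.sum(axis=1))  # log library size per cell

n_batch = batches.max() + 1
# 1 x n_batch arrays, matching the shape expected by
# library_log_means / library_log_vars.
library_log_means = np.array(
    [[log_lib[batches == b].mean() for b in range(n_batch)]]
)
library_log_vars = np.array(
    [[log_lib[batches == b].var() for b in range(n_batch)]]
)

print(library_log_means.shape)  # (1, 2)
```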

Attributes table#

training

Methods table#

generative(background, target)

Run the generative model.

inference(background, target[, n_samples])

Run the recognition model.

latent_kl_divergence(variational_mean, ...)

Computes KL divergence between a variational posterior and prior Gaussian.

library_kl_divergence(batch_index, ...)

Computes KL divergence between library size variational posterior and prior.

loss(concat_tensors, inference_outputs, ...)

Computes loss terms for contrastiveVI.

reconstruction_loss(x, px_rate, px_r, px_dropout)

Computes likelihood loss for zero-inflated negative binomial distribution.

sample()

Generate samples from the learned model.

Attributes#

ContrastiveVAE.training: bool#

Methods#

ContrastiveVAE.generative(background, target)[source]#

Run the generative model.

This function should return the parameters associated with the likelihood of the data. This is typically written as \(p(x|z)\).

This function should return a dictionary with str keys and Tensor values.

Return type:

dict[str, dict[str, Tensor]]

ContrastiveVAE.inference(background, target, n_samples=1)[source]#

Run the recognition model.

In the case of variational inference, this function will perform steps related to computing variational distribution parameters. In a VAE, this will involve running data through encoder networks.

This function should return a dictionary with str keys and Tensor values.

Return type:

dict[str, dict[str, Tensor]]

static ContrastiveVAE.latent_kl_divergence(variational_mean, variational_var, prior_mean, prior_var)[source]#

Computes KL divergence between a variational posterior and prior Gaussian.

Parameters:
  • variational_mean (Tensor) – Mean of the variational posterior Gaussian.

  • variational_var (Tensor) – Variance of the variational posterior Gaussian.

  • prior_mean (Tensor) – Mean of the prior Gaussian.

  • prior_var (Tensor) – Variance of the prior Gaussian.

Return type:

Tensor

Returns:

KL divergence for each data point. If number of latent samples == 1, the tensor has shape (batch_size, ). If number of latent samples > 1, the tensor has shape (n_samples, batch_size).
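The quantity this method computes has the standard closed form for Gaussians with given means and variances. A pure-Python sketch for a single latent dimension (illustrative only, not the module's tensorized implementation):

```python
import math

def gaussian_kl(q_mean, q_var, p_mean, p_var):
    """KL(N(q_mean, q_var) || N(p_mean, p_var)) for scalar Gaussians."""
    return 0.5 * (
        math.log(p_var / q_var)
        + (q_var + (q_mean - p_mean) ** 2) / p_var
        - 1.0
    )

print(gaussian_kl(0.0, 1.0, 0.0, 1.0))  # 0.0 when posterior equals prior
```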

ContrastiveVAE.library_kl_divergence(batch_index, variational_library_mean, variational_library_var, library)[source]#

Computes KL divergence between library size variational posterior and prior.

Both the variational posterior and prior are Log-Normal.

Parameters:
  • batch_index (Tensor) – Batch indices for batch-specific library size mean and variance.

  • variational_library_mean (Tensor) – Mean of variational Log-Normal.

  • variational_library_var (Tensor) – Variance of variational Log-Normal.

  • library (Tensor) – Sampled library size.

Return type:

Tensor

Returns:

KL divergence for each data point. If number of latent samples == 1, the tensor has shape (batch_size, ). If number of latent samples > 1, the tensor has shape (n_samples, batch_size).
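Because both the posterior and the prior are Log-Normal, their KL divergence equals the Gaussian KL between the corresponding log-scale parameters, with the prior's mean and variance selected per cell by batch index. A numpy sketch under that assumption (all array names and values are illustrative):

```python
import numpy as np

# Hypothetical per-batch prior statistics (1 x n_batch, log scale).
library_log_means = np.array([[1.9, 2.3]])
library_log_vars = np.array([[0.2, 0.4]])

batch_index = np.array([0, 1, 1])   # one batch label per cell
q_mean = np.array([2.0, 2.1, 2.5])  # variational log-scale means
q_var = np.array([0.1, 0.2, 0.3])   # variational log-scale variances

# Select each cell's prior parameters by its batch index.
p_mean = library_log_means[0, batch_index]
p_var = library_log_vars[0, batch_index]

# KL between Log-Normals = KL between the underlying Gaussians.
kl = 0.5 * (np.log(p_var / q_var) + (q_var + (q_mean - p_mean) ** 2) / p_var - 1.0)
print(kl.shape)  # (3,): one KL term per cell
```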

ContrastiveVAE.loss(concat_tensors, inference_outputs, generative_outputs, kl_weight=1.0)[source]#

Computes loss terms for contrastiveVI.

Parameters:
  • concat_tensors (dict[str, dict[str, Tensor]]) – Dictionary of data mini-batches. The “background” entry contains the background data mini-batch, and the “target” entry contains the target data mini-batch.

  • inference_outputs (dict[str, dict[str, Tensor]]) – Dictionary of inference step outputs. The keys are “background” and “target” for the corresponding outputs.

  • generative_outputs (dict[str, dict[str, Tensor]]) – Dictionary of generative step outputs. The keys are “background” and “target” for the corresponding outputs.

  • kl_weight (float (default: 1.0)) – Importance weight for KL divergence of background and salient latent variables, relative to KL divergence of library size.

Return type:

LossOutput

Returns:

An scvi.module.base.LossOutput instance that records the following:

  • loss – One-dimensional tensor for the overall loss used for optimization.

  • reconstruction_loss – Reconstruction loss with shape (n_samples, batch_size) if number of latent samples > 1, or (batch_size, ) if number of latent samples == 1.

  • kl_local – KL divergence term with shape (n_samples, batch_size) if number of latent samples > 1, or (batch_size, ) if number of latent samples == 1.
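How these terms combine is, schematically, a sum of reconstruction losses and KL terms, with the latent-variable KLs scaled by kl_weight (plus the Wasserstein penalty when enabled). The sketch below is only a schematic of that composition, not the module's actual code; every name in it is illustrative:

```python
def contrastive_loss_sketch(
    recon_background, recon_target,  # per-cell reconstruction losses
    kl_z_background, kl_z_target,    # background-latent KL terms
    kl_s_target,                     # salient-latent KL (target cells only)
    kl_library,                      # library-size KL terms
    kl_weight=1.0,
):
    """Schematic: sum per-cell terms, weighting latent KLs by kl_weight."""
    return (
        sum(recon_background) + sum(recon_target)
        + kl_weight * (sum(kl_z_background) + sum(kl_z_target) + sum(kl_s_target))
        + sum(kl_library)
    )

print(contrastive_loss_sketch([1.0], [2.0], [0.1], [0.2], [0.3], [0.05], kl_weight=0.5))
```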

static ContrastiveVAE.reconstruction_loss(x, px_rate, px_r, px_dropout)[source]#

Computes likelihood loss for zero-inflated negative binomial distribution.

Parameters:
  • x (Tensor) – Input data.

  • px_rate (Tensor) – Mean of distribution.

  • px_r (Tensor) – Inverse dispersion.

  • px_dropout (Tensor) – Logits of the zero-inflation probability.

Return type:

Tensor

Returns:

Negative log likelihood (reconstruction loss) for each data point. If number of latent samples == 1, the tensor has shape (batch_size, ). If number of latent samples > 1, the tensor has shape (n_samples, batch_size).
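The zero-inflated negative binomial likelihood this method evaluates can be sketched in pure Python for scalar inputs, with the negative binomial parameterized by mean px_rate and inverse dispersion px_r, and the zero-inflation probability given by the sigmoid of px_dropout. This is a simplified scalar version, not the module's tensorized implementation:

```python
import math

def zinb_nll(x, px_rate, px_r, px_dropout):
    """Negative log-likelihood of a scalar count x under a ZINB distribution.

    px_rate: NB mean, px_r: inverse dispersion, px_dropout: zero-inflation logits.
    """
    pi = 1.0 / (1.0 + math.exp(-px_dropout))  # zero-inflation probability
    # Negative binomial log pmf with mean px_rate and inverse dispersion px_r.
    log_nb = (
        math.lgamma(x + px_r) - math.lgamma(px_r) - math.lgamma(x + 1)
        + px_r * math.log(px_r / (px_r + px_rate))
        + x * math.log(px_rate / (px_r + px_rate))
    )
    if x == 0:
        # A zero can come from the dropout component or the NB itself.
        return -math.log(pi + (1.0 - pi) * math.exp(log_nb))
    return -(math.log(1.0 - pi) + log_nb)

# High dropout logits make zeros nearly free; nonzero counts still cost likelihood.
print(zinb_nll(0, px_rate=5.0, px_r=2.0, px_dropout=4.0)
      < zinb_nll(3, px_rate=5.0, px_r=2.0, px_dropout=4.0))  # True
```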

ContrastiveVAE.sample()[source]#

Generate samples from the learned model.