Generating Synthetic scRNA-seq Data

Hello,

I have been looking for a way to generate synthetic scRNA-seq data using one of the scVI models. Google brought me to the documentation on inference from an older version of your package (found here -> https://www.scvi-tools.org/en/0.6.8/scvi.inference.html?highlight=inference#module-scvi.inference).

It seems that scvi.inference is no longer available in version 0.8.1, and installing the 0.6.8 version of the package from PyPI resulted in lots of errors.

Is there an easy way to generate synthetic data with the current package, or would I need to manually feed a random z vector into the decoder of a trained model, say LinearSCVI?

Thank you very much for your time and help.

The `posterior_predictive_sample` function would generate posterior predictive samples, but I imagine you want to sample from the prior rather than the posterior.
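For comparison, here is a minimal sketch of drawing posterior predictive samples from a trained model (this assumes the `posterior_predictive_sample` method and the `synthetic_iid` toy dataset helper; exact names and defaults may vary between versions):

    import scvi
    from scvi.data import synthetic_iid

    # Toy dataset bundled with scvi-tools; substitute your own AnnData
    adata = synthetic_iid()
    scvi.data.setup_anndata(adata, batch_key="batch")

    model = scvi.model.SCVI(adata)
    model.train()

    # Draws x ~ p(x | z) with z ~ q(z | x), i.e., conditioned on the observed cells
    x_post = model.posterior_predictive_sample(n_samples=1)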

Something like the following should work, though there might be some import errors, CUDA device issues, etc. Keep in mind that the library size here is assumed to be 1 and the batch index of all cells is set to 0, which could be a problem if the model was trained on multiple batches.

    import numpy as np
    import torch
    import torch.nn.functional as F
    from torch.distributions import Normal

    # The import path for these distributions varies across scvi-tools versions
    # (scvi.distributions in recent releases, scvi.core.distributions in 0.8.x)
    from scvi.distributions import NegativeBinomial, ZeroInflatedNegativeBinomial


    def one_hot(index: torch.Tensor, n_cat: int) -> torch.Tensor:
        """One-hot encode a (n, 1) tensor of category indices (mirrors scvi's internal helper)."""
        onehot = torch.zeros(index.size(0), n_cat, device=index.device)
        return onehot.scatter_(1, index.type(torch.long), 1)


    @torch.no_grad()
    def prior_predictive_sample(
        self,
        n_samples: int = 1000,
    ) -> np.ndarray:
        r"""
        Generate observation samples from the prior predictive distribution.

        Parameters
        ----------
        n_samples
            Number of samples.

        Returns
        -------
        x_new : :py:class:`numpy.ndarray`
            array with shape (n_samples, n_genes)
        """
        if self.model.gene_likelihood not in ["zinb", "nb", "poisson"]:
            raise ValueError("Invalid gene_likelihood.")

        # Sample z from the standard normal prior N(0, I)
        qz_m = torch.zeros(n_samples, self.model.n_latent)
        qz_v = torch.ones(n_samples, self.model.n_latent)
        z = Normal(qz_m, qz_v).sample()

        # Assign every cell to batch 0 and label 0
        dec_batch_index = torch.zeros(n_samples, 1)
        y = torch.zeros(n_samples, 1)
        library = torch.zeros(n_samples, 1)  # log-library size; exp(0) = 1

        px_scale, px_r, px_rate, px_dropout = self.model.decoder(
            self.model.dispersion, z, library, dec_batch_index, y
        )
        if self.model.dispersion == "gene-label":
            px_r = F.linear(
                one_hot(y, self.model.n_labels), self.model.px_r
            )  # px_r gets transposed - last dimension is nb genes
        elif self.model.dispersion == "gene-batch":
            px_r = F.linear(one_hot(dec_batch_index, self.model.n_batch), self.model.px_r)
        elif self.model.dispersion == "gene":
            px_r = self.model.px_r
        px_r = torch.exp(px_r)

        if self.model.gene_likelihood == "poisson":
            l_train = px_rate
            l_train = torch.clamp(l_train, max=1e8)
            dist = torch.distributions.Poisson(
                l_train
            )  # Shape: (n_samples, n_genes)
        elif self.model.gene_likelihood == "nb":
            dist = NegativeBinomial(mu=px_rate, theta=px_r)
        elif self.model.gene_likelihood == "zinb":
            dist = ZeroInflatedNegativeBinomial(
                mu=px_rate, theta=px_r, zi_logits=px_dropout
            )
        else:
            raise ValueError(
                "{} reconstruction error not handled right now".format(
                    self.model.gene_likelihood
                )
            )

        exprs = dist.sample()

        return exprs.cpu().numpy()  # Shape: (n_samples, n_genes)
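
Assuming a trained `model` as in the snippet above, one way to use this is to bind the function to the model instance with `types.MethodType` (a sketch assuming scvi-tools 0.8, where the underlying VAE module lives at `model.model`):

    import types

    # Attach the sampler defined above and draw 500 synthetic cells
    model.prior_predictive_sample = types.MethodType(prior_predictive_sample, model)
    x_new = model.prior_predictive_sample(n_samples=500)  # shape (500, n_genes)

If the model was trained on multiple batches, you could also draw `dec_batch_index` at random from `range(model.model.n_batch)` instead of fixing it at 0, so the synthetic cells cover all batches.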