Directly accessing scVI's decoder

Hi, thanks for this really exciting tool.

I’m trying out some trajectory analysis where I would like to map predicted trajectories from the latent space back to the original expression space; in essence I would like to directly pass the points of these trajectories (which are not part of the anndata object) through the model’s decoder.

I see that the VAE class has a function ‘generative()’ which looks like what I am after, but I am unfamiliar with the architecture of the model and I am unsure how best to access this function. What would be the best way for a user to directly pass latent space points through the model’s decoder?

Any help would be much appreciated, thanks!

Rory

Hi! Thanks for using scvi-tools.

It might be helpful to look at the structure of the get_latent_representation() method.

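    import numpy as np
    import torch
    from torch.distributions import Normal

    # Method as found in scvi-tools (used by scvi.model.SCVI); `self` is the trained model
    @torch.no_grad()
    def get_latent_representation(
        self, adata=None, indices=None, give_mean=True, mc_samples=5000, batch_size=None
    ):
        adata = self._validate_anndata(adata)
        scdl = self._make_data_loader(
            adata=adata, indices=indices, batch_size=batch_size
        )
        latent = []
        for tensors in scdl:
            # encoder pass: returns the posterior parameters and a sample of z
            inference_inputs = self.module._get_inference_input(tensors)
            outputs = self.module.inference(**inference_inputs)
            qz_m = outputs["qz_m"]
            qz_v = outputs["qz_v"]
            z = outputs["z"]

            if give_mean:
                # does each model need to have this latent distribution param?
                if self.module.latent_distribution == "ln":
                    samples = Normal(qz_m, qz_v.sqrt()).sample([mc_samples])
                    z = torch.nn.functional.softmax(samples, dim=-1)
                    z = z.mean(dim=0)
                else:
                    z = qz_m

            latent += [z.cpu()]
        return np.array(torch.cat(latent))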

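Basically, what you want to do is create a dataloader with your values of z, iterate over it, and call self.module.generative(...) on each batch (from outside the class, that is model.module, where model is your trained SCVI instance).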
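
A minimal sketch of the first step, assuming your trajectory points are stored in a NumPy array of shape (n_points, n_latent). The helper name make_latent_loader is hypothetical and uses only standard PyTorch utilities (TensorDataset, DataLoader), not an scvi-tools API.

```python
import numpy as np
import torch
from torch.utils.data import DataLoader, TensorDataset


def make_latent_loader(z_points, batch_size=128):
    """Wrap latent points (n_points x n_latent) in a DataLoader.

    Hypothetical helper: plain PyTorch, not part of scvi-tools.
    """
    # float32 to match the dtype of the model's weights
    z_tensor = torch.as_tensor(np.asarray(z_points), dtype=torch.float32)
    # shuffle=False keeps the decoded output in the same order as the trajectory
    return DataLoader(TensorDataset(z_tensor), batch_size=batch_size, shuffle=False)


# Example: 300 made-up points in a 10-dimensional latent space
z_points = np.random.default_rng(0).normal(size=(300, 10))
for (z_batch,) in make_latent_loader(z_points):
    print(z_batch.shape)  # prints torch.Size([128, 10]) twice, then torch.Size([44, 10])
```

Each z_batch from this loader is what you would pass as z to generative.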

Note that the signature of generative in VAE is

    @auto_move_data
    def generative(
        self, z, library, batch_index, cont_covs=None, cat_covs=None, y=None
    ):

So to decode z you'll need a batch_index tensor of ints with shape (batch size, 1). For library you can pass a torch tensor of 1s, since it only affects the computation of px_rate. From the dictionary that generative returns, you'll want px_scale.
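
A sketch of decoding one minibatch along those lines, assuming module is the trained model's VAE (model.module for a trained scvi.model.SCVI instance model) and that generative has the signature shown above and returns a dictionary containing px_scale. decode_batch is a hypothetical helper, and newer scvi-tools releases may have changed the signature or return type of generative, so check it against your installed version.

```python
import torch


@torch.no_grad()  # decoding only, so no gradients are tracked
def decode_batch(module, z_batch, batch=0):
    """Decode latent points with module.generative() and return px_scale.

    Hypothetical helper; assumes generative(z, library, batch_index, ...)
    returns a dict with a "px_scale" entry.
    """
    z_batch = torch.as_tensor(z_batch, dtype=torch.float32)
    n = z_batch.shape[0]
    # library only affects px_rate, so 1s are fine when we only want px_scale
    library = torch.ones(n, 1)
    # one integer batch label per cell, shape (n, 1)
    batch_index = torch.full((n, 1), batch, dtype=torch.long)
    outputs = module.generative(z=z_batch, library=library, batch_index=batch_index)
    return outputs["px_scale"].cpu()
```

To decode a whole trajectory, call it on successive chunks, e.g. torch.cat([decode_batch(model.module, zb) for zb in torch.split(z_all, 128)]), where z_all is a float32 tensor holding all your latent points.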

Please feel free to follow up with additional questions.

Hi Adam,

Congratulations on the recent publication of scvi-tools.

Regarding your response, could you please clarify the dimensions of the z and library_size tensors mentioned above? I had assumed that the library_size tensor should have dimensions #cells x #genes, while z should be #cells x #latents. Is that correct? Also, if I am interested in px_rate (scaled gene expression, count values) rather than px_scale (normalized gene expression), what values should library_size be set to instead of 1s?

Once again, thanks for such a great tool. I am looking forward to your feedback.

Best,

Tu

z has shape (batch size, n_latent), which is 128 cells by 10 dimensions with the defaults.

library is (batch size, 1), so 128 by 1.

It's not clear to me what you're doing, so I'm not sure which value you ideally want, but you can likely just use the value the model itself uses (typically the observed library size).
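
To make the shapes concrete: library holds one value per cell, shape (n_cells, 1), not one per gene. A sketch of deriving it from a dense count matrix, assuming (as explained further down in this thread) that the model's library term is the log of the observed library size, i.e. the log of the total counts per cell. observed_log_library is a hypothetical helper; if your counts are a sparse matrix, densify them first.

```python
import numpy as np
import torch


def observed_log_library(counts):
    """Per-cell log library size from a dense (n_cells, n_genes) count matrix.

    Hypothetical helper. Returns a float32 tensor of shape (n_cells, 1).
    Cells with zero total counts would give -inf.
    """
    counts = np.asarray(counts, dtype=np.float64)
    totals = counts.sum(axis=1, keepdims=True)  # total counts per cell, shape (n_cells, 1)
    return torch.as_tensor(np.log(totals), dtype=torch.float32)


counts = np.array([[1, 2, 7], [50, 30, 20]])  # made-up counts: 2 cells x 3 genes
library = observed_log_library(counts)
print(library.shape)  # torch.Size([2, 1])
```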

Hi Adam,

Thanks for your prompt response.

What I want to output is gene expression values obtained by rescaling the normalized expression (px_scale) returned by the generative function with the observed library size. As you suggested, px_scale itself can be obtained by calling the generative function with library_size set to 1s.

Would we get the same result by passing the observed library size to the generative function and then extracting px_rate from the output instead? (I have in fact tried this, but all the returned px_rate values are Inf.)

Thanks again for a great tool, and I am looking forward to hearing your feedback!

Kind regards,

That is what we call px_rate. But I'm unclear whether you're building your own model or want this from scvi.model.SCVI. In the latter case, both px_scale and px_rate = library * px_scale can be obtained from this function.

Hi Adam,
Thank you very much for the explanation and kind support. I was able to execute the function as you suggested.
Best regards,

Hi Adam,
Sorry for bothering you again!

I would like to ask you a question about the relationship between library, px_scale, and px_rate.

You mentioned in one of the messages in this thread that px_rate = library * px_scale. However, when I checked the code carefully (the forward function of the DecoderSCVI class), the actual calculation is px_rate = torch.exp(library) * px_scale.

So, what is the role of torch.exp here, and how does it affect downstream calculations if library is set to the sum of the observed counts (by setting use_observed_lib_size=True), which is usually large?

Thank you very much in advance!
Best,

Hi there,

Sorry for the confusing syntax. In this case the library term is actually the log library size, which gets set in the inference function (scvi.module._vae). So the exp just undoes the log. I believe the reason for this convention is to keep values numerically stable when use_observed_lib_size=False, i.e. when the library size is inferred as a latent variable.
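
A small numeric illustration of that convention, using NumPy stand-ins and made-up numbers rather than the real decoder: if library holds the log library size, exp(library) * px_scale rescales the normalized expression to the observed depth, whereas passing a raw total of thousands of counts makes exp overflow to inf, which would be consistent with the Inf values for px_rate reported earlier in the thread.

```python
import numpy as np

# Made-up normalized expression (px_scale) for one cell over 3 genes; sums to 1
px_scale = np.array([0.5, 0.3, 0.2])
total_counts = 5000.0  # made-up observed library size for that cell

# Convention described above: library holds the *log* library size,
# and the decoder computes px_rate = exp(library) * px_scale
library = np.log(total_counts)
px_rate = np.exp(library) * px_scale
print(px_rate)  # approximately [2500. 1500. 1000.]

# Passing the raw total instead of its log: exp(5000) overflows to inf
with np.errstate(over="ignore"):
    bad_rate = np.exp(total_counts) * px_scale
print(bad_rate)  # [inf inf inf]
```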

Hi Justine,
Thanks for your prompt feedback and for clarifying about the syntax. It’s clear to me now.
Best regards,

Hello!

I'm trying to do exactly what the original message describes: using the decoder of a trained scVI model on data that is not part of the original AnnData object.

My latent is n_cells x ndim, and my library and batch_index are each arrays of ones of size n_cells x 1, but I'm receiving the error 'int' object is not callable when the generative function calls self.decoder.
Any idea? I'm not using a data loader to get batches, since I don't want to take gradients, just decode the data. Also, to create a valid data loader I would need to call _validate_anndata with an AnnData that contains only the latents, which has a different shape from the one used to train the model, so that raises an error as well.

Thanks!