What is the best way to extract a "full" batch effect corrected count matrix from scVI model?

Hello!

I would like to get a “full” batch-corrected count matrix from a trained scVI model for other scRNA-seq analysis tools that require count data. As I understand it, I can extract the batch-corrected frequency of each transcript in each cell from model.get_normalized_expression() with the transform_batch parameter. To get the “corrected counts”, I thought I should call the function with library_size='latent', as in the scVI paper (expected count = 𝜌 (frequency) × ℓ (cell-specific scaling factor)).
However, if adata.X in the model contained only the HVGs used for training, the model returned corrected counts for the HVGs only. I could get the full corrected matrix when I loaded the model with an adata whose adata.X contained all genes.
So, my questions are:

  1. Is using the get_normalized_expression() function with the library_size='latent' and batch_size="my batches" parameters the best way to get a batch-corrected count matrix?
  2. Is it okay to use a model trained only on HVGs (about 1,000 to 10,000 genes) with a full adata.X to get the full count matrix?
  3. If all of my assumptions are valid, could I use data normalized from those batch-corrected counts for expression-based plots such as dot plots or feature plots?

Thanks a lot!!


Hi jjuhh,

First, batch_size is the parameter that determines how many cells are run through the model in each ‘chunk’ of data. Do you mean the parameter transform_batch?

Other than that, what you describe will get you scale parameters with a learned library size for each cell and gene in the model. If you want counts, you will need to take these scale parameter estimates, pass them to a count distribution, and sample from that.

The meaning of ‘correcting’ the expression estimates is that you wish to analyze the data as if it came from one single batch. Similarly, for counts, you would want to ‘correct’ the technical factor of count depth so that the count depth is the same value for all cells.
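Here is a small numpy illustration of that point, using toy numbers rather than scVI output. Two cells with identical expression frequencies but different sequencing depths produce different expected counts when each keeps its own depth. Once every cell is put on the same depth, they produce identical ones. The frequencies stand in for what get_normalized_expression() returns with library_size=1 and a shared transform_batch. The per-cell depths are made up.

```python
import numpy as np

# Toy frequencies (each row sums to 1) for two cells with identical composition,
# standing in for batch-corrected output at library_size=1.
freqs = np.array([[0.5, 0.3, 0.2],
                  [0.5, 0.3, 0.2]])
observed_depth = np.array([[2000.0], [8000.0]])  # hypothetical per-cell total counts

# Keeping each cell's own depth: the 4x difference is purely technical.
raw_scale = freqs * observed_depth
# One shared, arbitrary depth for all cells: now directly comparable.
standardized = freqs * 10_000
```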

What you can do is run scales = get_normalized_expression(adata, transform_batch = "My favorite batch", library_size = 100000). Which batch and which standardized library size you pick is arbitrary. The point is that all scales will be as if they came from the same batch, so you can compare expression across all your samples. Again, this gives you scales that you need to pass to a function that samples from the given distribution. As a simple example, if you trained the model with gene_likelihood = 'poisson', then to make counts you would sample from a Poisson distribution with your scales from above, like this: counts = torch.distributions.poisson.Poisson(scales).sample().
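For reference, here is a minimal numpy version of the same sampling step. It swaps in numpy.random.Generator.poisson for torch.distributions.poisson.Poisson, which is the same distribution but needs no tensors. The scales array is a hypothetical stand-in for the output of the call above. If that output comes back as a pandas DataFrame, scales.to_numpy() gives the array. The torch line itself expects a tensor, e.g. torch.as_tensor(scales.to_numpy()).

```python
import numpy as np

# Hypothetical stand-in for get_normalized_expression(..., library_size=100000):
# a cells x genes array of Poisson means at one shared count depth.
scales = np.array([[60000.0, 30000.0, 10000.0],
                   [20000.0, 50000.0, 30000.0]])

rng = np.random.default_rng(seed=0)  # fixed seed so the draw is reproducible
# Each entry is an independent Poisson draw whose mean is the matching scale.
counts = rng.poisson(lam=scales)
```

Each call with a different seed gives a different, equally valid count matrix. The scales are the stable quantity; the counts are one random realization of them.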

The model can only work on the genes you have trained it on. If you want to do this, I would suggest training the model on all genes. If you are exclusively using it for this normalization task, the slightly more ‘blurry’ representation you get is probably not an issue.

I would suggest plotting the continuous values you get from get_normalized_expression() with a standardized library size directly instead. The issue with counts is that they are discrete, which makes them hard to plot. The expression estimates you get from get_normalized_expression() already include all the normalization you would otherwise apply to make the plot readable.

Hope this helps!
/Valentine


Valentine makes a lot of good points, I just want to standardize some of the terminology here.

Single experimental batch case

Denoised - Denoised expression for scVI would represent a conditional expectation for cell $n$ and gene $g$: $\mathbb{E}_{q(z_n\mid x_n)}[x_{ng} \mid z_n]$. There are some caveats:

  1. Denoised is not necessarily normalized, as the conditional expectation would use the observed or estimated library size of cell $n$
  2. In the case of gene_likelihood="zinb", the expression above is not exact. It would technically be $\mathbb{E}_{q(z_n\mid x_n)}[x_{ng} \mid z_n, \text{not zero-inflated}]$, i.e. the conditional expectation when also conditioning on the NB (non-zero-inflated) component of the ZINB distribution. This follows the original approach of the scVI manuscript, where 0s captured by the ZI component were considered excess and undesirable. However, this hypothesis oversimplifies the biology, and it’s currently recommended to train with gene_likelihood="nb" instead. @Valentine_Svensson can also likely explain very well why :slight_smile:
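A quick numeric check shows why the two expectations differ. It assumes the standard ZINB form, where a count is an extra 0 with probability π and otherwise an NB draw with mean μ. Under that form, the full mean is (1 - π)μ, while conditioning on the NB component gives back μ. The μ and π values below are made up for illustration.

```python
import numpy as np

def zinb_means(mu, pi):
    """Return (full ZINB mean, mean conditioned on the NB component).

    mu: NB mean per gene; pi: zero-inflation probability per gene.
    A ZINB draw is 0 with probability pi, otherwise an NB draw with mean mu,
    so E[x] = (1 - pi) * mu, while E[x | not zero-inflated] = mu.
    """
    mu = np.asarray(mu, dtype=float)
    pi = np.asarray(pi, dtype=float)
    return (1.0 - pi) * mu, mu

full, conditional = zinb_means(mu=[10.0, 4.0], pi=[0.2, 0.5])
```

The gap grows with π. For a gene with heavy zero inflation, the "denoised" value can be far above the gene's actual average count.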

Normalized - Denoised and scaled. We know that $\mathbb{E}_{q(z_n\mid x_n)}[x_{ng} \mid z_n] = \ell_n [f(z_n)]_g$, where $\ell_n$ is the scalar library size for cell $n$ and $f(z_n)$ is the vector of denoised gene expression frequencies for cell $n$, so $[f(z_n)]_g$ is the frequency for gene $g$. The library_size parameter of get_normalized_expression allows manipulating $\ell_n$.
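Here is a toy numpy sketch of that formula. It assumes a diagonal-Gaussian q(z_n | x_n) and a linear-softmax decoder as stand-ins for scVI's encoder and decoder networks; this is not scVI's actual architecture. The frequencies are averaged over posterior samples of z_n and then multiplied by ℓ_n. Each decoded frequency vector sums to 1, so the result sums to ℓ_n over genes. Swapping ℓ_n for a fixed value corresponds, roughly, to passing a number as library_size.

```python
import numpy as np

def softmax(a: np.ndarray) -> np.ndarray:
    # Numerically stable softmax over the gene axis (last axis).
    a = a - a.max(axis=-1, keepdims=True)
    e = np.exp(a)
    return e / e.sum(axis=-1, keepdims=True)

def denoised_expression(q_mean, q_std, W, b, library_size, n_samples=500, seed=0):
    """Monte Carlo estimate of l_n * E_q[f(z_n)] for one cell (toy model)."""
    rng = np.random.default_rng(seed)
    # Draw z ~ q(z | x) via z = mean + std * eps, eps ~ N(0, I).
    z = q_mean + q_std * rng.standard_normal((n_samples, q_mean.shape[0]))
    freqs = softmax(z @ W.T + b)              # n_samples x genes, rows sum to 1
    return library_size * freqs.mean(axis=0)  # average over posterior samples

W = np.array([[1.0, 0.0], [0.0, 1.0], [-1.0, -1.0]])  # 3 genes, 2 latent dims
b = np.zeros(3)
q_mean, q_std = np.array([0.5, -0.2]), np.array([0.1, 0.1])

own_depth = denoised_expression(q_mean, q_std, W, b, library_size=2000.0)  # l_n
fixed_depth = denoised_expression(q_mean, q_std, W, b, library_size=1e4)   # fixed
```

With the same seed, fixed_depth is exactly own_depth rescaled by 1e4 / 2000. Only the overall scale changes, not the relative expression across genes.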

Multi-batch case

In the case of multiple batches, $f(z_n)$, also called the decoder, is actually represented as $f(z_n, s_n)$, where $s_n$ is a one-hot encoding of the experimental batch. Therefore, while $z_n$ may be “batch-corrected”, the decoder always maps to the observed data, which clearly has batch-specific effects.

To mitigate this, we “counterfactually decode”, i.e. decode using $f(z_n, s_n')$, where $s_n'$ encodes a batch different from the observed one. This mechanism of decoding (transform_batch) allows us to produce denoised and normalized expression values that are corrected for batch.
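Here is a toy numpy sketch of counterfactual decoding. It assumes a linear-softmax decoder f(z, s) with a one-hot batch input, which stands in for scVI's neural decoder and is not its actual architecture. Decoding with the observed batch reproduces each cell's batch-specific profile. Decoding every cell with the same batch, as transform_batch does, leaves cells already from that batch unchanged and shifts the others.

```python
import numpy as np

def softmax(a):
    # Numerically stable softmax over the gene axis.
    a = a - a.max(axis=-1, keepdims=True)
    e = np.exp(a)
    return e / e.sum(axis=-1, keepdims=True)

def decode(z, batch_index, W_z, W_s, n_batches):
    """Toy decoder f(z, s) = softmax(W_z z + W_s s), with s a one-hot batch vector."""
    s = np.eye(n_batches)[batch_index]  # one-hot encode the batch labels
    return softmax(z @ W_z.T + s @ W_s.T)

rng = np.random.default_rng(0)
n_genes, n_latent, n_batches = 4, 2, 2
W_z = rng.normal(size=(n_genes, n_latent))
W_s = rng.normal(size=(n_genes, n_batches))  # batch-specific gene shifts

z = rng.normal(size=(3, n_latent))    # latent codes for 3 cells
observed_batch = np.array([0, 1, 1])  # batch each cell actually came from

observed = decode(z, observed_batch, W_z, W_s, n_batches)
# Counterfactual decoding: every cell decoded as if it came from batch 0.
counterfactual = decode(z, np.zeros(3, dtype=int), W_z, W_s, n_batches)
```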

If for some reason we wanted to create batch-corrected counts, we’d want to denoise and not normalize, and then sample from the likelihood distribution as Valentine described. This functionality is not currently in the codebase.

It’s also important to mention that transform_batch will not work perfectly when there is confounding in the experimental design. Making counterfactual decodings for cell types where $s_n'$ was not observed may have unintended consequences.

I can expand more on points of confusion, but I hope this helps clear up some of what this function does.


Hello,
Thanks for your reply. But what if we used a negative binomial instead of a Poisson? How would you modify this line?
counts = torch.distributions.poisson.Poisson(scales).sample()
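The thread does not answer this question, so the following is one possible approach rather than an official one. scVI's NB is parameterized by a mean μ and an inverse dispersion θ, giving a variance of μ + μ²/θ. With the default dispersion="gene", θ is learned per gene. An NB draw can be generated as a Gamma-Poisson mixture, sketched below in numpy. The θ values are hypothetical placeholders, since where the fitted θ is stored inside a trained model depends on the scvi-tools version. In torch, torch.distributions.NegativeBinomial(total_count=theta, logits=torch.log(scales) - torch.log(theta)) should describe the same distribution, since its mean is total_count × exp(logits).

```python
import numpy as np

def sample_negative_binomial(mu, theta, rng):
    """Draw NB counts with mean mu and inverse dispersion theta.

    Gamma-Poisson mixture: lam ~ Gamma(shape=theta, scale=mu/theta) has mean mu,
    and x ~ Poisson(lam) then has variance mu + mu**2 / theta.
    theta broadcasts against mu (e.g. one value per gene = per column).
    """
    mu = np.asarray(mu, dtype=float)
    theta = np.asarray(theta, dtype=float)
    lam = rng.gamma(shape=theta, scale=mu / theta)
    return rng.poisson(lam)

rng = np.random.default_rng(0)
scales = np.array([[50.0, 5.0, 0.5],
                   [20.0, 10.0, 1.0]])  # cells x genes NB means (toy values)
theta = np.array([10.0, 2.0, 1.0])      # hypothetical per-gene inverse dispersion
counts = sample_negative_binomial(scales, theta, rng)
```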

Thank you both for your detailed explanations! Standardizing the wording will be helpful for me. Thanks again!! :blush: