What is the best way to extract a "full" batch-effect-corrected count matrix from an scVI model?


I would like to get a “full” batch-corrected count matrix from a trained scVI model, for use with other scRNA-seq analysis tools that require count data. As I understand it, I can extract the batch-corrected frequency of each transcript in each cell with model.get_normalized_expression() and its transform_batch parameter. To get the “corrected counts”, I thought I should call the same function with library_size='latent', following the scVI paper (expected count = ρ (frequency) × ℓ (cell-specific scaling factor)).
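Written out in the scVI paper's notation, the relationship the question describes is sketched below, with ρ_ng the normalized frequency of gene g in cell n and ℓ_n that cell's scaling factor:

```latex
\mathbb{E}[x_{ng}] = \ell_n \, \rho_{ng}, \qquad \sum_{g} \rho_{ng} = 1
```

Since each cell's frequencies sum to 1 over genes, ℓ_n is that cell's total expected count.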
However, if adata.X in the model contained only the HVGs used for training, the model returned corrected counts for those HVGs only. But I could get the full corrected matrix when I loaded the model with an AnnData whose adata.X contained all genes.
So, my questions are:

  1. Is using get_normalized_expression() with the library_size='latent' and batch_size="my batches" parameters the best way to get a batch-corrected count matrix?
  2. Is it okay to use a model trained only on HVGs (about 1,000 to 10,000 genes) with the full adata.X to get the full count matrix?
  3. If all of my assumptions are valid, could I use data normalized from those batch-corrected counts for expression-based plots such as dot plots or feature plots?

Thanks a lot!!


Hi jjuhh,

First, batch_size is the parameter that determines how many cells are run through the model in each ‘chunk’ of data. Do you mean the parameter transform_batch?

Other than that, what you describe will give you expression scale estimates multiplied by each cell's learned library size, for every cell and gene in the model. If you want counts, you need to pass these estimates to a count distribution and sample from it.
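The Poisson example later in this answer is the simplest case. For a negative binomial likelihood, the sampling step also needs a per-gene inverse dispersion θ. The sketch below assumes you already have the mean estimates `mu` and a `theta` vector; `sample_nb_counts` is a hypothetical helper, and how to pull θ out of a trained scvi-tools model is not covered here. scVI's default zinb likelihood additionally has a zero-inflation component, which this sketch leaves out.

```python
import torch


def sample_nb_counts(mu: torch.Tensor, theta: torch.Tensor, seed: int = 0) -> torch.Tensor:
    """Draw one count per entry from a negative binomial with mean `mu` and inverse dispersion `theta`.

    mu:    (n_cells, n_genes) positive expected counts, e.g. scales times a library size.
    theta: (n_genes,) positive per-gene inverse dispersion (hypothetical input).
    """
    torch.manual_seed(seed)
    theta = theta.expand_as(mu)
    # torch's NegativeBinomial(total_count=r, logits=z) has mean r * exp(z),
    # so z = log(mu) - log(theta) with r = theta gives mean mu and variance mu + mu**2 / theta.
    nb = torch.distributions.NegativeBinomial(
        total_count=theta, logits=torch.log(mu) - torch.log(theta)
    )
    return nb.sample()


# Toy inputs: every entry has expected count 5, every gene has theta = 10.
mu = torch.full((2000, 50), 5.0)
theta = torch.full((50,), 10.0)
counts = sample_nb_counts(mu, theta)
print(counts.shape, counts.mean().item())
```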

The meaning of ‘correcting’ the expression estimates is that you wish to analyze the data as if it came from one single batch. Similarly, for counts, you would want to ‘correct’ the technical factor of count depth so that the count depth is the same value for all cells.
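A toy numeric illustration (made-up frequencies and library sizes, not real data) of why depth has to be standardized: two cells with identical frequencies ρ look tenfold different when each keeps its own ℓ, and identical once they share one library size. `expected_counts` is a hypothetical helper.

```python
import numpy as np


def expected_counts(rho: np.ndarray, library_size) -> np.ndarray:
    """Expected counts = frequencies x library size.

    rho:          (n_cells, n_genes) frequencies, each row summing to 1.
    library_size: a scalar (one shared depth) or an (n_cells, 1) array (per-cell depth).
    """
    return np.asarray(rho, dtype=float) * library_size


# Two toy cells with identical gene frequencies.
rho = np.array([[0.5, 0.3, 0.2],
                [0.5, 0.3, 0.2]])
# Toy cell-specific library sizes: the second cell was sequenced ten times deeper.
ell = np.array([[1_000.0],
                [10_000.0]])

per_cell = expected_counts(rho, ell)       # rows differ tenfold, purely from depth
standardized = expected_counts(rho, 1e4)   # one shared depth: rows are identical
print(per_cell)
print(standardized)
```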

What you can do is run scales = model.get_normalized_expression(adata, transform_batch="My favorite batch", library_size=100000). Which batch and which standardized library size you pick is arbitrary; the point is that all scales will be as if they came from the same batch, so you can compare expression across all your samples. Again, this gives you scales that you need to pass to a function that samples from the corresponding distribution. As a simple example, if you trained the model with gene_likelihood='poisson', then to make counts you would sample from a Poisson distribution with the scales above, like this: counts = torch.distributions.Poisson(torch.tensor(scales.values)).sample().
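A minimal sketch of those two steps as one function, assuming a trained scvi-tools model initialized with gene_likelihood='poisson'. It relies on get_normalized_expression returning a cells × genes pandas DataFrame (its default for a single sample); `sample_corrected_poisson_counts` is a hypothetical helper name, and the batch label and library size are whatever you choose.

```python
import numpy as np
import pandas as pd
import torch


def sample_corrected_poisson_counts(model, adata, batch, library_size=100_000, seed=0):
    """Sketch: batch-corrected, depth-standardized scales -> one Poisson count draw per entry.

    Assumes `model` is a trained scvi-tools model initialized with gene_likelihood="poisson".
    """
    # All cells decoded as if they came from `batch`, with one shared library size.
    scales = model.get_normalized_expression(
        adata, transform_batch=batch, library_size=library_size
    )
    torch.manual_seed(seed)
    # torch.distributions needs a tensor, not a pandas DataFrame.
    rates = torch.as_tensor(scales.to_numpy(), dtype=torch.float32)
    counts = torch.distributions.Poisson(rates).sample()
    # Keep the cell and gene labels from the DataFrame.
    return pd.DataFrame(
        counts.numpy().astype(np.int64), index=scales.index, columns=scales.columns
    )
```

For example, `counts = sample_corrected_poisson_counts(model, adata, "My favorite batch")`. Keeping the DataFrame's index and columns lets the sampled matrix be lined up with adata.obs_names and adata.var_names before it is handed to a count-based tool.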

The model can only work on the genes you have trained it on. If you want corrected values for all genes, I would suggest training the model on all genes. If you are exclusively using it for this normalization task, the slightly more ‘blurry’ representation you get is probably not an issue.
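A quick way to see which requested genes a model can actually return. This assumes the trained gene names come from the AnnData the model was set up with (for scvi-tools, `model.adata.var_names`); `split_by_training_genes` is a hypothetical helper, and the gene names below are toy values. The covered genes could then be passed through get_normalized_expression's gene_list parameter.

```python
import pandas as pd


def split_by_training_genes(trained_genes, requested_genes):
    """Return (covered, missing): requested genes the model was / was not trained on."""
    trained = pd.Index(trained_genes)
    requested = pd.Index(requested_genes)
    in_model = requested.isin(trained)
    # Both outputs keep the order of `requested_genes`.
    return requested[in_model], requested[~in_model]


# Toy example: model trained on three HVGs, user asks for two of them plus one other gene.
covered, missing = split_by_training_genes(["CD3E", "MS4A1", "LYZ"], ["CD3E", "GAPDH", "LYZ"])
print(list(covered))
print(list(missing))
```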

I would suggest plotting the continuous values you get from get_normalized_expression() with a standardized library size directly instead. The issue with counts is that they are discrete, which makes them hard to plot. The expression estimates you get from get_normalized_expression() already include all the normalization you would otherwise apply to make the plot readable.

Hope this helps!
