What is the best way to extract a "full" batch effect corrected count matrix from an scVI model?

Hi jjuhh,

First, batch_size is the parameter that determines how many cells are run through the model in each ‘chunk’ of data. Do you mean the parameter transform_batch?

Other than that, what you describe will get you scale parameters with a learned library size for each cell and gene in the model. If you want counts, you will need to take these scale parameter estimates, pass them to a count distribution, and sample from that.

The meaning of ‘correcting’ the expression estimates is that you wish to analyze the data as if it came from a single batch. Similarly, for counts, you would want to ‘correct’ the technical factor of count depth so that the count depth is the same value for all cells.
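To make the count depth point concrete, here is a minimal NumPy sketch with made-up numbers. The proportions, depths and variable names are hypothetical, and this only illustrates the idea of a shared library size, not what scVI does internally.

```python
import numpy as np

# Hypothetical toy data: two cells with the same underlying expression
# proportions, but sequenced at different depths.
proportions = np.array([0.5, 0.3, 0.2])  # per-gene fractions, sum to 1
depths = np.array([1_000, 10_000])       # per-cell count depth

# Expected counts (cells x genes) differ only because of depth.
expected = depths[:, None] * proportions[None, :]

# Rescale every cell to one shared, arbitrarily chosen library size.
library_size = 100_000
standardized = expected / expected.sum(axis=1, keepdims=True) * library_size

print(expected)
print(standardized)
```

After rescaling, both cells have identical values, so any remaining difference between real cells would reflect expression rather than depth.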

What you can do is run scales = model.get_normalized_expression(adata, transform_batch = "My favorite batch", library_size = 100000), where model is your trained scVI model. Which batch and which standardized library size you pick is arbitrary; the point is that all scales will be as if they came from the same batch, so you can compare expression between all your samples. Again, this gives you scales that you need to pass to a function that samples from the given distribution. As a simple example, if you trained the model with the parameter gene_likelihood = 'poisson', then to make counts you would sample from a Poisson distribution with your scales from above, like this: counts = torch.distributions.poisson.Poisson(scales).sample(), with scales first converted to a torch tensor.
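As a rough sketch of the sampling step, the snippet below uses NumPy instead of torch and a randomly generated stand-in for the scales, so it runs without a trained model. The stand-in array, the seed and the shapes are assumptions for illustration only; with a real model you would use the array returned by get_normalized_expression() instead (passing return_numpy=True should give a plain array).

```python
import numpy as np

rng = np.random.default_rng(0)

# Stand-in for the batch-corrected scales (cells x genes). With a trained
# model this would come from something like:
#   scales = model.get_normalized_expression(
#       adata, transform_batch="My favorite batch",
#       library_size=100000, return_numpy=True)
library_size = 100_000
proportions = rng.dirichlet(np.ones(5), size=3)  # 3 cells x 5 genes, rows sum to 1
scales = proportions * library_size

# Assuming a Poisson gene likelihood: draw one count per cell and gene,
# using the scale as the Poisson rate.
counts = rng.poisson(scales)

print(counts.shape, counts.dtype)
```

Each call draws a new random sample, so the counts will differ between runs unless the random generator is seeded. For a model trained with a different gene_likelihood, the Poisson draw would need to be replaced by the matching distribution and its extra parameters.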

The model can only work on the genes you have trained it on. If you want a corrected matrix for all genes, I would suggest training the model on all genes. If you are exclusively using it for this normalization task, the slightly more ‘blurry’ representation you get is probably not an issue.

I would suggest plotting the continuous values you get from get_normalized_expression() with a standardized library size directly instead. The issue with counts is that they are discrete, which makes them hard to plot. The expression estimates you get from get_normalized_expression() already include all the normalization you would otherwise apply to make the plot readable.

Hope this helps!
/Valentine
