Insufficient batch correction for certain cell-types

Hello,

Dataset

I have a scRNASeq dataset that was based on the following experimental design:

| timepoint | Nsamples |
|-----------|----------|
| t0        | 5        |
| t1        | 3        |
| t2        | 4        |

Aim

Correct for interindividual variability across timepoints and study gene expression for each cell-type between timepoints.

Batch correction

In order to remove interindividual variability, I first created two columns:

  1. One column named Replicate containing the sample IDs.
  2. A second column denoting the timepoint (t0, t1 or t2).

Then I performed batch correction as follows:

  1. Select 4000 highly variable genes
sc.pp.highly_variable_genes(adata, n_top_genes=4000, subset = True, layer = 'soupX_counts', flavor = "seurat_v3", batch_key="Replicate")

The soupX_counts layer was produced by the SoupX package during the QC steps to remove possible ambient mRNA contamination.

  2. Set up the SCVI model
scvi.model.SCVI.setup_anndata(adata, layer = "soupX_counts",
                             categorical_covariate_keys=["Replicate"],
                             continuous_covariate_keys=['pct_counts_mt', 'total_counts'])

I tried tweaking the parameters (as suggested in other posts on batch correction) and the following provided the best results for me:

model = scvi.model.SCVI(
        adata,
        n_layers=2,
        n_hidden=200,
        n_latent=20,
        gene_likelihood='zinb',
        dispersion="gene"
        )
  3. Define the training size
# use a 10% validation split, capped at ~20,000 cells
if 0.1 * adata.n_obs < 20000:
    train_size = 0.9
else:
    train_size = 1 - (20000 / adata.n_obs)
  4. Train the SCVI model
model.train(
        accelerator="gpu", 
        devices=1,
        early_stopping=True,
        train_size=train_size,
        early_stopping_patience=400,
        max_epochs=10000,
        batch_size=1024,
        limit_train_batches=20
        )

Results

Despite the batch correction step, there still seems to be some batch effect for certain cell types, especially classical monocytes.

In the figure above, cells from timepoint t0 (PRE) were selected and the UMAP was computed with scanpy. Cell-type annotation was performed with CellTypist.
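For context, the embedding workflow behind a figure like this presumably looks roughly like the sketch below (not the exact code used here; it assumes the timepoint column is the Condition column described later in the thread and that the CellTypist labels live in celltypist_cell_types_high_resolution):

import scanpy as sc

# Use the scVI latent space for neighbours/UMAP, then plot only the t0 (PRE) cells
adata.obsm["X_scVI"] = model.get_latent_representation()
sc.pp.neighbors(adata, use_rep="X_scVI")
sc.tl.umap(adata)
sc.pl.umap(
    adata[adata.obs["Condition"] == "t0"],
    color=["Replicate", "celltypist_cell_types_high_resolution"],
)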

Questions

  1. Do you think the observed batch effect appears strong enough to affect differential expression?
  2. Was the batch correction negatively affected by the unbalanced dataset?
  3. Do you have any suggestion to improve the batch correction process?

Thank you in advance.

Hi, I have a couple of clarifying questions:

  • What is the size of your dataset? Depending on that, the number of maximum epochs might be too large. How long does it take for early stopping to kick in?
  • What’s the reason behind setting limit_train_batches=20?

The size of the merged dataset (i.e. t0+t1+t2) is 86,500 cells. When early stopping kicked in depended on the early_stopping_patience value: with patience set to 45, early stopping happened within 7 minutes; at 150, within 20 minutes; and at 400, close to 50 minutes.

Regarding limit_train_batches, I’m not sure whether my reasoning is correct, but the idea was to constrain RAM usage: instead of loading all cells during training, I wanted to load batches of 20 until the whole dataset was used. However, I’m not sure whether this is correct. Does limit_train_batches get multiplied by batch_size to give the total number of cells used in each epoch? If that’s true, then I’m only using about 20k cells in each epoch.

Could you also specify the number of epochs for which your model trains? I have a feeling that, with those values of max_epochs and early_stopping_patience, your model is most likely overfitting on the training set. This can be checked by plotting "elbo_validation" and "elbo_train" from model.history after training. If the validation ELBO is significantly higher than its training counterpart, then the model has likely overfit. Something in the ballpark of max_epochs=400 and early_stopping_patience=10 might be more appropriate for this dataset size.
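A minimal sketch of that check (assuming matplotlib is available and the default scvi-tools history keys):

import matplotlib.pyplot as plt

# model.history entries are DataFrames indexed by epoch
train_elbo = model.history["elbo_train"]
val_elbo = model.history["elbo_validation"]

plt.plot(train_elbo.index, train_elbo["elbo_train"], label="train")
plt.plot(val_elbo.index, val_elbo["elbo_validation"], label="validation")
plt.xlabel("epoch")
plt.ylabel("ELBO")
plt.legend()
plt.show()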

My understanding of limit_train_batches is that it runs that number of minibatches during each training epoch, so yes, the model would be seeing about 20k observations in each epoch. Your main memory usage will mostly be set by your full dataset size, and your GPU RAM usage by batch_size and the model size, so limit_train_batches will only affect how long each epoch takes to complete.

Hi, I think overall it looks fine integration-wise. Yes, it will likely affect DE genes between replicates (between cell types I guess it might be fine). If you want to get classical monocytes better integrated, I would check the DE genes between those batches. It is very typical behaviour that monocytes get activated during the experiment (in humans, higher IL1B but also FOS and JUN). scVI usually does not integrate this away (which is helpful, as it would otherwise overintegrate). If you think it’s critical to get better integration, I would reduce the number of genes (to maybe 1500) and n_latent. I’m also not sure whether Replicate denotes a unique sample or whether it has the same value for different samples at different time points. In that case I would recommend replacing Replicate with a string concatenation of Replicate and Timepoint. I hope this is clear.
scANVI might also be helpful to get better integration (there you inform the model about the meaningful cellular variation).
I would generally recommend using batch_key instead of a categorical covariate key if you use a single category, as it offers more downstream capabilities (see transform_batch in the codebase).
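For example (a hedged sketch: once Replicate is registered via batch_key, transform_batch lets you decode expression as if all cells came from one reference batch; "p1" is just an illustrative batch category):

# Normalized expression decoded as if every cell belonged to batch "p1"
expr = model.get_normalized_expression(adata, transform_batch="p1")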

Thank you both for your suggestions. I have done a couple of optimizations by combining your recommendations.

Previous implementation

Martin’s intuition was right: my previous implementation was overfitting.

Optimization 1

I introduced the following changes in the code:

  1. Reduced highly variable genes to 3000:
Code
sc.pp.highly_variable_genes(adata, n_top_genes=3000, subset = True, layer = 'soupX_counts', flavor = "seurat_v3", batch_key="Replicate")
  2. Used SCANVI and replaced categorical_covariate_keys with batch_key (the model-construction step is sketched after this list).
Code
scvi.model.SCANVI.setup_anndata(adata, layer="soupX_counts",
                                batch_key="Replicate",
                                labels_key="celltypist_cell_types_high_resolution",
                                unlabeled_category="Unknown",
                                continuous_covariate_keys=['pct_counts_mt', 'total_counts'])
  3. Reduced max_epochs to 400 and early_stopping_patience to 10, and removed the limit_train_batches parameter.
Code
model.train(
        accelerator="gpu", 
        devices=1,
        early_stopping=True,
        train_size=train_size,
        max_epochs=400,
        early_stopping_patience=10,
        batch_size=1024,
        )
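For reference, the model-construction step between setup and training is not shown above; a minimal sketch, assuming SCANVI is built directly with the same hyperparameters as the earlier SCVI model (it can alternatively be initialized from a pretrained SCVI model with scvi.model.SCANVI.from_scvi_model):

model = scvi.model.SCANVI(
        adata,
        n_layers=2,
        n_hidden=200,
        n_latent=20,
        gene_likelihood='zinb',
        dispersion="gene"
        )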

Results: less overfitting, but the batch effect remained.

Optimization 2

I implemented Can’s recommendation and engineered a new replicate feature for the batch correction, called Replicate2.

Code
adata.obs['Replicate2'] = [f"{x}_{y}" for x,y in zip(adata.obs['Replicate'], adata.obs['Condition'])]
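An equivalent, slightly more idiomatic way to build the same column with pandas string operations (a sketch, assuming both obs columns can be cast to strings):

# Same Replicate2 column via vectorized string concatenation
adata.obs['Replicate2'] = adata.obs['Replicate'].astype(str) + "_" + adata.obs['Condition'].astype(str)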

Results: less overfitting (though it seems to need a bit more time to converge); the batch effect seems to have disappeared.

Optimization 3

I reduced n_latent to 10 and increased the early_stopping_patience to 20.

Results: better convergence and best integration so far.

New problem

Although optimization 3 yielded the best results, after performing differential gene expression I got no DEGs (which is very weird). I think the issue is with the way the Replicate2 feature has been engineered, which I find quite unintuitive. Let me give some clarifications on the experimental design:

I have the following columns in the dataset:
Replicate: {p1,p2,p4,p5,p6}
Condition: {t0, t1, t2}
The timepoints are called Condition because they correspond to before treatment (t0), after treatment (t1) and after many treatments (t2). So DEGs should be identified between t0 vs t1 and t1 vs t2.

Based on Can’s suggestion, the new replicate feature should look like the following:
Replicate2: {p1_t0, p1_t1, p1_t2, ..., p6_t0, p6_t1, p6_t2}

I think that by performing batch correction using Replicate2, all the biological signal between the timepoints disappeared. Therefore, I think that batch correction should be performed using Replicate instead of Replicate2. Could you please clarify why correcting with Replicate2 is better?

Can you also share your line of code for the DE (I need to know how batch_key is used there)?
In general, batch effects exist between sequencing samples/sample processings. I assume the time points of one replicate are independent (different patients or mice). If it’s the same patient sampled multiple times, using Replicate for integration makes sense (otherwise your design could contain e.g. male and female under the same replicate name, and this would lead to poor integration). In that case, it’s a single replicate across time points.
I still think your first integration looks fine, and your main point was how to get more integration.

Hi Can,

Regarding the experimental design, there are 5 patients (p1, p2, p4, p5, p6), each of whom provided samples at 3 timepoints (t0, t1, t2), as shown in the table below.

| Replicate | Condition | Replicate2 |
|-----------|-----------|------------|
| p1        | t0        | p1_t0      |
| p1        | t1        | p1_t1      |
| p1        | t2        | p1_t2      |
| p2        | t0        | p2_t0      |
| p2        | t1        | p2_t1      |
| p2        | t2        | p2_t2      |
| p4        | t0        | p4_t0      |
| p4        | t1        | p4_t1      |
| p4        | t2        | p4_t2      |
| p5        | t0        | p5_t0      |
| p5        | t1        | p5_t1      |
| p5        | t2        | p5_t2      |
| p6        | t0        | p6_t0      |
| p6        | t1        | p6_t1      |
| p6        | t2        | p6_t2      |

The idea is to remove inter-individual differences between patients across timepoints (Condition) and get DEGs between timepoints (Condition). My initial implementation included the column “Replicate” as batch_key and my last implementation included “Replicate2” as batch_key.
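A quick way to sanity-check this design in the AnnData object (a sketch, assuming the obs columns are named as above):

import pandas as pd

# Expect a 5 x 3 table with one sample per patient/timepoint combination
print(pd.crosstab(adata.obs['Replicate'], adata.obs['Condition']))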

Below is my code for DEGs:

pairs = [("t0", "t1"), ("t1", "t2"), ("t0", "t2")]

def getDEGs(cellType):
    global DEGs
    for pair in pairs:
        p1, p2 = pair
        cellIDX1 = ((adata.obs['celltypist_cell_types_low_resolution'] == cellType) & (adata.obs['Condition'] == p1))
        cellIDX2 = ((adata.obs['celltypist_cell_types_low_resolution'] == cellType) & (adata.obs['Condition'] == p2))
        tmp = model.differential_expression(idx1=cellIDX1, idx2=cellIDX2, mode='change', batch_correction=True)
        DEGs[f"{cellType}_{p1}_vs_{p2}"] = tmp

DEGs = {}
for cellType in celltypesLowResolution:
    getDEGs(cellType)

When “Replicate2” was used as batch_key, no DEGs were identified. To my understanding, if batch_correction is True, then by default batchid1 and batchid2 use all the categories specified in the batch_key passed to setup_anndata. So I think batch correcting with “Replicate2” leads to better integration but removes all biological signal.

I think in your setup it is fine to use Replicate as batch_key. The different donors are likely the reason for most of the batch effect, so there’s little batch effect between time points. It is fine to use Replicate2 in this case too, though then you want to either set batch_correction=False in model.differential_expression or, better, set batchid1 to all batches of condition 1 and batchid2 to all batches of condition 2. This way the model won’t correct gene expression for the covariate of interest and will generate those counts for the correct batch.
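A sketch of that second option plugged into the getDEGs loop above (assuming the Replicate2 values follow the patient_timepoint pattern from the table; note that p1 and p2 here are the condition labels from pairs, not patient IDs):

# Restrict the counterfactual batches to the Replicate2 categories of each condition,
# so the timepoint effect itself is not corrected away during DE
batches1 = adata.obs.loc[adata.obs['Condition'] == p1, 'Replicate2'].unique().tolist()
batches2 = adata.obs.loc[adata.obs['Condition'] == p2, 'Replicate2'].unique().tolist()

tmp = model.differential_expression(
    idx1=cellIDX1,
    idx2=cellIDX2,
    mode='change',
    batch_correction=True,
    batchid1=batches1,
    batchid2=batches2,
)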