Insufficient batch correction for certain cell types

Thank you both for your suggestions. I made a few optimizations by combining your recommendations.

Previous implementation

Martin's intuition was right: my previous implementation was overfitting.

Optimization 1

I introduced the following changes in the code:

  1. Reduced the number of highly variable genes to 3000:
Code
import scanpy as sc
sc.pp.highly_variable_genes(adata, n_top_genes=3000, subset=True, layer="soupX_counts", flavor="seurat_v3", batch_key="Replicate")
  2. Used SCANVI and replaced categorical_covariate_keys with batch_key.
Code
import scvi
scvi.model.SCANVI.setup_anndata(
    adata,
    layer="soupX_counts",
    batch_key="Replicate",
    labels_key="celltypist_cell_types_high_resolution",
    unlabeled_category="Unknown",
    continuous_covariate_keys=["pct_counts_mt", "total_counts"],
)
  3. Reduced max_epochs to 400 and early_stopping_patience to 10, and removed the limit_train_batches parameter.
Code
# model and train_size are defined earlier in the notebook (not shown here)
model.train(
    accelerator="gpu",
    devices=1,
    early_stopping=True,
    train_size=train_size,
    max_epochs=400,
    early_stopping_patience=10,
    batch_size=1024,
)
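One way to quantify the overfitting mentioned above is to compare the training and validation loss over the last few epochs. This is a minimal sketch, assuming the per-epoch losses are available as plain sequences. In scvi-tools they are usually stored in `model.history` (for example under the `elbo_train` and `elbo_validation` keys when a validation split is used), but check which keys your version exposes. The numbers in the usage example are made up for illustration and are not results from this dataset.

```python
import numpy as np


def generalization_gap(train_loss, val_loss, last_n=5):
    """Mean (validation - train) loss over the last `last_n` epochs.

    A gap that keeps widening while the training loss still decreases
    is a common sign of overfitting.
    """
    train = np.asarray(train_loss, dtype=float)
    val = np.asarray(val_loss, dtype=float)
    if train.shape != val.shape:
        raise ValueError("train and validation curves must have the same length")
    return float((val - train)[-last_n:].mean())


# Hypothetical toy curves: training loss keeps falling, validation loss turns up.
train = [900, 850, 820, 800, 790, 780]
val = [910, 870, 860, 865, 870, 880]
print(round(generalization_gap(train, val, last_n=3), 1))
```

Tracking this value across runs (before and after reducing the number of genes and epochs) gives a single number to compare instead of eyeballing the loss plots.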

Results: less overfitting, but the batch effect remained.

Optimization 2

I implemented Can's recommendation and engineered a new batch feature, Replicate2, which combines Replicate and Condition.

Code
adata.obs['Replicate2'] = [f"{x}_{y}" for x,y in zip(adata.obs['Replicate'], adata.obs['Condition'])]

Results: less overfitting (although training seems to need a bit more time to converge), and the batch effect seems to have disappeared.
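The list comprehension above can also be written as a vectorized pandas operation. This is a minimal sketch on a hypothetical toy `obs` table (in the real analysis it would be `adata.obs`). It also prints how many cells fall into each Replicate2 level, which is worth checking because splitting every replicate by condition leaves fewer cells per batch level.

```python
import pandas as pd

# Hypothetical toy obs table; in the real analysis this is adata.obs.
obs = pd.DataFrame({
    "Replicate": ["p1", "p1", "p1", "p2", "p2", "p2"],
    "Condition": ["t0", "t1", "t1", "t0", "t2", "t2"],
})

# Same values as the list comprehension, vectorized and stored as a category.
obs["Replicate2"] = (
    obs["Replicate"].astype(str) + "_" + obs["Condition"].astype(str)
).astype("category")

# Cells per batch level.
print(obs["Replicate2"].value_counts().sort_index())
```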

Optimization 3

I reduced n_latent to 10 and increased the early_stopping_patience to 20.
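Raising early_stopping_patience from 10 to 20 matters when the validation loss plateaus for a while before improving again. The sketch below is a simplified model of patience-based early stopping, assuming a loss that should decrease and a minimum improvement (`min_delta`) of 0. It is not the scvi-tools code, which as far as I know relies on a PyTorch Lightning callback, and the loss curve is made up for illustration.

```python
def stopping_epoch(val_losses, patience, min_delta=0.0):
    """Return the 0-based epoch at which early stopping fires, or None.

    Stops once the monitored loss has failed to improve on the best value
    by more than `min_delta` for `patience` consecutive epochs.
    """
    best = float("inf")
    wait = 0
    for epoch, loss in enumerate(val_losses):
        if loss < best - min_delta:
            best = loss  # new best value: reset the counter
            wait = 0
        else:
            wait += 1
            if wait >= patience:
                return epoch
    return None


# Hypothetical curve: improves, plateaus for 15 epochs, then improves again.
losses = [10.0, 9.0, 8.0] + [8.0] * 15 + [7.0, 6.0]

print(stopping_epoch(losses, patience=10))  # stops on the plateau
print(stopping_epoch(losses, patience=20))  # survives the plateau
```

With patience 10, training stops during the plateau; with patience 20, it reaches the later improvement. This is one plausible reason a larger patience can give better convergence.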

Results: better convergence and best integration so far.

New problem

Although optimization 3 yielded the best results, differential gene expression analysis afterwards returned no DEGs at all, which is very surprising. I suspect the issue lies in how the Replicate2 feature was engineered, which I find quite unintuitive. Let me clarify the experimental design:

I have the following columns in the dataset:
Replicate: {p1,p2,p4,p5,p6}
Condition: {t0, t1, t2}
The time points are stored in the Condition column because they correspond to before treatment (t0), after treatment (t1) and after many treatments (t2). DEGs should therefore be identified for t0 vs t1 and t1 vs t2.
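The two comparisons above are consecutive pairs of time points. This is a minimal sketch that enumerates them. The commented scvi-tools call only indicates how each pair could be passed to a trained model's `differential_expression` method. It is not executed here, and its exact arguments should be checked against the scvi-tools documentation.

```python
# Time points in chronological order (values of the Condition column).
conditions = ["t0", "t1", "t2"]

# Consecutive pairs: (t0, t1) and (t1, t2).
contrasts = list(zip(conditions[:-1], conditions[1:]))

for group1, group2 in contrasts:
    print(f"{group1} vs {group2}")
    # With a trained scvi-tools model this could look like:
    # de = model.differential_expression(groupby="Condition", group1=group1, group2=group2)
```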

Based on Can's suggestion, the new batch feature looks like this:
Replicate2: {p1_t0, p1_t1, p1_t2, ..., p6_t0, p6_t1, p6_t2}

I think that batch correction using Replicate2 removed all the biological signal between the time points. I therefore think batch correction should be performed using Replicate rather than Replicate2. Could you please clarify why correcting for Replicate2 is better?
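One way to make this concern concrete is to cross-tabulate each candidate batch key against Condition. This is a minimal sketch on a hypothetical toy `obs` table (in practice use `adata.obs`). If every level of a batch key contains cells from only one condition, as with Replicate2 by construction, the batch key is fully confounded with Condition. In that case, whatever the model treats as "batch" may also absorb the t0/t1/t2 differences. Replicate, in contrast, spans all conditions.

```python
import pandas as pd

# Hypothetical toy obs table; in practice use adata.obs.
obs = pd.DataFrame({
    "Replicate": ["p1", "p1", "p1", "p2", "p2", "p2"],
    "Condition": ["t0", "t1", "t2", "t0", "t1", "t2"],
})
obs["Replicate2"] = obs["Replicate"] + "_" + obs["Condition"]


def conditions_per_batch(obs, batch_key, condition_key="Condition"):
    """Number of distinct conditions observed within each batch level."""
    return obs.groupby(batch_key)[condition_key].nunique()


def fully_confounded(obs, batch_key, condition_key="Condition"):
    """True if every batch level contains cells from a single condition."""
    return bool((conditions_per_batch(obs, batch_key, condition_key) == 1).all())


print(pd.crosstab(obs["Replicate2"], obs["Condition"]))
print("Replicate confounded: ", fully_confounded(obs, "Replicate"))
print("Replicate2 confounded:", fully_confounded(obs, "Replicate2"))
```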