Thank you both for your suggestions. I made a few optimizations by combining your recommendations.
Previous implementation
Martin's intuition was right: my previous implementation was overfitting.
Optimization 1
I introduced the following changes in the code:
- Reduced the number of highly variable genes to 3000:
sc.pp.highly_variable_genes(adata, n_top_genes=3000, subset=True, layer="soupX_counts", flavor="seurat_v3", batch_key="Replicate")
- Used SCANVI and replaced categorical_covariate_keys with batch_key.
scvi.model.SCANVI.setup_anndata(
    adata,
    layer="soupX_counts",
    batch_key="Replicate",
    labels_key="celltypist_cell_types_high_resolution",
    unlabeled_category="Unknown",
    continuous_covariate_keys=["pct_counts_mt", "total_counts"],
)
- Reduced max_epochs to 400 and early_stopping_patience to 10, and removed the limit_train_batches parameter.
model.train(
accelerator="gpu",
devices=1,
early_stopping=True,
train_size=train_size,
max_epochs=400,
early_stopping_patience=10,
batch_size=1024,
)
Results: less overfitting, but the batch effect remained.
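For context on the early_stopping_patience changes: as I understand it, scvi-tools hands early stopping to PyTorch Lightning's EarlyStopping callback and by default monitors the validation ELBO. The pure-Python sketch below only illustrates the patience rule conceptually (it is not scvi-tools code); the function names and loss values are hypothetical. It also computes the final gap between validation and training loss, a rough overfitting indicator you could read off the per-epoch losses in model.history.

```python
from typing import Sequence


def early_stopping_epoch(val_losses: Sequence[float], patience: int, min_delta: float = 0.0) -> int:
    """Return the 0-based epoch at which patience-based early stopping would halt.

    Training stops once the validation loss has failed to improve by more than
    `min_delta` for `patience` consecutive epochs; otherwise it runs to the end.
    """
    best = float("inf")
    epochs_without_improvement = 0
    for epoch, loss in enumerate(val_losses):
        if loss < best - min_delta:
            # New best validation loss: reset the patience counter.
            best = loss
            epochs_without_improvement = 0
        else:
            epochs_without_improvement += 1
            if epochs_without_improvement >= patience:
                return epoch
    return len(val_losses) - 1


def generalization_gap(train_losses: Sequence[float], val_losses: Sequence[float]) -> float:
    """Validation minus training loss at the last epoch; a large positive gap hints at overfitting."""
    return val_losses[-1] - train_losses[-1]


# Hypothetical loss curves: validation loss bottoms out at epoch 2, training keeps improving.
train = [9.0, 7.0, 6.0, 5.0, 4.0, 3.5, 3.0]
val = [10.0, 8.0, 7.0, 7.5, 7.6, 7.7, 7.8]
```

With patience=3 this toy run stops at epoch 5; with patience=10 it runs through the last epoch. A larger patience (as in Optimization 3) simply gives a slowly converging model more epochs without improvement before training is cut off.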
Optimization 2
I implemented Can's recommendation and engineered a new batch feature for the batch correction, called Replicate2:
adata.obs['Replicate2'] = [f"{x}_{y}" for x,y in zip(adata.obs['Replicate'], adata.obs['Condition'])]
Results: less overfitting (though it seems to need a bit more time to converge), and the batch effect seems to have disappeared.
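To make explicit what Replicate2 encodes, here is a minimal pure-Python equivalent of the list comprehension above, run on hypothetical toy metadata; combined_batch_key is an illustrative helper, not part of scanpy or scvi-tools. Every replicate-timepoint sample becomes its own batch category, so, assuming each of the 5 replicates was sampled at all 3 timepoints, Replicate2 has 15 levels instead of the 5 in Replicate.

```python
from typing import List, Sequence


def combined_batch_key(replicates: Sequence[str], conditions: Sequence[str]) -> List[str]:
    """Build Replicate2-style labels: one '<replicate>_<condition>' label per cell."""
    return [f"{r}_{c}" for r, c in zip(replicates, conditions)]


# Hypothetical toy metadata for six cells (not the real dataset).
replicates = ["p1", "p1", "p1", "p2", "p2", "p2"]
conditions = ["t0", "t1", "t2", "t0", "t1", "t2"]

replicate2 = combined_batch_key(replicates, conditions)
n_levels_replicate = len(set(replicates))    # 2 batch levels
n_levels_replicate2 = len(set(replicate2))   # 6 batch levels, one per sample

# Assumption: every replicate was observed at every timepoint in the real design.
full_design_levels = {f"{r}_{c}" for r in ["p1", "p2", "p4", "p5", "p6"] for c in ["t0", "t1", "t2"]}
```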
Optimization 3
I reduced n_latent to 10 and increased early_stopping_patience to 20.
Results: better convergence and the best integration so far.
New problem
Although optimization 3 yielded the best results, differential gene expression analysis returned no DEGs at all (which is very weird). I suspect the issue is the way the Replicate2 feature was engineered, which I find quite unintuitive. Let me clarify the experimental design:
I have the following columns in the dataset:
Replicate: {p1, p2, p4, p5, p6}
Condition: {t0, t1, t2}
The timepoints are called Condition because they correspond to before treatment (t0), after treatment (t1) and after many treatments (t2). So DEGs should be identified for t0 vs t1 and t1 vs t2.
Following Can's suggestion, the new batch feature looks like this:
Replicate2: {p1_t0, p1_t1, p1_t2, ..., p6_t0, p6_t1, p6_t2}
I think that by performing batch correction with Replicate2, all the biological signal between the timepoints disappeared. Therefore, I think batch correction should be performed using Replicate instead of Replicate2. Could you please clarify why correcting for Replicate2 is better?
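One quick way to check this concern is to test whether Condition is nested within the batch key, i.e. whether each batch level contains only a single timepoint. If it is, the timepoint is fully determined by the batch label, so differences between timepoints cannot be distinguished from differences between batches. The pure-Python sketch below uses hypothetical toy metadata and illustrative helper names; with real data you would pass the values of adata.obs['Replicate2'] (or adata.obs['Replicate']) together with adata.obs['Condition'].

```python
from collections import defaultdict
from typing import Dict, Hashable, Sequence, Set


def conditions_per_batch(batches: Sequence[Hashable], conditions: Sequence[Hashable]) -> Dict[Hashable, Set[Hashable]]:
    """Map each batch level to the set of conditions observed in it."""
    seen: Dict[Hashable, Set[Hashable]] = defaultdict(set)
    for batch, condition in zip(batches, conditions):
        seen[batch].add(condition)
    return dict(seen)


def condition_nested_in_batch(batches: Sequence[Hashable], conditions: Sequence[Hashable]) -> bool:
    """True if every batch level contains exactly one condition (condition is confounded with batch)."""
    return all(len(cs) == 1 for cs in conditions_per_batch(batches, conditions).values())


# Hypothetical toy metadata: two replicates, each sampled at all three timepoints.
replicate = ["p1", "p1", "p1", "p2", "p2", "p2"]
condition = ["t0", "t1", "t2", "t0", "t1", "t2"]
replicate2 = [f"{r}_{c}" for r, c in zip(replicate, condition)]

nested_with_replicate = condition_nested_in_batch(replicate, condition)    # False: each replicate spans t0-t2
nested_with_replicate2 = condition_nested_in_batch(replicate2, condition)  # True: each level is a single timepoint
```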