Suggestions on parameters for training an scVI model

Hello!

I want to integrate 67 different samples (batches) from three different datasets using scVI. I have 235k cells in total and, as suggested in the atlas integration tutorial, I integrated using only the top 2,000 highly variable genes shared across all datasets, since this removes the batch effect better.

I set up the anndata and used the following settings:

import scvi

# Register the raw counts layer and the batch covariate, then build and train
scvi.model.SCVI.setup_anndata(anndata, layer="counts", batch_key="batch")
model = scvi.model.SCVI(anndata, n_layers=2, n_latent=30, gene_likelihood="zinb")
model.train()

By default, the training runs for 34 epochs (calculated as 20,000 / n_cells * 400, as defined in the train() function).
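For reference, here is a minimal sketch of that heuristic (assuming it matches the formula in the installed scvi-tools version, which also caps the result at 400 epochs):

n_cells = 235_000
# default heuristic: scale 400 epochs down for datasets larger than 20k cells
max_epochs = min(round((20_000 / n_cells) * 400), 400)
print(max_epochs)  # -> 34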

By plotting the training metrics stored in model.history, I see the following trends (see the plotting sketch after this list):

  • train_loss_step: fluctuating (slightly decreasing over the last epochs)
  • train_loss_epoch: decreasing and converging after about 10–15 epochs
  • elbo_train: decreasing (not yet converged; it looks like it would keep decreasing with more epochs)
  • reconstruction_loss_train: decreasing and converging after about 25 epochs
  • kl_local_train: rises to a peak after 2 epochs and then decreases (does not seem to converge within the 34 epochs)
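
For context, this is roughly how I plot them (a minimal sketch; model.history in scvi-tools is a dict of one-column DataFrames keyed by metric name):

import matplotlib.pyplot as plt

# plot each logged training curve in its own figure
for key in ["elbo_train", "reconstruction_loss_train", "kl_local_train"]:
    model.history[key].plot(title=key)
plt.show()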

How can I check whether the model performs well?
Are the trends described above enough to judge the model, and what trends would a good model show?

Should I manually set a higher number of epochs (e.g. 50) for better performance?
Should I increase n_latent (e.g. n_latent=50) so the latent space captures more of the information in my data?

Thanks a lot in advance for your help!

Hey, thanks for your question! I’d suggest passing check_val_every_n_epoch=1 to train() so that validation losses are computed as well. If you find that elbo_validation, validation_loss, and reconstruction_loss_validation are converging at the end of training, then your model is probably fine. Otherwise, you can tweak the learning rate or number of epochs depending on what’s going on. Hope this helps!
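
As a minimal sketch (max_epochs=100 and early_stopping=True here are illustrative choices, not requirements):

# train with per-epoch validation and stop once the validation ELBO plateaus
model = scvi.model.SCVI(anndata, n_layers=2, n_latent=30, gene_likelihood="zinb")
model.train(max_epochs=100, check_val_every_n_epoch=1, early_stopping=True)

# elbo_validation should flatten out toward the end of training
model.history["elbo_validation"].plot(title="elbo_validation")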


Thank you very much for your fast reply! It helped me a lot.

I also have another question: I would like to integrate my samples using all the genes instead of only the top 2000 highly variable genes.
However, the batches do not all share the same set of genes, so to create the concatenated matrix I would do the following:

import os
import anndata as ad

anndata_dir = 'C:/Users/Martina/Desktop/CAR-T Atlas Data/AnnData'
list_files = os.listdir(anndata_dir)
anndata_list = []

# read every .h5ad file in the directory into a list of AnnData objects
for filename in list_files:
    file_path = os.path.join(anndata_dir, filename)
    anndata_obj = ad.read_h5ad(file_path)
    anndata_list.append(anndata_obj)

# outer join keeps the union of genes across all samples
concatenated_anndata = ad.concat(anndata_list, axis=0, join='outer')

This keeps all the genes; for cells from batches that did not measure a given gene, ad.concat fills in zero counts.
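
To see how much zero-filling actually happens, I would count, per file, how many genes of the union were not measured (a small sketch reusing the objects from the loop above):

all_genes = concatenated_anndata.var_names
for filename, anndata_obj in zip(list_files, anndata_list):
    # genes present in the union but absent from this sample's original matrix
    n_missing = len(all_genes) - len(all_genes.intersection(anndata_obj.var_names))
    print(f"{filename}: {n_missing} genes zero-filled by the outer join")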

I would like to do this to better integrate the data, taking all the possible variability into account, and then run differential gene expression over all the genes to better characterise the cells.

Would you recommend doing this? Does it remove batch effects efficiently, or should I use a different approach? If so, which one would you recommend?

What makes me doubt this is the differential gene expression step: please correct me if I'm wrong, but cells from batches that do not measure some genes get zero counts added by ad.concat through the join='outer' setting, so the DGE results for those genes would be biased. The added zeros reflect missing information, not true absence of expression (those genes might well be expressed; I just don't have that information). A sketch of a more conservative alternative follows below.
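
One conservative alternative I considered is to restrict the analysis to genes measured in every dataset, so that no filled-in zeros enter the test (a sketch, reusing anndata_list and concatenated_anndata from above):

from functools import reduce

# intersection of var_names across all samples = genes measured everywhere
shared_genes = reduce(
    lambda a, b: a.intersection(b),
    (anndata_obj.var_names for anndata_obj in anndata_list),
)
adata_shared = concatenated_anndata[:, list(shared_genes)].copy()
# a model trained on adata_shared could then be used for the DGE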

Thank you so much for your help!!
