Train scVI on a sampled dataset

Hey,

We recently started using leverage score sampling from Seurat, which downsamples large datasets in a way that preserves rare cell types. Assuming I sample 10% of the cells from every batch in my dataset and train an scVI model on that subset, how do I project the remaining 90% of the cells? Since the scVI model has already learned the batch effects from the sampled data, is it possible to do a forward pass of the trained model to get the latent space for the remaining 90% of cells per batch?
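Something like this is what I have in mind (a rough sketch; `sketch_idx` is a placeholder for whatever indices the leverage score sampling returns, and `"batch"` stands in for my real batch column):

```python
import scvi

# adata: full AnnData with all cells
# sketch_idx: placeholder for the 10%-per-batch indices from the sampling
adata_sketch = adata[sketch_idx].copy()

scvi.model.SCVI.setup_anndata(adata_sketch, batch_key="batch")
model = scvi.model.SCVI(adata_sketch)
model.train()

# forward pass of the trained model over ALL cells, including the 90%
# it never saw; the batch categories must match those seen in training
latent_all = model.get_latent_representation(adata)
```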

Thanks

We now support passing train and validation indices. We would be very excited to see some benchmarking of model performance when passing stratified indices; I would assume one could speed up training quite significantly. You can enable it by passing validation_indices in datasplitter_kwargs in model.train.
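A minimal sketch of what this looks like (`val_idx` is a placeholder for your stratified indices; the kwarg name follows the description above, so please check it against your scvi-tools version):

```python
import numpy as np
import scvi

scvi.model.SCVI.setup_anndata(adata, batch_key="batch")
model = scvi.model.SCVI(adata)

# placeholder: 10% of cells held out for validation; in practice these
# would come from a stratified/leverage-score sampling scheme
val_idx = np.random.default_rng(0).choice(
    adata.n_obs, size=adata.n_obs // 10, replace=False
)

model.train(datasplitter_kwargs={"validation_indices": val_idx})
```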
Please keep us updated on the results, as we might be interested in highlighting this as a tutorial.


So, just to be sure I understand: we pass the sampled 10% of the data as the training set and the remaining cells as the validation set?

You can also add cells to the test set. The validation set is evaluated once per epoch, whereas cells in the test set are not used at all during training, so the latter would be more efficient. I would still keep ~10% of cells for validation, ideally selected with the same sketching technique.

Okay, I will try 10% training, 10% validation, and 80% test. I am curious, though: is that the same as a train + prediction split?
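Concretely, I am planning something like this (a rough sketch; `sketch_idx` is whatever the sampling returns, and `train_indices` is my assumption by analogy with the `validation_indices` kwarg mentioned above):

```python
import numpy as np

# sketch_idx: the 20% of cells kept by leverage score sampling (placeholder)
rng = np.random.default_rng(0)
sketch_idx = rng.permutation(sketch_idx)

# split the sketched cells in half: 10% train / 10% validation of all cells
half = len(sketch_idx) // 2
train_idx, val_idx = sketch_idx[:half], sketch_idx[half:]

# the remaining 80% implicitly become the test set
model.train(
    datasplitter_kwargs={
        "train_indices": train_idx,  # assumed kwarg, see note above
        "validation_indices": val_idx,
    },
)
```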

Yes, until recently the split was always random. Now you can provide a custom split to the dataloader.

Hey Cane,

Passing stratified indices worked out pretty well. I had a huge dataset of 15M cells, and a normal scVI run with fp16 precision and a batch size of 2048 would take ~9 hours. I sketched the dataset to downsample it while preserving the manifold and used that as the input for the train and validation sets. Specifically, I downsampled to 20% of the data (3M cells) and split those halfway between train and validation; the remaining 80% were passed as the test set. The model then took ~3 hours to converge. I compared the UMAPs of the two runs and they look very similar to me.

Nice. Do you subset to the same cells for the UMAPs? Could you share the losses? Plotting them would be sufficient. scib-metrics would also be great. Looking at UMAPs can be misleading, but it gives me the impression that the second one actually looks better (no gradient from blue to orange cells, etc.); I am not sure whether this is expected, though.
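For reference, the per-epoch losses live in `model.history`, and a minimal scib-metrics comparison of the two embeddings could look roughly like this (a sketch; `"batch"`, `"cell_type"`, and the two `.obsm` keys are placeholders for your own names):

```python
import matplotlib.pyplot as plt
from scib_metrics.benchmark import Benchmarker

# per-epoch ELBO curves recorded during training
ax = model.history["elbo_train"].plot()
model.history["elbo_validation"].plot(ax=ax)
plt.show()

# compare the full-data and sketched-training embeddings
bm = Benchmarker(
    adata,
    batch_key="batch",
    label_key="cell_type",
    embedding_obsm_keys=["X_scVI_full", "X_scVI_sketched"],
)
bm.benchmark()
bm.plot_results_table()
```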

Yes, the UMAP is on all cells. My understanding is that when you pass cells as test, they are never used during training but are just passed through the model at the end to get the latent space, via model.get_latent_representation(). Here are the loss plots: