In reference to the default train() implementation, consider the case where adata contains:
- 40k cells → max_epochs=200
- 500 cells → max_epochs=400
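Both numbers fit a rule that scales epochs inversely with dataset size and then caps the result. Here is a minimal sketch that assumes the rule is min(round(20000 / n_cells * 400), 400). The constants 20000 and 400, and the function name max_epochs_heuristic, are assumptions chosen to reproduce the two cases above.

```python
def max_epochs_heuristic(n_cells: int, epochs_cap: int = 400, decay_at_n_cells: int = 20_000) -> int:
    """Assumed heuristic: epochs shrink inversely with n_cells, never exceeding epochs_cap."""
    # For n_cells < decay_at_n_cells the raw value exceeds the cap, so small datasets get epochs_cap.
    return int(min(round(decay_at_n_cells / n_cells * epochs_cap), epochs_cap))


for n in (500, 20_000, 40_000, 1_000_000):
    print(n, max_epochs_heuristic(n))
```

Under this assumed rule, every dataset below 20k cells gets the 400-epoch cap. That is why the 500-cell case ends up with more epochs than the 40k-cell case.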
My concern is that the risk of overfitting would likely increase with fewer cells, yet the heuristic gives smaller datasets more epochs.
While this behavior may follow Lopez et al. 2018, who note that “bigger datasets require fewer epochs”, it isn’t clear what metric that conclusion was based on. In particular, did they rely only on the training ELBO, or were they able to diagnose [lack of] overfitting using, e.g., the marginal likelihood of a held-out test set?
Oh, I just realized scvi.train.Trainer has early stopping!
I guess the following should suffice:
model.train(early_stopping=True, early_stopping_monitor='reconstruction_loss_validation')
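Patience-based early stopping watches a monitored validation metric, such as reconstruction_loss_validation, and halts once that metric stops improving for a fixed number of checks. The sketch below is a simplified, pure-Python illustration of the idea. It is not Lightning's or scvi-tools' actual code, and the function name early_stop_epoch and its parameters are hypothetical.

```python
from typing import Optional, Sequence


def early_stop_epoch(val_losses: Sequence[float], patience: int, min_delta: float = 0.0) -> Optional[int]:
    """Return the check index at which training would halt, or None if it never halts.

    Simplified model: the metric is minimized, and training stops once it has failed
    to improve by more than min_delta for `patience` consecutive checks.
    """
    best = float("inf")
    checks_without_improvement = 0
    for epoch, loss in enumerate(val_losses):
        if loss < best - min_delta:
            # Meaningful improvement: record it and reset the counter.
            best = loss
            checks_without_improvement = 0
        else:
            checks_without_improvement += 1
            if checks_without_improvement >= patience:
                return epoch
    return None


print(early_stop_epoch([10, 8, 7, 7.5, 7.2, 7.1, 7.3], patience=3))
```

Patience is counted in checks rather than in cells seen. The same setting therefore corresponds to far more training on a large dataset than on a small one.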
The heuristic is really just that: a heuristic. It lets scVI give you something reasonable in less than an hour for a large number of cells.
Indeed, we also have early stopping. One thing to consider is that an epoch takes longer on a larger dataset, so the early stopping parameters would have to be adjusted accordingly. You can pass the parameter limit_train_batches to the train methods, and it gets passed through to PyTorch Lightning. It seems reasonable to me to limit the training batches so that each epoch covers around 50k to 100k cells.
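PyTorch Lightning reads an integer limit_train_batches as a number of minibatches and a float as a fraction of the training set. The sketch below turns a target number of cells per epoch into either form. The helper names are hypothetical, and the batch size of 128 is an assumption about scvi-tools' default.

```python
import math


def batches_for_target_cells(target_cells_per_epoch: int, batch_size: int = 128) -> int:
    """Integer form: number of minibatches so one epoch covers about target_cells_per_epoch cells."""
    return math.ceil(target_cells_per_epoch / batch_size)


def fraction_for_target_cells(target_cells_per_epoch: int, n_train_cells: int) -> float:
    """Float form: fraction of the training set, capped at 1.0 (i.e. a full epoch)."""
    return min(1.0, target_cells_per_epoch / n_train_cells)


print(batches_for_target_cells(75_000))
print(fraction_for_target_cells(75_000, 1_000_000))
```

Targeting about 75k cells per epoch at batch size 128 gives 586 batches. That value would then be passed as model.train(..., limit_train_batches=586).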
Thanks, that makes sense!