How to save intermediate checkpoints?

What is the standard way to save intermediate states during model training? We came up with a workaround (here) using PyTorch Lightning callbacks, as indicated in the documentation (here), but it seems largely unsupported and quite hacky.

Is there a way to periodically save checkpoints during training that is natively supported by scvi-tools, and also retains trainer state?

Hi @gregjohnso, thank you for your question. We don’t have a super straightforward way of saving intermediate checkpoints right now. I’ll see if I can implement a solution for our next release. For now, you can track the feature request here: Add changes to make saving intermediate checkpoints easier · Issue #2264 · scverse/scvi-tools · GitHub

Hi @gregjohnso, just wanted to update you that we now have an experimental SaveCheckpoint callback that subclasses Lightning’s ModelCheckpoint for compatibility with our model saves. You can enable it automatically by passing enable_checkpointing=True to most train methods, or pass it explicitly through the callbacks argument. We anticipate releasing this by the end of the year with our 1.1 release. In the meantime, feel free to install from the main branch - feedback is appreciated!
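
For anyone landing here later, a minimal sketch of the first option (the enable_checkpointing=True flag) might look like the following. It assumes scvi-tools installed from the main branch or 1.1+, and uses SCVI as an example model. The helper names checkpoint_train_kwargs and train_with_checkpoints are hypothetical, made up for illustration. Since the callback is experimental, where checkpoints are written and how often may change.

```python
from typing import Any


def checkpoint_train_kwargs(max_epochs: int = 100) -> dict[str, Any]:
    """Build keyword arguments for `model.train` that switch on checkpointing.

    Hypothetical helper for illustration; not part of scvi-tools.
    """
    return {
        "max_epochs": max_epochs,
        # Flag named in the reply above; experimental, so behavior may change.
        "enable_checkpointing": True,
    }


def train_with_checkpoints(adata, max_epochs: int = 100):
    """Train an SCVI model with checkpointing enabled.

    Assumes a scvi-tools version that ships the SaveCheckpoint callback
    (main branch or 1.1+) and an AnnData object holding raw counts.
    """
    # Imported lazily so the helper above can be used without scvi-tools installed.
    import scvi

    scvi.model.SCVI.setup_anndata(adata)
    model = scvi.model.SCVI(adata)
    model.train(**checkpoint_train_kwargs(max_epochs))
    return model
```

Calling train_with_checkpoints(adata) is then the whole workflow. The second option from the reply, passing the callback yourself through callbacks, should give more control over the checkpointing settings, but check the callback's signature in the version you have installed before relying on specific arguments.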

awesome thank you @martinkim0 !