What is the standard way of saving intermediate states during model training? We came up with a workaround (here) using PyTorch Lightning callbacks, as indicated in the documentation (here), but that approach seems largely unsupported and fairly hacky.
Is there a way to periodically save checkpoints during training that is natively supported by scvi-tools and also retains trainer state?
Hi @gregjohnso, thank you for your question. We don’t currently have a straightforward way of saving intermediate checkpoints. I’ll see if I can implement a solution for our next release - for now, you can track the feature request here: Add changes to make saving intermediate checkpoints easier · Issue #2264 · scverse/scvi-tools · GitHub
Hi @gregjohnso, just wanted to update you that we now have an experimental `SaveCheckpoint` callback that subclasses Lightning’s `ModelCheckpoint` for compatibility with our model saves. You can enable it automatically by passing `enable_checkpointing=True` into most `train` methods, or pass it in explicitly via the `callbacks` argument. We anticipate releasing this by the end of the year with our 1.1 release. In the meantime, feel free to install from the main branch - feedback is appreciated!
Awesome, thank you @martinkim0!