Loading an scVI model from a PyTorch Lightning checkpoint

I want to be able to evaluate the scVI model on a benchmark every n epochs. To do this, I am trying to save the model every n epochs and then load each saved model again to run the benchmark on it. The only way I have found to save the model every n epochs is to use the ModelCheckpoint callback and pass it to the train method. Loading from a saved checkpoint, however, requires converting the .ckpt file into a format that scvi.model.SCVI.load can read. Is this somehow possible, or is there another way to achieve what I am trying to do?

Unfortunately, as far as I can tell, Lightning's model checkpointing is not compatible with our load function out of the box, because we pass the PyTorch module as an init argument to our LightningModule wrapper. Some potential solutions that may be worth investigating:

  1. Subclassing ModelCheckpoint and overriding its methods (I think _save_checkpoint might work) to save the model in a way that is compatible with scvi-tools loading
  2. Running your benchmark within the training loop, depending on how complex it is. If the benchmark is simple, you can do this with scvi.train._callbacks.MetricsCallback, which takes a set of functions (each receiving the model as an argument) that compute metrics, and logs the results with Lightning (see the sketch below).
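For option 2, here is a minimal sketch. It assumes MetricsCallback accepts a dict mapping metric names to functions that each receive the model (as described above), that the callback is evaluated at validation time, and that extra keyword arguments to train are forwarded to the Trainer; the metric function is only a placeholder for your benchmark:

import numpy as np
from scvi.train._callbacks import MetricsCallback

def latent_norm(model):
    # placeholder benchmark: mean L2 norm of the latent representation
    latent = model.get_latent_representation()
    return float(np.linalg.norm(latent, axis=1).mean())

# assumes `model` is an scvi.model.SCVI instance that has already been set up
metrics_callback = MetricsCallback({"latent_norm": latent_norm})
model.train(
    max_epochs=100,
    check_val_every_n_epoch=5,  # how often the metric functions are evaluated
    callbacks=[metrics_callback],
)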

Thank you for the response. I went with the approach of subclassing the ModelCheckpoint class and overriding the _save_checkpoint method. The only thing I struggled with is that saving the model during training produces a model that is marked as untrained and hence cannot be loaded. As a workaround, I manually change the trained flag before and after saving the model. Is there another workaround for this issue?
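For reference, here is a minimal sketch of that approach. It assumes that _save_checkpoint(self, trainer, filepath) is the hook used by your pytorch-lightning version and that is_trained_ is the flag being toggled; the class name and the scvi_model constructor argument are made up for illustration:

import pytorch_lightning as pl

class SCVISaveCheckpoint(pl.callbacks.ModelCheckpoint):
    # Saves the high-level scvi-tools model instead of a Lightning .ckpt file.
    # The scvi model is passed in explicitly because the LightningModule seen
    # by the callback only wraps the low-level PyTorch module.
    def __init__(self, scvi_model, **kwargs):
        super().__init__(**kwargs)
        self.scvi_model = scvi_model

    def _save_checkpoint(self, trainer, filepath):
        # temporarily mark the model as trained so the snapshot can be saved
        # and loaded mid-training (the workaround described above)
        previously_trained = self.scvi_model.is_trained_
        self.scvi_model.is_trained_ = True
        self.scvi_model.save(filepath, overwrite=True)
        self.scvi_model.is_trained_ = previously_trained

Each directory saved this way can then be loaded with scvi.model.SCVI.load as usual.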

Hi @LeanderDiazBone @martinkim0, we've had a similar problem. We are running long training jobs and would like to save checkpoints along the way and compare model performance. Our workaround is as follows:

First, do your model training with a checkpoint callback. It looks something like the following:

import pytorch_lightning.callbacks as pl_callbacks
import copy
import scvi


SCVI_TRAIN_KWARGS = {
    "max_epochs": 10,
    "use_gpu": False,
    "batch_size": 32,
    "enable_checkpointing": True,
}

# pytorch lightning checkpoint callback
model_checkpoint_callback = pl_callbacks.ModelCheckpoint(
    dirpath=model_dir,
    filename="model_{epoch}",
    every_n_train_steps=100,
    save_last=True,
    save_top_k=-1,
    verbose=True,
)

scvi_train_kwargs = copy.deepcopy(SCVI_TRAIN_KWARGS)
scvi_train_kwargs["callbacks"] = [
    model_checkpoint_callback
]

scvi.model.SCVI.setup_anndata(adata, batch_key=BATCH_KEY, layer="counts")

model = scvi.model.SCVI(
    adata,
    n_latent=265,
    n_layers=2,
    encode_covariates=True,
)

model.train(**scvi_train_kwargs)
model.save(f"{model_dir}/model.pt", overwrite=True)

Then load the saved model, load the checkpoint, and apply the checkpoint to the model after some surgery:

import torch
from collections import OrderedDict

model.load(f"{model_dir}/model.pt", adata=adata)

checkpoint_path = f'{model_dir}/model_epoch=0.ckpt'
checkpoint = torch.load(checkpoint_path, map_location=torch.device('cpu'))

# go through the OrderedDict and trim "module." from the keys
state_dict_out = OrderedDict()
for k, v in checkpoint['state_dict'].items():
    state_dict_out[k.replace('module.','')] = v

checkpoint['state_dict'] = state_dict_out
model.module.load_state_dict(checkpoint['state_dict'])

This is obviously not an ideal way to handle the state, but it should get the weights in the right place. It is unclear whether there is some additional underlying state or configuration that this method misses.
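To tie this back to the original question, here is a hypothetical loop that applies the same surgery to every checkpoint written by the callback above and runs a benchmark on each one; run_benchmark is a stand-in for your own evaluation function:

import glob

for checkpoint_path in sorted(glob.glob(f"{model_dir}/model_*.ckpt")):
    # start from the saved base model, then overwrite its weights
    model = scvi.model.SCVI.load(f"{model_dir}/model.pt", adata=adata)
    checkpoint = torch.load(checkpoint_path, map_location=torch.device('cpu'))
    state_dict = OrderedDict(
        (k.replace('module.', '', 1), v) for k, v in checkpoint['state_dict'].items()
    )
    model.module.load_state_dict(state_dict)
    model.module.eval()
    run_benchmark(model)  # hypothetical: your own benchmark function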