Loading an scVI model from a pytorch lightning checkpoint

Hi @LeanderDiazBone @martinkim0, we’ve had a similar problem. We are running long-training jobs and would like to save at various checkpoints and compare model performance. Our workaround is as follows:

Do your model training with a checkpoint callback. It looks something like the following:

import pytorch_lightning.callbacks as pl_callbacks
import copy
import scvi


SCVI_TRAIN_KWARGS = {
    "max_epochs": 10,
    "use_gpu": False,
    "batch_size": 32,
    "enable_checkpointing": True,
}

# pytorch lightning checkpoint callback
model_checkpoint_callback = pl_callbacks.ModelCheckpoint(
    dirpath=model_dir,
    filename="model_{epoch}",
    every_n_train_steps=100,
    save_last=True,
    save_top_k=-1,
    verbose=True,
)

scvi_train_kwargs = copy.deepcopy(SCVI_TRAIN_KWARGS)
scvi_train_kwargs["callbacks"] = [
    model_checkpoint_callback
]

scvi.model.SCVI.setup_anndata(adata, batch_key=BATCH_KEY, layer="counts")

model = scvi.model.SCVI(
    adata,
    n_latent=265,
    n_layers=2,
    encode_covariates=True,
)

model.train(**scvi_train_kwargs)
model.save(f"{model_dir}/model.pt", overwrite=True)

Then load the saved model, load the checkpoint, and apply the checkpoint's weights to the model after some key surgery:

import torch
from collections import OrderedDict

model.load(f"{model_dir}/model.pt", adata=adata)

checkpoint_path = f'{model_dir}/model_epoch=0.ckpt'
checkpoint = torch.load(checkpoint_path, map_location=torch.device('cpu'))

# go through the OrderedDict and trim "module." from the keys
state_dict_out = OrderedDict()
for k, v in checkpoint['state_dict'].items():
    state_dict_out[k.replace('module.', '')] = v

checkpoint['state_dict'] = state_dict_out
model.module.load_state_dict(checkpoint['state_dict'])
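
Since the point of saving intermediate checkpoints is to compare model performance, you can wrap the surgery above in a loop over all checkpoints. This is a minimal sketch, assuming the checkpoints were written to model_dir as above and that the ELBO on adata is an adequate comparison metric for your use case:

from pathlib import Path

elbo_by_checkpoint = {}
for ckpt_path in sorted(Path(model_dir).glob("*.ckpt")):
    # re-load the fully trained model, then overwrite its weights
    # with this checkpoint's (trimmed) state dict
    ckpt_model = scvi.model.SCVI.load(f"{model_dir}/model.pt", adata=adata)
    ckpt = torch.load(ckpt_path, map_location=torch.device("cpu"))
    trimmed = OrderedDict(
        (k.replace("module.", ""), v) for k, v in ckpt["state_dict"].items()
    )
    ckpt_model.module.load_state_dict(trimmed)
    elbo_by_checkpoint[ckpt_path.name] = ckpt_model.get_elbo(adata)

print(elbo_by_checkpoint)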

This is obviously not an ideal way to handle the state, but it should get the weights in the right place. We're not sure whether there is additional underlying state that this method misses.
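
One cheap sanity check is to confirm that the trimmed checkpoint keys line up with the module's own state dict (load_state_dict is strict by default, so a mismatch would raise anyway, but this makes the difference explicit). A minimal sketch, reusing model and state_dict_out from above:

model_keys = set(model.module.state_dict().keys())
ckpt_keys = set(state_dict_out.keys())

# parameters/buffers the checkpoint does not cover, and extra checkpoint entries
print("missing from checkpoint:", model_keys - ckpt_keys)
print("unexpected in checkpoint:", ckpt_keys - model_keys)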