How to properly save and load a TotalVI trained model that's based on mudata?

I set up the totalVI model using MuData:

scvi.model.TOTALVI.setup_mudata(
    mdata,
    rna_layer="counts",
    protein_layer=None,
    batch_key="batch",
    modalities={
        "rna_layer": "rna_subset",
        "protein_layer": "protein",
        "batch_key": "rna_subset",
    },
)

But I can't seem to save the model correctly after training using scvi.model.TOTALVI.save():

scvi.model.TOTALVI.save(dir_path='model/model.pt')

Running the command reports the following error:

---------------------------------------------------------------------------
TypeError                                 Traceback (most recent call last)
/Users/ab/Downloads/vea_python/vea_python_velo/vea_totalvi_scvelo_no_hvg.ipynb Cell 34 line 1
----> 1 scvi.model.TOTALVI.save(dir_path='model/model.pt',save_anndata=True)

TypeError: BaseModelClass.save() missing 1 required positional argument: 'self'

I can save the model using model.save('model'), but I can't load it back for scvi using torch:

model.load_state_dict(torch.load('model/model.pt'))
---------------------------------------------------------------------------
AttributeError                            Traceback (most recent call last)
/Users/ab/Downloads/vea_python/vea_python_velo/Untitled-1.ipynb Cell 5 line 1
----> 1 model.load_state_dict(torch.load("model/model_old.pt"))
      2 model.eval()

AttributeError: 'TOTALVI' object has no attribute 'load_state_dict'

Hi, save is an instance method, so call it on the model instance that you created and trained rather than on the TOTALVI class:

model = scvi.model.TOTALVI(...)
model.train(...)

model.save(save_path)

You can then load it back in as follows:

model = scvi.model.TOTALVI.load(save_path, adata=mdata)
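
For context on the TypeError in the question: it is what Python raises whenever an instance method is called through the class without an instance. Below is a minimal sketch that reproduces the same kind of error with a toy class. ToyModel and its save method are hypothetical stand-ins for illustration only, not part of scvi-tools.

```python
class ToyModel:
    """Hypothetical stand-in for a trained model; not part of scvi-tools."""

    def __init__(self, name):
        self.name = name

    def save(self, dir_path):
        # Instance method: Python fills in `self` only when the call
        # is made through an instance.
        return f"saved {self.name} to {dir_path}"


# Calling through the class supplies no instance, so `self` is missing.
try:
    ToyModel.save(dir_path="model")
    error_message = None
except TypeError as exc:
    error_message = str(exc)

# Calling through an instance works as intended.
toy = ToyModel("totalvi_demo")
result = toy.save(dir_path="model")

print(error_message)
print(result)
```

The same reasoning applies to the trained TOTALVI object: calling save on the instance passes the model itself as self, whereas load is called on the class because it constructs a new instance.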