I set up the totalVI model using MuData:
scvi.model.TOTALVI.setup_mudata(
    mdata,
    rna_layer="counts",
    protein_layer=None,
    batch_key="batch",
    modalities={
        "rna_layer": "rna_subset",
        "protein_layer": "protein",
        "batch_key": "rna_subset",
    },
)
But I can't seem to save the model correctly after training using scvi.model.TOTALVI.save():
scvi.model.TOTALVI.save(dir_path='model/model.pt')
Running the command reports the following error:
---------------------------------------------------------------------------
TypeError Traceback (most recent call last)
/Users/ab/Downloads/vea_python/vea_python_velo/vea_totalvi_scvelo_no_hvg.ipynb Cell 34 line 1
----> 1 scvi.model.TOTALVI.save(dir_path='model/model.pt',save_anndata=True)
TypeError: BaseModelClass.save() missing 1 required positional argument: 'self'
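For reference, "missing 1 required positional argument: 'self'" is the generic Python error raised when an instance method is called on the class itself instead of on an object created from it. Below is a minimal sketch that reproduces the same message. FakeModel is a hypothetical stand-in written for illustration, not scvi-tools code.

```python
class FakeModel:
    """Hypothetical stand-in for a model class (not part of scvi-tools)."""

    def __init__(self):
        self.saved_to = None

    def save(self, dir_path):
        # Instance method: Python fills in `self` only when the call
        # is made on an object, e.g. model.save(...).
        self.saved_to = dir_path
        return dir_path


def call_on_class():
    """Call save() on the class itself and return the error message."""
    try:
        FakeModel.save(dir_path="model")  # no instance, so `self` is missing
    except TypeError as err:
        return str(err)
    return None


class_call_error = call_on_class()

# Calling the same method on an instance works as expected.
model = FakeModel()
instance_result = model.save(dir_path="model")

print(class_call_error)
print(instance_result)
```

The traceback above names BaseModelClass.save(), which suggests the same pattern applies here: save() would be called on the trained model object rather than on scvi.model.TOTALVI.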
I can save the model using model.save('model'), but I can't load it back for scvi with torch:
model.load_state_dict(torch.load('model/model.pt'))
---------------------------------------------------------------------------
AttributeError Traceback (most recent call last)
/Users/ab/Downloads/vea_python/vea_python_velo/Untitled-1.ipynb Cell 5 line 1
----> 1 model.load_state_dict(torch.load("model/model_old.pt"))
2 model.eval()
AttributeError: 'TOTALVI' object has no attribute 'load_state_dict'