DestVI: how to load a DestVI model that was saved as a pickle

Hi scvi team,

I saved a trained DestVI model as a pickle as follows:

import pickle

from scvi.model import DestVI

DestVI.setup_anndata(st_adata, layer="counts")
st_model = DestVI.from_rna_model(st_adata, sc_model)
st_model.view_anndata_setup()
st_model.train(max_epochs=2500)

# Serialize the whole trained model object with pickle
with open("st_model.pkl", "wb") as f:
    pickle.dump(st_model, f)

The error below occurs

Traceback (most recent call last):
 File "/data/wanglab_mgberis/UT_scRNAseq/./cuda_to_cpu.py", line 39, in <module>
   st_adata.obsm["proportions"] = st_model.get_proportions()
 File "/usr/local/lib/python3.10/dist-packages/scvi/model/_destvi.py", line 198, in get_proportions
   stdl = self._make_data_loader(adata=self.adata, indices=indices, batch_size=batch_size)
 File "/usr/local/lib/python3.10/dist-packages/scvi/model/base/_base_model.py", line 431, in _make_data_loader
   raise AssertionError(
AssertionError: AnnDataManager not found. Call `self._validate_anndata` prior to calling this function.

when I try to load it for downstream analysis, as shown below:

print('.1')
import pickle

from scvi.model import DestVI

with open("st_adata.pkl", "rb") as f:
    st_adata = pickle.load(f)

with open("sc_model.pkl", "rb") as f:
    sc_model = pickle.load(f)

print(sc_model)

# Register the spatial AnnData and build a fresh DestVI model from the single-cell model
DestVI.setup_anndata(st_adata, layer="counts")
st_model = DestVI.from_rna_model(st_adata, sc_model)
st_model.view_anndata_setup()

# Replace the fresh model with the pickled, already-trained one
with open("st_model.pkl", "rb") as f:
    st_model = pickle.load(f)

print('.2')
st_model._validate_anndata(st_adata)

print('.3')
st_adata.obsm["proportions"] = st_model.get_proportions()
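A possible explanation, offered as an assumption about scvi-tools internals rather than a confirmed diagnosis: as far as we understand, scvi-tools keeps AnnDataManager registrations in class-level stores keyed by ids rather than on the model instance itself, while pickle only serializes instance attributes. The hypothetical ToyModel class below is not scvi-tools code; it only illustrates that mechanism, with the class-level store cleared to mimic a fresh Python process.

```python
import pickle
import uuid


class ToyModel:
    # Hypothetical class. The store lives on the class, not the instance,
    # loosely mimicking (by assumption) how scvi-tools tracks registrations.
    _manager_store = {}

    def __init__(self):
        self.id = str(uuid.uuid4())
        # Register a "manager" for this instance in the class-level store.
        ToyModel._manager_store[self.id] = "registered manager"

    def get_manager(self):
        # Returns None when nothing is registered under this instance id.
        return ToyModel._manager_store.get(self.id)


model = ToyModel()
payload = pickle.dumps(model)  # only instance attributes (here: self.id) are saved

# Simulate a fresh Python process, where the class-level store starts empty.
ToyModel._manager_store.clear()

restored = pickle.loads(payload)
print(restored.id == model.id)  # True: the instance state survived
print(restored.get_manager())   # None: the registration did not
```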

Is it possible to still use this pickle instead of retraining the model?
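One way to reuse the pickled weights without retraining, sketched under assumptions: build and register a fresh model exactly as in the original run (same setup_anndata arguments and the same from_rna_model defaults), then copy the trained module weights across with PyTorch's load_state_dict, so the new model owns a valid registration. The helper name reload_pickled_destvi is hypothetical, setting is_trained_ by hand is an assumption about the base-model attribute, and unpickling may need the same scvi-tools version (and a GPU, if the model was pickled with CUDA tensors) as the environment that created the pickle.

```python
import pickle


def reload_pickled_destvi(pickle_path, st_adata, sc_model):
    """Hypothetical helper: move weights from a pickled DestVI into a freshly registered one."""
    # Imported lazily so this snippet can be imported without scvi-tools installed.
    from scvi.model import DestVI

    with open(pickle_path, "rb") as f:
        old_model = pickle.load(f)

    # Register the AnnData and rebuild the model with the same settings as training.
    DestVI.setup_anndata(st_adata, layer="counts")
    new_model = DestVI.from_rna_model(st_adata, sc_model)

    # Copy the trained weights; strict loading raises if the architectures differ.
    new_model.module.load_state_dict(old_model.module.state_dict())
    new_model.is_trained_ = True  # assumption: marks the model as trained
    new_model.module.eval()
    return new_model
```

Usage would then be st_model = reload_pickled_destvi("st_model.pkl", st_adata, sc_model) followed by st_model.get_proportions(). Rebuilding rather than reusing the unpickled object keeps the registration and the weights in the same, current process.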

While awaiting a response, I retrained the model and saved it using the code below:

import pickle
import tempfile

import destvi_utils
import matplotlib.pyplot as plt
import numpy as np
import scanpy as sc
import scvi
import seaborn as sns
import torch
from scvi.model import CondSCVI, DestVI

torch.set_float32_matmul_precision("high")

with open("st_adata.pkl", "rb") as f:
    st_adata = pickle.load(f)
with open("sc_model.pkl", "rb") as f:
    sc_model = pickle.load(f)
print(sc_model)
DestVI.setup_anndata(st_adata, layer="counts")
st_model = DestVI.from_rna_model(st_adata, sc_model)
st_model.view_anndata_setup()
st_model.train(max_epochs=2500)
print('Saving a model')
# save() creates a directory named "st_model.pt" containing model.pt
st_model.save("st_model.pt")

However, scvi.model.DestVI.load still cannot load the model. The code and error message are shown below:

import pickle
with open("st_adata.pkl", "rb") as f:
    st_adata = pickle.load(f)
DestVI.setup_anndata(st_adata, layer="counts")
st_model=scvi.model.DestVI.load('./st_model.pt',
                       adata=st_adata
                      )
INFO     File ./st_model.pt/model.pt already downloaded                                                            
---------------------------------------------------------------------------
RuntimeError                              Traceback (most recent call last)
Cell In[17], line 4
     
      3 DestVI.setup_anndata(st_adata, layer="counts")
----> 4 st_model=scvi.model.DestVI.load('./st_model.pt',
      5                        adata=st_adata
      6                       )

File ~/mambaforge/envs/scvi-env/lib/python3.12/site-packages/scvi/model/base/_base_model.py:733, in BaseModelClass.load(cls, dir_path, adata, accelerator, device, prefix, backup_url)
    731 model = _initialize_model(cls, adata, attr_dict)
    732 model.module.on_load(model)
--> 733 model.module.load_state_dict(model_state_dict)
    735 model.to_device(device)
    736 model.module.eval()

File ~/mambaforge/envs/scvi-env/lib/python3.12/site-packages/torch/nn/modules/module.py:2584, in Module.load_state_dict(self, state_dict, strict, assign)
   2576         error_msgs.insert(
   2577             0,
   2578             "Missing key(s) in state_dict: {}. ".format(
   2579                 ", ".join(f'"{k}"' for k in missing_keys)
   2580             ),
   2581         )
   2583 if len(error_msgs) > 0:
-> 2584     raise RuntimeError(
   2585         "Error(s) in loading state_dict for {}:\n\t{}".format(
   2586             self.__class__.__name__, "\n\t".join(error_msgs)
   2587         )
   2588     )
   2589 return _IncompatibleKeys(missing_keys, unexpected_keys)

RuntimeError: Error(s) in loading state_dict for MRDeconv:
	Unexpected key(s) in state_dict: "pyro_param_store". 

Could you please help?

Hey @jliu678,
What version of scvi-tools are you running?
If it's not the latest, please update to the most recent version and try again, as we made changes in exactly the spot where your error occurs.
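A small diagnostic, sketched under assumptions, for checking the version question. The two tracebacks above come from different Python installations (/usr/local/lib/python3.10 and ~/mambaforge/envs/scvi-env/lib/python3.12), so the scvi-tools version used to save may differ from the one used to load, which could explain keys like pyro_param_store showing up as unexpected. The helper names are hypothetical, and the "model_state_dict" entry inside model.pt is an assumption about the save format.

```python
import os


def diff_state_dict_keys(saved_keys, expected_keys):
    """Return (unexpected, missing) keys between a saved and an expected state_dict."""
    saved, expected = set(saved_keys), set(expected_keys)
    unexpected = sorted(saved - expected)  # present in the file, absent from the module
    missing = sorted(expected - saved)     # present in the module, absent from the file
    return unexpected, missing


def inspect_saved_destvi(dir_path, st_adata, sc_model):
    """Hypothetical helper: report the installed version and any key mismatch."""
    # Imported lazily so diff_state_dict_keys works without scvi-tools or torch.
    import scvi
    import torch
    from scvi.model import DestVI

    print("scvi-tools version:", scvi.__version__)

    # Assumption: save() stores the weights under "model_state_dict" in model.pt.
    checkpoint = torch.load(
        os.path.join(dir_path, "model.pt"), map_location="cpu", weights_only=False
    )
    saved_keys = checkpoint["model_state_dict"].keys()

    # Build a fresh model with the currently installed scvi-tools for comparison.
    DestVI.setup_anndata(st_adata, layer="counts")
    fresh = DestVI.from_rna_model(st_adata, sc_model)
    return diff_state_dict_keys(saved_keys, fresh.module.state_dict().keys())
```

Running inspect_saved_destvi("./st_model.pt", st_adata, sc_model) in each environment would show whether the key sets line up before attempting DestVI.load.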

You can also save the model with save_anndata=True, overwrite=True; then you won't need to load the AnnData from a pickle.
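A minimal sketch of that suggestion, assuming the standard save/load methods on scvi-tools models: save the weights and the registered AnnData into one directory, then reload without passing adata. The helper names are hypothetical; the save_anndata and overwrite parameters are the ones named above.

```python
def save_destvi_with_data(st_model, dir_path):
    """Hypothetical helper: save weights and the registered AnnData together."""
    # save_anndata=True also writes the AnnData into dir_path;
    # overwrite=True allows replacing an existing directory.
    st_model.save(dir_path, save_anndata=True, overwrite=True)


def load_destvi_with_data(dir_path):
    """Hypothetical helper: reload a DestVI model saved by save_destvi_with_data."""
    # Imported lazily so this snippet can be imported without scvi-tools installed.
    from scvi.model import DestVI

    # With adata omitted, load() reads the AnnData stored in dir_path.
    return DestVI.load(dir_path)
```

After st_model = load_destvi_with_data("st_model"), the spatial data is available as st_model.adata, so no separate st_adata.pkl is needed.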