Hi scvi team,
I saved a trained DestVI model with `pickle` as shown below:
DestVI.setup_anndata(st_adata, layer="counts")
st_model = DestVI.from_rna_model(st_adata, sc_model)
st_model.view_anndata_setup()
st_model.train(max_epochs=2500)
with open("st_model.pkl", "wb") as f:
pickle.dump(st_model, f)
The following error occurs
Traceback (most recent call last):
File "/data/wanglab_mgberis/UT_scRNAseq/./cuda_to_cpu.py", line 39, in <module>
st_adata.obsm["proportions"] = st_model.get_proportions()
File "/usr/local/lib/python3.10/dist-packages/scvi/model/_destvi.py", line 198, in get_proportions
stdl = self._make_data_loader(adata=self.adata, indices=indices, batch_size=batch_size)
File "/usr/local/lib/python3.10/dist-packages/scvi/model/base/_base_model.py", line 431, in _make_data_loader
raise AssertionError(
AssertionError: AnnDataManager not found. Call `self._validate_anndata` prior to calling this function.
when I try to load the pickled model to continue with downstream analysis, as follows:
print('.1')
import pickle
with open("st_adata.pkl", "rb") as f:
st_adata = pickle.load(f)
with open("sc_model.pkl", "rb") as f:
sc_model = pickle.load(f)
print(sc_model)
DestVI.setup_anndata(st_adata, layer="counts")
st_model = DestVI.from_rna_model(st_adata, sc_model)
st_model.view_anndata_setup()
with open("st_model.pkl", "rb") as f:
st_model = pickle.load(f)
print('.2')
st_model._validate_anndata(st_adata)
print('.3')
st_adata.obsm["proportions"] = st_model.get_proportions()
Is it possible to still use this pickled model instead of retraining it?
While waiting for a response, I retrained the model and saved it with `save()` using the code below:
import tempfile
import destvi_utils
import matplotlib.pyplot as plt
import numpy as np
import scanpy as sc
import scvi
import seaborn as sns
import torch
from scvi.model import CondSCVI, DestVI
torch.set_float32_matmul_precision("high")
import pickle
with open("st_adata.pkl", "rb") as f:
st_adata = pickle.load(f)
with open("sc_model.pkl", "rb") as f:
sc_model = pickle.load(f)
print(sc_model)
DestVI.setup_anndata(st_adata, layer="counts")
st_model = DestVI.from_rna_model(st_adata, sc_model)
st_model.view_anndata_setup()
st_model.train(max_epochs=2500)
print('Saving a model')
st_model.save("st_model.pt")
However, `scvi.model.DestVI.load` still cannot load the saved model. The code and error message are shown below:
import pickle;
with open("st_adata.pkl", "rb") as f:
st_adata = pickle.load(f)
DestVI.setup_anndata(st_adata, layer="counts")
st_model=scvi.model.DestVI.load('./st_model.pt',
adata=st_adata
)
INFO File ./st_model.pt/model.pt already downloaded
---------------------------------------------------------------------------
RuntimeError Traceback (most recent call last)
Cell In[17], line 4
3 DestVI.setup_anndata(st_adata, layer="counts")
----> 4 st_model=scvi.model.DestVI.load('./st_model.pt',
5 adata=st_adata
6 )
File ~/mambaforge/envs/scvi-env/lib/python3.12/site-packages/scvi/model/base/_base_model.py:733, in BaseModelClass.load(cls, dir_path, adata, accelerator, device, prefix, backup_url)
731 model = _initialize_model(cls, adata, attr_dict)
732 model.module.on_load(model)
--> 733 model.module.load_state_dict(model_state_dict)
735 model.to_device(device)
736 model.module.eval()
File ~/mambaforge/envs/scvi-env/lib/python3.12/site-packages/torch/nn/modules/module.py:2584, in Module.load_state_dict(self, state_dict, strict, assign)
2576 error_msgs.insert(
2577 0,
2578 "Missing key(s) in state_dict: {}. ".format(
2579 ", ".join(f'"{k}"' for k in missing_keys)
2580 ),
2581 )
2583 if len(error_msgs) > 0:
-> 2584 raise RuntimeError(
2585 "Error(s) in loading state_dict for {}:\n\t{}".format(
2586 self.__class__.__name__, "\n\t".join(error_msgs)
2587 )
2588 )
2589 return _IncompatibleKeys(missing_keys, unexpected_keys)
RuntimeError: Error(s) in loading state_dict for MRDeconv:
Unexpected key(s) in state_dict: "pyro_param_store".
Could you please help?