Hi. I am getting an error running DestVi (scvi-tools 1.3.0). I am basically importing a single cell reference AnnData, subsetting to keep just variable genes, and then finding shared genes with my spatial AnnData. I then fit the scLVM to this subset of my single cell reference data, and finally run DestVI on my spatial data using the single cell model. This is my code:
# Import single cell reference anndata
adata_ref= sc.read_h5ad('scadata_ref.h5ad')
# Filter cells with at least 10 genes
sc.pp.filter_genes(adata_ref, min_counts=10)
# Save raw counts
adata_ref.layers["counts"] = adata_ref.X.copy()
# Total-count normalize (library-size correct)
sc.pp.normalize_total(adata_ref, target_sum=1e4) # normalize to 10,000 reads per cell
# Apply log transformation
sc.pp.log1p(adata_ref)
# Find highly variable genes
sc.pp.highly_variable_genes(adata_ref, min_mean=0.0125, max_mean=3, min_disp=0.5)
# Store the normalized/log-transformed data for later use
adata_ref.raw = adata_ref
# Subset to keep just variable genes
adata_ref = adata_ref[:, adata_ref.var.highly_variable]
# Load spatial anndata
BLLES_W1= sc.read_h5ad("ALL_samples_adata.h5ad")
# Find shared genes with single cell reference anndata
intersect = np.intersect1d(adata_ref.var_names, BLLES_W1.var_names)
st_adata = BLLES_W1[:, intersect].copy()
sc_adata = adata_ref[:, intersect].copy()
# Prepare single cell reference anndata, use counts layer
CondSCVI.setup_anndata(sc_adata, layer="counts", batch_key="SampleCondition", labels_key="cell_type")
# fit model
sc_model = CondSCVI(sc_adata, weight_obs=False)
# train model
sc_model.train()
# Prepare spatial anndata
DestVI.setup_anndata(st_adata, layer="counts")
# Run DestVI
st_model = DestVI.from_rna_model(st_adata, sc_model)
RuntimeError Traceback (most recent call last)
Cell In[40], line 3
1 DestVI.setup_anndata(st_adata, layer="counts")
----> 3 st_model = DestVI.from_rna_model(st_adata, sc_model)
4 st_model.view_anndata_setup()
File ~/miniconda3/envs/spatial/lib/python3.11/site-packages/scvi/model/_destvi.py:151, in DestVI.from_rna_model(cls, st_adata, sc_model, vamp_prior_p, l1_reg, **module_kwargs)
146 else:
147 mean_vprior, var_vprior, mp_vprior = sc_model.get_vamp_prior(
148 sc_model.adata, p=vamp_prior_p
149 )
--> 151 return cls(
152 st_adata,
153 mapping,
154 decoder_state_dict,
155 px_decoder_state_dict,
156 px_r,
157 sc_model.module.n_hidden,
158 sc_model.module.n_latent,
159 sc_model.module.n_layers,
160 mean_vprior=mean_vprior,
161 var_vprior=var_vprior,
162 mp_vprior=mp_vprior,
163 dropout_decoder=dropout_decoder,
164 l1_reg=l1_reg,
165 **module_kwargs,
166 )
File ~/miniconda3/envs/spatial/lib/python3.11/site-packages/scvi/model/_destvi.py:93, in DestVI.__init__(self, st_adata, cell_type_mapping, decoder_state_dict, px_decoder_state_dict, px_r, n_hidden, n_latent, n_layers, dropout_decoder, l1_reg, **module_kwargs)
78 def __init__(
79 self,
80 st_adata: AnnData,
(...)
90 **module_kwargs,
91 ):
92 super().__init__(st_adata)
---> 93 self.module = self._module_cls(
94 n_spots=st_adata.n_obs,
95 n_labels=cell_type_mapping.shape[0],
96 decoder_state_dict=decoder_state_dict,
97 px_decoder_state_dict=px_decoder_state_dict,
98 px_r=px_r,
99 n_genes=st_adata.n_vars,
100 n_latent=n_latent,
101 n_layers=n_layers,
102 n_hidden=n_hidden,
103 dropout_decoder=dropout_decoder,
104 l1_reg=l1_reg,
105 **module_kwargs,
106 )
107 self.cell_type_mapping = cell_type_mapping
108 self._model_summary_string = "DestVI Model"
File ~/miniconda3/envs/spatial/lib/python3.11/site-packages/scvi/module/_mrdeconv.py:130, in MRDeconv.__init__(self, n_spots, n_labels, n_hidden, n_layers, n_latent, n_genes, decoder_state_dict, px_decoder_state_dict, px_r, dropout_decoder, dropout_amortization, mean_vprior, var_vprior, mp_vprior, amortization, l1_reg, beta_reg, eta_reg, extra_encoder_kwargs, extra_decoder_kwargs)
126 self.px_decoder = torch.nn.Sequential(
127 torch.nn.Linear(n_hidden, n_genes), torch.nn.Softplus()
128 )
129 # don't compute gradient for those parameters
--> 130 self.decoder.load_state_dict(decoder_state_dict)
131 for param in self.decoder.parameters():
132 param.requires_grad = False
File ~/miniconda3/envs/spatial/lib/python3.11/site-packages/torch/nn/modules/module.py:2593, in Module.load_state_dict(self, state_dict, strict, assign)
2585 error_msgs.insert(
2586 0,
2587 "Missing key(s) in state_dict: {}. ".format(
2588 ", ".join(f'"{k}"' for k in missing_keys)
2589 ),
2590 )
2592 if len(error_msgs) > 0:
-> 2593 raise RuntimeError(
2594 "Error(s) in loading state_dict for {}:\n\t{}".format(
2595 self.__class__.__name__, "\n\t".join(error_msgs)
2596 )
2597 )
2598 return _IncompatibleKeys(missing_keys, unexpected_keys)
RuntimeError: Error(s) in loading state_dict for FCLayers:
size mismatch for fc_layers.Layer 0.0.weight: copying a param with shape torch.Size([128, 37]) from checkpoint, the shape in current model is torch.Size([128, 33]).
size mismatch for fc_layers.Layer 1.0.weight: copying a param with shape torch.Size([128, 160]) from checkpoint, the shape in current model is torch.Size([128, 156]).