DestVI.from_rna_model RuntimeError: Error(s) in loading state_dict for FCLayers

Hi. I am getting an error running DestVi (scvi-tools 1.3.0). I am basically importing a single cell reference AnnData, subsetting to keep just variable genes, and then finding shared genes with my spatial AnnData. I then fit the scLVM to this subset of my single cell reference data, and finally run DestVI on my spatial data using the single cell model. This is my code:

# Import single cell reference anndata
adata_ref= sc.read_h5ad('scadata_ref.h5ad')
# Filter cells with at least 10 genes
sc.pp.filter_genes(adata_ref, min_counts=10)
# Save raw counts
adata_ref.layers["counts"] = adata_ref.X.copy()
# Total-count normalize (library-size correct)
sc.pp.normalize_total(adata_ref, target_sum=1e4) # normalize to 10,000 reads per cell
# Apply log transformation
sc.pp.log1p(adata_ref)
# Find highly variable genes
sc.pp.highly_variable_genes(adata_ref, min_mean=0.0125, max_mean=3, min_disp=0.5)
# Store the normalized/log-transformed data for later use
adata_ref.raw = adata_ref
# Subset to keep just variable genes
adata_ref = adata_ref[:, adata_ref.var.highly_variable]


# Load spatial anndata
BLLES_W1= sc.read_h5ad("ALL_samples_adata.h5ad")

# Find shared genes with single cell reference anndata
intersect = np.intersect1d(adata_ref.var_names, BLLES_W1.var_names)
st_adata = BLLES_W1[:, intersect].copy()
sc_adata = adata_ref[:, intersect].copy()

# Prepare single cell reference anndata, use counts layer
CondSCVI.setup_anndata(sc_adata, layer="counts", batch_key="SampleCondition", labels_key="cell_type")
# fit model
sc_model = CondSCVI(sc_adata, weight_obs=False)
# train model
sc_model.train()

# Prepare spatial anndata
DestVI.setup_anndata(st_adata, layer="counts")
# Run DestVI
st_model = DestVI.from_rna_model(st_adata, sc_model)


RuntimeError                              Traceback (most recent call last)
Cell In[40], line 3
      1 DestVI.setup_anndata(st_adata, layer="counts")
----> 3 st_model = DestVI.from_rna_model(st_adata, sc_model)
      4 st_model.view_anndata_setup()

File ~/miniconda3/envs/spatial/lib/python3.11/site-packages/scvi/model/_destvi.py:151, in DestVI.from_rna_model(cls, st_adata, sc_model, vamp_prior_p, l1_reg, **module_kwargs)
    146 else:
    147     mean_vprior, var_vprior, mp_vprior = sc_model.get_vamp_prior(
    148         sc_model.adata, p=vamp_prior_p
    149     )
--> 151 return cls(
    152     st_adata,
    153     mapping,
    154     decoder_state_dict,
    155     px_decoder_state_dict,
    156     px_r,
    157     sc_model.module.n_hidden,
    158     sc_model.module.n_latent,
    159     sc_model.module.n_layers,
    160     mean_vprior=mean_vprior,
    161     var_vprior=var_vprior,
    162     mp_vprior=mp_vprior,
    163     dropout_decoder=dropout_decoder,
    164     l1_reg=l1_reg,
    165     **module_kwargs,
    166 )

File ~/miniconda3/envs/spatial/lib/python3.11/site-packages/scvi/model/_destvi.py:93, in DestVI.__init__(self, st_adata, cell_type_mapping, decoder_state_dict, px_decoder_state_dict, px_r, n_hidden, n_latent, n_layers, dropout_decoder, l1_reg, **module_kwargs)
     78 def __init__(
     79     self,
     80     st_adata: AnnData,
   (...)
     90     **module_kwargs,
     91 ):
     92     super().__init__(st_adata)
---> 93     self.module = self._module_cls(
     94         n_spots=st_adata.n_obs,
     95         n_labels=cell_type_mapping.shape[0],
     96         decoder_state_dict=decoder_state_dict,
     97         px_decoder_state_dict=px_decoder_state_dict,
     98         px_r=px_r,
     99         n_genes=st_adata.n_vars,
    100         n_latent=n_latent,
    101         n_layers=n_layers,
    102         n_hidden=n_hidden,
    103         dropout_decoder=dropout_decoder,
    104         l1_reg=l1_reg,
    105         **module_kwargs,
    106     )
    107     self.cell_type_mapping = cell_type_mapping
    108     self._model_summary_string = "DestVI Model"

File ~/miniconda3/envs/spatial/lib/python3.11/site-packages/scvi/module/_mrdeconv.py:130, in MRDeconv.__init__(self, n_spots, n_labels, n_hidden, n_layers, n_latent, n_genes, decoder_state_dict, px_decoder_state_dict, px_r, dropout_decoder, dropout_amortization, mean_vprior, var_vprior, mp_vprior, amortization, l1_reg, beta_reg, eta_reg, extra_encoder_kwargs, extra_decoder_kwargs)
    126 self.px_decoder = torch.nn.Sequential(
    127     torch.nn.Linear(n_hidden, n_genes), torch.nn.Softplus()
    128 )
    129 # don't compute gradient for those parameters
--> 130 self.decoder.load_state_dict(decoder_state_dict)
    131 for param in self.decoder.parameters():
    132     param.requires_grad = False

File ~/miniconda3/envs/spatial/lib/python3.11/site-packages/torch/nn/modules/module.py:2593, in Module.load_state_dict(self, state_dict, strict, assign)
   2585         error_msgs.insert(
   2586             0,
   2587             "Missing key(s) in state_dict: {}. ".format(
   2588                 ", ".join(f'"{k}"' for k in missing_keys)
   2589             ),
   2590         )
   2592 if len(error_msgs) > 0:
-> 2593     raise RuntimeError(
   2594         "Error(s) in loading state_dict for {}:\n\t{}".format(
   2595             self.__class__.__name__, "\n\t".join(error_msgs)
   2596         )
   2597     )
   2598 return _IncompatibleKeys(missing_keys, unexpected_keys)

RuntimeError: Error(s) in loading state_dict for FCLayers:
	size mismatch for fc_layers.Layer 0.0.weight: copying a param with shape torch.Size([128, 37]) from checkpoint, the shape in current model is torch.Size([128, 33]).
	size mismatch for fc_layers.Layer 1.0.weight: copying a param with shape torch.Size([128, 160]) from checkpoint, the shape in current model is torch.Size([128, 156]).

Hey,
You can perhaps use batch_key=“SampleCondition” for your pre and post process analysis of the data, but the current destVI model requires that its scLVM model will not run with batch_key provided, therefore remove the “batch_key” from CondSCVI.setup_anndata and try again.

batch_key applies only to the scRNA, not the spatial data. But this is something we might add in a future upgrade of destVI model (then you will add it to the destVI setup_anndata - but its not ready yet).