BaseModelClass.load() raises KeyError: 'categorical_mapping' when loading a saved model

Prior to training, I can set up the data for my model without error:

# Register the AnnData object with the custom model before training.
axm.ImKs.setup_anndata(
    adata_axm,
    batch_key='pid',
    categorical_covariate_keys=['leiden40', 'leiden20', 'leiden10', 'leiden05'],
    ds_label_keys=['downsampled'],
    # Plain string literal: the f-prefix had no placeholders to format.
    layer='counts',
)

However, when I try to reload the data and the model after training, I get the following error, which I had not seen before:

# Read the AnnData object that was written to disk back into memory.
adata_axm_reload = anndata.read_h5ad(f"{adata_dir}/adata.h5ad")

# Restore the trained model against the reloaded AnnData; this is the call
# that ends in the KeyError shown in the traceback.
trained_axm = axm.ImKs.load(
    scvi_dir_trained,
    adata_axm_reload,
)

File /workspaces/scvi_work/axm/axm/_mymodel.py:2258, in ImKs.setup_anndata(cls, adata, layer, batch_key, size_factor_key, categorical_covariate_keys, ds_label_keys, continuous_covariate_keys, **kwargs)
   2235 anndata_fields = [
   2236     LayerField(REGISTRY_KEYS.X_KEY, layer, is_count_data=True),
   2237     batch_field,
   (...)
   2252     NumericalObsField(REGISTRY_KEYS.INDICES_KEY, "_indices"),
   2253 ]
   2255 adata_manager = AnnDataManager(
   2256     fields=anndata_fields, setup_method_args=setup_method_args
...
--> 217 mapping = state_registry[self.CATEGORICAL_MAPPING_KEY].copy()
    219 # extend mapping for new categories
    220 for c in np.unique(self._get_original_column(adata_target)):

KeyError: 'categorical_mapping'

Can you share your custom `setup_anndata` method?

Sure, it’s [almost] identical to the MultiVI setup_anndata, without protein:

    @classmethod
    @setup_anndata_dsp.dedent
    def setup_anndata(
        cls,
        adata: AnnData,
        layer: Optional[str] = None,
        batch_key: Optional[str] = None,
        size_factor_key: Optional[str] = None,
        categorical_covariate_keys: Optional[List[str]] = None,
        ds_label_keys: Optional[List[str]] = None,
        continuous_covariate_keys: Optional[List[str]] = None,
        **kwargs,
    ):
        """
        %(summary)s.

        Parameters
        ----------
        %(param_layer)s
        %(param_batch_key)s
        %(param_size_factor_key)s
        %(param_cat_cov_keys)s
        ds_label_keys
            Keys handed to the ``CategoricalJointObsField`` that is registered
            under ``REGISTRY_KEYS.LABELS_KEY``.
        %(param_cont_cov_keys)s
        **kwargs
            Forwarded unchanged to ``AnnDataManager.register_fields``.
        """
        # Called before any other local is bound, so that locals() holds only
        # the arguments this method was invoked with.
        setup_method_args = cls._get_setup_method_args(**locals())
        adata.obs["_indices"] = np.arange(adata.n_obs)

        # Exactly one field per registry key.
        #
        # The previous list registered BATCH_KEY twice and used LABELS_KEY for
        # both a plain CategoricalObsField placeholder and the joint
        # ``ds_label_keys`` field. Fields are registered in list order, so the
        # joint field replaced the placeholder's registry entry. When the
        # saved registry was later handed back to this method by ``load()``,
        # the placeholder tried to read its own state from the joint field's
        # entry and failed with ``KeyError: 'categorical_mapping'``.
        #
        # Only the shadowed duplicates are removed, so the entry stored under
        # each registry key is the one that already won before.
        # NOTE(review): if anything downstream relied on the obs column
        # created by the LABELS_KEY placeholder, register ``ds_label_keys``
        # under a dedicated registry key instead -- confirm against the
        # module code.
        anndata_fields = [
            LayerField(REGISTRY_KEYS.X_KEY, layer, is_count_data=True),
            CategoricalObsField(REGISTRY_KEYS.BATCH_KEY, batch_key),
            NumericalObsField(
                REGISTRY_KEYS.SIZE_FACTOR_KEY, size_factor_key, required=False
            ),
            CategoricalJointObsField(
                REGISTRY_KEYS.LABELS_KEY, ds_label_keys
            ),
            CategoricalJointObsField(
                REGISTRY_KEYS.CAT_COVS_KEY, categorical_covariate_keys
            ),
            NumericalJointObsField(
                REGISTRY_KEYS.CONT_COVS_KEY, continuous_covariate_keys
            ),
            NumericalObsField(REGISTRY_KEYS.INDICES_KEY, "_indices"),
        ]

        adata_manager = AnnDataManager(
            fields=anndata_fields, setup_method_args=setup_method_args
        )
        adata_manager.register_fields(adata, **kwargs)
        cls.register_manager(adata_manager)

    def _check_adata_modality_weights(self, adata):
        """
        Refuse explicitly supplied data when modality weights are per cell.

        :param adata: anndata object, or None when no data was supplied
        :return: None
        :raises RuntimeError: if ``adata`` is not None and
            ``self.module.modality_weights`` equals ``"cell"``
        """
        # Nothing to validate when no data was passed in; the module is not
        # consulted in that case.
        if adata is None:
            return

        if self.module.modality_weights == "cell":
            raise RuntimeError(
                "Held out data not permitted when using per cell weights"
            )

Can you provide the full traceback? And did you change anything w.r.t. the save function?

I inherit save() from BaseModelClass.

Here is the trace:

{
	"name": "KeyError",
	"message": "'categorical_mapping'",
	"stack": "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m\n\u001b[0;31mKeyError\u001b[0m                                  Traceback (most recent call last)\nCell \u001b[0;32mIn[3], line 2\u001b[0m\n\u001b[1;32m      1\u001b[0m adata_mvi \u001b[39m=\u001b[39m anndata\u001b[39m.\u001b[39mread_h5ad(\u001b[39m\"\u001b[39m\u001b[39m/workspaces/scvi_work/saved_adata/test_classifier/axm1000_ep4/ds1.0_leiden05_cl0_sd0_ss10000_sseed1_alpha-0.0/adata_multi.h5ad\u001b[39m\u001b[39m\"\u001b[39m)\n\u001b[0;32m----> 2\u001b[0m mvi \u001b[39m=\u001b[39m axm\u001b[39m.\u001b[39;49mImKs\u001b[39m.\u001b[39;49mload(dir_path\u001b[39m=\u001b[39;49m\u001b[39m\"\u001b[39;49m\u001b[39m/workspaces/scvi_work/saved_models/trained/test_classifier/axm1000_ep4/ds1.0_leiden05_cl0_sd0_ss10000_sseed1_alpha-0.0\u001b[39;49m\u001b[39m\"\u001b[39;49m, adata\u001b[39m=\u001b[39;49madata_mvi)\n\nFile \u001b[0;32m/workspaces/scvi_work/scvi-tools/scvi/model/base/_base_model.py:674\u001b[0m, in \u001b[0;36mBaseModelClass.load\u001b[0;34m(cls, dir_path, adata, use_gpu, prefix, backup_url)\u001b[0m\n\u001b[1;32m    670\u001b[0m \u001b[39m# Calling ``setup_anndata`` method with the original arguments passed into\u001b[39;00m\n\u001b[1;32m    671\u001b[0m \u001b[39m# the saved model. 
This enables simple backwards compatibility in the case of\u001b[39;00m\n\u001b[1;32m    672\u001b[0m \u001b[39m# newly introduced fields or parameters.\u001b[39;00m\n\u001b[1;32m    673\u001b[0m method_name \u001b[39m=\u001b[39m registry\u001b[39m.\u001b[39mget(_SETUP_METHOD_NAME, \u001b[39m\"\u001b[39m\u001b[39msetup_anndata\u001b[39m\u001b[39m\"\u001b[39m)\n\u001b[0;32m--> 674\u001b[0m \u001b[39mgetattr\u001b[39;49m(\u001b[39mcls\u001b[39;49m, method_name)(\n\u001b[1;32m    675\u001b[0m     adata, source_registry\u001b[39m=\u001b[39;49mregistry, \u001b[39m*\u001b[39;49m\u001b[39m*\u001b[39;49mregistry[_SETUP_ARGS_KEY]\n\u001b[1;32m    676\u001b[0m )\n\u001b[1;32m    678\u001b[0m model \u001b[39m=\u001b[39m _initialize_model(\u001b[39mcls\u001b[39m, adata, attr_dict)\n\u001b[1;32m    679\u001b[0m model\u001b[39m.\u001b[39mmodule\u001b[39m.\u001b[39mon_load(model)\n\nFile \u001b[0;32m/workspaces/scvi_work/axm/axm/_mymodel.py:2248\u001b[0m, in \u001b[0;36mImKs.setup_anndata\u001b[0;34m(cls, adata, layer, batch_key, size_factor_key, categorical_covariate_keys, ds_label_keys, continuous_covariate_keys, **kwargs)\u001b[0m\n\u001b[1;32m   2225\u001b[0m anndata_fields \u001b[39m=\u001b[39m [\n\u001b[1;32m   2226\u001b[0m     LayerField(REGISTRY_KEYS\u001b[39m.\u001b[39mX_KEY, layer, is_count_data\u001b[39m=\u001b[39m\u001b[39mTrue\u001b[39;00m),\n\u001b[1;32m   2227\u001b[0m     batch_field,\n\u001b[0;32m   (...)\u001b[0m\n\u001b[1;32m   2242\u001b[0m     NumericalObsField(REGISTRY_KEYS\u001b[39m.\u001b[39mINDICES_KEY, \u001b[39m\"\u001b[39m\u001b[39m_indices\u001b[39m\u001b[39m\"\u001b[39m),\n\u001b[1;32m   2243\u001b[0m ]\n\u001b[1;32m   2245\u001b[0m adata_manager \u001b[39m=\u001b[39m AnnDataManager(\n\u001b[1;32m   2246\u001b[0m     fields\u001b[39m=\u001b[39manndata_fields, setup_method_args\u001b[39m=\u001b[39msetup_method_args\n\u001b[1;32m   2247\u001b[0m )\n\u001b[0;32m-> 2248\u001b[0m adata_manager\u001b[39m.\u001b[39;49mregister_fields(adata, 
\u001b[39m*\u001b[39;49m\u001b[39m*\u001b[39;49mkwargs)\n\u001b[1;32m   2249\u001b[0m \u001b[39mcls\u001b[39m\u001b[39m.\u001b[39mregister_manager(adata_manager)\n\nFile \u001b[0;32m/workspaces/scvi_work/scvi-tools/scvi/data/_manager.py:179\u001b[0m, in \u001b[0;36mAnnDataManager.register_fields\u001b[0;34m(self, adata, source_registry, **transfer_kwargs)\u001b[0m\n\u001b[1;32m    176\u001b[0m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_validate_anndata_object(adata)\n\u001b[1;32m    178\u001b[0m \u001b[39mfor\u001b[39;00m field \u001b[39min\u001b[39;00m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39mfields:\n\u001b[0;32m--> 179\u001b[0m     \u001b[39mself\u001b[39;49m\u001b[39m.\u001b[39;49m_add_field(\n\u001b[1;32m    180\u001b[0m         field\u001b[39m=\u001b[39;49mfield,\n\u001b[1;32m    181\u001b[0m         adata\u001b[39m=\u001b[39;49madata,\n\u001b[1;32m    182\u001b[0m         source_registry\u001b[39m=\u001b[39;49msource_registry,\n\u001b[1;32m    183\u001b[0m         \u001b[39m*\u001b[39;49m\u001b[39m*\u001b[39;49mtransfer_kwargs,\n\u001b[1;32m    184\u001b[0m     )\n\u001b[1;32m    186\u001b[0m \u001b[39m# Save arguments for register_fields.\u001b[39;00m\n\u001b[1;32m    187\u001b[0m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_source_registry \u001b[39m=\u001b[39m deepcopy(source_registry)\n\nFile \u001b[0;32m/workspaces/scvi_work/scvi-tools/scvi/data/_manager.py:214\u001b[0m, in \u001b[0;36mAnnDataManager._add_field\u001b[0;34m(self, field, adata, source_registry, **transfer_kwargs)\u001b[0m\n\u001b[1;32m    211\u001b[0m \u001b[39mif\u001b[39;00m \u001b[39mnot\u001b[39;00m field\u001b[39m.\u001b[39mis_empty:\n\u001b[1;32m    212\u001b[0m     \u001b[39m# Transfer case: Source registry is used for validation and/or setup.\u001b[39;00m\n\u001b[1;32m    213\u001b[0m     \u001b[39mif\u001b[39;00m source_registry \u001b[39mis\u001b[39;00m \u001b[39mnot\u001b[39;00m \u001b[39mNone\u001b[39;00m:\n\u001b[0;32m--> 214\u001b[0m         
field_registry[_constants\u001b[39m.\u001b[39m_STATE_REGISTRY_KEY] \u001b[39m=\u001b[39m field\u001b[39m.\u001b[39;49mtransfer_field(\n\u001b[1;32m    215\u001b[0m             source_registry[_constants\u001b[39m.\u001b[39;49m_FIELD_REGISTRIES_KEY][\n\u001b[1;32m    216\u001b[0m                 field\u001b[39m.\u001b[39;49mregistry_key\n\u001b[1;32m    217\u001b[0m             ][_constants\u001b[39m.\u001b[39;49m_STATE_REGISTRY_KEY],\n\u001b[1;32m    218\u001b[0m             adata,\n\u001b[1;32m    219\u001b[0m             \u001b[39m*\u001b[39;49m\u001b[39m*\u001b[39;49mtransfer_kwargs,\n\u001b[1;32m    220\u001b[0m         )\n\u001b[1;32m    221\u001b[0m     \u001b[39melse\u001b[39;00m:\n\u001b[1;32m    222\u001b[0m         field_registry[_constants\u001b[39m.\u001b[39m_STATE_REGISTRY_KEY] \u001b[39m=\u001b[39m field\u001b[39m.\u001b[39mregister_field(\n\u001b[1;32m    223\u001b[0m             adata\n\u001b[1;32m    224\u001b[0m         )\n\nFile \u001b[0;32m/workspaces/scvi_work/scvi-tools/scvi/data/fields/_dataframe_field.py:217\u001b[0m, in \u001b[0;36mCategoricalDataFrameField.transfer_field\u001b[0;34m(self, state_registry, adata_target, extend_categories, **kwargs)\u001b[0m\n\u001b[1;32m    213\u001b[0m     \u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_setup_default_attr(adata_target)\n\u001b[1;32m    215\u001b[0m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39mvalidate_field(adata_target)\n\u001b[0;32m--> 217\u001b[0m mapping \u001b[39m=\u001b[39m state_registry[\u001b[39mself\u001b[39;49m\u001b[39m.\u001b[39;49mCATEGORICAL_MAPPING_KEY]\u001b[39m.\u001b[39mcopy()\n\u001b[1;32m    219\u001b[0m \u001b[39m# extend mapping for new categories\u001b[39;00m\n\u001b[1;32m    220\u001b[0m \u001b[39mfor\u001b[39;00m c \u001b[39min\u001b[39;00m np\u001b[39m.\u001b[39munique(\u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_get_original_column(adata_target)):\n\n\u001b[0;31mKeyError\u001b[0m: 'categorical_mapping'"
}