Error when trying to use scvi.model.SCANVI.from_scvi_model

Hi, I am trying to do label transfer using scvi-tools. However, when I try to run:

scanvi_model = scvi.model.SCANVI.from_scvi_model(
    scvi_model,
    adata=reference_adata,
    unlabeled_category="Unknown",
    labels_key=SCANVI_CELLTYPE_KEY,
)

I get the following error:

---------------------------------------------------------------------------
ValueError                                Traceback (most recent call last)
Cell In[8], line 1
----> 1 scanvi_model = scvi.model.SCANVI.from_scvi_model(
      2     scvi_model,
      3     adata=reference_adata,
      4     unlabeled_category="Unknown",
      5     labels_key=SCANVI_CELLTYPE_KEY,
      6 )

File scanvi/label_transfer/lib/python3.11/site-packages/scvi/model/_scanvi.py:244, in SCANVI.from_scvi_model(cls, scvi_model, unlabeled_category, labels_key, adata, **scanvi_kwargs)
    242 if scvi_labels_key is None:
    243     scvi_setup_args.update({"labels_key": labels_key})
--> 244 cls.setup_anndata(
    245     adata,
    246     unlabeled_category=unlabeled_category,
    247     **scvi_setup_args,
    248 )
    249 scanvi_model = cls(adata, **non_kwargs, **kwargs, **scanvi_kwargs)
    250 scvi_state_dict = scvi_model.module.state_dict()

File scanvi/label_transfer/lib/python3.11/site-packages/scvi/model/_scanvi.py:481, in SCANVI.setup_anndata(cls, adata, labels_key, unlabeled_category, layer, batch_key, size_factor_key, categorical_covariate_keys, continuous_covariate_keys, **kwargs)
    479     anndata_fields += cls._get_fields_for_adata_minification(adata_minify_type)
    480 adata_manager = AnnDataManager(fields=anndata_fields, setup_method_args=setup_method_args)
--> 481 adata_manager.register_fields(adata, **kwargs)
    482 cls.register_manager(adata_manager)

File scanvi/label_transfer/lib/python3.11/site-packages/scvi/data/_manager.py:174, in AnnDataManager.register_fields(self, adata, source_registry, **transfer_kwargs)
    171 self._validate_anndata_object(adata)
    173 for field in self.fields:
--> 174     self._add_field(
    175         field=field,
    176         adata=adata,
    177         source_registry=source_registry,
    178         **transfer_kwargs,
    179     )
    181 # Save arguments for register_fields.
    182 self._source_registry = deepcopy(source_registry)

File scanvi/label_transfer/lib/python3.11/site-packages/scvi/data/_manager.py:217, in AnnDataManager._add_field(self, field, adata, source_registry, **transfer_kwargs)
    209         field_registry[_constants._STATE_REGISTRY_KEY] = field.transfer_field(
    210             source_registry[_constants._FIELD_REGISTRIES_KEY][field.registry_key][
    211                 _constants._STATE_REGISTRY_KEY
   (...)
    214             **transfer_kwargs,
    215         )
    216     else:
--> 217         field_registry[_constants._STATE_REGISTRY_KEY] = field.register_field(adata)
    218 # Compute and set summary stats for the given field.
    219 state_registry = field_registry[_constants._STATE_REGISTRY_KEY]

File scanvi/label_transfer/lib/python3.11/site-packages/scvi/data/fields/_scanvi.py:73, in LabelsWithUnlabeledObsField.register_field(self, adata)
     71 state_registry = super().register_field(adata)
     72 mapping = state_registry[self.CATEGORICAL_MAPPING_KEY]
---> 73 return self._remap_unlabeled_to_final_category(adata, mapping)

File scanvi/label_transfer/lib/python3.11/site-packages/scvi/data/fields/_scanvi.py:56, in LabelsWithUnlabeledObsField._remap_unlabeled_to_final_category(self, adata, mapping)
     54 cat_dtype = CategoricalDtype(categories=mapping, ordered=True)
     55 # rerun setup for the batch column
---> 56 mapping = _make_column_categorical(
     57     adata.obs,
     58     self._original_attr_key,
     59     self.attr_key,
     60     categorical_dtype=cat_dtype,
     61 )
     63 return {
     64     self.CATEGORICAL_MAPPING_KEY: mapping,
     65     self.ORIGINAL_ATTR_KEY: self._original_attr_key,
     66     self.UNLABELED_CATEGORY: self._unlabeled_category,
     67 }

File scanvi/label_transfer/lib/python3.11/site-packages/scvi/data/_utils.py:219, in _make_column_categorical(df, column_key, alternate_column_key, categorical_dtype)
    217 if -1 in unique:
    218     received_categories = df[column_key].astype("category").cat.categories
--> 219     raise ValueError(
    220         f'Making .obs["{column_key}"] categorical failed. Expected categories: {mapping}. '
    221         f"Received categories: {received_categories}. "
    222     )
    223 df[alternate_column_key] = codes
    225 # make sure each category contains enough cells

ValueError: Making .obs["_scvi_labels"] categorical failed. Expected categories: ['0' 'Unknown']. Received categories: Int64Index([0], dtype='int64'). 

My reference_adata looks like this:

AnnData object with n_obs × n_vars = 601056 × 5000
    obs: 'sex', 'tissue', 'ethnicity', 'disease', 'assay', 'assay_ontology_term_id', 'sample_id', 'donor_id', 'dataset_id', 'development_stage', 'cell_type', 'n_genes_by_counts', 'log1p_n_genes_by_counts', 'total_counts', 'log1p_total_counts', 'pct_counts_in_top_50_genes', 'pct_counts_in_top_100_genes', 'pct_counts_in_top_200_genes', 'pct_counts_in_top_500_genes', 'n_counts', '_scvi_batch', '_scvi_labels', 'celltype_scanvi'
    var: 'n_cells', 'highly_variable', 'means', 'dispersions', 'dispersions_norm'
    uns: '_scvi', 'hvg', 'log1p', '_scvi_uuid', '_scvi_manager_uuid'
    layers: 'counts'

This is how I load my scvi_model:

scvi_model = scvi.model.SCVI.load("model/model_COVID19_reference_atlas_scvi0.16.2", reference_adata)

These are the versions of the packages installed in the environment that I created:

absl-py==2.1.0
aiohttp==3.9.5
aiosignal==1.3.1
anndata==0.10.7
appnope==0.1.4
array_api_compat==1.7.1
asttokens==2.4.1
attrs==23.2.0
chex==0.1.86
comm==0.2.2
contextlib2==21.6.0
contourpy==1.2.1
cycler==0.12.1
debugpy==1.8.1
decorator==5.1.1
docrep==0.3.2
et-xmlfile==1.1.0
etils==1.9.1
executing==2.0.1
filelock==3.14.0
flax==0.8.4
fonttools==4.53.0
frozenlist==1.4.1
fsspec==2024.6.0
future==1.0.0
grpcio==1.64.1
h5py==3.11.0
idna==3.7
importlib_resources==6.4.0
IProgress==0.4
ipykernel==6.29.4
ipython==8.25.0
ipywidgets==8.1.3
jax==0.4.28
jaxlib==0.4.28
jedi==0.19.1
Jinja2==3.1.4
joblib==1.4.2
jupyter_client==8.6.2
jupyter_core==5.7.2
jupyterlab_widgets==3.0.11
kiwisolver==1.4.5
legacy-api-wrap==1.4
lightning==2.1.4
lightning-utilities==0.11.2
llvmlite==0.42.0
Markdown==3.6
markdown-it-py==3.0.0
MarkupSafe==2.1.5
matplotlib==3.9.0
matplotlib-inline==0.1.7
mdurl==0.1.2
ml-dtypes==0.4.0
ml_collections==0.1.1
mpmath==1.3.0
msgpack==1.0.8
mudata==0.2.3
multidict==6.0.5
multipledispatch==1.0.0
natsort==8.4.0
nest-asyncio==1.6.0
networkx==3.3
numba==0.59.1
numpy==1.26.4
numpyro==0.15.0
openpyxl==3.1.3
opt-einsum==3.3.0
optax==0.2.2
orbax-checkpoint==0.5.15
packaging==24.0
pandas==1.5.3
parso==0.8.4
patsy==0.5.6
pexpect==4.9.0
pillow==10.3.0
platformdirs==4.2.2
prompt_toolkit==3.0.46
protobuf==5.27.0
psutil==5.9.8
ptyprocess==0.7.0
pure-eval==0.2.2
pyDeprecate==0.3.1
Pygments==2.18.0
pynndescent==0.5.12
pyparsing==3.1.2
pyro-api==0.1.2
pyro-ppl==1.9.1
python-dateutil==2.9.0.post0
pytorch-lightning==1.5.10
pytz==2024.1
PyYAML==6.0.1
pyzmq==26.0.3
rich==13.7.1
scanpy==1.10.1
scikit-learn==1.5.0
scipy==1.13.1
scvi-tools==1.1.2
seaborn==0.13.2
session_info==1.0.0
six==1.16.0
stack-data==0.6.3
statsmodels==0.14.2
stdlib-list==0.10.0
sympy==1.12.1
tensorboard==2.16.2
tensorboard-data-server==0.7.2
tensorstore==0.1.60
threadpoolctl==3.5.0
toolz==0.12.1
torch==2.3.1
torchmetrics==1.4.0.post0
tornado==6.4
tqdm==4.66.4
traitlets==5.14.3
typing_extensions==4.12.1
tzdata==2024.1
umap-learn==0.5.6
wcwidth==0.2.13
Werkzeug==3.0.3
widgetsnbextension==4.0.11
yarl==1.9.4
zipp==3.19.2

Thanks in advance for the help.