Hi, I am trying to do label transfer using scvi-tools. However, when I try to run:
scanvi_model = scvi.model.SCANVI.from_scvi_model(
scvi_model,
adata=reference_adata,
unlabeled_category="Unknown",
labels_key=SCANVI_CELLTYPE_KEY,
)
I got the following error:
---------------------------------------------------------------------------
ValueError Traceback (most recent call last)
Cell In[8], line 1
----> 1 scanvi_model = scvi.model.SCANVI.from_scvi_model(
2 scvi_model,
3 adata=reference_adata,
4 unlabeled_category="Unknown",
5 labels_key=SCANVI_CELLTYPE_KEY,
6 )
File scanvi/label_transfer/lib/python3.11/site-packages/scvi/model/_scanvi.py:244, in SCANVI.from_scvi_model(cls, scvi_model, unlabeled_category, labels_key, adata, **scanvi_kwargs)
242 if scvi_labels_key is None:
243 scvi_setup_args.update({"labels_key": labels_key})
--> 244 cls.setup_anndata(
245 adata,
246 unlabeled_category=unlabeled_category,
247 **scvi_setup_args,
248 )
249 scanvi_model = cls(adata, **non_kwargs, **kwargs, **scanvi_kwargs)
250 scvi_state_dict = scvi_model.module.state_dict()
File scanvi/label_transfer/lib/python3.11/site-packages/scvi/model/_scanvi.py:481, in SCANVI.setup_anndata(cls, adata, labels_key, unlabeled_category, layer, batch_key, size_factor_key, categorical_covariate_keys, continuous_covariate_keys, **kwargs)
479 anndata_fields += cls._get_fields_for_adata_minification(adata_minify_type)
480 adata_manager = AnnDataManager(fields=anndata_fields, setup_method_args=setup_method_args)
--> 481 adata_manager.register_fields(adata, **kwargs)
482 cls.register_manager(adata_manager)
File scanvi/label_transfer/lib/python3.11/site-packages/scvi/data/_manager.py:174, in AnnDataManager.register_fields(self, adata, source_registry, **transfer_kwargs)
171 self._validate_anndata_object(adata)
173 for field in self.fields:
--> 174 self._add_field(
175 field=field,
176 adata=adata,
177 source_registry=source_registry,
178 **transfer_kwargs,
179 )
181 # Save arguments for register_fields.
182 self._source_registry = deepcopy(source_registry)
File scanvi/label_transfer/lib/python3.11/site-packages/scvi/data/_manager.py:217, in AnnDataManager._add_field(self, field, adata, source_registry, **transfer_kwargs)
209 field_registry[_constants._STATE_REGISTRY_KEY] = field.transfer_field(
210 source_registry[_constants._FIELD_REGISTRIES_KEY][field.registry_key][
211 _constants._STATE_REGISTRY_KEY
(...)
214 **transfer_kwargs,
215 )
216 else:
--> 217 field_registry[_constants._STATE_REGISTRY_KEY] = field.register_field(adata)
218 # Compute and set summary stats for the given field.
219 state_registry = field_registry[_constants._STATE_REGISTRY_KEY]
File scanvi/label_transfer/lib/python3.11/site-packages/scvi/data/fields/_scanvi.py:73, in LabelsWithUnlabeledObsField.register_field(self, adata)
71 state_registry = super().register_field(adata)
72 mapping = state_registry[self.CATEGORICAL_MAPPING_KEY]
---> 73 return self._remap_unlabeled_to_final_category(adata, mapping)
File scanvi/label_transfer/lib/python3.11/site-packages/scvi/data/fields/_scanvi.py:56, in LabelsWithUnlabeledObsField._remap_unlabeled_to_final_category(self, adata, mapping)
54 cat_dtype = CategoricalDtype(categories=mapping, ordered=True)
55 # rerun setup for the batch column
---> 56 mapping = _make_column_categorical(
57 adata.obs,
58 self._original_attr_key,
59 self.attr_key,
60 categorical_dtype=cat_dtype,
61 )
63 return {
64 self.CATEGORICAL_MAPPING_KEY: mapping,
65 self.ORIGINAL_ATTR_KEY: self._original_attr_key,
66 self.UNLABELED_CATEGORY: self._unlabeled_category,
67 }
File scanvi/label_transfer/lib/python3.11/site-packages/scvi/data/_utils.py:219, in _make_column_categorical(df, column_key, alternate_column_key, categorical_dtype)
217 if -1 in unique:
218 received_categories = df[column_key].astype("category").cat.categories
--> 219 raise ValueError(
220 f'Making .obs["{column_key}"] categorical failed. Expected categories: {mapping}. '
221 f"Received categories: {received_categories}. "
222 )
223 df[alternate_column_key] = codes
225 # make sure each category contains enough cells
ValueError: Making .obs["_scvi_labels"] categorical failed. Expected categories: ['0' 'Unknown']. Received categories: Int64Index([0], dtype='int64').
My reference_adata looks like this:
AnnData object with n_obs × n_vars = 601056 × 5000
obs: 'sex', 'tissue', 'ethnicity', 'disease', 'assay', 'assay_ontology_term_id', 'sample_id', 'donor_id', 'dataset_id', 'development_stage', 'cell_type', 'n_genes_by_counts', 'log1p_n_genes_by_counts', 'total_counts', 'log1p_total_counts', 'pct_counts_in_top_50_genes', 'pct_counts_in_top_100_genes', 'pct_counts_in_top_200_genes', 'pct_counts_in_top_500_genes', 'n_counts', '_scvi_batch', '_scvi_labels', 'celltype_scanvi'
var: 'n_cells', 'highly_variable', 'means', 'dispersions', 'dispersions_norm'
uns: '_scvi', 'hvg', 'log1p', '_scvi_uuid', '_scvi_manager_uuid'
layers: 'counts'
This is how I load my scvi_model:
scvi_model = scvi.model.SCVI.load("model/model_COVID19_reference_atlas_scvi0.16.2", reference_adata)
These are the versions of the packages installed in the environment that I created:
absl-py==2.1.0
aiohttp==3.9.5
aiosignal==1.3.1
anndata==0.10.7
appnope==0.1.4
array_api_compat==1.7.1
asttokens==2.4.1
attrs==23.2.0
chex==0.1.86
comm==0.2.2
contextlib2==21.6.0
contourpy==1.2.1
cycler==0.12.1
debugpy==1.8.1
decorator==5.1.1
docrep==0.3.2
et-xmlfile==1.1.0
etils==1.9.1
executing==2.0.1
filelock==3.14.0
flax==0.8.4
fonttools==4.53.0
frozenlist==1.4.1
fsspec==2024.6.0
future==1.0.0
grpcio==1.64.1
h5py==3.11.0
idna==3.7
importlib_resources==6.4.0
IProgress==0.4
ipykernel==6.29.4
ipython==8.25.0
ipywidgets==8.1.3
jax==0.4.28
jaxlib==0.4.28
jedi==0.19.1
Jinja2==3.1.4
joblib==1.4.2
jupyter_client==8.6.2
jupyter_core==5.7.2
jupyterlab_widgets==3.0.11
kiwisolver==1.4.5
legacy-api-wrap==1.4
lightning==2.1.4
lightning-utilities==0.11.2
llvmlite==0.42.0
Markdown==3.6
markdown-it-py==3.0.0
MarkupSafe==2.1.5
matplotlib==3.9.0
matplotlib-inline==0.1.7
mdurl==0.1.2
ml-dtypes==0.4.0
ml_collections==0.1.1
mpmath==1.3.0
msgpack==1.0.8
mudata==0.2.3
multidict==6.0.5
multipledispatch==1.0.0
natsort==8.4.0
nest-asyncio==1.6.0
networkx==3.3
numba==0.59.1
numpy==1.26.4
numpyro==0.15.0
openpyxl==3.1.3
opt-einsum==3.3.0
optax==0.2.2
orbax-checkpoint==0.5.15
packaging==24.0
pandas==1.5.3
parso==0.8.4
patsy==0.5.6
pexpect==4.9.0
pillow==10.3.0
platformdirs==4.2.2
prompt_toolkit==3.0.46
protobuf==5.27.0
psutil==5.9.8
ptyprocess==0.7.0
pure-eval==0.2.2
pyDeprecate==0.3.1
Pygments==2.18.0
pynndescent==0.5.12
pyparsing==3.1.2
pyro-api==0.1.2
pyro-ppl==1.9.1
python-dateutil==2.9.0.post0
pytorch-lightning==1.5.10
pytz==2024.1
PyYAML==6.0.1
pyzmq==26.0.3
rich==13.7.1
scanpy==1.10.1
scikit-learn==1.5.0
scipy==1.13.1
scvi-tools==1.1.2
seaborn==0.13.2
session_info==1.0.0
six==1.16.0
stack-data==0.6.3
statsmodels==0.14.2
stdlib-list==0.10.0
sympy==1.12.1
tensorboard==2.16.2
tensorboard-data-server==0.7.2
tensorstore==0.1.60
threadpoolctl==3.5.0
toolz==0.12.1
torch==2.3.1
torchmetrics==1.4.0.post0
tornado==6.4
tqdm==4.66.4
traitlets==5.14.3
typing_extensions==4.12.1
tzdata==2024.1
umap-learn==0.5.6
wcwidth==0.2.13
Werkzeug==3.0.3
widgetsnbextension==4.0.11
yarl==1.9.4
zipp==3.19.2
Thanks in advance for the help.