Problems importing scvi [Windows]

Hi, I set up a conda environment using:
conda create -n scvi-env python=3.9

When importing in Jupyter, I use:
import scvi

I get an import error with pytorch_lightning:

ImportError: cannot import name 'get_num_classes' from 'torchmetrics.utilities.data'

Full error message:

ImportError Traceback (most recent call last)
Cell In [3], line 1
----> 1 import scvi

File ~\AppData\Local\Programs\Python\Python310\lib\site-packages\scvi\__init__.py:7
4 import logging
6 from ._constants import _CONSTANTS
----> 7 from ._settings import settings
9 # this import needs to come after prior imports to prevent circular import
10 from . import data, model, external, utils

File ~\AppData\Local\Programs\Python\Python310\lib\site-packages\scvi\_settings.py:5
2 from pathlib import Path
3 from typing import Union
----> 5 import pytorch_lightning as pl
6 import torch
7 from rich.console import Console

File ~\AppData\Local\Programs\Python\Python310\lib\site-packages\pytorch_lightning\__init__.py:20
17 _PACKAGE_ROOT = os.path.dirname(__file__)
18 _PROJECT_ROOT = os.path.dirname(_PACKAGE_ROOT)
---> 20 from pytorch_lightning import metrics # noqa: E402
21 from pytorch_lightning.callbacks import Callback # noqa: E402
22 from pytorch_lightning.core import LightningDataModule, LightningModule # noqa: E402

File ~\AppData\Local\Programs\Python\Python310\lib\site-packages\pytorch_lightning\metrics\__init__.py:15
1 # Copyright The PyTorch Lightning team.
2 #
3 # Licensed under the Apache License, Version 2.0 (the "License");
(…)
12 # See the License for the specific language governing permissions and
13 # limitations under the License.
---> 15 from pytorch_lightning.metrics.classification import ( # noqa: F401
16 Accuracy,
17 AUC,
18 AUROC,
19 AveragePrecision,
20 ConfusionMatrix,
21 F1,
22 FBeta,
23 HammingDistance,
24 IoU,
25 Precision,
26 PrecisionRecallCurve,
27 Recall,
28 ROC,
29 StatScores,
30 )
31 from pytorch_lightning.metrics.metric import Metric, MetricCollection # noqa: F401
32 from pytorch_lightning.metrics.regression import ( # noqa: F401
33 ExplainedVariance,
34 MeanAbsoluteError,
(…)
39 SSIM,
40 )

File ~\AppData\Local\Programs\Python\Python310\lib\site-packages\pytorch_lightning\metrics\classification\__init__.py:14
1 # Copyright The PyTorch Lightning team.
2 #
3 # Licensed under the Apache License, Version 2.0 (the "License");
(…)
12 # See the License for the specific language governing permissions and
13 # limitations under the License.
---> 14 from pytorch_lightning.metrics.classification.accuracy import Accuracy # noqa: F401
15 from pytorch_lightning.metrics.classification.auc import AUC # noqa: F401
16 from pytorch_lightning.metrics.classification.auroc import AUROC # noqa: F401

File ~\AppData\Local\Programs\Python\Python310\lib\site-packages\pytorch_lightning\metrics\classification\accuracy.py:18
14 from typing import Any, Callable, Optional
16 from torchmetrics import Accuracy as _Accuracy
---> 18 from pytorch_lightning.metrics.utils import deprecated_metrics, void
21 class Accuracy(_Accuracy):
22 @deprecated_metrics(target=_Accuracy)
23 def __init__(
24 self,
(…)
31 dist_sync_fn: Callable = None,
32 ):

File ~\AppData\Local\Programs\Python\Python310\lib\site-packages\pytorch_lightning\metrics\utils.py:22
20 from torchmetrics.utilities.data import dim_zero_mean as _dim_zero_mean
21 from torchmetrics.utilities.data import dim_zero_sum as _dim_zero_sum
---> 22 from torchmetrics.utilities.data import get_num_classes as _get_num_classes
23 from torchmetrics.utilities.data import select_topk as _select_topk
24 from torchmetrics.utilities.data import to_categorical as _to_categorical

ImportError: cannot import name 'get_num_classes' from 'torchmetrics.utilities.data'

Based on the paths of the traceback, it’s not clear that the conda environment you created is actually being used.

It also seems like you're using Windows. Please see our prerequisites for using Windows.

Thank you so much for your help.

When running:
pip install "jax[cpu]" -f https://whls.blob.core.windows.net/unstable/index.html --use-deprecated legacy-resolver

I got this error output:
ERROR: Could not find a version that satisfies the requirement jaxlib==0.3.24 (from jax[cpu]) (from versions: 0.1.70, 0.1.70+cuda101, 0.1.70+cuda111, 0.1.71, 0.1.71+cuda101, 0.1.71+cuda111, 0.1.72, 0.1.72+cuda111, 0.1.73, 0.1.75, 0.1.75+cuda11.cudnn82, 0.1.76, 0.1.76+cuda11.cudnn82, 0.3.0, 0.3.0+cuda11.cudnn82, 0.3.2, 0.3.2+cuda11.cudnn82, 0.3.5, 0.3.5+cuda11.cudnn82, 0.3.7, 0.3.7+cuda11.cudnn82, 0.3.11, 0.3.11+cuda11.cudnn82, 0.3.14, 0.3.14+cuda11.cudnn82, 0.3.17, 0.3.17+cuda11.cudnn82, 0.3.20, 0.3.20+cuda11.cudnn82, 0.3.22, 0.3.22+cuda11.cudnn82)
ERROR: No matching distribution found for jaxlib==0.3.24 (from jax[cpu])

I ran with jaxlib==0.3.22 but still cannot import scvi.
Is there a different jaxlib version I should use?

I was able to install jax and jaxlib with:
pip install "jax[cpu]" -f https://whls.blob.core.windows.net/unstable/index.html

In the notebook:

from platform import python_version
print(python_version())

Output : 3.9.15

import os
print (os.environ['CONDA_DEFAULT_ENV'])

Output : scvi-env

import scvi

Output:

ImportError Traceback (most recent call last)
Cell In [3], line 1
----> 1 import scvi

File ~\.conda\envs\scvi-env\lib\site-packages\scvi\__init__.py:7
4 import logging
6 from ._constants import _CONSTANTS
----> 7 from ._settings import settings
9 # this import needs to come after prior imports to prevent circular import
10 from . import data, model, external, utils

File ~\.conda\envs\scvi-env\lib\site-packages\scvi\_settings.py:5
2 from pathlib import Path
3 from typing import Union
----> 5 import pytorch_lightning as pl
6 import torch
7 from rich.console import Console

File ~\.conda\envs\scvi-env\lib\site-packages\pytorch_lightning\__init__.py:20
17 _PACKAGE_ROOT = os.path.dirname(__file__)
18 _PROJECT_ROOT = os.path.dirname(_PACKAGE_ROOT)
---> 20 from pytorch_lightning import metrics # noqa: E402
21 from pytorch_lightning.callbacks import Callback # noqa: E402
22 from pytorch_lightning.core import LightningDataModule, LightningModule # noqa: E402

File ~\.conda\envs\scvi-env\lib\site-packages\pytorch_lightning\metrics\__init__.py:15
1 # Copyright The PyTorch Lightning team.
2 #
3 # Licensed under the Apache License, Version 2.0 (the "License");
(…)
12 # See the License for the specific language governing permissions and
13 # limitations under the License.
---> 15 from pytorch_lightning.metrics.classification import ( # noqa: F401
16 Accuracy,
17 AUC,
18 AUROC,
19 AveragePrecision,
20 ConfusionMatrix,
21 F1,
22 FBeta,
23 HammingDistance,
24 IoU,
25 Precision,
26 PrecisionRecallCurve,
27 Recall,
28 ROC,
29 StatScores,
30 )
31 from pytorch_lightning.metrics.metric import Metric, MetricCollection # noqa: F401
32 from pytorch_lightning.metrics.regression import ( # noqa: F401
33 ExplainedVariance,
34 MeanAbsoluteError,
(…)
39 SSIM,
40 )

File ~\.conda\envs\scvi-env\lib\site-packages\pytorch_lightning\metrics\classification\__init__.py:14
1 # Copyright The PyTorch Lightning team.
2 #
3 # Licensed under the Apache License, Version 2.0 (the "License");
(…)
12 # See the License for the specific language governing permissions and
13 # limitations under the License.
---> 14 from pytorch_lightning.metrics.classification.accuracy import Accuracy # noqa: F401
15 from pytorch_lightning.metrics.classification.auc import AUC # noqa: F401
16 from pytorch_lightning.metrics.classification.auroc import AUROC # noqa: F401

File ~\.conda\envs\scvi-env\lib\site-packages\pytorch_lightning\metrics\classification\accuracy.py:18
14 from typing import Any, Callable, Optional
16 from torchmetrics import Accuracy as _Accuracy
---> 18 from pytorch_lightning.metrics.utils import deprecated_metrics
21 class Accuracy(_Accuracy):
23 @deprecated_metrics(target=_Accuracy)
24 def __init__(
25 self,
(…)
32 dist_sync_fn: Callable = None,
33 ):

File ~\.conda\envs\scvi-env\lib\site-packages\pytorch_lightning\metrics\utils.py:22
20 from torchmetrics.utilities.data import dim_zero_mean as _dim_zero_mean
21 from torchmetrics.utilities.data import dim_zero_sum as _dim_zero_sum
---> 22 from torchmetrics.utilities.data import get_num_classes as _get_num_classes
23 from torchmetrics.utilities.data import select_topk as _select_topk
24 from torchmetrics.utilities.data import to_categorical as _to_categorical

ImportError: cannot import name 'get_num_classes' from 'torchmetrics.utilities.data' (C:\Users\khakhlab\.conda\envs\scvi-env\lib\site-packages\torchmetrics\utilities\data.py)

Thank you again for your time and help

I have the exact same problem on Windows. Can you recommend a solution?

Can you try

pip install "jax[cpu]===0.3.14" -f https://whls.blob.core.windows.net/unstable/index.html --use-deprecated legacy-resolver

in a fresh virtual environment and then installing scvi-tools?

Hi, I did what you said and unfortunately I ran into the following error:

RuntimeError                              Traceback (most recent call last)
Cell In [3], line 1
----> 1 import scvi

File c:\Users\Bruno\Enviroments\newSars_env\lib\site-packages\scvi\__init__.py:10
      7 from ._settings import settings
      9 # this import needs to come after prior imports to prevent circular import
---> 10 from . import data, model, external, utils
     12 # https://github.com/python-poetry/poetry/pull/2366#issuecomment-652418094
     13 # https://github.com/python-poetry/poetry/issues/144#issuecomment-623927302
     14 try:

File c:\Users\Bruno\Enviroments\newSars_env\lib\site-packages\scvi\data\__init__.py:3
      1 from anndata import read_csv, read_h5ad, read_loom, read_text
----> 3 from ._datasets import (
      4     annotation_simulation,
      5     brainlarge_dataset,
      6     breast_cancer_dataset,
      7     cortex,
      8     dataset_10x,
      9     frontalcortex_dropseq,
     10     heart_cell_atlas_subsampled,
     11     mouse_ob_dataset,
     12     pbmc_dataset,
...
     66   msg = (f'jaxlib version {jaxlib_version} is newer than and '
     67          f'incompatible with jax version {jax_version}. Please '
     68          'update your jax and/or jaxlib packages.')

RuntimeError: jaxlib is version 0.3.14, but this version of jax requires version >= 0.3.22.

Hi, I also tried this and ran into the same error as before:

---------------------------------------------------------------------------
ImportError                               Traceback (most recent call last)
Input In [1], in <cell line: 1>()
----> 1 import scvi

File ~\AppData\Local\R-MINI~1\envs\test-scvi\lib\site-packages\scvi\__init__.py:7, in <module>
      4 import logging
      6 from ._constants import _CONSTANTS
----> 7 from ._settings import settings
      9 # this import needs to come after prior imports to prevent circular import
     10 from . import data, model, external, utils

File ~\AppData\Local\R-MINI~1\envs\test-scvi\lib\site-packages\scvi\_settings.py:5, in <module>
      2 from pathlib import Path
      3 from typing import Union
----> 5 import pytorch_lightning as pl
      6 import torch
      7 from rich.console import Console

File ~\AppData\Local\R-MINI~1\envs\test-scvi\lib\site-packages\pytorch_lightning\__init__.py:20, in <module>
     17 _PACKAGE_ROOT = os.path.dirname(__file__)
     18 _PROJECT_ROOT = os.path.dirname(_PACKAGE_ROOT)
---> 20 from pytorch_lightning import metrics  # noqa: E402
     21 from pytorch_lightning.callbacks import Callback  # noqa: E402
     22 from pytorch_lightning.core import LightningDataModule, LightningModule  # noqa: E402

File ~\AppData\Local\R-MINI~1\envs\test-scvi\lib\site-packages\pytorch_lightning\metrics\__init__.py:15, in <module>
      1 # Copyright The PyTorch Lightning team.
      2 #
      3 # Licensed under the Apache License, Version 2.0 (the "License");
   (...)
     12 # See the License for the specific language governing permissions and
     13 # limitations under the License.
---> 15 from pytorch_lightning.metrics.classification import (  # noqa: F401
     16     Accuracy,
     17     AUC,
     18     AUROC,
     19     AveragePrecision,
     20     ConfusionMatrix,
     21     F1,
     22     FBeta,
     23     HammingDistance,
     24     IoU,
     25     Precision,
     26     PrecisionRecallCurve,
     27     Recall,
     28     ROC,
     29     StatScores,
     30 )
     31 from pytorch_lightning.metrics.metric import Metric, MetricCollection  # noqa: F401
     32 from pytorch_lightning.metrics.regression import (  # noqa: F401
     33     ExplainedVariance,
     34     MeanAbsoluteError,
   (...)
     39     SSIM,
     40 )

File ~\AppData\Local\R-MINI~1\envs\test-scvi\lib\site-packages\pytorch_lightning\metrics\classification\__init__.py:14, in <module>
      1 # Copyright The PyTorch Lightning team.
      2 #
      3 # Licensed under the Apache License, Version 2.0 (the "License");
   (...)
     12 # See the License for the specific language governing permissions and
     13 # limitations under the License.
---> 14 from pytorch_lightning.metrics.classification.accuracy import Accuracy  # noqa: F401
     15 from pytorch_lightning.metrics.classification.auc import AUC  # noqa: F401
     16 from pytorch_lightning.metrics.classification.auroc import AUROC  # noqa: F401

File ~\AppData\Local\R-MINI~1\envs\test-scvi\lib\site-packages\pytorch_lightning\metrics\classification\accuracy.py:18, in <module>
     14 from typing import Any, Callable, Optional
     16 from torchmetrics import Accuracy as _Accuracy
---> 18 from pytorch_lightning.metrics.utils import deprecated_metrics
     21 class Accuracy(_Accuracy):
     23     @deprecated_metrics(target=_Accuracy)
     24     def __init__(
     25         self,
   (...)
     32         dist_sync_fn: Callable = None,
     33     ):

File ~\AppData\Local\R-MINI~1\envs\test-scvi\lib\site-packages\pytorch_lightning\metrics\utils.py:22, in <module>
     20 from torchmetrics.utilities.data import dim_zero_mean as _dim_zero_mean
     21 from torchmetrics.utilities.data import dim_zero_sum as _dim_zero_sum
---> 22 from torchmetrics.utilities.data import get_num_classes as _get_num_classes
     23 from torchmetrics.utilities.data import select_topk as _select_topk
     24 from torchmetrics.utilities.data import to_categorical as _to_categorical

ImportError: cannot import name 'get_num_classes' from 'torchmetrics.utilities.data' (C:\Users\klinker\AppData\Local\R-MINI~1\envs\test-scvi\lib\site-packages\torchmetrics\utilities\data.py)

Blockquote

Can you try the following (for Windows):

conda create -n scvi-tools-env python=3.9 pip
pip install "jax[cpu]===0.3.14" -f https://whls.blob.core.windows.net/unstable/index.html --use-deprecated legacy-resolver
pip install scvi-tools
1 Like

In my case, it worked.
I think the problem was that I had installed pytorch before installing scvi.