Problems importing scvi [Windows]

Hi, I installed through conda using:
conda create -n scvi-env python=3.9

When importing in Jupyter, I use:
import scvi

I get an import error with pytorch_lightning:

ImportError: cannot import name 'get_num_classes' from 'torchmetrics.utilities.data'

Full error message:

ImportError Traceback (most recent call last)
Cell In [3], line 1
----> 1 import scvi

File ~\AppData\Local\Programs\Python\Python310\lib\site-packages\scvi\__init__.py:7
4 import logging
6 from ._constants import _CONSTANTS
----> 7 from ._settings import settings
9 # this import needs to come after prior imports to prevent circular import
10 from . import data, model, external, utils

File ~\AppData\Local\Programs\Python\Python310\lib\site-packages\scvi\_settings.py:5
2 from pathlib import Path
3 from typing import Union
----> 5 import pytorch_lightning as pl
6 import torch
7 from rich.console import Console

File ~\AppData\Local\Programs\Python\Python310\lib\site-packages\pytorch_lightning\__init__.py:20
17 _PACKAGE_ROOT = os.path.dirname(__file__)
18 _PROJECT_ROOT = os.path.dirname(_PACKAGE_ROOT)
---> 20 from pytorch_lightning import metrics # noqa: E402
21 from pytorch_lightning.callbacks import Callback # noqa: E402
22 from pytorch_lightning.core import LightningDataModule, LightningModule # noqa: E402

File ~\AppData\Local\Programs\Python\Python310\lib\site-packages\pytorch_lightning\metrics\__init__.py:15
1 # Copyright The PyTorch Lightning team.
2 #
3 # Licensed under the Apache License, Version 2.0 (the "License");
(...)
12 # See the License for the specific language governing permissions and
13 # limitations under the License.
---> 15 from pytorch_lightning.metrics.classification import ( # noqa: F401
16 Accuracy,
17 AUC,
18 AUROC,
19 AveragePrecision,
20 ConfusionMatrix,
21 F1,
22 FBeta,
23 HammingDistance,
24 IoU,
25 Precision,
26 PrecisionRecallCurve,
27 Recall,
28 ROC,
29 StatScores,
30 )
31 from pytorch_lightning.metrics.metric import Metric, MetricCollection # noqa: F401
32 from pytorch_lightning.metrics.regression import ( # noqa: F401
33 ExplainedVariance,
34 MeanAbsoluteError,
(...)
39 SSIM,
40 )

File ~\AppData\Local\Programs\Python\Python310\lib\site-packages\pytorch_lightning\metrics\classification\__init__.py:14
1 # Copyright The PyTorch Lightning team.
2 #
3 # Licensed under the Apache License, Version 2.0 (the "License");
(...)
12 # See the License for the specific language governing permissions and
13 # limitations under the License.
---> 14 from pytorch_lightning.metrics.classification.accuracy import Accuracy # noqa: F401
15 from pytorch_lightning.metrics.classification.auc import AUC # noqa: F401
16 from pytorch_lightning.metrics.classification.auroc import AUROC # noqa: F401

File ~\AppData\Local\Programs\Python\Python310\lib\site-packages\pytorch_lightning\metrics\classification\accuracy.py:18
14 from typing import Any, Callable, Optional
16 from torchmetrics import Accuracy as _Accuracy
---> 18 from pytorch_lightning.metrics.utils import deprecated_metrics, void
21 class Accuracy(_Accuracy):
22 @deprecated_metrics(target=_Accuracy)
23 def __init__(
24 self,
(...)
31 dist_sync_fn: Callable = None,
32 ):

File ~\AppData\Local\Programs\Python\Python310\lib\site-packages\pytorch_lightning\metrics\utils.py:22
20 from torchmetrics.utilities.data import dim_zero_mean as _dim_zero_mean
21 from torchmetrics.utilities.data import dim_zero_sum as _dim_zero_sum
---> 22 from torchmetrics.utilities.data import get_num_classes as _get_num_classes
23 from torchmetrics.utilities.data import select_topk as _select_topk
24 from torchmetrics.utilities.data import to_categorical as _to_categorical

ImportError: cannot import name 'get_num_classes' from 'torchmetrics.utilities.data'

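Based on the paths in the traceback, it's not clear that the conda environment you created is actually being used: they point to a system-wide Python 3.10 installation (~\AppData\Local\Programs\Python\Python310), not to the python=3.9 environment you created.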
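If the notebook kernel really runs inside the new environment, both `sys.prefix` and the location scvi would be imported from should point inside that conda environment (for example under `...\envs\scvi-env\`), not under `Programs\Python\Python310`. A small sketch that prints this from a notebook cell; `describe_environment` is a hypothetical helper name, and it only uses the standard library:

```python
import importlib.util
import os
import sys


def describe_environment(package="scvi"):
    """Collect where this interpreter lives and where `package` would be imported from."""
    # find_spec locates a top-level package without importing it (None if not installed)
    spec = importlib.util.find_spec(package)
    return {
        "python_version": sys.version.split()[0],
        "executable": sys.executable,
        "prefix": sys.prefix,
        # Set by `conda activate` in the shell that started this process (may be absent)
        "conda_env": os.environ.get("CONDA_DEFAULT_ENV"),
        "package_origin": spec.origin if spec is not None else None,
    }


if __name__ == "__main__":
    for key, value in describe_environment().items():
        print(f"{key}: {value}")
```

Note that `CONDA_DEFAULT_ENV` reflects the shell that launched Jupyter, so `executable` and `package_origin` are the more direct evidence of which environment the kernel is using.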

It also seems like you're using Windows. Please see our prerequisites for using Windows.

Thank you so much for your help.

When running:
pip install "jax[cpu]" -f https://whls.blob.core.windows.net/unstable/index.html --use-deprecated legacy-resolver

I got this error output:
ERROR: Could not find a version that satisfies the requirement jaxlib==0.3.24 (from jax[cpu]) (from versions: 0.1.70, 0.1.70+cuda101, 0.1.70+cuda111, 0.1.71, 0.1.71+cuda101, 0.1.71+cuda111, 0.1.72, 0.1.72+cuda111, 0.1.73, 0.1.75, 0.1.75+cuda11.cudnn82, 0.1.76, 0.1.76+cuda11.cudnn82, 0.3.0, 0.3.0+cuda11.cudnn82, 0.3.2, 0.3.2+cuda11.cudnn82, 0.3.5, 0.3.5+cuda11.cudnn82, 0.3.7, 0.3.7+cuda11.cudnn82, 0.3.11, 0.3.11+cuda11.cudnn82, 0.3.14, 0.3.14+cuda11.cudnn82, 0.3.17, 0.3.17+cuda11.cudnn82, 0.3.20, 0.3.20+cuda11.cudnn82, 0.3.22, 0.3.22+cuda11.cudnn82)
ERROR: No matching distribution found for jaxlib==0.3.24 (from jax[cpu])
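The error means the index at that URL simply has no jaxlib 0.3.24 wheel. To see which CPU-only builds it does offer, the version list from the message can be filtered down to entries without a `+cuda...` local tag. A standard-library sketch; the list is copied from the error above, and the helper names are illustrative only:

```python
import re

# Version list copied verbatim from the pip error message above
AVAILABLE = (
    "0.1.70, 0.1.70+cuda101, 0.1.70+cuda111, 0.1.71, 0.1.71+cuda101, 0.1.71+cuda111, "
    "0.1.72, 0.1.72+cuda111, 0.1.73, 0.1.75, 0.1.75+cuda11.cudnn82, 0.1.76, "
    "0.1.76+cuda11.cudnn82, 0.3.0, 0.3.0+cuda11.cudnn82, 0.3.2, 0.3.2+cuda11.cudnn82, "
    "0.3.5, 0.3.5+cuda11.cudnn82, 0.3.7, 0.3.7+cuda11.cudnn82, 0.3.11, "
    "0.3.11+cuda11.cudnn82, 0.3.14, 0.3.14+cuda11.cudnn82, 0.3.17, "
    "0.3.17+cuda11.cudnn82, 0.3.20, 0.3.20+cuda11.cudnn82, 0.3.22, 0.3.22+cuda11.cudnn82"
)


def as_tuple(v):
    """Turn a plain release string like '0.3.22' into (0, 3, 22) for numeric comparison."""
    return tuple(int(part) for part in v.split("."))


def cpu_versions(listing):
    """Keep only builds without a local '+cuda...' tag, sorted numerically."""
    versions = [v.strip() for v in listing.split(",")]
    plain = [v for v in versions if re.fullmatch(r"\d+(\.\d+)*", v)]
    return sorted(plain, key=as_tuple)


def newest_at_most(listing, requested):
    """Newest CPU-only build that is not newer than `requested`, or None."""
    candidates = [v for v in cpu_versions(listing) if as_tuple(v) <= as_tuple(requested)]
    return candidates[-1] if candidates else None


if __name__ == "__main__":
    print(cpu_versions(AVAILABLE)[-1])          # newest CPU-only jaxlib on that index
    print(newest_at_most(AVAILABLE, "0.3.24"))  # closest build to the requested 0.3.24
```

Both lines print 0.3.22, the newest CPU-only jaxlib listed at that URL at the time of the error.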

I ran with jaxlib==0.3.22 but still cannot import scvi.
Is there a different jaxlib version I should use?

I was able to install jax and jaxlib with:
pip install "jax[cpu]" -f https://whls.blob.core.windows.net/unstable/index.html

In the notebook:

from platform import python_version
print(python_version())

Output : 3.9.15

import os
print(os.environ['CONDA_DEFAULT_ENV'])

Output : scvi-env

import scvi

Output:

ImportError Traceback (most recent call last)
Cell In [3], line 1
----> 1 import scvi

File ~\.conda\envs\scvi-env\lib\site-packages\scvi\__init__.py:7
4 import logging
6 from ._constants import _CONSTANTS
----> 7 from ._settings import settings
9 # this import needs to come after prior imports to prevent circular import
10 from . import data, model, external, utils

File ~\.conda\envs\scvi-env\lib\site-packages\scvi\_settings.py:5
2 from pathlib import Path
3 from typing import Union
----> 5 import pytorch_lightning as pl
6 import torch
7 from rich.console import Console

File ~\.conda\envs\scvi-env\lib\site-packages\pytorch_lightning\__init__.py:20
17 _PACKAGE_ROOT = os.path.dirname(__file__)
18 _PROJECT_ROOT = os.path.dirname(_PACKAGE_ROOT)
---> 20 from pytorch_lightning import metrics # noqa: E402
21 from pytorch_lightning.callbacks import Callback # noqa: E402
22 from pytorch_lightning.core import LightningDataModule, LightningModule # noqa: E402

File ~\.conda\envs\scvi-env\lib\site-packages\pytorch_lightning\metrics\__init__.py:15
1 # Copyright The PyTorch Lightning team.
2 #
3 # Licensed under the Apache License, Version 2.0 (the "License");
(...)
12 # See the License for the specific language governing permissions and
13 # limitations under the License.
---> 15 from pytorch_lightning.metrics.classification import ( # noqa: F401
16 Accuracy,
17 AUC,
18 AUROC,
19 AveragePrecision,
20 ConfusionMatrix,
21 F1,
22 FBeta,
23 HammingDistance,
24 IoU,
25 Precision,
26 PrecisionRecallCurve,
27 Recall,
28 ROC,
29 StatScores,
30 )
31 from pytorch_lightning.metrics.metric import Metric, MetricCollection # noqa: F401
32 from pytorch_lightning.metrics.regression import ( # noqa: F401
33 ExplainedVariance,
34 MeanAbsoluteError,
(...)
39 SSIM,
40 )

File ~\.conda\envs\scvi-env\lib\site-packages\pytorch_lightning\metrics\classification\__init__.py:14
1 # Copyright The PyTorch Lightning team.
2 #
3 # Licensed under the Apache License, Version 2.0 (the "License");
(...)
12 # See the License for the specific language governing permissions and
13 # limitations under the License.
---> 14 from pytorch_lightning.metrics.classification.accuracy import Accuracy # noqa: F401
15 from pytorch_lightning.metrics.classification.auc import AUC # noqa: F401
16 from pytorch_lightning.metrics.classification.auroc import AUROC # noqa: F401

File ~\.conda\envs\scvi-env\lib\site-packages\pytorch_lightning\metrics\classification\accuracy.py:18
14 from typing import Any, Callable, Optional
16 from torchmetrics import Accuracy as _Accuracy
---> 18 from pytorch_lightning.metrics.utils import deprecated_metrics
21 class Accuracy(_Accuracy):
23 @deprecated_metrics(target=_Accuracy)
24 def __init__(
25 self,
(...)
32 dist_sync_fn: Callable = None,
33 ):

File ~\.conda\envs\scvi-env\lib\site-packages\pytorch_lightning\metrics\utils.py:22
20 from torchmetrics.utilities.data import dim_zero_mean as _dim_zero_mean
21 from torchmetrics.utilities.data import dim_zero_sum as _dim_zero_sum
---> 22 from torchmetrics.utilities.data import get_num_classes as _get_num_classes
23 from torchmetrics.utilities.data import select_topk as _select_topk
24 from torchmetrics.utilities.data import to_categorical as _to_categorical

ImportError: cannot import name 'get_num_classes' from 'torchmetrics.utilities.data' (C:\Users\khakhlab\.conda\envs\scvi-env\lib\site-packages\torchmetrics\utilities\data.py)
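In this traceback the paths do point into scvi-env, so the new environment is being used, yet the same import fails. The failing line is pytorch_lightning's own `from torchmetrics.utilities.data import get_num_classes`, which suggests that the installed pytorch_lightning and torchmetrics versions do not match each other. A minimal standard-library sketch for checking this from the same kernel; `installed_version` and `module_has_attr` are hypothetical helper names, and it imports `torchmetrics.utilities.data` directly because importing pytorch_lightning would just raise the same error:

```python
import importlib
from importlib.metadata import PackageNotFoundError, version


def installed_version(dist_name):
    """Return the installed version of a distribution, or None if it is missing."""
    try:
        return version(dist_name)
    except PackageNotFoundError:
        return None


def module_has_attr(module_name, attr):
    """Import `module_name` and report whether it defines `attr`.

    Returns False (instead of raising) if the module itself cannot be imported.
    """
    try:
        module = importlib.import_module(module_name)
    except ImportError:
        return False
    return hasattr(module, attr)


if __name__ == "__main__":
    for dist in ("pytorch-lightning", "torchmetrics", "scvi-tools"):
        print(dist, installed_version(dist))
    # The exact name that fails to import in the traceback above
    print("get_num_classes available:",
          module_has_attr("torchmetrics.utilities.data", "get_num_classes"))
```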

Thank you again for your time and help

I have the exact same problem on Windows, can you recommend any solution?

Can you try

pip install "jax[cpu]===0.3.14" -f https://whls.blob.core.windows.net/unstable/index.html --use-deprecated legacy-resolver

in a fresh virtual environment and then installing scvi-tools?

Hi, I did what you said and unfortunately I ran into the following error:

RuntimeError                              Traceback (most recent call last)
Cell In [3], line 1
----> 1 import scvi

File c:\Users\Bruno\Enviroments\newSars_env\lib\site-packages\scvi\__init__.py:10
      7 from ._settings import settings
      9 # this import needs to come after prior imports to prevent circular import
---> 10 from . import data, model, external, utils
     12 # https://github.com/python-poetry/poetry/pull/2366#issuecomment-652418094
     13 # https://github.com/python-poetry/poetry/issues/144#issuecomment-623927302
     14 try:

File c:\Users\Bruno\Enviroments\newSars_env\lib\site-packages\scvi\data\__init__.py:3
      1 from anndata import read_csv, read_h5ad, read_loom, read_text
----> 3 from ._datasets import (
      4     annotation_simulation,
      5     brainlarge_dataset,
      6     breast_cancer_dataset,
      7     cortex,
      8     dataset_10x,
      9     frontalcortex_dropseq,
     10     heart_cell_atlas_subsampled,
     11     mouse_ob_dataset,
     12     pbmc_dataset,
...
     66   msg = (f'jaxlib version {jaxlib_version} is newer than and '
     67          f'incompatible with jax version {jax_version}. Please '
     68          'update your jax and/or jaxlib packages.')

RuntimeError: jaxlib is version 0.3.14, but this version of jax requires version >= 0.3.22.
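This message comes from jax's own version check: the installed jax release needs jaxlib 0.3.22 or newer, but jaxlib 0.3.14 is installed. One possible cause, not confirmed in this thread, is that installing scvi-tools afterwards pulled in a newer jax without upgrading jaxlib. A sketch that reads both installed versions and compares jaxlib against the minimum quoted in the error; the minimum is passed in by hand because this does not reproduce jax's internal check, and it assumes plain numeric release strings:

```python
from importlib.metadata import PackageNotFoundError, version


def release_tuple(v):
    """'0.3.14' -> (0, 3, 14); drops any local tag such as '+cuda11.cudnn82'.

    Assumes a plain numeric release (no 'rc'/'dev' suffixes).
    """
    return tuple(int(part) for part in v.split("+")[0].split("."))


def jaxlib_ok(jaxlib_version, minimum):
    """True if jaxlib_version >= minimum, comparing numeric release parts."""
    return release_tuple(jaxlib_version) >= release_tuple(minimum)


def report(minimum):
    """Print installed jax/jaxlib versions and flag a jaxlib older than `minimum`."""
    found = {}
    for dist in ("jax", "jaxlib"):
        try:
            found[dist] = version(dist)
        except PackageNotFoundError:
            found[dist] = None
    print(found)
    if found["jaxlib"] is None:
        print("jaxlib is not installed")
    elif not jaxlib_ok(found["jaxlib"], minimum):
        print(f"jaxlib {found['jaxlib']} is older than the required {minimum}")


if __name__ == "__main__":
    # Minimum copied from the RuntimeError above
    report("0.3.22")
```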

Hi, I also tried this and ran into the same error as before:

---------------------------------------------------------------------------
ImportError                               Traceback (most recent call last)
Input In [1], in <cell line: 1>()
----> 1 import scvi

File ~\AppData\Local\R-MINI~1\envs\test-scvi\lib\site-packages\scvi\__init__.py:7, in <module>
      4 import logging
      6 from ._constants import _CONSTANTS
----> 7 from ._settings import settings
      9 # this import needs to come after prior imports to prevent circular import
     10 from . import data, model, external, utils

File ~\AppData\Local\R-MINI~1\envs\test-scvi\lib\site-packages\scvi\_settings.py:5, in <module>
      2 from pathlib import Path
      3 from typing import Union
----> 5 import pytorch_lightning as pl
      6 import torch
      7 from rich.console import Console

File ~\AppData\Local\R-MINI~1\envs\test-scvi\lib\site-packages\pytorch_lightning\__init__.py:20, in <module>
     17 _PACKAGE_ROOT = os.path.dirname(__file__)
     18 _PROJECT_ROOT = os.path.dirname(_PACKAGE_ROOT)
---> 20 from pytorch_lightning import metrics  # noqa: E402
     21 from pytorch_lightning.callbacks import Callback  # noqa: E402
     22 from pytorch_lightning.core import LightningDataModule, LightningModule  # noqa: E402

File ~\AppData\Local\R-MINI~1\envs\test-scvi\lib\site-packages\pytorch_lightning\metrics\__init__.py:15, in <module>
      1 # Copyright The PyTorch Lightning team.
      2 #
      3 # Licensed under the Apache License, Version 2.0 (the "License");
   (...)
     12 # See the License for the specific language governing permissions and
     13 # limitations under the License.
---> 15 from pytorch_lightning.metrics.classification import (  # noqa: F401
     16     Accuracy,
     17     AUC,
     18     AUROC,
     19     AveragePrecision,
     20     ConfusionMatrix,
     21     F1,
     22     FBeta,
     23     HammingDistance,
     24     IoU,
     25     Precision,
     26     PrecisionRecallCurve,
     27     Recall,
     28     ROC,
     29     StatScores,
     30 )
     31 from pytorch_lightning.metrics.metric import Metric, MetricCollection  # noqa: F401
     32 from pytorch_lightning.metrics.regression import (  # noqa: F401
     33     ExplainedVariance,
     34     MeanAbsoluteError,
   (...)
     39     SSIM,
     40 )

File ~\AppData\Local\R-MINI~1\envs\test-scvi\lib\site-packages\pytorch_lightning\metrics\classification\__init__.py:14, in <module>
      1 # Copyright The PyTorch Lightning team.
      2 #
      3 # Licensed under the Apache License, Version 2.0 (the "License");
   (...)
     12 # See the License for the specific language governing permissions and
     13 # limitations under the License.
---> 14 from pytorch_lightning.metrics.classification.accuracy import Accuracy  # noqa: F401
     15 from pytorch_lightning.metrics.classification.auc import AUC  # noqa: F401
     16 from pytorch_lightning.metrics.classification.auroc import AUROC  # noqa: F401

File ~\AppData\Local\R-MINI~1\envs\test-scvi\lib\site-packages\pytorch_lightning\metrics\classification\accuracy.py:18, in <module>
     14 from typing import Any, Callable, Optional
     16 from torchmetrics import Accuracy as _Accuracy
---> 18 from pytorch_lightning.metrics.utils import deprecated_metrics
     21 class Accuracy(_Accuracy):
     23     @deprecated_metrics(target=_Accuracy)
     24     def __init__(
     25         self,
   (...)
     32         dist_sync_fn: Callable = None,
     33     ):

File ~\AppData\Local\R-MINI~1\envs\test-scvi\lib\site-packages\pytorch_lightning\metrics\utils.py:22, in <module>
     20 from torchmetrics.utilities.data import dim_zero_mean as _dim_zero_mean
     21 from torchmetrics.utilities.data import dim_zero_sum as _dim_zero_sum
---> 22 from torchmetrics.utilities.data import get_num_classes as _get_num_classes
     23 from torchmetrics.utilities.data import select_topk as _select_topk
     24 from torchmetrics.utilities.data import to_categorical as _to_categorical

ImportError: cannot import name 'get_num_classes' from 'torchmetrics.utilities.data' (C:\Users\klinker\AppData\Local\R-MINI~1\envs\test-scvi\lib\site-packages\torchmetrics\utilities\data.py)


Can you try the following (for Windows):

conda create -n scvi-tools-env python=3.9 pip
conda activate scvi-tools-env
pip install "jax[cpu]===0.3.14" -f https://whls.blob.core.windows.net/unstable/index.html --use-deprecated legacy-resolver
pip install scvi-tools

In my case, it worked.
I think the problem was that I had installed PyTorch before installing scvi.

Thank you very much, this worked for me too when I used:

pip install "jax[cpu]===0.3.22" -f https://whls.blob.core.windows.net/unstable/index.html --use-deprecated legacy-resolver

I had the same error on a Linux machine. Reinstalling with python=3.9 instead of 3.7, and with conda-forge::scvi-tools instead of bioconda::scvi-tools, somehow made the error disappear.

Indeed, we have not been releasing via bioconda; if you use conda, install from conda-forge instead. Most packages are dropping Python 3.7 support as well, so anyone reading this who is starting fresh should ideally use Python 3.10!

It has been a few months since this thread was active, but I had the same problem importing scvi (Windows) and found this discussion.
I followed the instructions precisely, and installation of scvi in my fresh venv seems to go smoothly, but when I try to import it I still get an error:
RuntimeError: jaxlib is version 0.3.14, but this version of jax requires version >=0.4.7.
When I change "jax[cpu]===0.3.22" to "jax[cpu]===0.4.7", pip cannot find it at that URL.
Do you have any advice, Adam? Much appreciated.

The best advice I have for Windows at the moment, while likely unsatisfying, is to use the Windows Subsystem for Linux (WSL), which should solve all of these problems.

OK, thank you. That is kind of what I was thinking…

I am also having the same problem. It would be really great if scvi-tools could be backward compatible with the versions of jax that are available on Windows. Is this possible?

Hi, sorry for the late reply. We don't plan on making scvi-tools backward compatible with jax 0.3.x due to some recent patches; apologies for the inconvenience. I believe they are working on releasing community-built jax installations for 0.4.x, and you can follow that here. Otherwise, we'd recommend using the Windows Subsystem for Linux as mentioned before, or installing a previous version of scvi-tools (anything <0.20.1).

Should be fixed now with the jax Windows build here:

Thank you. After facing difficulties installing scvi-tools on Windows, I've finally managed to get it up and running by following the provided solution and also installing scanpy. It's working now, but I'm currently unable to use the GPU.
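A first step for the GPU question is to check which devices each library can actually see from inside the environment. Keep in mind that the `jax[cpu]` wheels used throughout this thread are CPU-only builds, so jax installed that way will not use a GPU regardless of the hardware. A sketch, assuming torch and/or jax may or may not be installed; `torch.cuda.is_available()`, `torch.version.cuda` and `jax.devices()` are existing calls in those libraries, while `gpu_report` is a hypothetical helper:

```python
def gpu_report():
    """Report whether torch and jax are importable and what devices they see."""
    report = {}
    try:
        import torch
        report["torch"] = torch.__version__
        report["torch_cuda_available"] = torch.cuda.is_available()
        # CUDA version torch was built against; None for CPU-only builds
        report["torch_built_with_cuda"] = torch.version.cuda
    except ImportError:
        report["torch"] = None
    try:
        import jax
        report["jax"] = jax.__version__
        # Platforms of the devices jax can use, e.g. ["cpu"] on a CPU-only install
        report["jax_platforms"] = sorted({d.platform for d in jax.devices()})
    except ImportError:
        report["jax"] = None
    return report


if __name__ == "__main__":
    for key, value in gpu_report().items():
        print(f"{key}: {value}")
```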