AttributeError: module 'jax' has no attribute 'Device'

Hey everyone,

When I install scVI as described in your documentation, I get this error when importing it. I attached a screenshot of the error log. Thanks in advance for any help you can offer!

Hi, could you run pip install --upgrade jax jaxlib chex?


Hi Martin,

Thanks for getting back to me! I will try this when I get home tonight!

I encountered the same issue after upgrading scvi-tools from 0.16 to 0.20.3. Your suggestion worked nicely. Many thanks!


This solution no longer works: it upgrades JAX and jaxlib to versions that no longer provide ShapedArray, which results in the following error:

ImportError: cannot import name 'ShapedArray' from 'jax'
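If it is unclear which side of this API change an environment is on, one quick check is whether the installed jax still exposes the names from the two error messages in this thread (jax.Device in the original report, jax.ShapedArray here). The helper below is a minimal sketch using only the standard library; the name module_has_attr is hypothetical and not part of jax or scvi-tools. It returns None instead of raising when the module itself fails to import with an ImportError.

```python
import importlib
from typing import Optional


def module_has_attr(module_name: str, attr: str) -> Optional[bool]:
    """Return True/False depending on whether `module_name` defines `attr`,
    or None if the module cannot be imported at all (hypothetical helper)."""
    try:
        module = importlib.import_module(module_name)
    except ImportError:
        return None
    # hasattr is False for names that were removed from the module namespace
    return hasattr(module, attr)


# Check the two names that appear in the errors reported in this thread.
for name in ("Device", "ShapedArray"):
    print(f"jax.{name}:", module_has_attr("jax", name))
```

A result of True for Device and False for ShapedArray would be consistent with a newer jax paired with an older scvi-tools that still imports ShapedArray.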

Hi, what version of scvi-tools do you have installed?
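To answer this in one go, and to also capture the versions of the JAX-related packages mentioned in this thread, a sketch like the following can be run inside the affected environment. It relies only on importlib.metadata from the standard library (Python 3.8+); the package list is an assumption based on the packages named in these posts, and installed_versions is a hypothetical helper name.

```python
from importlib.metadata import PackageNotFoundError, version
from typing import Dict, Iterable, Optional

# Distribution names discussed in this thread (assumed list; extend as needed).
PACKAGES = ("scvi-tools", "jax", "jaxlib", "chex", "flax")


def installed_versions(packages: Iterable[str] = PACKAGES) -> Dict[str, Optional[str]]:
    """Map each distribution name to its installed version, or None if absent."""
    result: Dict[str, Optional[str]] = {}
    for pkg in packages:
        try:
            result[pkg] = version(pkg)
        except PackageNotFoundError:
            result[pkg] = None
    return result


for pkg, ver in installed_versions().items():
    print(f"{pkg}: {ver or 'not installed'}")
```

Pasting this output into a reply makes it easy to spot a mismatch between scvi-tools and the jax/jaxlib/flax stack.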




Could you please update to the latest stable release (1.0.4) and check if the issue persists?


That fixed it. It required upgrading both scvi-tools (to the latest version) and flax.


Hi all,
I recently upgraded scvi-tools to the 1.0.4 release, as described above, because I was experiencing the same errors. The errors only went away when I resorted to creating a whole new conda environment with:

conda create --name new_scvi_tools scvi-tools==1.0.4 --channel=conda-forge

But now, when I run the code:

weights = 'importance'  # the 2023 paper describing lvm-DE reports this performs better
# this also returns the log2 fold change

It takes forever: as of writing, it has been running for over 12 hours. Previously, with weights='uniform' (the default), this took only about a minute, since I have only ~7k cells (though with all genes). Is this because I recently upgraded scvi-tools to 1.0.4 and my model was trained with an older version? Is it because older versions didn't add importance weights to the model, so it is now going back and calculating them? The progress bar is stuck at 0 and my CPU has been maxed out, so it seems to still be running. There are no error messages, but also no progress logging messages.
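One way to tell whether importance weighting is simply much more expensive, rather than stuck, is to time both weights settings on a small subset of cells before committing to a full run. Below is a minimal, library-agnostic timing sketch; timed and compare_settings are hypothetical helper names, not scvi-tools functions.

```python
import time
from typing import Any, Callable, Dict, Tuple


def timed(fn: Callable[..., Any], *args: Any, **kwargs: Any) -> Tuple[Any, float]:
    """Call fn(*args, **kwargs) and return (result, elapsed wall-clock seconds)."""
    start = time.perf_counter()
    result = fn(*args, **kwargs)
    return result, time.perf_counter() - start


def compare_settings(
    fn: Callable[..., Any],
    settings: Dict[str, Dict[str, Any]],
    *args: Any,
    **kwargs: Any,
) -> Dict[str, float]:
    """Time fn once per labelled set of keyword overrides; return {label: seconds}."""
    timings: Dict[str, float] = {}
    for label, overrides in settings.items():
        # Overrides take precedence over the shared keyword arguments.
        _, seconds = timed(fn, *args, **{**kwargs, **overrides})
        timings[label] = seconds
    return timings
```

For example, with a trained model, a call such as compare_settings(model.differential_expression, {"uniform": {"weights": "uniform"}, "importance": {"weights": "importance"}}, adata=adata_small, groupby="cell_type") would time both settings, where adata_small (a subset of your cells) and the "cell_type" column are placeholders for your own data. This assumes your scvi-tools version's differential_expression accepts the weights argument used in the snippet above; repeating it at two subset sizes gives a rough sense of how the cost grows with the number of cells.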