AttributeError: module 'jax' has no attribute 'Device'

Hey everyone,

When I install scVI as described in your documentation, I get an error as soon as I import it. I attached a screenshot of the error log. Thanks in advance for any help you can offer!

Hi, could you run pip install --upgrade jax jaxlib chex?

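If the upgrade does not help, it can be useful to check what is actually installed. Below is a minimal sketch, not an official scvi-tools diagnostic. It assumes the error comes from a dependency such as chex referring to `jax.Device` on an older JAX install that does not expose that attribute. The function names `installed_versions` and `jax_has_device` are hypothetical helpers written for illustration. The sketch prints the installed versions of the packages mentioned in this thread and reports whether `jax.Device` exists:

```python
import importlib.metadata
import importlib.util

# Distributions mentioned in this thread ("scvi-tools" is scVI's PyPI name).
PACKAGES = ["scvi-tools", "jax", "jaxlib", "chex"]


def installed_versions(packages=PACKAGES):
    """Return {distribution name: version string, or None if not installed}."""
    versions = {}
    for name in packages:
        try:
            versions[name] = importlib.metadata.version(name)
        except importlib.metadata.PackageNotFoundError:
            versions[name] = None
    return versions


def jax_has_device():
    """Return hasattr(jax, "Device"), or None if jax is not installed.

    Assumption: the AttributeError in this thread is raised because this
    attribute is missing from an older jax release.
    """
    if importlib.util.find_spec("jax") is None:
        return None
    import jax
    return hasattr(jax, "Device")


if __name__ == "__main__":
    for name, version in installed_versions().items():
        print(f"{name}: {version or 'not installed'}")
    print("jax.Device available:", jax_has_device())
```

You can run it before and after `pip install --upgrade jax jaxlib chex`. If the last line changes from `False` to `True`, the missing attribute was most likely the cause.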

Hi Martin,

Thanks for getting back to me! I will try this when I get home tonight!

I encountered the same issue after upgrading scvi-tools from 0.16 to 0.20.3. Your suggestion worked nicely. Many thanks!