AttributeError: module 'jax' has no attribute 'Device'

Hey everyone,

When I install scVI as described in your documentation, I get this error as soon as I import it. I attached a screenshot of the error log. Thanks in advance for any help you can offer!

Hi, could you run pip install --upgrade jax jaxlib chex?
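To check whether this is a version mismatch before and after upgrading, here is a small sketch. It prints the installed jax, jaxlib and chex versions and reports whether `jax.Device` exists. The helper name `jax_device_check` is hypothetical and is not part of scVI or JAX. The sketch assumes the error comes from a newer dependency, such as chex, that references `jax.Device` while an older jax release without that attribute is installed.

```python
import importlib.metadata
import importlib.util


def jax_device_check() -> str:
    """Hypothetical helper: report package versions and whether jax.Device exists.

    Assumption: the AttributeError is caused by a newer package (e.g. chex)
    expecting `jax.Device` while an older jax is installed.
    """
    lines = []
    # Look up installed distribution versions without importing the packages.
    for pkg in ("jax", "jaxlib", "chex"):
        try:
            lines.append(f"{pkg}: {importlib.metadata.version(pkg)}")
        except importlib.metadata.PackageNotFoundError:
            lines.append(f"{pkg}: not installed")

    if importlib.util.find_spec("jax") is None:
        lines.append("jax.Device: jax not importable")
    else:
        try:
            import jax  # may still fail, e.g. on a jax/jaxlib mismatch
        except Exception as exc:  # report instead of crashing the check
            lines.append(f"jax.Device: import failed ({exc!r})")
        else:
            status = "present" if hasattr(jax, "Device") else "missing (try upgrading jax/jaxlib)"
            lines.append(f"jax.Device: {status}")
    return "\n".join(lines)


if __name__ == "__main__":
    print(jax_device_check())
```

Running it once before and once after the `pip install --upgrade` (in a fresh Python session) should show whether the upgrade actually took effect in the environment scVI is imported from.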

Hi Martin,

Thanks for getting back to me! I will try this when I get home tonight!