MRVI Jax Error "Invalid component distribution"

Hello! I am trying to run MRVI on some scRNA-seq data and ran into an error (listed below) during model.train(). I first hit this error with jax version 0.9.2. The first line of output from model.train() was “UserWarning: Running mrVI with Jax version larger 0.4.35 can cause performance issues”, so I tried switching to jax version 0.4.35, but that caused a different problem where numpyro could no longer be loaded (see errors below). Does anyone know how to solve these errors?

Original Error:
AssertionError: Invalid component distribution: Independent. The mixture components must have a support that does not depend on their parameters (expected ParameterFreeConstraint, but found IndependentConstraint(Real(), 1)).

Error after trying to switch versions:

From model = MRVI(adata):
ModuleNotFoundError: Please install ['numpyro'] to use this functionality

From import numpyro:
ImportError: cannot import name 'debug_info' from 'jax.api_util'

Version Info:

python v3.12.12, scvi v1.4.0.post1, jax v0.9.2 and v0.4.35, OS is “Rocky-Linux-8”
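
For anyone comparing against their own environment, the installed versions can be listed with a short script. This is a minimal sketch using only the standard library; report_versions is a hypothetical helper, and the default package list is an assumption about which distributions matter here.

```python
from importlib.metadata import PackageNotFoundError, version


def report_versions(packages=("scvi-tools", "jax", "jaxlib", "numpyro", "torch")):
    """Return {distribution name: installed version, or 'not installed'}."""
    found = {}
    for name in packages:
        try:
            # Looks up the installed distribution's metadata without importing it,
            # so a broken import (e.g. numpyro against a mismatched jax) is harmless.
            found[name] = version(name)
        except PackageNotFoundError:
            found[name] = "not installed"
    return found


if __name__ == "__main__":
    for name, ver in report_versions().items():
        print(f"{name}: {ver}")
```

Reading the metadata instead of importing each package means the report still works when one of the imports itself is what fails.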

Can you move on to running it with the “torch” backend?

We are dropping support for the Jax backend for MRVI, and for the rest of the models, for these reasons. Running with PyTorch should give the same performance, and run faster.

The 0.4.35 warning is old. We will drop it as well.


Wow that worked, thank you so much! I used the scvi.external.TorchMRVI class instead of scvi.external.MRVI and it worked like a charm.

Oh I see, and setting the backend parameter to “torch” in the normal scvi.external.MRVI class works too.
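
For reference, the two options mentioned in this thread might look roughly like the following. This is a sketch, not an official recipe: build_torch_mrvi is a hypothetical helper, it assumes adata has already been registered with the matching setup_anndata call, and passing backend as a keyword argument to the MRVI constructor is an assumption based on the description above.

```python
def build_torch_mrvi(adata, use_torch_class=False):
    """Build an MRVI model that runs on PyTorch instead of Jax.

    Assumes `adata` was already registered via the matching
    `setup_anndata` call (not shown here).
    """
    if use_torch_class:
        # Option 1: the dedicated PyTorch class named in the thread.
        from scvi.external import TorchMRVI

        return TorchMRVI(adata)

    # Option 2: the regular class with the backend set to "torch".
    # Assumption: `backend` is accepted as a constructor keyword.
    from scvi.external import MRVI

    return MRVI(adata, backend="torch")
```

Either way, training is then started with model.train() as before.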