Hello! I am trying to run MRVI on some scRNA-seq data and hit an error (listed below) during model.train(). I first got it with jax 0.9.2. The first line of output from model.train() was the warning “UserWarning: Running mrVI with Jax version larger 0.4.35 can cause performance issues”, so I switched to jax 0.4.35. That caused a different problem: numpyro could no longer be loaded (see the second set of errors below). Does anyone know how to solve these errors?
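To check which side of the warning's 0.4.35 threshold a given jax version falls on, a small comparison like the following could help. It is a stdlib-only sketch: parse_version and exceeds_threshold are hypothetical helpers, not scvi-tools functions, and the packaging library's Version class would be the more robust choice. It compares numeric tuples rather than raw strings, because a plain string comparison would wrongly rank "0.10.0" below "0.4.35".

```python
# Hypothetical helpers (not part of scvi-tools): compare a version string
# against the 0.4.35 threshold quoted in the mrVI UserWarning.
def parse_version(version: str) -> tuple[int, ...]:
    """Turn '0.9.2' into (0, 9, 2); stops at the first non-numeric component,
    so '1.4.0.post1' becomes (1, 4, 0)."""
    parts = []
    for piece in version.split("."):
        digits = ""
        for ch in piece:
            if ch.isdigit():
                digits += ch
            else:
                break
        if not digits:
            break
        parts.append(int(digits))
    return tuple(parts)


def exceeds_threshold(version: str, threshold: str = "0.4.35") -> bool:
    """True if `version` is strictly newer than `threshold` (numeric tuple order)."""
    return parse_version(version) > parse_version(threshold)


if __name__ == "__main__":
    for v in ["0.9.2", "0.4.35"]:
        print(v, exceeds_threshold(v))
```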
Original error (jax 0.9.2, during model.train()):
AssertionError: Invalid component distribution: Independent. The mixture components must have a support that does not depend on their parameters (expected ParameterFreeConstraint, but found IndependentConstraint(Real(), 1)).
Errors after switching to jax 0.4.35:
Running model = MRVI(adata) raises: ModuleNotFoundError: Please install ['numpyro'] to use this functionality
Running import numpyro directly raises: ImportError: cannot import name 'debug_info' from 'jax.api_util'
Version Info:
Python 3.12.12, scvi-tools 1.4.0.post1, jax 0.9.2 (first error) and 0.4.35 (second set of errors), OS: Rocky Linux 8
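Since both errors look version-related, a quick report of what is actually installed in the active environment may help whoever answers. This is a minimal sketch using only the standard library's importlib.metadata; the package list is an assumption about which distributions matter here (jaxlib is included because it is installed alongside jax), and installed_versions is a hypothetical helper name.

```python
from importlib.metadata import PackageNotFoundError, version

# Assumed set of distributions relevant to the mrVI / numpyro / jax errors.
PACKAGES = ["jax", "jaxlib", "numpyro", "scvi-tools"]


def installed_versions(names=PACKAGES):
    """Return {distribution name: installed version string, or None if missing}."""
    found = {}
    for name in names:
        try:
            found[name] = version(name)
        except PackageNotFoundError:
            found[name] = None
    return found


if __name__ == "__main__":
    for name, ver in installed_versions().items():
        print(f"{name}: {ver if ver is not None else 'not installed'}")
```

Posting that output alongside the tracebacks makes it easier to see which numpyro release is paired with which jax release.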