Error training MrVI: Backend 'cuda' failed to initialize

Hi,
I am eager to try out the new MrVI (scvi-tools version 1.2.0) on my data, but I am running into the following error.

Here is how I set up the model:

MRVI.setup_anndata(adata_immune, sample_key="sample", batch_key="version_10x", layer="counts")

model = MRVI(adata_immune)

The error occurs when training the model.

model.train(max_epochs=400)
Trainer will use only 1 of 2 GPUs because it is running inside an interactive / notebook environment. You may try to set `Trainer(devices=2)` but please note that multi-GPU inside interactive / notebook environments is considered experimental and unstable. Your mileage may vary.
---------------------------------------------------------------------------
RuntimeError                              Traceback (most recent call last)
Cell In[43], line 1
----> 1 model.train(max_epochs=400) #, accelerator='gpu', devices=1)
      3 # default accelerator='auto'
      4 # note: model takes time to run -- run overnight with accelerator='cpu' and only got to 88% 

File ~/.local/lib/python3.10/site-packages/scvi/external/mrvi/_model.py:252, in MRVI.train(self, max_epochs, accelerator, devices, train_size, validation_size, batch_size, early_stopping, plan_kwargs, **trainer_kwargs)
    248 plan_kwargs = plan_kwargs or {}
    249 train_kwargs["plan_kwargs"] = dict(
    250     deepcopy(DEFAULT_TRAIN_KWARGS["plan_kwargs"]), **plan_kwargs
    251 )
--> 252 super().train(**train_kwargs)

File ~/.local/lib/python3.10/site-packages/scvi/model/base/_jaxmixin.py:68, in JaxTrainingMixin.train(self, max_epochs, accelerator, devices, train_size, validation_size, shuffle_set_split, batch_size, datasplitter_kwargs, plan_kwargs, **trainer_kwargs)
     65 if max_epochs is None:
     66     max_epochs = get_max_epochs_heuristic(self.adata.n_obs)
---> 68 _, _, device = parse_device_args(
     69     accelerator,
     70     devices,
     71     return_device="jax",
     72     validate_single_device=True,
     73 )
     74 try:
     75     self.module.to(device)

File ~/.local/lib/python3.10/site-packages/scvi/model/_utils.py:145, in parse_device_args(accelerator, devices, return_device, validate_single_device)
    143     device = jax.devices("cpu")[0]
    144     if _accelerator != "cpu":
--> 145         device = jax.devices(_accelerator)[device_idx]
    146     return _accelerator, _devices, device
    148 return _accelerator, _devices

File ~/.local/lib/python3.10/site-packages/jax/_src/xla_bridge.py:1077, in devices(backend)
   1052 def devices(
   1053     backend: str | xla_client.Client | None = None
   1054 ) -> list[xla_client.Device]:
   1055   """Returns a list of all devices for a given backend.
   1056 
   1057   .. currentmodule:: jaxlib.xla_extension
   (...)
   1075     List of Device subclasses.
   1076   """
-> 1077   return get_backend(backend).devices()

File ~/.local/lib/python3.10/site-packages/jax/_src/xla_bridge.py:1011, in get_backend(platform)
   1007 @lru_cache(maxsize=None)  # don't use util.memoize because there is no X64 dependence.
   1008 def get_backend(
   1009     platform: None | str | xla_client.Client = None
   1010 ) -> xla_client.Client:
-> 1011   return _get_backend_uncached(platform)

File ~/.local/lib/python3.10/site-packages/jax/_src/xla_bridge.py:996, in _get_backend_uncached(platform)
    994 if backend is None:
    995   if platform in _backend_errors:
--> 996     raise RuntimeError(f"Backend '{platform}' failed to initialize: "
    997                        f"{_backend_errors[platform]}. "
    998                        f'Available backends are {list(bs)}')
    999   raise RuntimeError(
   1000       f"Unknown backend {platform}. Available backends are {list(bs)}")
   1001 return backend

RuntimeError: Backend 'cuda' failed to initialize: . Available backends are ['cpu']

I don't know why I am getting the error "Backend 'cuda' failed to initialize". I have already confirmed that a GPU is available in my environment, and PyTorch reports CUDA version 12.4 (torch.version.cuda).

Here is some additional information:

!nvcc --version
/opt/conda/lib/python3.10/pty.py:89: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.
  pid, fd = os.forkpty()
nvcc: NVIDIA (R) Cuda compiler driver
Copyright (c) 2005-2021 NVIDIA Corporation
Built on Mon_May__3_19:15:13_PDT_2021
Cuda compilation tools, release 11.3, V11.3.109
Build cuda_11.3.r11.3/compiler.29920130_0
import torch
torch.cuda.is_available()
True
print(torch.version.cuda)
12.4
torch.cuda.current_device()
0
torch.cuda.device_count()
2
torch.cuda.get_device_name(0)
'Tesla T4'
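The checks above only show what PyTorch sees. MrVI's training path asks JAX for a device (the traceback ends in jax.devices), so it can help to query JAX directly as well. A minimal sketch, assuming only that the jax package may or may not be installed; the helper name jax_backend_report is hypothetical:

```python
def jax_backend_report() -> dict:
    """Summarize which backend and devices JAX can see (hypothetical helper)."""
    try:
        import jax  # MrVI trains through JAX, not PyTorch
    except ImportError:
        return {"jax_installed": False, "default_backend": None, "devices": []}
    return {
        "jax_installed": True,
        # "cpu" here means JAX found no usable GPU, even if PyTorch did
        "default_backend": jax.default_backend(),
        "devices": [str(d) for d in jax.devices()],
    }


if __name__ == "__main__":
    print(jax_backend_report())
```

In an environment like the one above, a default_backend of "cpu" would be consistent with the "Available backends are ['cpu']" message, even though torch.cuda.is_available() returns True.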

Thank you for your help!
Kate

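Hi, thanks for trying out MrVI. It uses JAX, not PyTorch, as its GPU backend, so torch.cuda.is_available() returning True does not mean MrVI can use the GPU. The error shows that JAX in this environment only has a CPU backend ("Available backends are ['cpu']"). Please install a GPU-enabled version of JAX, following the Installation page of the JAX documentation.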
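After installing a GPU-enabled JAX, one way to confirm the fix before a long training run is to ask JAX for CUDA devices directly, since that is the call that failed in the traceback. A minimal sketch; the helper name choose_accelerator is hypothetical, and passing its result to train assumes (as the MRVI.train signature in the traceback suggests) that train accepts accelerator and devices arguments:

```python
def choose_accelerator() -> str:
    """Return "gpu" if JAX can see a CUDA device, otherwise "cpu" (hypothetical helper)."""
    try:
        import jax
    except ImportError:
        # No JAX at all: MrVI cannot train, but report CPU rather than crash here
        return "cpu"
    try:
        # Same call that failed in the traceback. It raises RuntimeError
        # ("Backend 'cuda' failed to initialize" or "Unknown backend ...")
        # when JAX was installed without CUDA support.
        return "gpu" if jax.devices("cuda") else "cpu"
    except RuntimeError:
        return "cpu"


if __name__ == "__main__":
    print(choose_accelerator())
```

With the model set up as in the question, this could be used as model.train(max_epochs=400, accelerator=choose_accelerator(), devices=1). Falling back to CPU keeps training possible, but as the notebook comment in the traceback notes, a CPU run can take a very long time.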

Hi,
Thank you! Installing the GPU-enabled version of JAX allowed the model to run.