Hi,
I am eager to try out the new MrVI model (scvi-tools version 1.2.0) on my data, but I am running into the following error.
Here is how I set up the model:
MRVI.setup_anndata(adata_immune, sample_key="sample", batch_key="version_10x", layer="counts")
model = MRVI(adata_immune)
The error occurs when training the model.
model.train(max_epochs=400)
Trainer will use only 1 of 2 GPUs because it is running inside an interactive / notebook environment. You may try to set `Trainer(devices=2)` but please note that multi-GPU inside interactive / notebook environments is considered experimental and unstable. Your mileage may vary.
---------------------------------------------------------------------------
RuntimeError Traceback (most recent call last)
Cell In[43], line 1
----> 1 model.train(max_epochs=400) #, accelerator='gpu', devices=1)
3 # default accelerator='auto'
4 # note: model takes time to run -- run overnight with accelerator='cpu' and only got to 88%
File ~/.local/lib/python3.10/site-packages/scvi/external/mrvi/_model.py:252, in MRVI.train(self, max_epochs, accelerator, devices, train_size, validation_size, batch_size, early_stopping, plan_kwargs, **trainer_kwargs)
248 plan_kwargs = plan_kwargs or {}
249 train_kwargs["plan_kwargs"] = dict(
250 deepcopy(DEFAULT_TRAIN_KWARGS["plan_kwargs"]), **plan_kwargs
251 )
--> 252 super().train(**train_kwargs)
File ~/.local/lib/python3.10/site-packages/scvi/model/base/_jaxmixin.py:68, in JaxTrainingMixin.train(self, max_epochs, accelerator, devices, train_size, validation_size, shuffle_set_split, batch_size, datasplitter_kwargs, plan_kwargs, **trainer_kwargs)
65 if max_epochs is None:
66 max_epochs = get_max_epochs_heuristic(self.adata.n_obs)
---> 68 _, _, device = parse_device_args(
69 accelerator,
70 devices,
71 return_device="jax",
72 validate_single_device=True,
73 )
74 try:
75 self.module.to(device)
File ~/.local/lib/python3.10/site-packages/scvi/model/_utils.py:145, in parse_device_args(accelerator, devices, return_device, validate_single_device)
143 device = jax.devices("cpu")[0]
144 if _accelerator != "cpu":
--> 145 device = jax.devices(_accelerator)[device_idx]
146 return _accelerator, _devices, device
148 return _accelerator, _devices
File ~/.local/lib/python3.10/site-packages/jax/_src/xla_bridge.py:1077, in devices(backend)
1052 def devices(
1053 backend: str | xla_client.Client | None = None
1054 ) -> list[xla_client.Device]:
1055 """Returns a list of all devices for a given backend.
1056
1057 .. currentmodule:: jaxlib.xla_extension
(...)
1075 List of Device subclasses.
1076 """
-> 1077 return get_backend(backend).devices()
File ~/.local/lib/python3.10/site-packages/jax/_src/xla_bridge.py:1011, in get_backend(platform)
1007 @lru_cache(maxsize=None) # don't use util.memoize because there is no X64 dependence.
1008 def get_backend(
1009 platform: None | str | xla_client.Client = None
1010 ) -> xla_client.Client:
-> 1011 return _get_backend_uncached(platform)
File ~/.local/lib/python3.10/site-packages/jax/_src/xla_bridge.py:996, in _get_backend_uncached(platform)
994 if backend is None:
995 if platform in _backend_errors:
--> 996 raise RuntimeError(f"Backend '{platform}' failed to initialize: "
997 f"{_backend_errors[platform]}. "
998 f'Available backends are {list(bs)}')
999 raise RuntimeError(
1000 f"Unknown backend {platform}. Available backends are {list(bs)}")
1001 return backend
RuntimeError: Backend 'cuda' failed to initialize: . Available backends are ['cpu']
I don’t know why I am getting the error “Backend ‘cuda’ failed to initialize”. I already confirmed that a GPU is available in my environment, and my PyTorch build reports CUDA 12.4.
Here is some additional information:
!nvcc --version
nvcc: NVIDIA (R) Cuda compiler driver
Copyright (c) 2005-2021 NVIDIA Corporation
Built on Mon_May__3_19:15:13_PDT_2021
Cuda compilation tools, release 11.3, V11.3.109
Build cuda_11.3.r11.3/compiler.29920130_0
import torch
torch.cuda.is_available()
True
print(torch.version.cuda)
12.4
torch.cuda.current_device()
0
torch.cuda.device_count()
2
torch.cuda.get_device_name(0)
'Tesla T4'
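One more thing that may be relevant: as far as I understand, MrVI trains via JAX rather than PyTorch, so the torch checks above may not reflect what JAX itself sees. A minimal sketch of the equivalent JAX-side check (with a guard in case jax isn't importable) would be something like:

```python
# Sketch of a JAX-side GPU check. torch.cuda.is_available() only tells us
# what PyTorch sees; JAX initializes its own CUDA backend separately.
import importlib.util

if importlib.util.find_spec("jax") is None:
    print("jax is not installed in this environment")
else:
    import jax
    # If the CUDA backend failed to initialize, the backend falls back to
    # "cpu" and jax.devices() typically lists only CPU devices.
    print(jax.default_backend())
    print(jax.devices())
```

If `jax.default_backend()` reports "cpu" here, my guess is that the installed jaxlib has no CUDA support (or was built against a different CUDA version than the one on the system), but I am not sure.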
Thank you for your help!
Kate