Has anyone recently gotten scVI (ideally 1.0.4) working with “GPU” (well, “mps”) acceleration on an Apple ARM M1, M2, or M3? I’ve tried a variety of incantations when installing torch and jax. Either it doesn’t see the GPU, or it does and then throws the tensor error below, which suggests something is very borked somewhere in the software chain.
ValueError: Expected parameter loc (Tensor of shape (128, 30)) of distribution Normal(loc: torch.Size([128, 30]), scale: torch.Size([128, 30])) to satisfy the constraint Real(), but found invalid values:
tensor([[nan, nan, nan, ..., nan, nan, nan],
        [nan, nan, nan, ..., nan, nan, nan],
        [nan, nan, nan, ..., nan, nan, nan],
        ...,
        [nan, nan, nan, ..., nan, nan, nan],
        [nan, nan, nan, ..., nan, nan, nan],
        [nan, nan, nan, ..., nan, nan, nan]], device='mps:0',
       grad_fn=<LinearBackward0>)
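For anyone trying to reproduce this, it may help to separate the two failure modes, "MPS not visible" versus "MPS visible but producing NaNs", before involving scVI at all. Below is a minimal standalone sketch under a few assumptions. It uses only standard PyTorch calls (`torch.backends.mps.is_built()`, `torch.backends.mps.is_available()`, and the `PYTORCH_ENABLE_MPS_FALLBACK` environment variable). The helper names and the 64 input features are hypothetical choices for illustration. The (128, 30) output shape simply mirrors the failing tensor in the error above. This is not how scVI itself selects a device.

```python
import importlib.util
import os


def describe_torch_backends() -> dict:
    """Report whether PyTorch is importable and whether its MPS backend is usable."""
    info = {"torch_installed": importlib.util.find_spec("torch") is not None}
    if not info["torch_installed"]:
        return info
    import torch

    info["torch_version"] = str(torch.__version__)
    # Older PyTorch releases have no torch.backends.mps module at all.
    mps = getattr(torch.backends, "mps", None)
    # is_built(): this PyTorch wheel was compiled with MPS support.
    info["mps_built"] = bool(mps is not None and mps.is_built())
    # is_available(): the current macOS version and hardware can run MPS kernels.
    info["mps_available"] = bool(mps is not None and mps.is_available())
    # PyTorch's switch for falling back to CPU on ops MPS does not implement.
    info["mps_fallback_env"] = os.environ.get("PYTORCH_ENABLE_MPS_FALLBACK")
    return info


def mps_linear_smoke_test(in_features: int = 64, seed: int = 0, atol: float = 1e-4):
    """Run one Linear layer on CPU and on MPS and compare the outputs.

    Returns None when PyTorch or MPS is unavailable. in_features=64 is an
    arbitrary, hypothetical choice; the (128, 30) output matches the error.
    """
    if not describe_torch_backends().get("mps_available"):
        return None
    import torch

    torch.manual_seed(seed)
    layer = torch.nn.Linear(in_features, 30)
    x = torch.randn(128, in_features)
    with torch.no_grad():
        cpu_out = layer(x)
        # Module.to() moves parameters in place; the CPU result is already computed.
        mps_out = layer.to("mps")(x.to("mps")).cpu()
    return {
        "mps_output_finite": bool(torch.isfinite(mps_out).all()),
        "matches_cpu": bool(torch.allclose(cpu_out, mps_out, atol=atol)),
    }


if __name__ == "__main__":
    print(describe_torch_backends())
    print(mps_linear_smoke_test())
```

If `mps_available` comes back False, the problem is the installed wheel or macOS version rather than scVI. If the smoke test passes but training still yields NaNs, training the same model on CPU would show whether the NaNs are MPS-specific. My assumption is that scvi-tools 1.x exposes a Lightning-style `accelerator` argument on `train()` for this, but check the docs for your exact version.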