MacBook M1/M2 MPS acceleration with scVI

Has anyone recently gotten scVI (ideally 1.0.4) working with “GPU” (well, “mps”) acceleration on an Apple ARM M1, M2, or M3? I’ve tried a variety of incantations when installing torch and jax. Either it doesn’t see the GPU, or it does and then throws a tensor error, which suggests something is very borked somewhere in the software chain.

ValueError: Expected parameter loc (Tensor of shape (128, 30)) of distribution Normal(loc: torch.Size([128, 30]), scale: torch.Size([128, 30])) to satisfy the constraint Real(), but found invalid values:
tensor([[nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        ...,
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan]], device='mps:0',
       grad_fn=<LinearBackward0>)
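For anyone trying to reproduce this, a minimal sketch of checking whether PyTorch sees the Metal backend at all, and falling back to CPU otherwise. The helper name `pick_accelerator` is hypothetical. Passing the result to scVI's `train()` through an `accelerator` argument is an assumption about your scvi-tools version, so check its docs first.

```python
def pick_accelerator() -> str:
    """Return "mps" if PyTorch reports a usable Metal backend, else "cpu"."""
    try:
        import torch
    except ImportError:
        # No PyTorch in this environment, so there is nothing to accelerate.
        return "cpu"

    # torch.backends.mps is missing on older PyTorch builds, hence getattr.
    mps = getattr(torch.backends, "mps", None)
    if mps is not None and mps.is_built() and mps.is_available():
        return "mps"
    return "cpu"


accelerator = pick_accelerator()
print("accelerator:", accelerator)

# Assumption: if your scvi-tools version accepts it, something like
#   model.train(accelerator=accelerator)
# would select the device. Given the NaNs above, "cpu" is the safe choice.
```

Note that seeing "mps" here only means the device is visible. It says nothing about whether training on it is numerically sound.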

Same issue here. I am on a Mac M3.

I had similar issues and found this:

I think this is a PyTorch issue, not an scVI issue.

I think this is the PyTorch issue where they track MPS compatibility:

I think the specific function that’s incompatible (at least for my usage) was

aten::_standard_gamma
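A rough way to check whether that op works on your own PyTorch build is to draw a Gamma sample on the MPS device, since `torch.distributions.Gamma.rsample` goes through `_standard_gamma`. This is only a sketch, and the function name `probe_gamma_on_mps` is made up for illustration. When an op is missing on MPS, PyTorch's error message suggests setting `PYTORCH_ENABLE_MPS_FALLBACK=1` to run it on the CPU instead. That variable is deliberately left unset here so a missing kernel shows up as an error.

```python
def probe_gamma_on_mps() -> str:
    """Return "ok", "nan", "unsupported", or "no-mps" for Gamma sampling on MPS."""
    try:
        import torch
    except ImportError:
        return "no-mps"

    mps = getattr(torch.backends, "mps", None)
    if mps is None or not mps.is_available():
        return "no-mps"

    try:
        concentration = torch.ones(8, device="mps")
        rate = torch.ones(8, device="mps")
        # rsample() dispatches to aten::_standard_gamma under the hood.
        sample = torch.distributions.Gamma(concentration, rate).rsample()
    except NotImplementedError:
        # The op has no MPS kernel in this PyTorch build.
        return "unsupported"

    return "nan" if torch.isnan(sample).any().item() else "ok"


print("gamma on mps:", probe_gamma_on_mps())
```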

I’ve tried testing out the nightly PyTorch versions with the MPS backend and have had no success. Technically it should work, since they’ve implemented the lgamma kernel, which was the last one needed to fully support running scVI. However, it looks like there might be issues with the implementation or numerical instabilities, since I’ve also experienced NaNs in the first epoch of training.
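One way to narrow this down is to compare `torch.lgamma` on MPS against the CPU result for the same inputs. The sketch below does that. The helper name and the input range are arbitrary choices for illustration, and a clean result would not rule out problems in other kernels or at other input values.

```python
def lgamma_max_abs_diff():
    """Max |lgamma_cpu - lgamma_mps| over a test grid.

    Returns None when MPS (or PyTorch) is unavailable, and NaN when the
    MPS result contains any NaNs.
    """
    try:
        import torch
    except ImportError:
        return None

    mps = getattr(torch.backends, "mps", None)
    if mps is None or not mps.is_available():
        return None

    # Arbitrary positive test grid; real training may hit other ranges.
    x = torch.linspace(0.1, 50.0, steps=1000)
    on_cpu = torch.lgamma(x)
    on_mps = torch.lgamma(x.to("mps")).cpu()

    if torch.isnan(on_mps).any().item():
        return float("nan")
    return (on_cpu - on_mps).abs().max().item()


print("lgamma max abs diff (cpu vs mps):", lgamma_max_abs_diff())
```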

So yes, this is an issue on the PyTorch end. Unfortunately, there’s not much we can do to support the Metal backend.

Thanks. I’ll keep track of this thread, and I hope that if anyone gets scVI working on a nightly (or better, stable) build of PyTorch, they will report it here!