MacBook M1/M2 MPS acceleration with scVI

Has anyone recently gotten scVI (ideally 1.0.4) working with "GPU" (well, MPS) acceleration on an Apple ARM M1, M2, or M3? I've tried a variety of incantations when installing torch and jax, and it either doesn't see the GPU at all, or it does and then throws a tensor error that suggests something is very borked somewhere in the software chain.
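For reference, this is roughly how I'm checking whether torch sees the Metal device and how I'm asking scVI to use it (a minimal sketch on synthetic data; it assumes scvi-tools ~1.0, where train() accepts the Lightning-style accelerator/devices arguments):

import torch
import scvi

# Does this torch build include, and currently detect, the Metal backend?
print(torch.backends.mps.is_built(), torch.backends.mps.is_available())

# Tiny synthetic dataset, just enough to exercise the training loop
adata = scvi.data.synthetic_iid()
scvi.model.SCVI.setup_anndata(adata)

model = scvi.model.SCVI(adata)
# These arguments are forwarded to the Lightning Trainer
model.train(max_epochs=5, accelerator="mps", devices=1)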

ValueError: Expected parameter loc (Tensor of shape (128, 30)) of distribution Normal(loc: torch.Size([128, 30]), scale: torch.Size([128, 30])) to satisfy the constraint Real(), but found invalid values:
tensor([[nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        ...,
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan]], device='mps:0',
       grad_fn=<LinearBackward0>)

Same issue here. I am on a Mac M3

I had similar issues and found this:

I think this is a pytorch issue, not an scVI issue.

I think this is the pytorch issue where they track mps compatibility:

I think the specific function that’s incompatible (at least for my usage) was

aten::_standard_gamma
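If you want to poke at that op directly, a quick probe is to sample from a Gamma distribution on the MPS device, since that goes through aten::_standard_gamma under the hood (values here are arbitrary):

import torch

# Gamma sampling calls aten::_standard_gamma; on older torch builds this
# raised NotImplementedError on MPS unless PYTORCH_ENABLE_MPS_FALLBACK=1
# forced a CPU fallback.
g = torch.distributions.Gamma(torch.ones(3, device="mps"),
                              torch.ones(3, device="mps"))
print(g.sample())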

I’ve tried testing out the nightly PyTorch versions with the MPS backend and have had no success. Technically it should work, since they’ve implemented the lgamma kernel, which was the last one needed to fully support running scVI. But it looks like there might be issues with the implementation or numerical instabilities, since I’ve also experienced NaNs in the first epoch of training.

So yes, this is an issue on the PyTorch end - unfortunately there’s not much we can do to support the Metal backend.

Thanks - I’ll keep track of this thread and I hope if anyone gets scVI working on a nightly (or better, stable) branch of pytorch they will report it here!

I ran into the same issue with torch 2.2.2 on the MPS backend, where torch.lgamma produces -inf.
Very confusing to debug.
Thanks for tracking it here!
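In case it helps anyone reproduce it, this is roughly the kind of sanity check I'd run (shapes and values are arbitrary; any -inf or large CPU/MPS mismatch points at the MPS kernel):

import torch

x = torch.rand(128, 30) + 0.5        # strictly positive inputs
cpu = torch.lgamma(x)
mps = torch.lgamma(x.to("mps")).cpu()
# lgamma is finite for positive inputs, so any -inf or large gap is a red flag
print(torch.isinf(mps).any().item(), (cpu - mps).abs().max().item())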

Hi, thanks for the pointer. We tried it beforehand and got NaN errors (that was with a not-yet-released torch version). I looked deeper into it, and it’s a very strange one. Apparently MPS handles broadcast_all oddly (and the line torch.lgamma(theta) produces inf values). This is an example of the behavior:

import torch
from torch.distributions.utils import broadcast_all

# lgamma(1) = 0, so both results should be all zeros after broadcasting
b = torch.full((5, 3), 1., device='mps')
c = torch.full((5, 1), 1., device='mps')
b, c = broadcast_all(b, c)
torch.lgamma(b), torch.lgamma(c)
(tensor([[0., 0., 0.],
         [0., 0., 0.],
         [0., 0., 0.],
         [0., 0., 0.],
         [0., 0., 0.]], device='mps:0'),
 tensor([[ 0.0000,  0.0000,  0.0000],
         [ 0.0000,  0.0000, -0.0906],
         [-0.0906, -0.0906, -0.0906],
         [-0.0906, -0.0906, -0.0906],
         [-0.0906, -0.0906, -0.0906]], device='mps:0'))

The actual underlying issue is something the MPS-torch team has to look into (my guess is that it’s related to how the broadcast views point into memory). However, this also points to two ways in which scVI does work on MPS. First: gene_likelihood='poisson' is fully supported. Second: if you use dispersion='gene-label' when setting up the model, the broadcast_all call has no effect and I’m not getting NaN errors (if you don’t use a labels_key, this setup has no effect on the training procedure and is safe; it won’t work for scANVI though). The speedup I’m getting on an M1 Max in High Power mode is about 80% with a very large batch size, which makes it possible to test the GPU’s capabilities. I would be curious to get more people to benchmark it (especially on M3 chips).
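Concretely, the two setups that have been working for me look roughly like this (a sketch on synthetic data; argument names follow scvi-tools ~1.0, and the accelerator/devices arguments are passed through to Lightning):

import scvi

adata = scvi.data.synthetic_iid()
scvi.model.SCVI.setup_anndata(adata)

# Option 1: the Poisson likelihood has no dispersion parameter theta,
# so the problematic broadcast_all/lgamma path on theta is never hit
model = scvi.model.SCVI(adata, gene_likelihood="poisson")
model.train(max_epochs=5, accelerator="mps", devices=1)

# Option 2: gene-label dispersion gives theta a shape where broadcast_all
# is effectively a no-op; without a labels_key this should behave like the
# default during training, but it will not work for scANVI
model = scvi.model.SCVI(adata, dispersion="gene-label")
model.train(max_epochs=5, accelerator="mps", devices=1)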

I would conclude that using MPS, even with the proposed change, should be considered experimental only, as we use broadcasting in various places and across different models, and there is an open PyTorch MPS issue about exactly this: MPS lgamma function changes results when using broadcasting · Issue #132605 · pytorch/pytorch · GitHub