Hi, thanks for the pointer. We tried it beforehand and got NaN errors (that was with the unreleased torch version). I looked deeper into it, and it's a very strange one. Apparently MPS handles broadcast_all oddly, and the line torch.lgamma(theta) produces inf values. This is an example of the behavior:
import torch
from torch.distributions.utils import broadcast_all

b = torch.full((5, 3), 1., device='mps')
c = torch.full((5, 1), 1., device='mps')
b, c = broadcast_all(b, c)  # expands c to shape (5, 3) as a stride-0 view
torch.lgamma(b), torch.lgamma(c)  # lgamma(1) should be 0. everywhere
(tensor([[0., 0., 0.],
[0., 0., 0.],
[0., 0., 0.],
[0., 0., 0.],
[0., 0., 0.]], device='mps:0'),
tensor([[ 0.0000, 0.0000, 0.0000],
[ 0.0000, 0.0000, -0.0906],
[-0.0906, -0.0906, -0.0906],
[-0.0906, -0.0906, -0.0906],
[-0.0906, -0.0906, -0.0906]], device='mps:0'))
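To probe the memory-layout hunch (just a hypothesis on my side, not confirmed), one can force a contiguous copy of the broadcast view before calling lgamma and compare against the CPU result. Continuing from the snippet above:

# Hypothesis check: the broadcast view of c has stride 0 along dim 1;
# .contiguous() materializes it into real memory before lgamma runs
print(torch.lgamma(c.contiguous()))

# CPU reference: lgamma(1) == 0, so this should be all zeros
print(torch.lgamma(c.cpu()))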
While the actual issue is something the MPS-torch team has to look into (my guess is some memory-pointer/striding problem with the broadcast views), it also points to two ways in which scVI does work on MPS. First: gene_likelihood='poisson' is fully supported, presumably because the Poisson likelihood has no lgamma(theta) term. Second: with dispersion='gene-label' when setting up the model, broadcast_all has no effect and I'm not getting NaN errors (if no labels_key is used, this setup has no effect on the training procedure and is safe; it won't work for scANVI though). A sketch of both setups follows below. The speedup I'm getting on an M1 Max in High Power mode is 80% with a very high batch size, which allows testing the GPU's capabilities. I would be curious to get more people to benchmark it (especially on M3 chips).
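For anyone who wants to try the two working configurations, here is a minimal sketch. It assumes the current scvi-tools API (SCVI.setup_anndata, the gene_likelihood/dispersion constructor arguments, and the accelerator flag passed through to Lightning); adata, batch_key="batch", and batch_size=4096 are placeholders for your own data and settings:

import scvi

# Assumes `adata` is an already-loaded AnnData object (placeholder here)
scvi.model.SCVI.setup_anndata(adata, batch_key="batch")

# Option 1: Poisson likelihood, no lgamma(theta) term in the loss
model = scvi.model.SCVI(adata, gene_likelihood="poisson")

# Option 2: gene-label dispersion, which sidesteps broadcast_all
# (safe without labels_key; not compatible with scANVI)
model = scvi.model.SCVI(adata, dispersion="gene-label")

# Train on Apple Silicon; a large batch size stresses the GPU
model.train(accelerator="mps", batch_size=4096)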