I’ve tested the nightly PyTorch builds with the MPS backend and have had no success. Technically it should work, since they’ve implemented the lgamma kernel, which was the last one needed to fully support running scVI. However, there may be issues with the implementation or numerical instabilities, since I’ve also seen NaNs in the first epoch of training.
So yes, this is an issue on the PyTorch end; unfortunately there’s not much we can do on our side to support the Metal backend.
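
If you want to check whether the lgamma kernel is misbehaving on your machine, something like the following rough sketch may help. It compares `torch.lgamma` on the CPU against the same call on MPS. The helper names (`pick_device`, `compare_lgamma`) and the input range are my own choices for illustration, not anything from scvi-tools, and a clean result here would not rule out problems in other kernels.

```python
import torch


def pick_device() -> str:
    """Return "mps" if this PyTorch build can use it, else "cpu"."""
    mps = getattr(torch.backends, "mps", None)  # absent on older builds
    return "mps" if mps is not None and mps.is_available() else "cpu"


def compare_lgamma(device: str, n: int = 1000) -> float:
    """Max absolute difference between CPU lgamma and lgamma on `device`.

    Returns NaN if the result computed on `device` contains any NaNs.
    The input range (0.01 to 100) is an arbitrary choice for this sketch.
    """
    x = torch.linspace(0.01, 100.0, n, dtype=torch.float32)
    reference = torch.lgamma(x)                      # computed on CPU
    candidate = torch.lgamma(x.to(device)).cpu()     # computed on `device`
    if torch.isnan(candidate).any():
        return float("nan")
    return (reference - candidate).abs().max().item()


if __name__ == "__main__":
    device = pick_device()
    print(f"device: {device}")
    print(f"max |lgamma_cpu - lgamma_{device}| = {compare_lgamma(device)}")
```

A large difference or a NaN would point at the kernel itself. A small difference would suggest the NaNs during training come from somewhere else, for example reduced numerical stability elsewhere in the model on MPS.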