PyTorch recently announced GPU support for ARM (Apple silicon) Macs. However, I don’t know how to use this feature in scvi.model.SCVI.train. From my own searching, it seems to come down to specifying the correct device: the default device is “CPU”, whereas “MPS” is the device needed to use this new feature.
Does anyone know how to proceed?
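For reference, in plain PyTorch (1.12 or later) the Apple GPU is selected by creating a device from the string "mps". The following is a minimal sketch that assumes only that public PyTorch API. It does not show how scvi-tools hands a device to its training loop, and the helper names `choose_device_name` and `detect_device_name` are hypothetical.

```python
def choose_device_name(cuda_available: bool, mps_available: bool) -> str:
    """Pick a PyTorch device string from backend availability flags.

    Hypothetical helper: prefers CUDA, then Apple's MPS backend, else CPU.
    """
    if cuda_available:
        return "cuda"
    if mps_available:
        return "mps"
    return "cpu"


def detect_device_name() -> str:
    """Query the installed PyTorch build for usable GPU backends."""
    import torch  # lazy import so choose_device_name works without torch

    # torch.backends.mps only exists in PyTorch >= 1.12, so guard the lookup
    mps_backend = getattr(torch.backends, "mps", None)
    mps_ok = bool(mps_backend is not None and mps_backend.is_available())
    return choose_device_name(torch.cuda.is_available(), mps_ok)


# Usage (requires torch):
#   import torch
#   device = torch.device(detect_device_name())
#   x = torch.ones(3, device=device)
```

On an Apple-silicon Mac with a recent PyTorch build, `detect_device_name()` should return "mps"; whether a given training library then accepts that device is a separate question.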
We’re waiting on support in PyTorch Lightning, which is the API we use for our training loops. They have an open issue for tracking it: “MPS (Mac M1) device support”, Issue #13102 in Lightning-AI/lightning on GitHub. However, my hunch is that for most of our models the speedup may be negligible, because these optimizations seem to make a significant difference only for larger models. Please check out our JaxSCVI model and stay tuned for a more general Jax interface, which we found to give 2x-4x speedups on both CPU and GPU!
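One way to check the size hunch on your own machine is a small timing comparison. The sketch below is a hypothetical stand-in, not scvi-tools’ training loop. It times a single square float32 matrix multiply at several sizes. The function names `median_runtime` and `matmul_timing` are illustrative, and it assumes PyTorch with the MPS backend is installed.

```python
import time
from typing import Callable


def median_runtime(fn: Callable[[], object], repeats: int = 5) -> float:
    """Median wall-clock time of fn() in seconds over `repeats` calls."""
    timings = []
    for _ in range(repeats):
        start = time.perf_counter()
        fn()
        timings.append(time.perf_counter() - start)
    timings.sort()
    return timings[len(timings) // 2]


def matmul_timing(device_name: str, size: int, repeats: int = 5) -> float:
    """Median time of a size x size float32 matmul on a PyTorch device."""
    import torch  # lazy import: median_runtime itself needs only the stdlib

    device = torch.device(device_name)
    a = torch.randn(size, size, device=device)
    b = torch.randn(size, size, device=device)

    def step() -> float:
        # .item() copies a scalar back to the host, which waits for the
        # asynchronous GPU work to finish, so the timing covers the matmul
        return (a @ b).sum().item()

    step()  # warm-up call, excluded from the measurement
    return median_runtime(step, repeats)


# Usage on an Apple-silicon Mac (requires torch with MPS):
#   for size in (64, 512, 4096):
#       print(size, matmul_timing("cpu", size), matmul_timing("mps", size))
```

If the hunch holds, the small sizes should show little or no advantage for "mps", since fixed per-call overhead dominates, while the largest size is where any GPU speedup would appear. No specific numbers are implied here.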
Thank you so much for your response.
I am looking forward to the upgrade that supports the ARM GPU in model training.
Also, I notice that JaxSCVI seems to be an experimental feature. I will try it once it’s officially released.