Using low-precision matrix multiplication to boost performance

Hey,

I was wondering if we could use the low-precision matrix multiplication settings described in the "Speed Up Model Training" page of the PyTorch Lightning 2.4.0 documentation to speed up training on GPUs. Do we need to modify the training code of scVI to use them, or should they work out of the box?

Thanks

Hi, for integration my experience is that torch.set_float32_matmul_precision("medium") is safe to use. For downstream analysis like differential gene expression this might not be the case, as those analyses are more sensitive. It is unfortunately very hard to test in a way that gives good guarantees.
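
For reference, a minimal sketch of what this looks like in practice, assuming the standard scvi-tools SCVI workflow (the adata object and the "batch" key are placeholders for your own data). The precision setting is a global PyTorch switch made once before training; no changes to scVI itself are needed:

```python
import torch
import scvi

# Allow TF32-backed float32 matmuls on supported GPUs; trades a small
# amount of precision for substantially faster matrix multiplications.
torch.set_float32_matmul_precision("medium")

# Standard scvi-tools workflow; the setting above is picked up globally,
# so training proceeds exactly as usual.
scvi.model.SCVI.setup_anndata(adata, batch_key="batch")  # placeholder keys
model = scvi.model.SCVI(adata)
model.train()
```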

Thanks. The speedup in integration from setting torch.set_float32_matmul_precision("medium") is significant (at least 2x). Could you suggest this somewhere in the scvi docs, for cases where you have millions of cells and want to speed things up?

No, recommending it would require more extensive testing, and I'm not sure how similar we would expect the results to be. To get a major speed-up, set batch_size to 1024 or 2048 (speed scales roughly linearly with batch size).
We will have torch.compile support soon-ish (another ~50% speed-up).
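
As a quick sketch of the batch-size change, assuming a model built as above (batch_size is an existing parameter of SCVI.train(); the default is 128):

```python
# Larger minibatches keep the GPU busier; speed scales roughly linearly
# with batch_size up to the point where GPU memory becomes the limit.
model.train(batch_size=1024)
```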

How do I track the training so that I know whether switching to low precision did or didn't change anything? What would you look for?

Also, in "Increase scVI integration speed" (reply #2), @Valentine_Svensson didn't see any significant speed increase from changing batch_size. Is that still the case?

For things like DE, it's pretty hard to tell, as two very similar reconstruction losses might still yield different results: the loss is not very sensitive to low rates (whether a rate is 1e-5 or 1e-6 yields almost the same loss, but makes a difference for DE). We are working on making downstream analysis more robust.
For integration quality, just check that the embedding makes sense (your cell types are well separated and your batches mix well). We tested large batch sizes for the CELLxGENE Census (30 million cells) and found that increasing batch size was helpful for speed, also judging by the tracked losses. The effect might be smaller on other datasets, but I think it's now the default for many labs doing large integrations.
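
A rough sketch of both checks, assuming a trained scvi-tools model and scanpy for plotting (the "cell_type" and "batch" column names are placeholders for whatever your AnnData uses):

```python
import scanpy as sc

# 1) Compare training curves between runs: the ELBO history is stored on
#    the model and should look similar with and without low precision.
model.history["elbo_train"].plot()

# 2) Inspect the embedding: cell types should separate, batches should mix.
adata.obsm["X_scVI"] = model.get_latent_representation()
sc.pp.neighbors(adata, use_rep="X_scVI")
sc.tl.umap(adata)
sc.pl.umap(adata, color=["cell_type", "batch"])
```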