Topic modelling or factorization on integrated scRNA-seq data?

Does anyone have any thoughts or advice on approaching topic modelling or factorization on scVI-integrated data? I have some collections of scRNA-seq data from various sources, and scVI does a great job of producing a batch-corrected representation. Within some of the cell types, there's more of a continuous gradation of expression. If it were a single dataset, I'd try something like NMF to look at these gene programs, but in this case I'd be concerned about batch-specific factors.

Is there a reasonably clean way to approach this? E.g., NMF on transform_batch-corrected counts? In papers, it's not uncommon to see something like NMF run on each sample independently, followed by a search for similar factors across samples, but this feels a little clunky.
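For concreteness, here is a minimal sketch of the two approaches mentioned above, using only NumPy. It assumes the batch-corrected input is already a dense, nonnegative cells × genes array (for example, the output of scvi-tools' `get_normalized_expression` with `transform_batch` set; that call is not made here). `nmf` and `match_factors` are hypothetical helpers, not part of scvi-tools or scikit-learn. The NMF uses Lee-Seung multiplicative updates with a Frobenius loss for brevity; a KL-divergence loss is a common alternative for count-like data.

```python
import numpy as np


def nmf(X, k, n_iter=500, seed=0, eps=1e-10):
    """Hypothetical helper: factor a nonnegative cells x genes matrix
    X ~= W @ H using Lee-Seung multiplicative updates (Frobenius loss).

    W is cells x k (per-cell program usage); H is k x genes (gene loadings).
    """
    X = np.asarray(X, dtype=float)
    if (X < 0).any():
        raise ValueError("NMF requires a nonnegative matrix")
    rng = np.random.default_rng(seed)
    n_cells, n_genes = X.shape
    # Scale the random init so W @ H starts at roughly the magnitude of X.
    scale = np.sqrt(X.mean() / k)
    W = rng.random((n_cells, k)) * scale
    H = rng.random((k, n_genes)) * scale
    for _ in range(n_iter):
        # Multiplicative updates keep W and H nonnegative; eps avoids 0/0.
        H *= (W.T @ X) / (W.T @ W @ H + eps)
        W *= (X @ H.T) / (W @ H @ H.T + eps)
    return W, H


def match_factors(H_a, H_b):
    """Hypothetical helper for the per-sample approach: for each factor
    (row) of H_a, return the index of the most similar row of H_b and the
    cosine similarity of their gene loadings. Greedy, not one-to-one."""
    A = H_a / (np.linalg.norm(H_a, axis=1, keepdims=True) + 1e-12)
    B = H_b / (np.linalg.norm(H_b, axis=1, keepdims=True) + 1e-12)
    sim = A @ B.T  # k_a x k_b matrix of cosine similarities
    best = sim.argmax(axis=1)
    return best, sim[np.arange(len(best)), best]
```

Running `nmf` once on the pooled corrected matrix avoids the matching step entirely, at the cost of trusting the correction. Running it per sample and calling `match_factors` on each pair of loading matrices is the "clunky" route, made a little more explicit.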

I appreciate any advice!