Hi there,
I’m trying to build a custom scvi-tools model with the Pyro backend that reads data from a MuData object, including some data stored in a MuData .varm field. I’ve extended BaseModelClass with PyroSviTrainMixin, and things were going smoothly until I tried to register a custom MuDataVarmField.
The field is added to the registry without problems, but when I train my model with the default train() method provided by PyroSviTrainMixin, the code breaks because it apparently tries to index the registered .varm field with the minibatch (cell) indices. For example:
IndexError: index 46328 is out of bounds for axis 0 with size 30
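The failure mode can be reproduced outside scvi-tools with a few lines of NumPy. This is a minimal sketch, assuming the loader slices every registered array along axis 0 with the same cell indices. The field names, the `load_minibatch` helper and the shapes are hypothetical stand-ins, not scvi-tools internals. An obs-aligned array has one row per cell, so cell indices are valid for it; a var-aligned .varm array has only one row per feature (30 here), so any cell index of 30 or more goes out of bounds.

```python
import numpy as np

# Hypothetical shapes: many cells (obs) but only 30 features in this modality.
n_obs, n_vars, k = 50_000, 30, 10

# Stand-ins for registered fields (names are illustrative only).
registered = {
    "X": np.zeros((n_obs, n_vars), dtype=np.float32),      # obs-aligned: one row per cell
    "varm_feat": np.ones((n_vars, k), dtype=np.float32),   # var-aligned: one row per feature
}


def load_minibatch(fields, obs_idx):
    """Mimic a loader that slices every registered array along axis 0 by cell indices."""
    return {key: arr[obs_idx] for key, arr in fields.items()}


def try_load(fields, obs_idx):
    """Return the IndexError message if slicing fails, else None."""
    try:
        load_minibatch(fields, obs_idx)
        return None
    except IndexError as err:
        return str(err)


# A minibatch of cell indices; 46328 is a valid cell but not a valid feature row.
obs_idx = np.array([0, 5, 46328])
msg = try_load(registered, obs_idx)
print(msg)  # index 46328 is out of bounds for axis 0 with size 30
```

Slicing only the obs-aligned field with the same indices works fine, which is consistent with the varm field being the only one that breaks.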
Is there a way to tell the model not to minibatch this field (which is 30 x 10 in this example) during training? If not, could someone point me to the class or method I’d need to override to add this? I’m fairly new to scvi-tools and to the Lightning data loading machinery, so I’m hoping someone with more experience has a quick answer.
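For reference, the behavior being asked for (minibatch the per-cell data, but hand the whole var-level matrix to every step) can be sketched in plain PyTorch. This is a hypothetical illustration, not the scvi-tools data loader or any scvi-tools API: `ObsOnlyDataset` and `ToyModule` are made-up names, and the assumption is that the .varm matrix can simply live on the module as a non-trainable buffer instead of going through the minibatch loader.

```python
import numpy as np
import torch
from torch import nn
from torch.utils.data import DataLoader, Dataset


class ObsOnlyDataset(Dataset):
    """Hypothetical dataset: index only obs-aligned arrays; var-aligned arrays stay out."""

    def __init__(self, obs_fields):
        self.obs_fields = {k: torch.as_tensor(v) for k, v in obs_fields.items()}
        lengths = {v.shape[0] for v in self.obs_fields.values()}
        assert len(lengths) == 1, "all obs-aligned fields must share n_obs"
        self.n_obs = lengths.pop()

    def __len__(self):
        return self.n_obs

    def __getitem__(self, idx):
        # Returns one cell's row from each obs-aligned field.
        return {k: v[idx] for k, v in self.obs_fields.items()}


class ToyModule(nn.Module):
    """Hypothetical module that keeps the full (n_vars x k) varm matrix as a buffer."""

    def __init__(self, varm_matrix):
        super().__init__()
        # A buffer follows .to(device) and is saved in state_dict, but is not trained.
        self.register_buffer("varm", torch.as_tensor(varm_matrix))

    def forward(self, x):
        # x: (batch, n_vars); every minibatch sees the whole var-level matrix.
        return x @ self.varm  # -> (batch, k)


n_obs, n_vars, k = 1000, 30, 10
X = np.random.default_rng(0).random((n_obs, n_vars), dtype=np.float32)
varm = np.ones((n_vars, k), dtype=np.float32)

loader = DataLoader(ObsOnlyDataset({"X": X}), batch_size=128, shuffle=True)
module = ToyModule(varm)

batch = next(iter(loader))
out = module(batch["X"])
print(batch["X"].shape, out.shape)  # torch.Size([128, 30]) torch.Size([128, 10])
```

The design choice here is to separate fields by the axis they are aligned to: only obs-aligned arrays go through the minibatching path, while the var-aligned matrix is attached once to the module.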