Mini-batching error after registering MuData .varm field with custom scvi-tools model

Hi there,

I’m trying to build a custom scvi-tools model using the Pyro backend that reads information from a MuData object, including some data from a MuData .varm field. I’ve extended the BaseModelClass with PyroSviTrainMixin and things were going smoothly until I tried to register a custom MuDataVarmField.

The field gets added to the registry just fine, but when I try to train my module using the default train() function provided by PyroSviTrainMixin, the code breaks because (apparently) it tries to provide minibatch indices to the .varm field I’ve registered. For example:

IndexError: index 46328 is out of bounds for axis 0 with size 30
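
For illustration, the same message can be reproduced with plain NumPy by indexing a 30 x 10 array with a cell-level index. This is only a sketch of the suspected mechanism, not the actual scvi-tools code path; the array and the index are stand-ins for the registered .varm field and a minibatch index.

```python
import numpy as np

# Stand-in for the registered .varm field: 30 rows, 10 columns.
varm_field = np.zeros((30, 10))

# Stand-in for a minibatch of cell indices drawn from a much larger dataset.
cell_indices = np.array([46328])

message = None
try:
    # Indexing axis 0 of the 30-row array with a cell index fails.
    varm_field[cell_indices]
except IndexError as err:
    message = str(err)

print(message)
```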

Is there any way to tell the model not to mini-batch this field (which is 30 x 10 in this example) during training? If not, could someone point me to the class or method I’d need to override to add this functionality? I’m pretty new to scvi-tools and to the Lightning data loader stuff, so I’m hoping that someone with more experience might have a quick answer to this :sweat_smile:

Hi, thank you for using scvi-tools!

One way to achieve this would be to modify PyroSviTrainMixin.train to pass in the data_and_attributes argument, similar to how it is done here. This dictionary should contain the registry keys and corresponding dtypes for the fields that you would like to be mini-batched; the DataSplitter will then load only those fields. Hope this helps!
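
As a rough sketch of what that dictionary could look like: the registry keys below ("X", "batch", "my_varm_field") and the helper build_data_and_attributes are hypothetical names for illustration, so substitute whatever keys your own setup method registers. The block only builds the dictionary and does not call scvi-tools.

```python
import numpy as np

# Hypothetical registry keys; replace with the keys your model registers.
REGISTERED_KEYS = ["X", "batch", "my_varm_field"]

# Keys that are not indexed by cells and so should not be mini-batched.
NON_MINIBATCHED_KEYS = {"my_varm_field"}


def build_data_and_attributes(registered_keys, skip_keys, dtype=np.float32):
    """Map each registry key that should be mini-batched to its dtype."""
    return {key: dtype for key in registered_keys if key not in skip_keys}


data_and_attributes = build_data_and_attributes(
    REGISTERED_KEYS, NON_MINIBATCHED_KEYS
)

# Only the cell-indexed fields remain; the .varm field is left out.
print(sorted(data_and_attributes))
```

In an overridden train(), this dictionary would then be handed to the DataSplitter as data_and_attributes=data_and_attributes, assuming the splitter forwards extra keyword arguments to the underlying data loader in your scvi-tools version.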