I’m having trouble following the data flow in the simplest scVI case. As defined in BaseModuleClass._generic_forward(), the model first goes through an inference phase (encoder) and then through a generative phase (decoder).
However, what I don’t understand is that in VAE._regular_inference(), which I take to be the first step during training, batch_index is passed to the different encoders. The same applies to VAE.generative() and the corresponding decoder(s). batch_index is usually obtained by slicing the minibatch tensors as tensors[REGISTRY_KEYS.BATCH_KEY].
My first question is: what does batch_index contain exactly?
My second question is: why is this variable passed to the different encoders in VAE._regular_inference()? I understand why we are passing x and *categorical_input, but I get lost trying to work out where batch_index comes into play in the codebase. Maybe a pointer would be helpful?
The column named by batch_key in .obs is registered during setup_anndata and becomes available via the REGISTRY_KEYS.BATCH_KEY key.
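For example (a toy sketch with hypothetical data; the column name "donor" is just an illustration):

```python
import anndata as ad
import numpy as np
import pandas as pd
import scvi

# Hypothetical toy data: 6 cells coming from two donors.
adata = ad.AnnData(
    X=np.random.poisson(1.0, size=(6, 10)).astype(np.float32),
    obs=pd.DataFrame({"donor": ["A", "A", "B", "B", "A", "B"]}),
)

# Register "donor" as the batch column.
scvi.model.SCVI.setup_anndata(adata, batch_key="donor")
# The column is integer-encoded (here "A" -> 0, "B" -> 1) and stored under
# REGISTRY_KEYS.BATCH_KEY, so every minibatch later exposes it as
# tensors[REGISTRY_KEYS.BATCH_KEY].
```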
Now let’s look at where the data gets processed:
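Roughly, the module slices the minibatch dictionary in _get_inference_input. This is a paraphrased and simplified sketch; the exact code differs between scvi-tools versions:

```python
from scvi import REGISTRY_KEYS

# Paraphrased / simplified from scvi.module.VAE._get_inference_input.
def _get_inference_input(self, tensors):
    x = tensors[REGISTRY_KEYS.X_KEY]
    batch_index = tensors[REGISTRY_KEYS.BATCH_KEY]  # sliced from the minibatch dict
    cont_covs = tensors.get(REGISTRY_KEYS.CONT_COVS_KEY)
    cat_covs = tensors.get(REGISTRY_KEYS.CAT_COVS_KEY)
    return dict(x=x, batch_index=batch_index, cont_covs=cont_covs, cat_covs=cat_covs)
```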
Thus, batch_index contains the integer-encoded batch label of each cell: the experimental batch information used for integration, nothing to do with minibatching.
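Concretely, for a minibatch it is just a column of integer codes (illustrative values only):

```python
import torch

# A minibatch of 4 cells drawn from two registered batches ("A" -> 0, "B" -> 1).
batch_index = torch.tensor([[0], [1], [1], [0]])  # shape (n_cells_in_minibatch, 1)
```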
Now let’s look at where it’s used in _regular_inference:
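Paraphrased and simplified (the encoders’ return values differ slightly between versions), the relevant lines look like this:

```python
# Paraphrased / simplified from scvi.module.VAE._regular_inference.
encoder_input = x_  # log1p(x), optionally concatenated with continuous covariates
qz_m, qz_v, z = self.z_encoder(encoder_input, batch_index, *categorical_input)
ql_m, ql_v, library = self.l_encoder(encoder_input, batch_index, *categorical_input)
```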
It’s passed to the z_encoder, which was initialized as follows:
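A paraphrased and simplified excerpt of VAE.__init__ (argument names may differ slightly by version):

```python
# Paraphrased / simplified from scvi.module.VAE.__init__.
cat_list = [n_batch] + list([] if n_cats_per_cov is None else n_cats_per_cov)
encoder_cat_list = cat_list if encode_covariates else None  # encode_covariates is False by default
self.z_encoder = Encoder(
    n_input_encoder,
    n_latent,
    n_cat_list=encoder_cat_list,
    n_layers=n_layers,
    n_hidden=n_hidden,
    dropout_rate=dropout_rate,
    distribution=latent_distribution,
)
```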
Notice that encoder_cat_list is None by default, because encode_covariates defaults to False.
If you keep following the code into FCLayers, you’ll find that when this is None, none of the categorical covariates passed in (including batch_index) are actually used.
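The Encoder is built on FCLayers, and a paraphrased, simplified look at its constructor shows why:

```python
# Paraphrased / simplified from scvi.nn.FCLayers.__init__.
if n_cat_list is not None:
    # Categories with a single level carry no information and are dropped.
    self.n_cat_list = [n_cat if n_cat > 1 else 0 for n_cat in n_cat_list]
else:
    # n_cat_list=None (the z_encoder default here): the list stays empty, so
    # nothing is one-hot encoded and batch_index is silently ignored.
    self.n_cat_list = []
cat_dim = sum(self.n_cat_list)  # extra input width reserved for the one-hot blocks
```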
This is not the case for the decoder, which does use the batch information.
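Again paraphrased and simplified (the exact decoder arguments vary across versions):

```python
# Paraphrased / simplified from scvi.module.VAE.__init__: the decoder always
# receives the full cat_list, whose first entry is n_batch.
self.decoder = DecoderSCVI(
    n_input_decoder,
    n_input,
    n_cat_list=cat_list,
    n_layers=n_layers,
    n_hidden=n_hidden,
)

# ...and in VAE.generative(), batch_index is forwarded to it:
px_scale, px_r, px_rate, px_dropout = self.decoder(
    self.dispersion, decoder_input, library, batch_index, *categorical_input, y
)
```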
The other categorical inputs are also one-hot encoded and concatenated, but this depends on the initial n_cat_list passed in.
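In FCLayers.forward this looks roughly like the following (paraphrased; the real code uses its own one-hot helper):

```python
import torch

# Paraphrased / simplified from scvi.nn.FCLayers.forward.
one_hot_cat_list = []
for n_cat, cat in zip(self.n_cat_list, cat_list):  # cat_list is (batch_index, *cat_covs)
    if n_cat > 1:
        one_hot_cat = torch.nn.functional.one_hot(cat.squeeze(-1).long(), n_cat).float()
        one_hot_cat_list.append(one_hot_cat)
# If self.n_cat_list is empty (the encoder default), x passes through unchanged.
x = torch.cat([x, *one_hot_cat_list], dim=-1)
```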
Finally, I know this is confusing; this is legacy code we had to keep so that old saved models remain compatible. We are refactoring parts of the codebase with JAX, and those implementations will be a lot cleaner.