About VAE._regular_inference()

Hi all,

I’m having trouble following the data flow in the simplest scVI case. As defined in BaseModuleClass._generic_forward(), the model first goes through an inference phase (encoder) and then through a generative process (decoder).

However, what I don’t understand is that in VAE._regular_inference(), which I understand is the first step during training, batch_index is passed to the different encoders. The same applies to VAE.generative() and the corresponding decoder(s). batch_index is usually obtained by indexing the tensors as tensors[REGISTRY_KEYS.BATCH_KEY].

  • My first question is: what does batch_index contain exactly?

  • Then my second question would be: why is this variable passed to the different encoders in VAE._regular_inference()? I understand why we are passing x and *categorical_input, but I get lost trying to understand where batch_index comes into play in the codebase. Maybe a pointer can be helpful?

Let me know if something is not clear!

First, let’s look at setup_anndata

We see that the .obs column named by batch_key will be available under the REGISTRY_KEYS.BATCH_KEY key.

Now let’s look at where the data gets processed:

Thus, batch_index holds each cell’s batch label (the batch information used for integration; it has nothing to do with minibatching).
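To make that concrete, here is a minimal sketch of what such a batch_index could look like. It is not the scvi-tools code: the labels and the names obs_batch, categories and code_of are made up for illustration, and I am assuming the batch column is stored as one integer code per cell.

```python
# Hypothetical per-cell batch labels, as they might appear in an .obs column.
obs_batch = ["donor_A", "donor_B", "donor_A", "donor_C"]

# Assign each distinct label an integer code (sorted so the mapping is stable).
categories = sorted(set(obs_batch))
code_of = {label: code for code, label in enumerate(categories)}

# One code per cell, kept as a column of shape (n_cells, 1).
batch_index = [[code_of[label]] for label in obs_batch]

# Number of distinct batches, i.e. the size a model would need to know about.
n_batch = len(categories)

print(batch_index)  # [[0], [1], [0], [2]]
print(n_batch)      # 3
```

So each row says which batch a cell came from; the rows of a training minibatch are just a slice of this column.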

Now let’s look where it’s used in _regular_inference:

It’s passed to the z_encoder, which was initialized as such:

Notice that encoder_cat_list is None by default.

If you keep following the code, you’ll find that when it is None, none of the categorical covariates passed in are actually used.

This is not the case for the decoder, which does use the batch information.

The other categorical inputs are also one-hot encoded and concatenated, but this depends on the n_cat_list passed in at initialization.
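The behaviour described above can be sketched in plain Python. This is a simplified illustration and not the library’s layer code: one_hot and layer_input are hypothetical helpers, the numbers are invented, and I am assuming that covariates without a declared size are simply skipped.

```python
def one_hot(index, n_cat):
    """Return a one-hot list of length n_cat with 1.0 at position index."""
    return [1.0 if i == index else 0.0 for i in range(n_cat)]


def layer_input(features, cat_values, n_cat_list=None):
    """Build one cell's layer input: features plus one-hot covariates.

    n_cat_list holds the number of categories per covariate and is fixed
    when the layer is created. Covariates with no declared size are dropped.
    """
    n_cat_list = list(n_cat_list or [])
    out = list(features)
    # zip stops at the shorter sequence, so an empty n_cat_list uses nothing.
    for n_cat, value in zip(n_cat_list, cat_values):
        if n_cat > 1:  # a covariate with a single category adds no information
            out += one_hot(value, n_cat)
    return out


features = [0.5, 1.5]  # hypothetical input features for one cell
cats = [2, 0]          # batch code 2, then one extra covariate with code 0

encoder_in = layer_input(features, cats, n_cat_list=None)    # encoder default
decoder_in = layer_input(features, cats, n_cat_list=[3, 2])  # decoder-style

print(encoder_in)  # [0.5, 1.5]
print(decoder_in)  # [0.5, 1.5, 0.0, 0.0, 1.0, 1.0, 0.0]
```

The same arguments are passed in both calls; only the sizes given at construction decide whether they end up in the input.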

Finally, I know this is confusing; it is legacy code we had to keep so that old saved models remain compatible. We are refactoring some parts of the codebase with JAX, and those implementations will be a lot cleaner.

We also keep a basic scVI implementation here:

Hi Adam,

Thanks for the prompt response.

Maybe the encoder was not the best example, since no batch information is concatenated there by default, but the same question would apply to the decoder.

From what I understand:

In the call qz, z = self.z_encoder(encoder_input, batch_index, *categorical_input):

  • *categorical_input will be unpacked into the extra covariates whose category counts were given as VAE.__init__(n_cats_per_cov)
  • batch_index corresponds to the batch covariate sized by VAE.__init__(n_batch), which defaults to 0 if no REGISTRY_KEYS.BATCH_KEY has been registered
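One way to picture the two points above is to separate the sizes given at construction from the values passed at call time. The sketch below is my own simplification, not the VAE source: build_cat_lists is a hypothetical helper, and the encode_covariates flag is an assumed switch for whether the encoder receives the sizes at all.

```python
def build_cat_lists(n_batch=0, n_cats_per_cov=None, encode_covariates=False):
    """Return (encoder_cat_list, decoder_cat_list) of category counts.

    The batch count comes first, followed by the counts of any extra
    categorical covariates, matching the order of the call-time arguments
    (batch_index, *categorical_input).
    """
    cat_list = [n_batch] + list(n_cats_per_cov or [])
    # Assumption: the encoder only gets the sizes when explicitly requested.
    encoder_cat_list = cat_list if encode_covariates else None
    return encoder_cat_list, cat_list


# Three batches plus two extra covariates with 2 and 4 categories.
enc_sizes, dec_sizes = build_cat_lists(n_batch=3, n_cats_per_cov=[2, 4])

print(enc_sizes)  # None
print(dec_sizes)  # [3, 2, 4]
```

Under these assumptions, n_batch and n_cats_per_cov only describe how many categories to expect, while batch_index and categorical_input carry the actual per-cell codes.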

Again, thanks for the prompt response!