Given that these terms are added directly to the loss function, isn't this saying that we want the density of z1 under pz1 to be high and its density under qz1 to be low? But z1 is sampled (via the reparameterization trick) from qz1, so how is this possible?
Additionally: I presume that understanding those two losses would also explain why the encoder_z2_z1 and decoder_z1_z2 parts of the architecture are necessary. If not, what are this additional encoder and decoder used for?
Thank you for using scvi-tools. These lines are easiest to understand as the KL divergence term for z_1: specifically, KL(q_\eta(z_1) || p_\theta(z_1)) = loss_z1_weight + loss_z1_unweight. The second encoder/decoder pair is a set of additional networks that decomposes z_1 per cell type; this is the defining difference between scVI and scANVI. The math is detailed in our user guide: scANVI - scvi-tools. Let me know if this helps!
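To make the sign question from the first post concrete, here is an illustrative sketch (not the actual scvi-tools code; the distributions and shapes are made up) of how a single-sample Monte Carlo estimate of KL(q || p) naturally splits into a +log q term and a -log p term, so both appear in the loss with opposite signs even though z1 is sampled from q:

```python
import torch
from torch.distributions import Normal, kl_divergence

torch.manual_seed(0)

# q(z1): variational posterior from the encoder (illustrative parameters)
qz1 = Normal(loc=torch.zeros(10), scale=torch.full((10,), 0.5))
# p(z1): prior over z1 (here a standard normal, for illustration)
pz1 = Normal(loc=torch.zeros(10), scale=torch.ones(10))

# Reparameterized sample z1 ~ q(z1), as in the reparameterization trick
z1 = qz1.rsample()

# KL(q||p) = E_{z1~q}[log q(z1) - log p(z1)]; a one-sample estimate:
loss_weight = qz1.log_prob(z1).sum()     # +log q(z1): pushed down when minimized
loss_unweight = -pz1.log_prob(z1).sum()  # -log p(z1): log p pushed up when minimized
kl_estimate = loss_weight + loss_unweight

# Averaging the estimate over many samples approaches the analytic KL
samples = qz1.rsample((10_000,))
mc_kl = (qz1.log_prob(samples) - pz1.log_prob(samples)).sum(-1).mean()
analytic_kl = kl_divergence(qz1, pz1).sum()
```

So minimizing the sum does not ask for z1 to be unlikely under q in absolute terms; it penalizes the gap between log q and log p, pulling the posterior toward the prior.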
Hi @Justin_Hong, can I follow up on this and ask a few questions about how z2 is used? From what I gathered from the code, z2, or rather the mean and variance produced by encoder_z2_z1, is only used to calculate the KL loss, whereas the reconstruction and classification losses use z1, the direct output of the encoder. Could you please confirm that this is correct? And if z2 is only used in the KL calculation, why is it needed in the first place? Thanks!