loss_z1_weight and loss_z1_unweight in scANVI

In the scANVI implementation (scvi-tools/_scanvae.py at commit d636093bc8d49c8e03fcb4bc0a8bc8130cb29fe2 in scverse/scvi-tools on GitHub), there are two losses I don’t quite understand:

loss_z1_unweight = -Normal(pz1_m, torch.sqrt(pz1_v)).log_prob(z1s).sum(dim=-1)
loss_z1_weight = Normal(qz1_m, torch.sqrt(qz1_v)).log_prob(z1).sum(dim=-1)

Given that these are added directly to the loss function, isn’t this saying that we want P(z1 | pz1) to be high and P(z1 | qz1) to be low? But z1 is sampled from qz1 (via the reparameterization trick), so how is this possible?

Additionally, I presume that understanding these two losses would also answer why the encoder_z2_z1 and decoder_z1_z2 parts of the architecture are necessary. If not, what are this additional encoder and decoder used for?

Thanks in advance,

Hi Yuge,

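Thank you for using scvi-tools. These lines are easier to understand together, as the KL divergence term for z_1. Specifically, loss_z1_weight + loss_z1_unweight = \log q_\eta(z_1) - \log p_\theta(z_1), evaluated at the sampled z_1, which is a single-sample Monte Carlo estimate of KL(q_\eta(z_1)||p_\theta(z_1)). So the \log q_\eta term is not pushing z_1 away from its own posterior: it acts as a negative-entropy term that keeps the posterior from collapsing, while the \log p_\theta term pulls the posterior toward the prior. The second encoder and decoder pair (encoder_z2_z1 and decoder_z1_z2) are additional networks that break down z_1 per cell type. This is the defining difference between scVI and scANVI. The math is detailed in our user guide: scANVI - scvi-tools. Let me know if this helps!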
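
To make the KL reading concrete, here is a minimal PyTorch sketch (not the scvi-tools code itself) that averages the two terms over many samples and compares the result with the closed-form KL from torch.distributions. The parameter values are made up for illustration, the prior is fixed to a standard normal rather than produced by a network, and z1 is used for both terms (the quoted code uses z1s for the prior term), so treat it only as a sanity check of the identity under those assumptions.

```python
import torch
from torch.distributions import Normal, kl_divergence

torch.manual_seed(0)

# Hypothetical toy parameters (assumptions, not values from scvi-tools).
n_latent = 4
qz1_m = torch.tensor([0.5, -1.0, 0.2, 1.5])  # posterior mean
qz1_v = torch.tensor([0.3, 0.8, 1.2, 0.5])   # posterior variance
pz1_m = torch.zeros(n_latent)                # prior mean (fixed here for simplicity)
pz1_v = torch.ones(n_latent)                 # prior variance (fixed here for simplicity)

q = Normal(qz1_m, torch.sqrt(qz1_v))
p = Normal(pz1_m, torch.sqrt(pz1_v))

# Draw many reparameterized samples z1 ~ q so the average is a good estimate.
n_samples = 200_000
z1 = q.rsample((n_samples,))  # shape: (n_samples, n_latent)

# Same form as the two lines quoted in the question.
loss_z1_weight = q.log_prob(z1).sum(dim=-1)     # log q(z1)
loss_z1_unweight = -p.log_prob(z1).sum(dim=-1)  # -log p(z1)

# The mean of (log q - log p) under q is a Monte Carlo estimate of KL(q || p).
mc_kl = (loss_z1_weight + loss_z1_unweight).mean()

# Closed-form KL between the two diagonal Gaussians, summed over latent dims.
exact_kl = kl_divergence(q, p).sum(dim=-1)

print(f"Monte Carlo estimate: {mc_kl.item():.4f}")
print(f"Closed-form KL:       {exact_kl.item():.4f}")
```

The two printed numbers should agree closely. During training only one sample per cell is drawn, so the per-cell sum is a noisy but unbiased estimate of the same quantity.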