In the scANVI implementation, there are two losses I don’t quite understand (scvi-tools/_scanvae.py at commit d636093bc8d49c8e03fcb4bc0a8bc8130cb29fe2, scverse/scvi-tools on GitHub):
loss_z1_unweight = -Normal(pz1_m, torch.sqrt(pz1_v)).log_prob(z1s).sum(dim=-1)
loss_z1_weight = Normal(qz1_m, torch.sqrt(qz1_v)).log_prob(z1).sum(dim=-1)
Given that these are added directly to the loss function, isn’t this saying that we want P(z1 | pz1) to be high and P(z1 | qz1) to be low? But z1 is generated from qz1 (via the reparameterization trick), so how is this possible?
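One way to read the two terms together: at a single z1 drawn from qz1, their sum is log q(z1) − log p(z1). That is a one-sample Monte Carlo estimate of KL(q ‖ p), the same KL term found in a standard VAE objective. The sketch below checks this numerically for univariate Gaussians using only the Python standard library. It is not scvi-tools code. It uses standard deviations instead of the variances (pz1_v, qz1_v) used in the original, and it ignores the broadcasting over labels that produces z1s.

```python
import math
import random


def normal_log_prob(x: float, mean: float, std: float) -> float:
    """Log density of a univariate Normal(mean, std**2) evaluated at x."""
    return -0.5 * ((x - mean) / std) ** 2 - math.log(std) - 0.5 * math.log(2 * math.pi)


def analytic_kl(mq: float, sq: float, mp: float, sp: float) -> float:
    """Closed-form KL(N(mq, sq**2) || N(mp, sp**2))."""
    return math.log(sp / sq) + (sq ** 2 + (mq - mp) ** 2) / (2 * sp ** 2) - 0.5


def mc_kl(mq: float, sq: float, mp: float, sp: float,
          n_samples: int = 100_000, seed: int = 0) -> float:
    """Average of log q(z) - log p(z) over reparameterized samples z ~ q."""
    rng = random.Random(seed)
    total = 0.0
    for _ in range(n_samples):
        # Reparameterization: z is a deterministic function of (mq, sq) and fixed noise.
        z = mq + sq * rng.gauss(0.0, 1.0)
        loss_weight = normal_log_prob(z, mq, sq)     # plays the role of loss_z1_weight
        loss_unweight = -normal_log_prob(z, mp, sp)  # plays the role of loss_z1_unweight
        total += loss_weight + loss_unweight
    return total / n_samples


if __name__ == "__main__":
    print(analytic_kl(0.5, 0.8, 0.0, 1.0))  # closed form, ≈ 0.168
    print(mc_kl(0.5, 0.8, 0.0, 1.0))        # Monte Carlo estimate of the same quantity
```

Since E_q[log q(z)] is minus the entropy of q, pushing loss_z1_weight down at q's own samples keeps q from collapsing. Meanwhile, loss_z1_unweight pulls the samples toward regions where the prior is dense. Individual per-sample values of the sum can be negative, but their expectation is the non-negative KL.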
Additionally, I presume that understanding these two losses would also explain why the encoder_z2_z1 and decoder_z1_z2 parts of the architecture are necessary. If not, what are this additional encoder and decoder used for?
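A possible reading of the extra modules rests on one assumption: scANVI's hierarchical prior, with z2 ~ N(0, I) and z1 | z2, y Gaussian with parameters given by a network of (z2, y). Under that assumption, decoder_z1_z2 parameterizes the prior p(z1 | z2, y), which would explain why its outputs pz1_m and pz1_v appear in loss_z1_unweight. encoder_z2_z1 then parameterizes an approximate posterior q(z2 | z1, y) that supplies the z2 samples. The sketch below combines these pieces into a one-sample estimate of KL(q(z1, z2) ‖ p(z1, z2 | y)). Several simplifications are assumptions of this sketch, not details of scvi-tools:

- The two "networks" are hypothetical fixed linear functions in one dimension.
- The label y is fixed instead of being enumerated and weighted by a classifier.
- The KL for z2 is computed in closed form against N(0, 1).

```python
import math
import random


def normal_log_prob(x: float, mean: float, std: float) -> float:
    """Log density of a univariate Normal(mean, std**2) evaluated at x."""
    return -0.5 * ((x - mean) / std) ** 2 - math.log(std) - 0.5 * math.log(2 * math.pi)


def kl_to_standard_normal(mean: float, std: float) -> float:
    """Closed-form KL(N(mean, std**2) || N(0, 1))."""
    return -math.log(std) + (std ** 2 + mean ** 2) / 2 - 0.5


# Hypothetical 1-D stand-ins for the neural networks (not the scvi-tools modules).
def toy_encoder_z2_z1(z1: float, y: int) -> tuple[float, float]:
    """Approximate posterior q(z2 | z1, y): returns (mean, std)."""
    return 0.5 * z1 + 0.1 * y, 0.7


def toy_decoder_z1_z2(z2: float, y: int) -> tuple[float, float]:
    """Hierarchical prior p(z1 | z2, y): returns (mean, std)."""
    return 1.2 * z2 - 0.2 * y, 0.9


def hierarchical_kl_sample(qz1_m: float, qz1_s: float, y: int, rng: random.Random) -> float:
    """One-sample estimate of KL(q(z1, z2) || p(z1, z2 | y))."""
    z1 = qz1_m + qz1_s * rng.gauss(0.0, 1.0)     # z1 ~ q(z1 | x)
    qz2_m, qz2_s = toy_encoder_z2_z1(z1, y)
    z2 = qz2_m + qz2_s * rng.gauss(0.0, 1.0)     # z2 ~ q(z2 | z1, y)
    pz1_m, pz1_s = toy_decoder_z1_z2(z2, y)
    kl_z2 = kl_to_standard_normal(qz2_m, qz2_s)  # KL(q(z2 | z1, y) || N(0, 1))
    loss_z1_weight = normal_log_prob(z1, qz1_m, qz1_s)
    loss_z1_unweight = -normal_log_prob(z1, pz1_m, pz1_s)
    return kl_z2 + loss_z1_weight + loss_z1_unweight


def mean_hierarchical_kl(qz1_m: float, qz1_s: float, y: int,
                         n_samples: int = 20_000, seed: int = 0) -> float:
    """Average many one-sample estimates to approximate the expected KL."""
    rng = random.Random(seed)
    total = sum(hierarchical_kl_sample(qz1_m, qz1_s, y, rng) for _ in range(n_samples))
    return total / n_samples
```

In this reading, each module is needed for one of the terms. Without encoder_z2_z1 there is no z2 to condition the prior on. Without decoder_z1_z2, loss_z1_unweight has no prior density to evaluate.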
Thanks in advance,
Yuge