Understanding gimVI loss logits

Thanks for sharing the idea and code of gimVI. I am trying to understand the loss from the code; please correct me if I have anything wrong.

  1. Let's say we have sequencing data denoted seq (e.g. n × 3000) and spatial data denoted spa (e.g. m × 500).
  2. First, a joint encoder infers z_mu and z_var: z_mu, z_var, z ← Enc1(seq, spa). There is also a library-size encoder for the seq data: l_mu, l_var, l ← Enc2(seq).
  3. Given z and l, the decoder yields seq_px, seq_px_r, seq_p_dropout ← Dec(z, l) for the sequencing data, and then seq_recon ← ZINB(seq_px, seq_px_r, seq_p_dropout).
  4. Similarly, spa_recon ← NB(spa_px, spa_px_r), where spa_px and spa_px_r come from the shared-weight decoder Dec(z, l). The spatial data does not need an inferred library size, so log(x) can be used instead.
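
To make steps 3 and 4 concrete, here is a minimal sketch of how I understand the two reconstruction terms for a single cell. It is not taken from the gimVI code: the function names are mine, I am assuming px is the NB mean (mu), px_r the inverse dispersion (theta), and I pass the dropout as a probability (pi) rather than a logit to keep it short.

```python
import math

def nb_log_prob(x, mu, theta):
    """Log-probability of count x under an NB with mean mu and inverse dispersion theta."""
    return (
        math.lgamma(x + theta) - math.lgamma(theta) - math.lgamma(x + 1)
        + theta * math.log(theta / (theta + mu))
        + x * math.log(mu / (theta + mu))
    )

def zinb_log_prob(x, mu, theta, pi):
    """Log-probability of x under a zero-inflated NB; pi is the dropout probability."""
    nb = nb_log_prob(x, mu, theta)
    if x == 0:
        # A zero is either a dropout or a genuine NB zero.
        return math.log(pi + (1.0 - pi) * math.exp(nb))
    # A non-zero count can only come from the NB component.
    return math.log(1.0 - pi) + nb

# Toy counts for three genes in one cell (made-up numbers, for illustration only).
counts = [0, 2, 5]
seq_recon = sum(zinb_log_prob(x, mu=3.0, theta=2.0, pi=0.1) for x in counts)
spa_recon = sum(nb_log_prob(x, mu=3.0, theta=2.0) for x in counts)
```

So in my reading, seq_recon and spa_recon are log-likelihoods summed over genes, and they enter the loss with a minus sign.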

My questions are:

  1. What is the final loss?
    For the sequencing data, I think it should be:
    seq_loss = -seq_recon + kl(seq_z, normal(0, 1)) + kl(seq_l, normal(seq_x.mean(), seq_x.var().sqrt()))
    And the spatial loss should be:
    spa_loss = -spa_recon + kl(spa_z, normal(0, 1))
    Is final_loss = seq_loss + spa_loss?

  2. Why do we need an adv_loss here?
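
To be explicit about question 1, this is the loss I have in mind, written as a small sketch. It reflects only my reading and is what I am asking you to confirm; the helper names and the toy numbers are hypothetical, and it leaves out the adv_loss from question 2 entirely.

```python
import math

def kl_normal(mu_q, sd_q, mu_p, sd_p):
    """KL( N(mu_q, sd_q^2) || N(mu_p, sd_p^2) ) for scalar Gaussians."""
    return (
        math.log(sd_p / sd_q)
        + (sd_q ** 2 + (mu_q - mu_p) ** 2) / (2.0 * sd_p ** 2)
        - 0.5
    )

def seq_loss(seq_recon, z_mu, z_sd, l_mu, l_sd, lib_mean, lib_sd):
    """-log-likelihood + KL on z (standard normal prior) + KL on the library l."""
    kl_z = sum(kl_normal(m, s, 0.0, 1.0) for m, s in zip(z_mu, z_sd))
    kl_l = kl_normal(l_mu, l_sd, lib_mean, lib_sd)
    return -seq_recon + kl_z + kl_l

def spa_loss(spa_recon, z_mu, z_sd):
    """-log-likelihood + KL on z only; no library term for the spatial data."""
    kl_z = sum(kl_normal(m, s, 0.0, 1.0) for m, s in zip(z_mu, z_sd))
    return -spa_recon + kl_z

# Made-up values for one seq cell and one spatial cell with a 2-d latent space.
final_loss = (
    seq_loss(-12.0, [0.1, -0.2], [0.9, 1.1], 6.0, 0.5, 6.2, 0.4)
    + spa_loss(-8.0, [0.0, 0.3], [1.0, 1.0])
)
```

If the actual code weights or averages the two terms differently, that is exactly the part I would like to understand.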

Thanks for your time and looking forward to your reply.