Understanding the GIMVI training plan

Hi,

I am trying to understand the adversarial training in the GIMVI training plan. If I understand correctly, its purpose is to encourage mixing in the latent space, i.e., the latent code learnt from adata_seq can also represent adata_spatial information and vice versa.

  1. In your implementation below, you trained a classifier to predict the “adversarial” mode label. But it won’t make a difference to just use the true mode label as it is a binary classification, right? (i.e., “predicting all seq samples to 0, spatial samples to 1” is equivalent to “predicting all seq samples to 1, spatial samples to 0”)
  2. Would it make more sense to use a uniform target (0.5, 0.5) instead of (0, 1), so that the latent space cannot "distinguish" which modality it comes from?
```python
# fool classifier if doing adversarial training
batch_tensor = [
    torch.zeros((z.shape[0], 1), device=z.device) + i
    for i, z in enumerate(zs)
]
if kappa > 0 and self.adversarial_classifier is not False:
    fool_loss = self.loss_adversarial_classifier(
        torch.cat(zs), torch.cat(batch_tensor), False
    )
    loss += fool_loss * kappa
```
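To make question 2 concrete, here is a minimal sketch in plain PyTorch (not scvi-tools code) comparing the two candidate fooling objectives for a two-class modality classifier. `fool_loss_flipped` and `fool_loss_uniform` are hypothetical helper names, and the logits are made-up toy values.

```python
import torch
import torch.nn.functional as F

def fool_loss_flipped(logits: torch.Tensor, labels: torch.Tensor) -> torch.Tensor:
    # Cross-entropy against the *other* modality (1 - label); only valid for 2 classes.
    return F.cross_entropy(logits, 1 - labels)

def fool_loss_uniform(logits: torch.Tensor) -> torch.Tensor:
    # Cross-entropy against the uniform target (1/n, ..., 1/n), i.e. (0.5, 0.5) here.
    log_probs = F.log_softmax(logits, dim=1)
    uniform = torch.full_like(log_probs, 1.0 / logits.shape[1])
    return -(uniform * log_probs).sum(dim=1).mean()

labels = torch.tensor([0, 1])                                 # 0 = seq, 1 = spatial
confident_wrong = torch.tensor([[-5.0, 5.0], [5.0, -5.0]])    # classifier confidently wrong
undecided = torch.zeros(2, 2)                                 # classifier outputs (0.5, 0.5)

print(fool_loss_flipped(confident_wrong, labels).item())
print(fool_loss_flipped(undecided, labels).item())
print(fool_loss_uniform(confident_wrong).item())
print(fool_loss_uniform(undecided).item())
```

Under the flipped target, the confidently wrong logits score close to 0 and the undecided ones score log 2. The uniform target reverses that ordering: its minimum (log 2) is reached by the undecided output, while the confidently wrong output is penalized heavily.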

What do you mean by the true mode label? But yes, it doesn't matter whether spatial is 0 and seq is 1 or the other way around.

So here is what is happening in the code. In this part:

The encoder is being trained so that the classifier predicts the wrong label. In this implementation, the loss rewards the classifier for assigning probability to all labels except the correct one.
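The "all other labels" target can be sketched as follows. This is a simplified re-implementation for illustration only, assuming `n_classes >= 2`; `wrong_label_target` and `fool_loss` are hypothetical names, not the scvi-tools functions. For two classes the target reduces to the flipped one-hot label, which is the equivalence raised in question 1.

```python
import torch
import torch.nn.functional as F

def wrong_label_target(labels: torch.Tensor, n_classes: int) -> torch.Tensor:
    # Spread probability mass 1 / (n_classes - 1) over every label except the true one.
    one_hot = F.one_hot(labels, n_classes).float()
    return (1.0 - one_hot) / (n_classes - 1)

def fool_loss(logits: torch.Tensor, labels: torch.Tensor) -> torch.Tensor:
    # Cross-entropy between the classifier's prediction and the wrong-label target.
    target = wrong_label_target(labels, logits.shape[1])
    return -(target * F.log_softmax(logits, dim=1)).sum(dim=1).mean()

labels = torch.tensor([0, 1, 1])
print(wrong_label_target(labels, 2))  # [[0, 1], [1, 0], [1, 0]]: just the flipped labels
print(wrong_label_target(labels, 3))  # [[0, .5, .5], [.5, 0, .5], [.5, 0, .5]]
```

With more than two batches the target is spread over several wrong labels, so the distinction only disappears in the two-modality case that GIMVI uses.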

And in this part of the code:

The classifier is being separately trained to correctly predict the modality. Note that z is detached here.
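A minimal sketch of that classifier update, assuming a toy linear encoder and classifier, each with its own optimizer (all names hypothetical, not the GIMVI modules). Detaching `z` stops the gradient of this loss at the latent code, so only the classifier's weights change.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)
encoder = nn.Linear(10, 4)     # toy stand-in for the GIMVI encoders
classifier = nn.Linear(4, 2)   # modality classifier: 0 = seq, 1 = spatial
opt_classifier = torch.optim.Adam(classifier.parameters(), lr=1e-2)

x = torch.randn(8, 10)
labels = torch.randint(0, 2, (8,))

z = encoder(x)
# Detach so the classifier loss cannot push gradients back into the encoder.
cls_loss = F.cross_entropy(classifier(z.detach()), labels)
opt_classifier.zero_grad()
cls_loss.backward()
opt_classifier.step()

print(encoder.weight.grad is None)  # True: the encoder received no gradient
```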

So this is how it forces the latent space to mix better. Please let me know if there is any other confusion.

Thanks for your explanation. I understand how it is implemented in the adversarial training plan as you pointed out.

My point was more about its application in the GIMVI training plan.
In GIMVI, the labels are either "seq" or "spatial", so it doesn't matter whether the classifier is trained to predict the wrong label or the true modality, as there are only two labels.

  • Can you elaborate on your last point? Is it necessary to train both classifiers to achieve latent space mixing, or is one sufficient?

The GIMVI training plan effectively does the same steps. First, the encoder tries to represent the cells in a way that fools the classifier (the classifier is frozen here), and then the exact same classifier is trained to predict the correct modality. So there is only one classifier, trained in alternation with the encoder.
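The two alternating steps can be sketched as a toy training loop. This is plain PyTorch under simplifying assumptions, not the GIMVI code: a single linear encoder and decoder stand in for the seq/spatial models, an MSE reconstruction loss stands in for the ELBO, and the data are random. Freezing is done here with `requires_grad_(False)`, and `kappa` weights the fooling term as in the snippet quoted above.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)
encoder = nn.Linear(10, 4)    # toy stand-in for the encoders
decoder = nn.Linear(4, 10)    # toy stand-in for the decoders
classifier = nn.Linear(4, 2)  # predicts modality from z

opt_vae = torch.optim.Adam(list(encoder.parameters()) + list(decoder.parameters()), lr=1e-2)
opt_cls = torch.optim.Adam(classifier.parameters(), lr=1e-2)
kappa = 1.0  # weight of the fooling term

x = torch.randn(64, 10)
modality = torch.randint(0, 2, (64,))  # 0 = seq, 1 = spatial

for step in range(100):
    # Step 1: update encoder/decoder while the classifier is frozen.
    classifier.requires_grad_(False)
    z = encoder(x)
    recon_loss = F.mse_loss(decoder(z), x)
    # Fool the classifier: with two modalities the wrong label is 1 - modality.
    fool_loss = F.cross_entropy(classifier(z), 1 - modality)
    opt_vae.zero_grad()
    (recon_loss + kappa * fool_loss).backward()
    opt_vae.step()
    classifier.requires_grad_(True)

    # Step 2: train the same classifier to predict the true modality on detached z.
    cls_loss = F.cross_entropy(classifier(z.detach()), modality)
    opt_cls.zero_grad()
    cls_loss.backward()
    opt_cls.step()
```

Keeping the classifier's parameters out of `opt_vae` would already prevent step 1 from updating them; freezing them explicitly also avoids computing classifier gradients that would otherwise have to be zeroed before step 2.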

Thanks! I get it now.