I am trying to understand the adversarial training in the GIMVI training plan. If I understand it correctly, its purpose is to encourage latent space mixing, i.e., the latent code learnt from adata_seq can also represent the adata_spatial information and vice versa.
In your implementation below, you train a classifier to predict the "adversarial" mode label. But since this is binary classification, wouldn't using the true mode label make no difference? (i.e., "predicting all seq samples as 0 and spatial samples as 1" is equivalent to "predicting all seq samples as 1 and spatial samples as 0".)
Would it make more sense to use uniform targets (0.5, 0.5) instead of (0, 1), so that the latent space cannot "distinguish" which mode a cell comes from?
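To make the distinction concrete, here is a tiny numeric comparison of the two objectives for the binary case (my own illustration, not scvi-tools code): cross-entropy against the *flipped* label versus cross-entropy against a *uniform* (0.5, 0.5) target, as a function of the probability the classifier assigns to the true modality.

```python
import math

def fool_loss_flipped(p_true):
    # cross-entropy against the flipped label: -log(1 - p_true)
    return -math.log(1.0 - p_true)

def fool_loss_uniform(p_true):
    # cross-entropy against the uniform target (0.5, 0.5)
    return -0.5 * (math.log(p_true) + math.log(1.0 - p_true))

for p_true in (0.9, 0.5, 0.1):
    print(f"p_true={p_true}: flipped={fool_loss_flipped(p_true):.3f}, "
          f"uniform={fool_loss_uniform(p_true):.3f}")
```

The flipped-label loss keeps decreasing as p_true goes to 0, so it pushes the encoder to make the classifier confidently *wrong*, while the uniform-target loss is minimized exactly at p_true = 0.5, i.e., at genuine indistinguishability.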
# fool classifier if doing adversarial training
if kappa > 0 and self.adversarial_classifier is not False:
    batch_tensor = [
        torch.zeros((z.shape[0], 1), device=z.device) + i
        for i, z in enumerate(zs)
    ]
    fool_loss = self.loss_adversarial_classifier(
        torch.cat(zs), torch.cat(batch_tensor), False
    )
    loss += fool_loss * kappa
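For reference, the `False` flag passed to `loss_adversarial_classifier` is what turns this into a "fool" loss: the cross-entropy target places its mass on every class *except* the true one. A minimal sketch of such a soft-target cross-entropy (my own reconstruction for illustration, not the scvi-tools source; it takes plain lists of per-row class probabilities and integer labels):

```python
import math

def loss_adversarial_classifier(probs, labels, predict_true_class):
    """Soft-target cross-entropy over rows of class probabilities.

    predict_true_class=True  -> target is the one-hot true label
                                (classifier update)
    predict_true_class=False -> target spreads mass over all other
                                classes (encoder / "fool" update)
    """
    n_classes = len(probs[0])
    total = 0.0
    for p_row, y in zip(probs, labels):
        if predict_true_class:
            target = [1.0 if c == y else 0.0 for c in range(n_classes)]
        else:
            # one minus the one-hot label, renormalized to sum to 1
            target = [0.0 if c == y else 1.0 / (n_classes - 1)
                      for c in range(n_classes)]
        total += -sum(t * math.log(p)
                      for t, p in zip(target, p_row) if t > 0.0)
    return total / len(probs)
```

With two classes, the `False` target collapses to exactly the flipped label, which is what the question above is getting at.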
Thanks for your explanation. I understand how it is implemented in the adversarial training plan as you pointed out.
My point was more related to its application in the GIMVI training plan.
In GIMVI, the labels refer to either "seq" or "spatial", so it doesn't matter whether the classifier is trained to predict the wrong label or the true modality, as there are only two labels.
Can you elaborate on your last point? Is it necessary to train both objectives to achieve latent space mixing, or is one sufficient?
The GIMVI training plan effectively does the same steps. First the encoder tries to represent the cells in a way that fools the classifier (the classifier is frozen here), and then that same classifier is trained to predict the correct modality.
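The two alternating updates described above can be sketched as follows. This is a minimal toy loop, not the scvi-tools code: the encoder and classifier are stand-in linear layers, and I use the negative cross-entropy as a stand-in for the "wrong label" fooling objective.

```python
import torch

torch.manual_seed(0)

encoder = torch.nn.Linear(5, 2)      # stand-in for the VAE encoder
classifier = torch.nn.Linear(2, 2)   # modality classifier on z
enc_opt = torch.optim.SGD(encoder.parameters(), lr=0.1)
cls_opt = torch.optim.SGD(classifier.parameters(), lr=0.1)
ce = torch.nn.CrossEntropyLoss()

x = torch.randn(8, 5)
mode = torch.randint(0, 2, (8,))     # 0 = seq, 1 = spatial

# step 1: update the ENCODER to fool the (frozen) classifier
z = encoder(x)
fool_loss = -ce(classifier(z), mode)  # maximize classifier error
enc_opt.zero_grad()
fool_loss.backward()
cls_opt.zero_grad()                   # discard classifier grads: frozen
enc_opt.step()

# step 2: update the CLASSIFIER to predict the true modality
z = encoder(x).detach()               # encoder frozen here
cls_loss = ce(classifier(z), mode)
cls_opt.zero_grad()
cls_loss.backward()
cls_opt.step()
```

The key mechanics are the `cls_opt.zero_grad()` after the fooling backward pass (so the classifier never learns from the fooling objective) and the `.detach()` in step 2 (so the encoder never learns from the classifier's objective).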