Progress not reporting loss from custom training plan

After training, I am able to view the losses that were computed by running:

    >>> print(model.trainer.logger.history.items())
    dict_items([('train_loss_step',       train_loss_step
    epoch
    0          316.210358
    1          295.575592
    2          304.064575
    3          294.447571
    4          288.721558
    5           291.47522
    6          297.350006
    7          290.012665
    8          283.560333
    9          281.745026), ('train_loss_epoch',       train_loss_epoch
    epoch
    0           378.040009
    1           314.651703
    2           304.716156
    3           298.997986
    4           295.423248
    5           292.827759
    6           290.723419
    7            289.18985
    8           288.019775
    9           286.723969), ('elbo_train',        elbo_train
    epoch
    0      397.563904
    1      345.731934
    2       339.36911
    3      334.903931
    4      331.399536
    5      328.528107
    6      325.961426
    7      323.683929
    8      321.890503
    9      320.092224), ('kl_global_train',       kl_global_train
    epoch
    0                 0.0
    1                 0.0
    2                 0.0
    3                 0.0
    4                 0.0
    5                 0.0
    6                 0.0
    7                 0.0
    8                 0.0
    9                 0.0), ('kl_local_train',       kl_local_train
    epoch
    0          19.523869
    1          31.158157
    2          34.827053
    3          36.177135
    4          36.339882
    5          36.152225
    6           35.77433
    7          35.108616
    8          34.561852
    9          34.136429), ('reconstruction_loss_train',       reconstruction_loss_train
    epoch
    0                    378.040039
    1                    314.573761
    2                    304.542053
    3                    298.726807
    4                    295.059662
    5                    292.375885
    6                    290.187103
    7                    288.575317
    8                    287.328644
    9                    285.955811)])
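Judging by the printed output, the history looks like a dict that maps each metric name to a one-column pandas DataFrame indexed by epoch. Below is a minimal sketch, assuming that structure, of lining the metrics up side by side in a single table. The `history` dict here is a hand-built stand-in that copies the first three epochs of two metrics from the output above; it is not the real logger object.

```python
import pandas as pd

# Stand-in for model.trainer.logger.history (assumed structure): metric name ->
# one-column DataFrame indexed by "epoch". Values copied from the output above.
history = {
    "train_loss_epoch": pd.DataFrame(
        {"train_loss_epoch": [378.040009, 314.651703, 304.716156]},
        index=pd.Index([0, 1, 2], name="epoch"),
    ),
    "elbo_train": pd.DataFrame(
        {"elbo_train": [397.563904, 345.731934, 339.36911]},
        index=pd.Index([0, 1, 2], name="epoch"),
    ),
}

# Align every metric on the shared epoch index, one column per metric
losses = pd.concat(list(history.values()), axis=1)
print(losses)
```

With the real object, the same call would presumably be `pd.concat(list(model.trainer.logger.history.values()), axis=1)`. Step-level and epoch-level metrics may not share the same index, so concatenating them together could introduce NaNs.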

However, during training, the progress bar does not report the loss:

    Epoch 4/10:  30%|████████████                            | 3/10 [00:14<00:31,  4.55s/it, v_num=1]

The training loss is logged in the following lines of a custom training plan, which steps its optimizers manually (manual optimization):

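    def training_step(self, batch, batch_idx):

        ...  # elided: sets up vae_opt and computes laplace_alpha

        # The module's loss does not include the training-only regularizer
        _, _, vae_losses = self.forward(batch, loss_kwargs=self.loss_kwargs)
        regularized_vae_loss = vae_losses.loss + laplace_alpha

        # Manual optimization: zero the gradients, backpropagate, step
        vae_opt.zero_grad()
        self.manual_backward(regularized_vae_loss)
        vae_opt.step()

        # Log the regularized loss, then the per-component metrics
        self.log("train_loss", regularized_vae_loss, on_epoch=True)
        self.compute_and_log_metrics(vae_losses, self.train_metrics, "train")

        return regularized_vae_loss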

Note that the novel regularizer laplace_alpha, which I compute during training_step(), shouldn’t be applied during post-training inference, so module.forward() omits it completely. For now, module.forward() is simply inherited from BaseModuleClass.

I believe you want the prog_bar option here:

https://pytorch-lightning.readthedocs.io/en/stable/common/lightning_module.html#log


Thanks @adamgayoso! I’ve confirmed that this works great for the following two cases:

  1. calling self.log(prog_bar=True) in TrainingPlan.training_step
  2. calling self.log_dict(prog_bar=True) in TrainingPlan.compute_and_log_metrics

But I now have two more questions:

  1. I’ve traced both the call to LightningModule.log_dict() from scvi.train.TrainingPlan.compute_and_log_metrics() and the call to LightningModule.log() from scvi.train.TrainingPlan.training_step(), and noticed that prog_bar=False in both cases, yet the loss is still reported in the progress bar. How is this achieved?
  2. What is the reason for calling self.log("train_loss", scvi_loss.loss, on_epoch=True) in scvi.train.TrainingPlan.training_step(), when the next line already passes the whole dict (including scvi_loss.loss) to LightningModule.log_dict() via compute_and_log_metrics()?