After training, I can view the logged losses by running:
```
>>> print(model.trainer.logger.history.items())
dict_items([('train_loss_step', train_loss_step
epoch
0 316.210358
1 295.575592
2 304.064575
3 294.447571
4 288.721558
5 291.47522
6 297.350006
7 290.012665
8 283.560333
9 281.745026), ('train_loss_epoch', train_loss_epoch
epoch
0 378.040009
1 314.651703
2 304.716156
3 298.997986
4 295.423248
5 292.827759
6 290.723419
7 289.18985
8 288.019775
9 286.723969), ('elbo_train', elbo_train
epoch
0 397.563904
1 345.731934
2 339.36911
3 334.903931
4 331.399536
5 328.528107
6 325.961426
7 323.683929
8 321.890503
9 320.092224), ('kl_global_train', kl_global_train
epoch
0 0.0
1 0.0
2 0.0
3 0.0
4 0.0
5 0.0
6 0.0
7 0.0
8 0.0
9 0.0), ('kl_local_train', kl_local_train
epoch
0 19.523869
1 31.158157
2 34.827053
3 36.177135
4 36.339882
5 36.152225
6 35.77433
7 35.108616
8 34.561852
9 34.136429), ('reconstruction_loss_train', reconstruction_loss_train
epoch
0 378.040039
1 314.573761
2 304.542053
3 298.726807
4 295.059662
5 292.375885
6 290.187103
7 288.575317
8 287.328644
9 285.955811)])
```
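Each value in that dict prints as a one-column pandas DataFrame indexed by `epoch`, which makes the output hard to scan. Assuming that shape (it is what the repr suggests), a minimal sketch for lining the metrics up in one table is `pd.concat(..., axis=1)`. The `history` dict below is a hand-built stand-in holding only two of the logged metrics for the first three epochs; in a real session it would be `model.trainer.logger.history`.

```python
import pandas as pd

# Hand-built stand-in for model.trainer.logger.history: two of the metrics printed
# above (first three epochs only), each a one-column DataFrame indexed by "epoch".
epochs = pd.Index([0, 1, 2], name="epoch")
history = {
    "train_loss_epoch": pd.DataFrame(
        {"train_loss_epoch": [378.040009, 314.651703, 304.716156]}, index=epochs
    ),
    "reconstruction_loss_train": pd.DataFrame(
        {"reconstruction_loss_train": [378.040039, 314.573761, 304.542053]}, index=epochs
    ),
}

# Align all metrics on the shared epoch index: one row per epoch, one column per metric.
combined = pd.concat(history.values(), axis=1)
print(combined)
```

Because every frame shares the same `epoch` index, `concat` aligns rows by epoch label rather than by position.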
However, during training the progress bar does not report the loss; it shows only `v_num`:
```
Epoch 4/10: 30%|██████████████████████████████████████████████████████ | 3/10 [00:14<00:31, 4.55s/it, v_num=1]
```
The training loss is logged in `training_step()` of a custom training plan that uses manual optimization:
```python
def training_step(self, batch, batch_idx):
    ...
    _, _, vae_losses = self.forward(batch, loss_kwargs=self.loss_kwargs)
    # Add the training-only regularizer to the VAE loss
    regularized_vae_loss = vae_losses.loss + laplace_alpha
    # Manual optimization step
    vae_opt.zero_grad()
    self.manual_backward(regularized_vae_loss)
    vae_opt.step()
    self.log("train_loss", regularized_vae_loss, on_epoch=True)
    self.compute_and_log_metrics(vae_losses, self.train_metrics, "train")
    return regularized_vae_loss
```
Note that the novel regularizer `laplace_alpha`, which I compute in `training_step()`, shouldn't be applied during post-training inference, so `module.forward()` omits it completely. For now, `module.forward()` is just inherited from `BaseModuleClass`.
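The separation described above, where the penalty belongs to the training objective but not to `forward()`, can be sketched in plain PyTorch. This assumes nothing about how `laplace_alpha` is actually computed: `ToyModule`, `penalty`, `training_loss`, and the L1 weight penalty standing in for the regularizer are all hypothetical placeholders.

```python
import torch
from torch import nn


class ToyModule(nn.Module):
    """Hypothetical module whose forward() returns only the unregularized loss."""

    def __init__(self, dim=3):
        super().__init__()
        self.linear = nn.Linear(dim, dim)

    def forward(self, x):
        # Reconstruction-style loss only; inference code calling forward() never sees the penalty.
        return ((self.linear(x) - x) ** 2).mean()


def penalty(module, alpha=1e-3):
    # Hypothetical stand-in for laplace_alpha: an L1 penalty on all parameters.
    return alpha * sum(p.abs().sum() for p in module.parameters())


def training_loss(module, x):
    # Only the training path adds the regularizer on top of forward()'s loss.
    return module(x) + penalty(module)
```

Keeping the penalty outside `forward()` means any inherited inference code that calls `forward()` gets the plain loss without extra flags.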