The only relevant changes are to the training and validation step functions and to the logging in the `TrainingPlan`:
@torch.no_grad()
def compute_and_log_metrics(
    self,
    loss_recorder: LossRecorder,
    metrics: MetricCollection,
    mode: str,
    metrics_eval: Optional[dict] = None,
):
    """
    Update the metric collection and log the loss plus any extra metrics.

    Parameters
    ----------
    loss_recorder
        LossRecorder object from the scvi-tools module. Its ``n_obs``,
        ``loss_sum`` and ``extra_metric_attrs`` attributes are read.
    metrics
        Torchmetrics collection that is updated with the minibatch loss.
        It is only updated here, not logged.
    mode
        Suffix appended to every logged metric name (e.g. ``"train"`` or
        ``"validation"``).
    metrics_eval
        Optional evaluation metrics given as ``{name: value}``. Each one is
        logged as ``f"{name}_{mode}_eval"``.

    Raises
    ------
    ValueError
        If an extra or evaluation metric is a tensor that is not 0-d.
    """
    n_obs_minibatch = loss_recorder.n_obs
    loss_sum = loss_recorder.loss_sum

    def _log_scalar_metric(name, value):
        # Shared by the extra-metric and eval-metric loops: only scalar
        # tensors may be logged, and they are detached from the autograd
        # graph first. Non-tensor values are logged unchanged.
        if isinstance(value, torch.Tensor):
            if value.shape != torch.Size([]):
                raise ValueError(
                    "Extra tracked metrics should be 0-d tensors. "
                    f"Metric {name!r} has shape {tuple(value.shape)}."
                )
            value = value.detach()
        self.log(
            name,
            value,
            on_step=self.log_on_step,
            on_epoch=self.log_on_epoch,
            batch_size=n_obs_minibatch,
        )

    # NOTE(review): the collection is only updated here. It is never passed to
    # self.log/self.log_dict, so Lightning will not compute or reset it.
    # Confirm it is consumed elsewhere.
    metrics.update(
        loss=loss_sum,
        n_obs_minibatch=n_obs_minibatch,
    )

    # The loss itself is logged directly, without the 0-d check applied to
    # extra metrics.
    self.log(
        f"loss_{mode}",
        loss_sum,
        on_step=self.log_on_step,
        on_epoch=self.log_on_epoch,
        batch_size=n_obs_minibatch,
    )

    # Extra metrics that were attached to the loss recorder.
    for extra_metric in loss_recorder.extra_metric_attrs:
        _log_scalar_metric(
            f"{extra_metric}_{mode}", getattr(loss_recorder, extra_metric)
        )

    # Extra evaluation metrics supplied by the caller.
    if metrics_eval is not None:
        for extra_metric, met in metrics_eval.items():
            _log_scalar_metric(f"{extra_metric}_{mode}_eval", met)
def training_step(self, batch, batch_idx, optimizer_idx=0):
    """
    Run one training step and return the loss to optimize.

    Before the forward pass, each entry of ``self.loss_weights`` is converted
    to its current value with ``self.compute_loss_weight``. The result is
    stored in ``self.loss_kwargs`` under the same key.

    Parameters
    ----------
    batch
        Minibatch passed unchanged to ``self.forward``.
    batch_idx
        Index of the minibatch (not used here).
    optimizer_idx
        Optimizer index (not used here; accepted for interface compatibility).

    Returns
    -------
    The ``loss`` attribute of the loss recorder produced by ``self.forward``.
    """
    for loss_name, raw_weight in self.loss_weights.items():
        current_weight = self.compute_loss_weight(weight=raw_weight)
        self.loss_kwargs.update({loss_name: current_weight})

    _, _, loss_recorder = self.forward(batch, loss_kwargs=self.loss_kwargs)

    # Loss and extra metrics are logged by compute_and_log_metrics.
    self.compute_and_log_metrics(loss_recorder, self.train_metrics, "train")
    return loss_recorder.loss
def validation_step(self, batch, batch_idx):
    """
    Run one validation step and log its loss and metrics.

    Unlike ``training_step``, no loss weights are recomputed here. The current
    contents of ``self.loss_kwargs`` are passed to ``self.forward`` unchanged.

    Parameters
    ----------
    batch
        Minibatch passed unchanged to ``self.forward``.
    batch_idx
        Index of the minibatch (not used here).
    """
    # loss kwargs here contains `n_obs` equal to n_training_obs, so when
    # relevant the actual loss value is rescaled to the number of training
    # examples.
    _, _, loss_recorder = self.forward(batch, loss_kwargs=self.loss_kwargs)

    # NOTE(review): eval_metrics() takes no batch argument but is called on
    # every validation batch; confirm this per-batch recomputation is intended.
    eval_results = self.module.eval_metrics()

    self.compute_and_log_metrics(
        loss_recorder,
        self.val_metrics,
        "validation",
        metrics_eval=eval_results,
    )