scvi-tools-based model does not log metrics on every step

I built a new scvi-tools-based model and want to log both validation and training loss on every step rather than per epoch. I got this working for the validation metrics, but the training metrics are still reported per epoch. I adapted the TrainingPlan to log on step by setting the following parameters:

    on_step=True,
    on_epoch=False,
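
Concretely, my TrainingPlan subclass stores these as attributes that the logging code below reads (the class name and constructor signature here are just a sketch of my setup):

    from scvi.train import TrainingPlan

    class StepLoggingTrainingPlan(TrainingPlan):
        """TrainingPlan variant that can log per step instead of per epoch."""

        def __init__(self, module, log_on_step=True, log_on_epoch=False, **kwargs):
            super().__init__(module, **kwargs)
            # read later by compute_and_log_metrics()
            self.log_on_step = log_on_step
            self.log_on_epoch = log_on_epoch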

I also pass the corresponding parameters to the train function:

    log_every_n_steps=1,
    check_val_every_n_epoch=1,
    val_check_interval=1,
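
Put together, the call looks roughly like this (forwarding plan_kwargs to my custom plan's constructor is specific to my setup; the remaining keyword arguments are passed through to the Lightning Trainer):

    model.train(
        max_epochs=10,
        # forwarded to the custom TrainingPlan
        plan_kwargs={"log_on_step": True, "log_on_epoch": False},
        # forwarded to the Lightning Trainer
        log_every_n_steps=1,
        check_val_every_n_epoch=1,
        val_check_interval=1,
    )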

Any idea why the training metric is still logged only once per epoch rather than per step? I don't think it is caused by a custom end-of-epoch hook either; at least I have not found any so far in the scvi-tools base code that I use.

Also, if I set both on_step and on_epoch to True, I get loss_val_step (n steps), loss_val_epoch (n epochs), loss_train_epoch (n epochs), and loss_train_step (also n epochs instead of n steps; the DataFrame index is also named epoch).
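
For reference, I inspect the logged metrics through model.history (a dict of DataFrames that scvi-tools fills from the Lightning logger), roughly like this:

    for name, df in model.history.items():
        # with on_step=True and on_epoch=True, Lightning appends
        # "_step" / "_epoch" to each logged metric name
        print(name, len(df), df.index.name)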

Can you provide more detail on the exact changes you made?

The only relevant changes should be the training/validation step functions and the logging in the TrainingPlan:

    # imports needed for this snippet (paths as of scvi-tools 0.x; may differ by version)
    from typing import Optional

    import torch
    from torchmetrics import MetricCollection

    from scvi.module.base import LossRecorder

    @torch.no_grad()
    def compute_and_log_metrics(
            self,
            loss_recorder: LossRecorder,
            metrics: MetricCollection,
            mode: str,
            metrics_eval: Optional[dict] = None,
    ):
        """
        Computes and logs metrics.

        Parameters
        ----------
        loss_recorder
            LossRecorder object from scvi-tools module
        metrics
            MetricCollection used to accumulate the loss across batches
        mode
            Postfix string added to the logged metric names
        metrics_eval
            Evaluation metrics given as a dict mapping name to metric value
        """
        n_obs_minibatch = loss_recorder.n_obs
        loss_sum = loss_recorder.loss_sum

        # use the torchmetric object
        metrics.update(
            loss=loss_sum,
            n_obs_minibatch=n_obs_minibatch,
        )
        # pytorch lightning handles everything with the torchmetric object
        # self.log_dict(
        #     metrics,
        #     on_step=on_step,
        #     on_epoch=on_epoch,
        #     batch_size=n_obs_minibatch,
        # )

        self.log(
            f"loss_{mode}",
            loss_sum,
            on_step=self.log_on_step,
            on_epoch=self.log_on_epoch,
            batch_size=n_obs_minibatch,
        )

        # accumulate extra metrics passed to loss recorder
        for extra_metric in loss_recorder.extra_metric_attrs:
            met = getattr(loss_recorder, extra_metric)
            if isinstance(met, torch.Tensor):
                if met.shape != torch.Size([]):
                    raise ValueError("Extra tracked metrics should be 0-d tensors.")
                met = met.detach()
            self.log(
                f"{extra_metric}_{mode}",
                met,
                on_step=self.log_on_step,
                on_epoch=self.log_on_epoch,
                batch_size=n_obs_minibatch,
            )
        # accumulate extra eval metrics
        if metrics_eval is not None:
            for extra_metric, met in metrics_eval.items():
                if isinstance(met, torch.Tensor):
                    if met.shape != torch.Size([]):
                        raise ValueError("Extra tracked metrics should be 0-d tensors.")
                    met = met.detach()
                self.log(
                    f"{extra_metric}_{mode}_eval",
                    met,
                    on_step=self.log_on_step,
                    on_epoch=self.log_on_epoch,
                    batch_size=n_obs_minibatch,
                )

    def training_step(self, batch, batch_idx, optimizer_idx=0):
        for loss, weight in self.loss_weights.items():
            self.loss_kwargs.update({loss: self.compute_loss_weight(weight=weight)})
        _, _, scvi_loss = self.forward(batch, loss_kwargs=self.loss_kwargs)
        # self.log("train_loss", scvi_loss.loss, on_epoch=True)  # Saved above via loss recorder
        self.compute_and_log_metrics(scvi_loss, self.train_metrics, "train")
        return scvi_loss.loss

    def validation_step(self, batch, batch_idx):
        # loss kwargs here contains `n_obs` equal to n_training_obs
        # so when relevant, the actual loss value is rescaled to number
        # of training examples
        _, _, scvi_loss = self.forward(batch, loss_kwargs=self.loss_kwargs)
        # self.log("validation_loss", scvi_loss.loss, on_epoch=True)  # Saved above via loss recorder
        metrics_eval = self.module.eval_metrics()
        self.compute_and_log_metrics(
            scvi_loss, self.val_metrics, "validation", metrics_eval=metrics_eval
        )