Implementing warmup for a model component based on the epoch

Hi,

This might be a silly question, but I’m trying to implement a warmup schedule for one of the components of my model.

I’ve implemented a makeshift solution, but I was hoping to avoid digging through the codebase if someone already knows the clean way to do this.

For example, is there a straightforward way to access the current epoch from within the module class during training?

Thanks in advance!

The idiomatic way to do this with PyTorch Lightning is a custom callback that implements hooks such as on_epoch_end. For many use cases, though, simply reusing the existing KL warmup value should suffice.
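A minimal sketch of the callback approach, assuming a linear warmup schedule. The names linear_warmup, ComponentWarmup, n_warmup_epochs and the module attribute component_weight are hypothetical, chosen for illustration. So that the sketch runs without Lightning installed, the class is duck-typed; in a real setup it would subclass Lightning's Callback and be passed to the trainer via Trainer(callbacks=[...]). It uses the on_train_epoch_start(trainer, pl_module) hook so the weight is updated before each epoch's training steps.

```python
def linear_warmup(epoch: int, n_warmup_epochs: int, max_weight: float = 1.0) -> float:
    """Linearly ramp a weight from 0 to max_weight over n_warmup_epochs epochs."""
    if n_warmup_epochs <= 0:
        # No warmup requested: use the full weight from the start.
        return max_weight
    # Fraction of warmup completed, capped at 1.0 once warmup is over.
    return max_weight * min(1.0, epoch / n_warmup_epochs)


class ComponentWarmup:
    """Duck-typed stand-in for a Lightning Callback (subclass Callback in real use).

    Before each training epoch, writes the current warmup weight onto the
    module as a hypothetical attribute named `component_weight`.
    """

    def __init__(self, n_warmup_epochs: int, max_weight: float = 1.0):
        self.n_warmup_epochs = n_warmup_epochs
        self.max_weight = max_weight

    def on_train_epoch_start(self, trainer, pl_module):
        # trainer.current_epoch is 0-based, so epoch 0 gets weight 0.
        pl_module.component_weight = linear_warmup(
            trainer.current_epoch, self.n_warmup_epochs, self.max_weight
        )
```

Inside the module, the component's loss term would then be scaled by self.component_weight. Alternatively, since a LightningModule exposes self.current_epoch, calling linear_warmup(self.current_epoch, ...) directly in training_step should work too, without a callback.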


Hi, yes, reusing the KL weight was my makeshift solution. Thanks, I’ll check out the existing callback functions. :)