Freezing layers: register_hook vs. requires_grad


Looking at the way scArches module freezing is implemented in scverse/scvi-tools on GitHub (commit 096099e34568b1bf94aad5273b9f92c202c9c755):

    for key, mod in module.named_modules():
        # skip over protected modules
        if key.split(".")[0] in mod_no_hooks_yes_grad:
            continue
        if isinstance(mod, FCLayers):
            hook_first_layer = False if no_hook_cond(key) else True
            mod.set_online_update_hooks(hook_first_layer)

It looks like set_online_update_hooks zeroes the gradients (in the backward pass?) and so freezes the FCLayers modules. If so, why do this in addition to the usual requires_grad=False setting for freezing layers? If not, what is the call above doing?

Thank you in advance for teaching me,

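The issue is that the input dims for non-batch information and for batch categorical information feed into the same linear layer, so they share a single weight tensor. requires_grad is set per tensor, so it cannot freeze only part of that weight. Splitting the layer in two would slow down the forward pass of the model in PyTorch (though JAX could make this simpler).

Therefore, a gradient hook is used instead, and it is applied only to the appropriate input dims.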
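
To make the difference concrete, here is a minimal sketch of the idea in plain PyTorch. It is not the scvi-tools implementation: the layer sizes, the layout "expression columns first, one-hot batch columns last", and the function name keep_batch_columns_only are all assumptions made up for illustration. It only shows that a tensor hook can zero the gradient for some columns of a weight while the others keep training, which requires_grad cannot do.

```python
import torch
from torch import nn

torch.manual_seed(0)

# Hypothetical sizes, chosen only for this illustration.
n_genes, n_batches, n_hidden = 5, 3, 4

# A single linear layer receives [expression | one-hot batch] as one input,
# so both kinds of input share the same weight tensor.
layer = nn.Linear(n_genes + n_batches, n_hidden)


def keep_batch_columns_only(grad: torch.Tensor) -> torch.Tensor:
    """Return a gradient that is zero everywhere except the batch columns."""
    masked = torch.zeros_like(grad)
    masked[:, -n_batches:] = grad[:, -n_batches:]
    return masked


# requires_grad applies to a whole tensor, so it cannot freeze some columns.
# A tensor hook rewrites the gradient before it is stored in .grad.
layer.weight.register_hook(keep_batch_columns_only)
# The bias is frozen as a whole here, so the plain flag is enough for it.
layer.bias.requires_grad_(False)

# Fake mini-batch: random expression plus a random one-hot batch label.
batch_ids = torch.randint(n_batches, (8,))
x = torch.cat([torch.randn(8, n_genes), torch.eye(n_batches)[batch_ids]], dim=1)

weight_before = layer.weight.detach().clone()

optimizer = torch.optim.SGD([layer.weight], lr=0.1)
layer(x).pow(2).sum().backward()
optimizer.step()

# Columns for the expression inputs are untouched; batch columns moved.
gene_cols_unchanged = torch.equal(
    layer.weight[:, :n_genes], weight_before[:, :n_genes]
)
batch_cols_changed = not torch.equal(
    layer.weight[:, -n_batches:], weight_before[:, -n_batches:]
)
print(gene_cols_unchanged, batch_cols_changed)
```

One caveat with this approach: the hook only zeroes the gradient, so an optimizer setting that changes parameters without a gradient (weight decay, for example) could still move the "frozen" columns. Plain SGD is used above to avoid that.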