I have a general question about how to pass trained scvi models to downstream analysis tools, e.g. Captum (a model interpretability library for PyTorch).
These tools, e.g. Captum, usually accept an `nn.Module` as input. How can I retrieve the `nn.Module` from a compact scvi model object and pass it to these tools?
The same question applies to any algorithm that takes an `nn.Module` or `pl.LightningModule` as input. For example, in Weights & Biases, `wandb.watch()` can take an `nn.Module` or `pl.LightningModule` to log gradients and weights during training. This can be useful for diagnosing the model, but it is not clear how to do it with scvi models.