I have a general question about how to pass trained scvi models to downstream analysis tools, e.g. Captum (a model interpretability library for PyTorch).
These tools, e.g. Captum, usually accept an nn.Module as input. How can I retrieve the nn.Module from the high-level scvi model object and pass it to these tools?
The same question applies to any tool that takes an nn.Module or pl.LightningModule as input. For example, in Weights & Biases, wandb.watch() can take an nn.Module or pl.LightningModule to log gradients and weights during training, which is useful for diagnosing the model, but it is not clear how to do this with scvi models.
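To make the question concrete, this is the kind of wrapper I imagine Captum would need. It is only a sketch under assumptions: ToyModule is a hypothetical stand-in for the trained torch module (which I assume is reachable as model.module on a trained scvi model), make_forward_func and saliency are my own illustrative helpers, and the attribution is a plain input gradient computed with torch.autograd, similar in spirit to Captum's Saliency. The wrapper seems necessary because, as far as I can tell, the scvi module's forward() takes a dict of tensors and returns several outputs, whereas attribution methods want a callable that maps one input tensor to one output tensor.

```python
from typing import Callable

import torch
from torch import nn


class ToyModule(nn.Module):
    """Hypothetical stand-in for the trained scvi module (assumed to be model.module)."""

    def __init__(self, n_genes: int = 20, n_latent: int = 5):
        super().__init__()
        self.encoder = nn.Sequential(
            nn.Linear(n_genes, 16), nn.ReLU(), nn.Linear(16, n_latent)
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Maps expression vectors of shape (n_cells, n_genes) to (n_cells, n_latent)
        return self.encoder(x)


def make_forward_func(
    module: nn.Module, latent_dim: int
) -> Callable[[torch.Tensor], torch.Tensor]:
    """Hypothetical helper: wrap a module as a tensor -> tensor callable that
    returns one scalar per cell (the chosen latent dimension)."""
    module.eval()  # disable dropout etc. before computing attributions

    def forward_func(x: torch.Tensor) -> torch.Tensor:
        return module(x)[:, latent_dim]

    return forward_func


def saliency(
    forward_func: Callable[[torch.Tensor], torch.Tensor], x: torch.Tensor
) -> torch.Tensor:
    """Absolute input gradient |d output / d input| for each cell and gene."""
    x = x.clone().requires_grad_(True)
    out = forward_func(x)
    # Cells are processed independently, so the gradient of the sum
    # gives each cell's own gradient in its row.
    (grad,) = torch.autograd.grad(out.sum(), x)
    return grad.abs()


if __name__ == "__main__":
    torch.manual_seed(0)
    module = ToyModule()
    x = torch.rand(8, 20)
    attributions = saliency(make_forward_func(module, latent_dim=0), x)
    print(attributions.shape)  # torch.Size([8, 20])
```

If this is roughly right, the remaining question is which method of model.module should be called inside forward_func to obtain the latent representation; I have not verified that, so I left it open. The resulting callable could then be passed to Captum directly, e.g. captum.attr.IntegratedGradients(forward_func). Along the same lines, I would guess that calling wandb.watch(model.module) before model.train() is the way to log gradients, but I have not confirmed it.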