Passing scvi models to interpretation algorithms for downstream analysis

I have a general question about how to pass trained scvi models to downstream analysis tools, e.g. Captum (a model interpretability library for PyTorch).
These tools usually accept an nn.Module as input. How can I retrieve the nn.Module from the packaged scvi model objects and pass it to these tools?

The same question applies to any tool that takes an nn.Module or pl.LightningModule as input. For example, in Weights & Biases, wandb.watch() can take an nn.Module or pl.LightningModule to log gradients and weights during training, which helps diagnose the model, but it is not clear how to do this with scvi models.

The underlying nn.Module is accessible as model.module. After training, the PyTorch Lightning trainer, which holds the Lightning module (trainer.lightning_module), is accessible as model.trainer. In our design, the nn.Module is passed to a Lightning module, which we call a TrainingPlan, and that TrainingPlan is then trained with PyTorch Lightning.
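
For Captum specifically, model.module usually cannot be passed as is: Captum calls its forward function with input tensors and expects a tensor back, whereas scvi modules' forward takes a dictionary of minibatch tensors and returns several outputs at once. A minimal sketch, assuming that `model.module.inference(x, batch_index)` returns a dict whose `"qz"` entry is a torch `Normal` (the case for the VAE module in recent scvi-tools releases; check your version), is to wrap the module in a small `nn.Module` that returns the latent mean. `LatentMeanWrapper`, `ToyEncoderModule` and `input_x_gradient` are hypothetical names: the toy module only stands in for `model.module` so the sketch runs without a trained model, and `input_x_gradient` computes the input-times-gradient quantity that Captum's `InputXGradient` computes, using plain autograd.

```python
import torch
from torch import nn
from torch.distributions import Normal


class LatentMeanWrapper(nn.Module):
    """Expose the posterior mean of an scvi-style module as a plain tensor.

    Assumption: module.inference(x, batch_index) returns a dict whose "qz"
    entry is a torch Normal (as for the VAE module in recent scvi-tools).
    """

    def __init__(self, module: nn.Module):
        super().__init__()
        self.module = module

    def forward(self, x: torch.Tensor, batch_index: torch.Tensor) -> torch.Tensor:
        inference_outputs = self.module.inference(x, batch_index)
        # Mean of q(z | x), shape (n_cells, n_latent).
        return inference_outputs["qz"].loc


class ToyEncoderModule(nn.Module):
    """Hypothetical stand-in for model.module so the sketch runs on its own."""

    def __init__(self, n_genes: int = 20, n_latent: int = 4):
        super().__init__()
        self.mean_layer = nn.Linear(n_genes, n_latent)
        self.log_var_layer = nn.Linear(n_genes, n_latent)

    def inference(self, x: torch.Tensor, batch_index: torch.Tensor) -> dict:
        # batch_index mirrors the assumed signature but is unused in the toy.
        h = torch.log1p(x)
        loc = self.mean_layer(h)
        scale = torch.exp(0.5 * self.log_var_layer(h))
        qz = Normal(loc, scale)
        return {"qz": qz, "z": qz.rsample()}


def input_x_gradient(forward_model: nn.Module, x: torch.Tensor,
                     batch_index: torch.Tensor, target: int) -> torch.Tensor:
    """Input-times-gradient attribution of latent dimension `target`."""
    x = x.detach().clone().requires_grad_(True)
    output = forward_model(x, batch_index)[:, target]
    # Each output row depends only on its own input row (expected for scvi
    # encoders in eval mode), so the gradient of the sum is per-cell.
    (grad,) = torch.autograd.grad(output.sum(), x)
    return (x * grad).detach()


torch.manual_seed(0)
toy_module = ToyEncoderModule().eval()
wrapper = LatentMeanWrapper(toy_module)
counts = torch.poisson(torch.full((8, 20), 3.0))
batch_index = torch.zeros(8, 1, dtype=torch.long)
attributions = input_x_gradient(wrapper, counts, batch_index, target=0)
print(attributions.shape)  # torch.Size([8, 20])
```

With a trained model, the wrapper would be `LatentMeanWrapper(model.module)`, and with Captum installed something like `IntegratedGradients(wrapper).attribute(x, target=0, additional_forward_args=(batch_index,))`, where `x` is a float tensor of counts and `batch_index` an `(n_cells, 1)` tensor of batch codes; Captum expands `additional_forward_args` to match its interpolated inputs. Keep the module in eval mode (scvi-tools normally switches to it after `train()`) so batch-norm layers use running statistics.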

You can pass any custom callbacks or loggers directly to the PyTorch Lightning trainer through our model's train method: extra keyword arguments (trainer_kwargs) are forwarded to the trainer.
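
Besides passing a logger through trainer_kwargs, the kind of gradient statistics wandb.watch() collects can be gathered with plain PyTorch tensor hooks on model.module: since the TrainingPlan trains that same module object, hooks registered on its parameters before training fire on every backward pass. The sketch below is a swapped-in, dependency-free alternative rather than what wandb does internally; `GradNormRecorder` is a hypothetical name, and the toy network only stands in for model.module.

```python
import torch
from torch import nn


class GradNormRecorder:
    """Record the L2 norm of every parameter's gradient on each backward pass.

    Works on any nn.Module, e.g. model.module of an scvi-tools model. Because
    the hooks live on the parameters, they fire whichever training loop (such
    as the Lightning TrainingPlan) runs the backward pass.
    """

    def __init__(self, module: nn.Module):
        self.history = {}  # parameter name -> list of gradient norms
        self._handles = []
        for name, param in module.named_parameters():
            if param.requires_grad:
                self.history[name] = []
                self._handles.append(param.register_hook(self._make_hook(name)))

    def _make_hook(self, name: str):
        def hook(grad: torch.Tensor) -> None:
            # .item() syncs with the GPU each step; acceptable for diagnostics.
            self.history[name].append(grad.detach().norm().item())
            # Returning None leaves the gradient unchanged.
        return hook

    def remove(self) -> None:
        """Detach all hooks so later backward passes are not recorded."""
        for handle in self._handles:
            handle.remove()
        self._handles.clear()


# Tiny demonstration on a toy network (a stand-in for model.module).
net = nn.Sequential(nn.Linear(5, 3), nn.ReLU(), nn.Linear(3, 1))
recorder = GradNormRecorder(net)
x = torch.randn(4, 5)
for _ in range(2):
    net.zero_grad()
    net(x).sum().backward()
print({name: len(norms) for name, norms in recorder.history.items()})
# prints {'0.weight': 2, '0.bias': 2, '2.weight': 2, '2.bias': 2}
recorder.remove()
```

With a scvi-tools model, you would create `GradNormRecorder(model.module)` before calling `model.train(...)` and read `recorder.history` afterwards. For Weights & Biases itself, call `wandb.watch(model.module, log="all")` after `wandb.init(...)` and before training, or pass a Lightning `WandbLogger` (from `lightning.pytorch.loggers`, or `pytorch_lightning.loggers` on older installs) through the train method, e.g. `model.train(max_epochs=10, logger=WandbLogger(project="my-project"))`, where the project name is a placeholder.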

I can clarify any remaining points of confusion, but regular PyTorch Lightning features should be fairly easy to access.
