Overriding functions from submodule

Hello!

I am using the conditional Variational auto-encoder model (scvi.module.VAEC) and would like to adjust its functionalities.

First of all, I would like to return additional loss values alongside the ones that are currently passed to the LossRecorder. The documentation (scvi.module.base.LossRecorder — scvi-tools) states that “Additional metrics can be passed as keyword arguments”. However, if I pass, for example, another torch.tensor(0.0), I get the following error: TypeError: __init__() takes from 2 to 5 positional arguments but 6 were given. So my first question is: in what format does the LossRecorder expect the additional arguments?

Secondly, I would like to adjust functionality from the classes that VAEC's training relies on. For example, I would like to override training_epoch_end(self, outputs) from the TrainingPlan. Overriding the function in my ModelClass(UnsupervisedTrainingMixin, BaseModelClass), where I initialise VAEC, did not work. Do you know how I can easily override functions in the TrainingPlan?

class MyModelClass(UnsupervisedTrainingMixin, BaseModelClass):
    """Defines the train/val/ood methods"""

    def __init__(
            self,
            adata: AnnData,
            dataset,
            args
    ):
        super(MyModelClass, self).__init__(adata)
        self.module = scvi.module.VAEC()
        ...

    def training_epoch_end(self, outputs):
        super(MyModelClass, self).training_epoch_end(outputs)
        ...
...

Can you share how you’re using it exactly?

You can, but you need to write a new class called MyTrainingPlan(TrainingPlan) (inherits from our training plan) and then you can just write that one function there. You’d also need to write your own train method in MyModelClass that uses your new training plan (mostly copy paste of our method).
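The override pattern described above looks roughly like this. Since constructing the real scvi.train.TrainingPlan needs a module and data, this sketch uses a self-contained stand-in base class; in real code you would subclass scvi.train.TrainingPlan instead:

```python
# Stand-in for scvi.train.TrainingPlan, used only so the sketch runs on its own.
# In real code: `from scvi.train import TrainingPlan` and subclass that instead.
class TrainingPlanStandIn:
    def training_epoch_end(self, outputs):
        # the base class aggregates/logs the epoch's outputs
        return f"base handled {len(outputs)} outputs"


class MyTrainingPlan(TrainingPlanStandIn):
    def training_epoch_end(self, outputs):
        base = super().training_epoch_end(outputs)  # keep the base behaviour
        # ... add custom end-of-epoch logic here ...
        return base + ", plus custom logic"


print(MyTrainingPlan().training_epoch_end([0, 1, 2]))
```

Your custom train method would then construct MyTrainingPlan in place of the default plan before handing it to the trainer.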

I am returning the LossRecorder in the loss() function, as is done in scvi.module.VAEC. For testing I just tried to return another zero torch tensor, as is done for the kl_global variable:

LossRecorder(loss, reconst_loss, kl_divergence_z, torch.tensor(0.0), torch.tensor(0.0))

Thanks, this works! However, the initialisation of the TrainingPlan in the train() function provided in https://docs.scvi-tools.org/en/stable/tutorials/notebooks/model_user_guide.html uses len(data_splitter.train_idx). I had to remove this input because I got an error that data_splitter does not have an attribute train_idx. This seems to be correct, as train_idx is also not listed as an attribute in the DataSplitter documentation (scvi.dataloaders.DataSplitter — scvi-tools). I am not sure whether this is a mistake in the documentation or whether something went wrong when I transferred the functions.

You will need to use keyword arguments like

LossRecorder(loss, reconst_loss, kl_divergence_z, metric1=torch.tensor(0.0), metric2=torch.tensor(0.0))
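The positional call fails because only the first few arguments have positional slots (the losses and KL terms); everything else is collected via **kwargs. A runnable stand-in of that kind of signature, with a hypothetical class name and plain floats instead of tensors so it is self-contained, shows where keyword extras end up:

```python
# Hypothetical stand-in mirroring the shape of LossRecorder's signature:
# fixed positional slots, extra metrics accepted only via **kwargs.
class RecorderSketch:
    def __init__(self, loss, reconstruction_loss, kl_local, kl_global=0.0, **kwargs):
        self.loss = loss
        self.extra_metrics = kwargs  # extra named metrics land here

# A fifth positional argument has no slot to go to -> TypeError,
# analogous to the error seen with LossRecorder:
try:
    RecorderSketch(1.0, 0.5, 0.1, 0.0, 0.0)
except TypeError as err:
    print(err)

# Passed by keyword, the extra metric is accepted and recorded:
rec = RecorderSketch(1.0, 0.5, 0.1, 0.0, my_metric=0.0)
print(rec.extra_metrics)
```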

You can write your own train method that does not use our TrainRunner object which would avoid this problem, or set .train_idx = [1] if you’re not using it.

Please let me know if you have further questions!

Thanks again, I am able to record different loss metrics :slightly_smiling_face:

I followed the Constructing a high-level model — scvi-tools tutorial with the VAEC module instead of scvi. It seems that my model does not perform any validation, because no validation loss is included in model.history. I did declare the training and validation sizes in the DataSplitter (train_size=0.8 and validation_size=1-train_size). Do you know why my model is not performing any validation?

On a different note: where in the code can I influence when and with what data the training, validation and testing are performed? As I understand it, the model currently takes the AnnData set and randomly splits it into a train and a validation set. While it does not matter to me which data is used for training/validation, I would like to control which data the model uses for testing, i.e. I only want to test it on a specific cell type A, leaving cell type A out during training and validation.

You need to add, for example, check_val_every_n_epoch=1 to the train command; you can use all the trainer arguments of the PyTorch Lightning Trainer.
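This works because extra keyword arguments to train() are forwarded to the Lightning Trainer. A toy sketch of that forwarding pattern (the names here are illustrative, not scvi's actual code):

```python
# Toy sketch: extra kwargs to a train() wrapper are collected and passed on.
# In scvi-tools they eventually reach something like pl.Trainer(**trainer_kwargs).
def train(max_epochs=400, train_size=0.9, **trainer_kwargs):
    return {"max_epochs": max_epochs, "trainer_kwargs": trainer_kwargs}


cfg = train(max_epochs=100, check_val_every_n_epoch=1)
print(cfg)
```

So calling model.train(..., check_val_every_n_epoch=1) should make the validation loss appear in model.history.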

You can train with an AnnData that excludes this cell type and then run one of the get_... methods on the held-out AnnData.
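Concretely, the split could look like the following sketch. The cell_type column name and the model calls (shown as comments) are assumptions; the runnable part uses a plain-list mask so the example is self-contained:

```python
# With AnnData this would be boolean-mask indexing, e.g. (assumed column name):
#   train_adata = adata[adata.obs["cell_type"] != "A"].copy()
#   heldout_adata = adata[adata.obs["cell_type"] == "A"].copy()
#   model = MyModelClass(train_adata, ...); model.train(...)
#   then call one of the get_... methods on heldout_adata
# The same masking logic on a plain list of labels:
cell_types = ["A", "B", "A", "C", "B"]
train_idx = [i for i, ct in enumerate(cell_types) if ct != "A"]
heldout_idx = [i for i, ct in enumerate(cell_types) if ct == "A"]
print(train_idx, heldout_idx)
```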