Incorporating augmentations (per batch) for scVI training

I have a module of augmentations that I would like to try with an scvi.model.SCVI model. In particular, I only want my augmentations to be applied during training of this model. What is the easiest way to integrate these augmentations into the training pipeline? I know I could directly modify parts of the scvi-tools codebase, but I was wondering whether there is already a simpler way.

What kind of augmentations do you have? If they are simple enough, you could potentially re-write them as Lightning callbacks and pass them in with the callbacks argument in train(). Otherwise, depending on the complexity of the augmentations, you might have to fiddle with creating custom torch modules.

Hey, thanks for getting back to me.

I’m considering augmentations that directly target the raw read count matrix such as:

  • Random masking of gene expressions in the read count matrix
  • Poisson noise added to gene expressions in the read count matrix

I’m also considering augmentations that can be applied after the log-transformation in scVI. This would be things like:

  • Gaussian noise
  • Monotonic function transformations (e.g. linear functions, power functions, etc., all with randomly sampled parameters)

Note that a given transformation should be applied on a per-training-step basis.

Currently I'm implementing this as a separate Augmentations class that directly modifies a given batch in the training step. I just need to add a couple of extra lines to the training_step function (see: scvi-tools/scvi/train/_trainingplans.py at main · scverse/scvi-tools · GitHub). I suspect there are better ways of doing this. What do you think?
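
For concreteness, here is a stripped-down sketch of what the Augmentations class looks like (the real version has more options, and I'm glossing over dtype handling):

import torch


class Augmentation:
    """Stripped-down sketch of my augmentations; each method edits batch["X"]."""

    def __init__(self, batch):
        self.batch = batch

    def random_mask(self, batch, mask_prob: float = 0.1):
        # Zero out a random subset of entries in the raw count matrix.
        x = batch["X"]
        mask = torch.rand(x.shape, device=x.device) < mask_prob
        x[mask] = 0
        return batch

    def poisson_noise(self, batch):
        # Replace each count with a Poisson draw centered on the observed count.
        batch["X"] = torch.poisson(batch["X"].float())
        return batch

    def random_linear_augmentation(self, batch, scale_range=(0.8, 1.2)):
        # Multiply counts by a randomly sampled scale factor per training step.
        low, high = scale_range
        scale = torch.empty(1, device=batch["X"].device).uniform_(low, high)
        batch["X"] = batch["X"] * scale
        return batch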

For the first set of augmentations targeting the raw read counts, I think it would be straightforward to implement those in a Lightning Callback that overrides the on_train_batch_start method, since it has access to the batch before it is passed into training_step. The only consideration would be to modify the tensors in-place as the hooks aren’t supposed to return anything. Then, you can initialize the callback and pass it into scvi.model.SCVI.train(callbacks=[...]).
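
Something along these lines should work (untested sketch; the masking here is just a stand-in for your own augmentation logic, adata is assumed to already be registered with setup_anndata, and depending on your Lightning version the import may be pytorch_lightning instead):

import torch
from lightning.pytorch import Callback

import scvi


class AugmentationCallback(Callback):
    """Apply augmentations to the raw counts right before each training step."""

    def __init__(self, mask_prob: float = 0.1):
        self.mask_prob = mask_prob

    def on_train_batch_start(self, trainer, pl_module, batch, batch_idx):
        # Modify the tensors in place -- this hook should not return a new batch.
        x = batch["X"]
        mask = torch.rand(x.shape, device=x.device) < self.mask_prob
        x[mask] = 0


model = scvi.model.SCVI(adata)
model.train(callbacks=[AugmentationCallback(mask_prob=0.1)])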

For a more flexible approach, you could subclass our TrainingPlan and then override the training_step (I guess you’re doing something similar to this already).

class AugmentationsTrainingPlan(scvi.train.TrainingPlan):

    def training_step(self, batch, batch_idx):
        # modify batch here
        return super().training_step(batch, batch_idx)

And before initializing SCVI, you’d want to set the following:

scvi.model.SCVI._training_plan_cls = AugmentationsTrainingPlan

This would be necessary for the augmentations applied after the log1p transformation, since that transformation occurs inside the forward method.

Thanks again for the reply. Here is the approach I am now following:

class AugmentationsTrainingPlan(scvi.train.TrainingPlan):
    def __init__(self, augmentation_fn=None, **kwargs):
        super().__init__(**kwargs)
        self.augmentation_fn = augmentation_fn

    def training_step(self, batch, batch_idx):
        if self.augmentation_fn is None:
            return super().training_step(batch, batch_idx)

        # Instantiate the Augmentation class
        augmentation = Augmentation(batch)

        # Mapping of augmentation functions
        augmentation_map = {
            "random_mask": augmentation.random_mask,
            "poisson_noise": augmentation.poisson_noise,
            "random_linear_augmentation": augmentation.random_linear_augmentation
        }

        # Apply the selected augmentation
        if self.augmentation_fn in augmentation_map:
            batch = augmentation_map[self.augmentation_fn](batch)
        else:
            raise ValueError("Invalid augmentation function provided")

        # Print which augmentation has been applied
        print(f"Applied augmentation: {self.augmentation_fn}")

        # Call the parent training_step method with the augmented batch
        return super().training_step(batch, batch_idx)

I then run the following code:

from scvi.module.base import BaseModuleClass

scvi.model.SCVI.setup_anndata(adata, layer="counts", batch_key="batch")
scvi.model.SCVI._training_plan_cls = AugmentationsTrainingPlan(module=BaseModuleClass())
model = scvi.model.SCVI(adata)
model.train()

I obtain the following error:


TypeError                                 Traceback (most recent call last)
<ipython-input-30-97f5c113344e> in <cell line: 4>()
      2 scvi.model.SCVI._training_plan_cls = AugmentationsTrainingPlan(module=BaseModuleClass())
      3 model = scvi.model.SCVI(adata)
----> 4 model.train()

/usr/local/lib/python3.10/dist-packages/scvi/model/base/_training_mixin.py in train(self, max_epochs, accelerator, devices, train_size, validation_size, shuffle_set_split, load_sparse_tensor, batch_size, early_stopping, datasplitter_kwargs, plan_kwargs, data_module, **trainer_kwargs)
    127     plan_kwargs = plan_kwargs or {}
--> 128     training_plan = self._training_plan_cls(self.module, **plan_kwargs)

/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py in _wrapped_call_impl(self, *args, **kwargs)
   1517     else:
-> 1518         return self._call_impl(*args, **kwargs)

/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py in _call_impl(self, *args, **kwargs)
   1526             or _global_forward_hooks or _global_forward_pre_hooks):
-> 1527         return forward_call(*args, **kwargs)

/usr/local/lib/python3.10/dist-packages/scvi/train/_trainingplans.py in forward(self, *args, **kwargs)
    277         """Passthrough to the module's forward method."""
--> 278         return self.module(*args, **kwargs)

/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py in _wrapped_call_impl(self, *args, **kwargs)
-> 1518         return self._call_impl(*args, **kwargs)

/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py in _call_impl(self, *args, **kwargs)
-> 1527         return forward_call(*args, **kwargs)

/usr/local/lib/python3.10/dist-packages/scvi/module/base/_decorators.py in auto_transfer_args(self, *args, **kwargs)
     31     if self.training:
---> 32         return fn(self, *args, **kwargs)

/usr/local/lib/python3.10/dist-packages/scvi/module/base/_base_module.py in forward(self, tensors, get_inference_input_kwargs, get_generative_input_kwargs, inference_kwargs, generative_kwargs, loss_kwargs, compute_loss)
--> 203         return _generic_forward(
    204             self,
    205             tensors,

/usr/local/lib/python3.10/dist-packages/scvi/module/base/_base_module.py in _generic_forward(module, tensors, inference_kwargs, generative_kwargs, loss_kwargs, get_inference_input_kwargs, get_generative_input_kwargs, compute_loss)
    738     inference_inputs = module._get_inference_input(tensors, **get_inference_input_kwargs)
--> 739     inference_outputs = module.inference(**inference_inputs, **inference_kwargs)

TypeError: scvi.module.base._base_module.BaseModuleClass.inference() argument after ** must be a mapping, not NoneType

You can assume the following:

  • Each augmentation modifies batch["X"] in place and returns batch as expected
  • Each augmentation generates synthetic values to replace the original data, so the batch size never changes
  • I am using scvi-colab

So my questions are:

  1. I am not sure what is going wrong here. Any ideas?
  2. If I want to choose when to apply my augmentations (e.g. whether before or after the log1p transformation) what would be the best way to proceed?

@martinkim0 @adamgayoso

For (2), if you want to apply augmentations after log1p, I think the most straightforward way is to pass log_variational=False to the model constructor, which disables the automatic log1p transform, and then follow the same steps as for the pre-log1p augmentations, adding the log1p step yourself in training_step.
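
Roughly like this (untested sketch; the Gaussian noise is just an example of a post-log1p augmentation):

import torch

import scvi


class PostLog1pTrainingPlan(scvi.train.TrainingPlan):
    def training_step(self, batch, batch_idx):
        # log_variational=False below disables the model's own log1p, so we
        # apply it here and then augment in log space.
        x = torch.log1p(batch["X"].float())
        batch["X"] = x + 0.1 * torch.randn_like(x)  # e.g. Gaussian noise
        return super().training_step(batch, batch_idx)


scvi.model.SCVI._training_plan_cls = PostLog1pTrainingPlan
model = scvi.model.SCVI(adata, log_variational=False)
model.train()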

For (1), could you try putting the augmentations in the on_after_batch_transfer hook instead? The API should be similar.

Can you please clarify what you mean here?

I see there is already a definition of on_after_batch_transfer hook in the scvi.dataloaders.DataSplitter class. If my aim is to be able to define my augmentations at model initialization, I suspect there is quite a bit of work to do (but I’m happy to be proven wrong).

Sorry, what I meant was if you could replace training_step with on_after_batch_transfer as the hook for the training plan to see if this might fix the error you were running into.

I managed to solve the problem by implementing my augmentations in the on_after_batch_transfer function of the DataSplitter. I can then pass the augmentations I wish to apply into the model.train() call.
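
In case it helps anyone else, this is roughly what I ended up with (simplified sketch; it reuses the Augmentation class from above, assumes the model picks up the splitter class via _data_splitter_cls analogously to _training_plan_cls, and uses trainer.training to skip validation batches):

import scvi
from scvi.dataloaders import DataSplitter


class AugmentationDataSplitter(DataSplitter):
    def __init__(self, *args, augmentation_fn=None, **kwargs):
        super().__init__(*args, **kwargs)
        self.augmentation_fn = augmentation_fn

    def on_after_batch_transfer(self, batch, dataloader_idx):
        batch = super().on_after_batch_transfer(batch, dataloader_idx)
        # Only augment training batches, not validation batches.
        if self.augmentation_fn is not None and self.trainer.training:
            augmentation = Augmentation(batch)
            batch = getattr(augmentation, self.augmentation_fn)(batch)
        return batch


scvi.model.SCVI._data_splitter_cls = AugmentationDataSplitter
model = scvi.model.SCVI(adata)
model.train(datasplitter_kwargs={"augmentation_fn": "poisson_noise"})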

I will try the post-log1p augmentations next and let you know how it goes. Thanks again.

One more question. I would like to apply an augmentation that's based on the sequencing technology (i.e. 10X or Smart-seq). This information is stored in my AnnData object. What would be the easiest way to pass it on to the model so that I can apply, say, one augmentation to 10X data and another to Smart-seq data?

Great! Glad you were able to figure it out.

If you don’t already have a batch key that you're passing into the model, you can register the obs field containing the sequencing metadata and pass it into setup_anndata(batch_key=...). Then, in any of the Lightning hooks that take batch as an argument, you can access batch[scvi.REGISTRY_KEYS.BATCH_KEY], which will give you a tensor of integer-encoded values for the sequencing tech.

If you already have a batch key, you can instead pass this information in with categorical_covariate_keys in setup_anndata, and similarly access it with batch[REGISTRY_KEYS.CAT_COVS_KEY]. The only difference is that this tensor will contain integer-encoded values for all of the covariates you pass in categorical_covariate_keys (the argument takes a list of keys).
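
For example, inside any hook that receives the batch, something like this pulls out the codes in either case (sketch; it assumes the sequencing tech is the first, or only, key in categorical_covariate_keys):

import scvi


def get_tech_codes(batch, from_cat_covs: bool = False):
    """Pull the integer-encoded sequencing tech out of a training batch (sketch)."""
    if from_cat_covs:
        # Registered via categorical_covariate_keys: one column per covariate.
        return batch[scvi.REGISTRY_KEYS.CAT_COVS_KEY][:, :1]
    # Registered via setup_anndata(batch_key=...): an (n_cells, 1) tensor of codes.
    return batch[scvi.REGISTRY_KEYS.BATCH_KEY]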

A side effect of this method is that, even if you don’t want it, the sequencing tech information will be passed into the decoder as a covariate. Disabling this might take a bit more work.

Hey @martinkim0. Thanks for all the help. I fixed the post-log1p augmentations by applying them directly in the _regular_inference() function of the VAE module, instead of applying them in the on_after_batch_transfer data hook. I have one more issue:

Using Sequencing Tech (basically 2 categories) To Inform the Augmentation

I already have a batch key. I want my augmentations to be applied depending on the sequencing tech (either 10X or Smart-seq). I registered the sequencing tech as a new key, but I do not want it to be passed on to the decoder as a covariate. I just want to use this key to inform the augmentation (e.g. if the sequencing tech is 10X, multiply read counts by 0.8; otherwise, multiply read counts by some other float scale factor).

I’m not sure how to register the new key while also preventing the new categorical variable from being passed to the decoder as a covariate.

Also, the sequencing tech falls into one of two categories (10X or Smart-seq2). If the categories are integer-encoded, how do I find out which code corresponds to which tech? My current (but inelegant) solution is to use the fact that 10X library sizes are generally much smaller than Smart-seq library sizes.
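
For reference, this is roughly the kind of augmentation I have in mind (sketch; it assumes the sequencing tech is the first categorical covariate, and the code-to-tech mapping in scale_by_code is exactly the part I'm unsure about):

import scvi


def tech_conditional_scale(batch, scale_by_code=None):
    """Scale counts differently depending on the sequencing tech (sketch)."""
    if scale_by_code is None:
        # Which integer code is 10X and which is Smart-seq is the open question;
        # my current workaround infers it from the library sizes.
        scale_by_code = {0: 0.8, 1: 1.0}
    codes = batch[scvi.REGISTRY_KEYS.CAT_COVS_KEY][:, 0]
    x = batch["X"]
    for code, scale in scale_by_code.items():
        rows = codes == code
        x[rows] = x[rows] * scale
    return batch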