I’m working on a custom implementation with scvi-tools and have a design question regarding the best way to handle specific points selected during training.
In my setup, at the end of each epoch, I select a subset of points that will be used for loss calculation in the subsequent epoch (using the loss method of the VAE module). The selection of these points is implemented through a custom TrainingPlan.
I’m considering two possible approaches to manage these points throughout the training process:

1. Add them as attributes of the VAE module. For instance, during the initialization of the SCVI model, this would look like: self.module.new_points = new_pts.
2. Store them in adata, using something like adata.uns['_scvi_new_points']. In this case, I believe I would also need to modify the __getitem__ method of the AnnTorchDataset class.
Which approach would you recommend? Are there any advantages, drawbacks, or potential pitfalls for each? If there’s a better alternative, I’d love to hear about it.
Hi, you would want to use a custom Lightning callback for this (if I understand correctly, after each epoch you select a new set of cells for the subsequent epoch). This callback has to create a new data split and dataloader using custom indices (passed to the DataSplitter). I would recommend the on_train_epoch_start hook of the Callback class (see the PyTorch Lightning 2.5 Callback documentation), which sounds like the correct handle here.
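Just to make that concrete, below is a minimal sketch (not tested against a specific scvi-tools version). It assumes the DataSplitter rebuilds its train dataloader from its train_idx attribute and that you pass reload_dataloaders_every_n_epochs=1 to the Trainer so the loader is actually recreated each epoch; the indices are set at epoch end so they are in place before the next epoch's reload. Adapt the names to your setup.

```python
import numpy as np
from lightning.pytorch.callbacks import Callback


class ResampleIndicesCallback(Callback):
    """Pick a new set of training cells at the end of each epoch; the train
    dataloader built from these indices is then used for the next epoch."""

    def __init__(self, datasplitter, n_obs_total, n_select):
        self.datasplitter = datasplitter  # the scvi-tools DataSplitter in use
        self.n_obs_total = n_obs_total
        self.n_select = n_select

    def on_train_epoch_end(self, trainer, pl_module):
        # Replace this random draw with whatever selection criterion you use.
        new_idx = np.random.choice(self.n_obs_total, size=self.n_select, replace=False)
        # The dataloader is rebuilt from train_idx when Lightning reloads it
        # (requires reload_dataloaders_every_n_epochs=1 on the Trainer).
        self.datasplitter.train_idx = new_idx
```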
Thank you for your response and suggestion! Let me clarify my use case a bit further:
The selected points are not directly used for training, so they don’t need to be included in the training data loader. Instead, these points are specifically used to calculate some parameters at the beginning of each epoch, which are then utilized in the loss calculation throughout the current epoch. Therefore, I need access to these points at the start of every epoch to compute the necessary parameters.
At the end of each epoch, these points are updated. They are generated using the decoder and exist in the data space (with a shape of (K, n_cells)), but there’s no guarantee that they correspond to actual points in the training set.
Given this setup, do you think the custom Lightning callback is still the best approach? Or would you suggest a different mechanism to manage and update these points efficiently across epochs?
Yes, everything you want to do at the beginning of each epoch (or step) should be handled by a Callback (and by Lightning in general). This creates the smallest overhead, and you don’t want to modify the module or the data during training.
Another option would be to generate them during each step and define them in the decoder call. If you compute them once per epoch, just pass them to the loss function (you can check our current kl_weight schedule for how to implement the Callback).
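A hedged sketch of that per-epoch variant, assuming your custom TrainingPlan forwards its loss_kwargs dict to module.loss() (the same channel the kl_weight schedule uses) and that your loss() accepts a new_points keyword; generate_points is a hypothetical stand-in for however you produce the (K, n_cells) array from the decoder:

```python
import torch
from lightning.pytorch.callbacks import Callback


class UpdateLossPointsCallback(Callback):
    """Recompute the (K, n_cells) points once per epoch and route them to the loss."""

    def __init__(self, generate_points):
        # generate_points(module) -> torch.Tensor is supplied by you and wraps
        # however you draw the points from the decoder.
        self.generate_points = generate_points

    def on_train_epoch_start(self, trainer, pl_module):
        with torch.no_grad():
            new_points = self.generate_points(pl_module.module)
        # The training plan forwards loss_kwargs to the module's loss() at every
        # step, so the freshly computed points are used throughout this epoch.
        pl_module.loss_kwargs.update({"new_points": new_points})
```

Depending on your scvi-tools version, you should be able to pass the callback through model.train(..., callbacks=[...]), since extra keyword arguments are forwarded to the Lightning Trainer.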