Mapping old scANVI code to 0.16.2

Hey!

I’m just trying to update the HLCA reference-building code to 0.16.2, and it looks like there have been quite a few API changes since the version it was written for. Specifically, scANVI training now goes via an scVI model that is then extended to a scANVI model, rather than having the (quite useful) n_epochs_unsupervised and n_epochs_semisupervised separation. I was wondering if someone could confirm how I’ve mapped the parameters from the old to the new code version. The old version should be 0.8.2 or so, if I’m not mistaken.

setup_anndata should be pretty consistent, with the addition of the size_factors argument, which I can now pass directly.
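For reference, this is roughly what I mean (a sketch; size_factor_key is my guess at the new argument name, and the obs column names are placeholders for my actual keys):

import scvi

# Sketch of the new setup call; "batch" and "size_factors" are
# placeholder column names in adata.obs
scvi.model.SCVI.setup_anndata(
    adata,
    batch_key="batch",
    size_factor_key="size_factors",
)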

The model setup via sca.models.SCANVI now maps pretty much exactly to scvi.model.SCVI, with the exception of use_cuda, which has moved to .train() and is now called use_gpu.
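In other words, something like this (a sketch; the hyperparameters are placeholders, not the actual HLCA settings):

# Old (~0.8.x): device was chosen at construction time
# vae = sca.models.SCANVI(adata, n_latent=30, use_cuda=True)

# New (0.16.x): no device argument at construction;
# use_gpu moves to .train() instead
vae = scvi.model.SCVI(adata, n_latent=30)
vae.train(max_epochs=500, use_gpu=True)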

My main question is about the .train() function, which now seems quite different. Initially I had:

vae_epochs = 500
scanvi_epochs = 200

early_stopping_kwargs = {
    "early_stopping_metric": "elbo",
    "save_best_state_metric": "elbo",
    "patience": 10,
    "threshold": 0,
    "reduce_lr_on_plateau": True,
    "lr_patience": 8,
    "lr_factor": 0.1,
}

early_stopping_kwargs_scanvi = {
    "early_stopping_metric": "accuracy",
    "save_best_state_metric": "accuracy",
    "on": "full_dataset",
    "patience": 10,
    "threshold": 0.001,
    "reduce_lr_on_plateau": True,
    "lr_patience": 8,
    "lr_factor": 0.1,
}

vae.train(
    n_epochs_unsupervised=vae_epochs,
    n_epochs_semisupervised=scanvi_epochs,
    unsupervised_trainer_kwargs=dict(early_stopping_kwargs=early_stopping_kwargs),
    semisupervised_trainer_kwargs=dict(metrics_to_monitor=["elbo", "accuracy"],
                                       early_stopping_kwargs=early_stopping_kwargs_scanvi),
    frequency=1
)

Now, I’m trying to map just the scvi.model.SCVI part, and I think this should be:

vae_epochs = 500

early_stopping_kwargs = {
    "early_stopping": True,
    "early_stopping_monitor": "elbo_validation",
    "early_stopping_patience": 10,
    "early_stopping_min_delta": 0.0,
}

plan_kwargs = {
    "reduce_lr_on_plateau": True,
    "lr_patience": 8,
    "lr_factor": 0.1,
}

# Train scVI model
vae.train(
    max_epochs=vae_epochs,
    plan_kwargs=plan_kwargs,
    **early_stopping_kwargs,
    check_val_every_n_epoch=1,
    use_gpu=True,
)

Could someone confirm that this should be the same specification? I’m mainly unsure about:

- threshold == early_stopping_min_delta
- patience == early_stopping_patience
- frequency == check_val_every_n_epoch
- the separation of plan_kwargs and the early stopping kwargs

Thanks a lot!

Further questions for the scANVI model training part:

Is the old "on": "full_dataset" the same as "train_size": 1.0? And is there any way to do early stopping without a validation set for scANVI if you are only interested in the embedding? It seems all early_stopping_monitor options assume there is a validation set.

Hi Malte,

The mapping of parameters looks good to me, with the exception of the on="full_dataset" kwarg. If you look at the old code, it explicitly grabbed a subset of the data based on this kwarg (see the old early stopping code). However, the codebase now uses PyTorch Lightning’s early stopping implementation, which runs the check on a validation hook. It’s possible you could hijack the datasplitter so that the validation dataloader just returns the full dataset instead, so that it will be used in the early stopping callback. I believe this would match the behavior of the old code.

Hi Justin!

Thanks a lot for your answer! I see now that the on keyword referred to something other than what I thought. I guess the default train_size was always 0.9 then, even if on="full_dataset" was used?

Regarding hijacking the datasplitter to determine the subset picked for validation: that doesn’t seem to be quite as simple as just using train_size=1.0, validation_size=1.0 in the .train() method, due to this line:

Or did you have something else in mind? I was hoping not to have to fork the repo and change validation to train somewhere in the code.

Also, as far as I understand, I can’t recreate the old scANVI early stopping parameters anyway, no? It seems accuracy is no longer a valid metric to monitor for early stopping. Would you recommend monitoring the elbo, or just not implementing early stopping, if you’re looking for performance at least equivalent to the old accuracy-based early stopping?

> Or did you have something else in mind? I was hoping not to have to fork the repo and change validation to train somewhere in the code.

Unfortunately, what I had in mind was forking and making this change in the data splitter class. There’s no easy way to do that with the current API.
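Very roughly, the change in a fork might look something like this (an untested sketch; the attribute names are from my reading of the 0.16.x DataSplitter and may be off):

import numpy as np
from scvi.dataloaders import DataSplitter

# Untested sketch: after the normal split, overwrite the validation
# indices so the validation dataloader (and hence the early stopping
# check) runs on the full dataset, like the old on="full_dataset"
class FullDatasetValSplitter(DataSplitter):
    def setup(self, stage=None):
        super().setup(stage)
        self.val_idx = np.arange(self.adata_manager.adata.n_obs)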

> Also, as far as I understand, I can’t recreate the old scANVI early stopping parameters anyway, no? It seems accuracy is no longer a valid metric to monitor for early stopping. Would you recommend monitoring the elbo, or just not implementing early stopping, if you’re looking for performance at least equivalent to the old accuracy-based early stopping?

Correct, the accuracy metric is no longer in the scANVI model (although it could be added back in and then used as the early stopping metric). I would recommend still using the validation elbo with a high max_epochs argument, but it’s hard to say whether that would be at least as good as using accuracy. One thing you could do without forking the repo (instead subclassing the SCANVI/SCANVAE class) is to add the accuracy metric to the LossRecorder at the end of the loss function, then change early_stopping_monitor to the respective metric.
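In rough pseudocode, the subclassing idea would look something like this (treat it as a sketch only; the exact loss() signature and how extra metrics get logged differ between versions, so check the source):

from scvi.module import SCANVAE

# Sketch only: compute an accuracy on the labelled cells inside loss()
# and attach it to the returned LossRecorder so it gets logged and can
# be used as the early stopping monitor
class SCANVAEWithAccuracy(SCANVAE):
    def loss(self, *args, **kwargs):
        loss_recorder = super().loss(*args, **kwargs)
        # ...compute classifier accuracy on the labelled cells here and
        # attach it to loss_recorder under a name like "accuracy"...
        return loss_recorder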

In fact, there is already a classification loss metric computed in the loss that you may be able to use: [classification loss](https://github.com/scverse/scvi-tools/blob/5496b993b07e94ac4a6f111589d61a34154bb126/scvi/module/_scanvae.py#L218-L238). I believe you should be able to use this out of the box with early_stopping_monitor="classification_loss_validation". This term is only present when there are labelled cells in your dataset.
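i.e., something like this (a sketch, assuming the metric is logged under that name):

scanvi_model.train(
    max_epochs=200,
    check_val_every_n_epoch=1,
    early_stopping=True,
    early_stopping_monitor="classification_loss_validation",
    early_stopping_patience=10,
)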

Tbh, with the old code I’m not sure what it means to do early stopping w.r.t. accuracy on “full_dataset”, as the full dataset would have contained both labelled and unlabelled cells. Probably a bug in the old code.

Yes, this would be like the cross-entropy loss. You could also do early_stopping_monitor="classification_loss_train" if you’re using train_size=1.0.
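For example (sketch):

# With train_size=1.0 there is no validation split, so monitor the
# training-set classification loss instead
scanvi_model.train(
    max_epochs=200,
    train_size=1.0,
    early_stopping=True,
    early_stopping_monitor="classification_loss_train",
)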

@LuckyMD can you describe what you mean here? You don’t need to use the size factor argument there unless you specifically want to pass scran size factors, etc.

> Yes, this would be like the cross-entropy loss. You could also do early_stopping_monitor="classification_loss_train" if you’re using train_size=1.0.

Ooohh… I didn’t know that was an option. The docs only hinted at "classification_loss_validation", as well as the other "..._validation" losses (I didn’t look at the code… probably should have). That will definitely be something I’ll look into to optimize performance once I have one version of the updated code working.

> @LuckyMD can you describe what you mean here?

I’m passing scran size factors, which I assume should be pretty similar to how scvi previously calculated size factors internally. As you mentioned that option might be unstable, I thought this could be a better idea than setting use_observed_lib_size=False. But you’re right, being consistent with the old code would mean using the latter option.
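i.e., the two options as I understand them (a sketch; "scran_size_factors" is just my obs column name):

# Option A (what I'm doing now): pass precomputed scran size factors
scvi.model.SCVI.setup_anndata(adata, size_factor_key="scran_size_factors")
vae = scvi.model.SCVI(adata)

# Option B (consistent with the old code): infer a latent library size
# instead of using the observed one
scvi.model.SCVI.setup_anndata(adata)
vae = scvi.model.SCVI(adata, use_observed_lib_size=False)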