Hey!
I'm trying to update the HLCA reference-building code to scvi-tools 0.16.2, and it looks like there have been quite a few API changes since the version it was written for. Specifically, scANVI training now has to go through an scVI model specification first, which is then extended into a scANVI model, rather than having the (quite useful) separation into `n_epochs_unsupervised` and `n_epochs_semisupervised`. I was wondering if someone could confirm how I've mapped the parameters from the old to the new code. The old version should be 0.8.2 or so, if I'm not mistaken.
`setup_anndata` should be pretty consistent, apart from the addition of the size-factor argument (`size_factor_key`), which lets me pass size factors directly.
The model setup via `sca.models.SCANVI` now maps almost exactly onto `scvi.model.SCVI`, with the exception of `use_cuda`, which has moved to `.train()` and is now called `use_gpu`.
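As a small illustration of that one change, here is a hypothetical pure-Python helper (`split_model_kwargs` is my own name, not part of scvi-tools or scArches). It assumes the only difference in the constructor arguments is that `use_cuda` leaves the model and reappears as `use_gpu` on `.train()`; every other key is passed through unchanged.

```python
def split_model_kwargs(old_kwargs):
    """Split old constructor kwargs into (model_kwargs, train_kwargs)."""
    model_kwargs = dict(old_kwargs)  # copy so the caller's dict is untouched
    train_kwargs = {}
    if "use_cuda" in model_kwargs:
        # Assumed rename: use_cuda on the constructor -> use_gpu on .train()
        train_kwargs["use_gpu"] = model_kwargs.pop("use_cuda")
    return model_kwargs, train_kwargs


# Illustrative values only, not the actual HLCA settings
old_kwargs = {"n_layers": 2, "n_latent": 30, "use_cuda": True}
model_kwargs, train_kwargs = split_model_kwargs(old_kwargs)
print(model_kwargs)  # {'n_layers': 2, 'n_latent': 30}
print(train_kwargs)  # {'use_gpu': True}
```

Under this assumption, `model_kwargs` would go to `scvi.model.SCVI(adata, **model_kwargs)` and `train_kwargs` to `.train(**train_kwargs)`.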
My main question is about the `.train()` function, which now looks quite different. Initially I had:
```python
vae_epochs = 500
scanvi_epochs = 200

early_stopping_kwargs = {
    "early_stopping_metric": "elbo",
    "save_best_state_metric": "elbo",
    "patience": 10,
    "threshold": 0,
    "reduce_lr_on_plateau": True,
    "lr_patience": 8,
    "lr_factor": 0.1,
}

early_stopping_kwargs_scanvi = {
    "early_stopping_metric": "accuracy",
    "save_best_state_metric": "accuracy",
    "on": "full_dataset",
    "patience": 10,
    "threshold": 0.001,
    "reduce_lr_on_plateau": True,
    "lr_patience": 8,
    "lr_factor": 0.1,
}

vae.train(
    n_epochs_unsupervised=vae_epochs,
    n_epochs_semisupervised=scanvi_epochs,
    unsupervised_trainer_kwargs=dict(early_stopping_kwargs=early_stopping_kwargs),
    semisupervised_trainer_kwargs=dict(
        metrics_to_monitor=["elbo", "accuracy"],
        early_stopping_kwargs=early_stopping_kwargs_scanvi,
    ),
    frequency=1,
)
```
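To make the two-stage structure concrete, here is a hypothetical sketch (`split_old_train_call` is my own name) that splits the single old call into one set of keyword arguments per stage: the scVI stage gets `n_epochs_unsupervised` as `max_epochs`, and the scANVI stage gets `n_epochs_semisupervised`. In 0.16 the second stage would typically be created from the trained scVI model, e.g. with `scvi.model.SCANVI.from_scvi_model`. Mapping `frequency` to `check_val_every_n_epoch` is an assumption (it is one of the points in question), and the early-stopping dictionaries are left out of this sketch because they need their own translation.

```python
def split_old_train_call(old_call):
    """Split old two-phase train() kwargs into (scvi_stage, scanvi_stage)."""
    shared = {}
    if "frequency" in old_call:
        # Assumed equivalent of the old validation frequency
        shared["check_val_every_n_epoch"] = old_call["frequency"]
    scvi_stage = {"max_epochs": old_call["n_epochs_unsupervised"], **shared}
    scanvi_stage = {"max_epochs": old_call["n_epochs_semisupervised"], **shared}
    return scvi_stage, scanvi_stage


old_call = {"n_epochs_unsupervised": 500, "n_epochs_semisupervised": 200, "frequency": 1}
scvi_stage, scanvi_stage = split_old_train_call(old_call)
print(scvi_stage)    # {'max_epochs': 500, 'check_val_every_n_epoch': 1}
print(scanvi_stage)  # {'max_epochs': 200, 'check_val_every_n_epoch': 1}
```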
Now I'm trying to map just the `scvi.model.SCVI` part, and I think it should be:
```python
vae_epochs = 500

# Early-stopping settings (keyword arguments to .train())
early_stopping_kwargs = {
    "early_stopping": True,
    "early_stopping_monitor": "elbo_validation",
    "early_stopping_patience": 10,
    "early_stopping_min_delta": 0.0,
}

# Learning-rate schedule settings (passed as plan_kwargs)
plan_kwargs = {
    "reduce_lr_on_plateau": True,
    "lr_patience": 8,
    "lr_factor": 0.1,
}

# Train scVI model
vae.train(
    max_epochs=vae_epochs,
    plan_kwargs=plan_kwargs,
    check_val_every_n_epoch=1,
    use_gpu=True,
    **early_stopping_kwargs,
)
```
Could someone confirm that this is the equivalent specification? I'm mainly unsure about:

- `threshold` == `early_stopping_min_delta`
- `patience` == `early_stopping_patience`
- `frequency` == `check_val_every_n_epoch`
- the separation into `plan_kwargs` and early-stopping kwargs.
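To make the proposed correspondence explicit, here is a small, hypothetical pure-Python translator (`translate_early_stopping` is my own name, not an scvi-tools function). It encodes exactly the mapping asked about above, which is the assumption under question rather than a confirmed equivalence: `patience`, `threshold` and `frequency` become trainer arguments, the learning-rate keys go into `plan_kwargs`, and `"elbo"` becomes `"elbo_validation"`. The sketch assumes `save_best_state_metric` and `on` have no direct counterpart and drops them, and it raises an error for any other metric (such as the scANVI `"accuracy"`) rather than guessing.

```python
# Assumed old -> new names for trainer-level early-stopping arguments
OLD_TO_NEW_TRAINER = {
    "patience": "early_stopping_patience",
    "threshold": "early_stopping_min_delta",
}
# Assumed to keep their names and move into plan_kwargs
PLAN_KEYS = ("reduce_lr_on_plateau", "lr_patience", "lr_factor")
# Assumed metric rename; other metrics are deliberately unmapped
METRIC_NAMES = {"elbo": "elbo_validation"}
# Assumed to have no direct counterpart; silently dropped
DROPPED_KEYS = ("save_best_state_metric", "on")


def translate_early_stopping(old, frequency=1):
    """Return (trainer_kwargs, plan_kwargs) for an old early_stopping_kwargs dict."""
    trainer_kwargs = {"early_stopping": True, "check_val_every_n_epoch": frequency}
    plan_kwargs = {}
    for key, value in old.items():
        if key == "early_stopping_metric":
            if value not in METRIC_NAMES:
                raise ValueError(f"no assumed counterpart for metric {value!r}")
            trainer_kwargs["early_stopping_monitor"] = METRIC_NAMES[value]
        elif key in OLD_TO_NEW_TRAINER:
            trainer_kwargs[OLD_TO_NEW_TRAINER[key]] = value
        elif key in PLAN_KEYS:
            plan_kwargs[key] = value
        elif key in DROPPED_KEYS:
            continue
        else:
            raise KeyError(f"unmapped key {key!r}")
    return trainer_kwargs, plan_kwargs


early_stopping_kwargs = {
    "early_stopping_metric": "elbo",
    "save_best_state_metric": "elbo",
    "patience": 10,
    "threshold": 0,
    "reduce_lr_on_plateau": True,
    "lr_patience": 8,
    "lr_factor": 0.1,
}
trainer_kwargs, plan_kwargs = translate_early_stopping(early_stopping_kwargs)
print(trainer_kwargs)
print(plan_kwargs)
```

If the mapping holds, the two dicts could then be passed as `vae.train(max_epochs=vae_epochs, plan_kwargs=plan_kwargs, use_gpu=True, **trainer_kwargs)`, mirroring the call in the post.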
Thanks a lot!