Early stopping debug

Hi, I have a user running the latest scvi-tools on an scRNA-seq dataset. For some reason the model triggers an early stop that I can't quite understand (I was lucky enough never to have early stopping occur until now!).

2024-04-25 17:20:28,337: INFO - AnnData object with n_obs × n_vars = 32322 × 22512
    obs: 'sample_id', 'tissue', 'condition', 'batch', 'subject', 'CMV_status', 'doublet_scores', 'predicted_doublets', 'n_genes_by_counts', 'log1p_n_genes_by_counts', 'total_counts', 'log1p_total_counts', 'total_counts_hb', 'log1p_total_counts_hb', 'pct_counts_hb', 'total_counts_mt', 'log1p_total_counts_mt', 'pct_counts_mt', 'total_counts_rp', 'log1p_total_counts_rp', 'pct_counts_rp', 'MarkersNeutro_score', 'S_score', 'G2M_score', 'phase'
    var: 'gene_ids', 'feature_types', 'genome', 'hb', 'mt', 'rp', 'n_cells_by_counts', 'mean_counts', 'log1p_mean_counts', 'pct_dropout_by_counts', 'total_counts', 'log1p_total_counts', 'highly_variable', 'means', 'dispersions', 'dispersions_norm', 'highly_variable_nbatches', 'highly_variable_intersection', 'hvg_exclude', 'highly_variable_rank'
    uns: 'hvg', 'log1p', 'pca', 'sample_id_colors'
    obsm: 'X_pca'
    varm: 'PCs'
    layers: 'logged_counts', 'raw_counts'
2024-04-25 17:20:28,338: INFO - raw counts found in layer
2024-04-25 17:20:28,845: INFO - Unable to initialize backend 'cuda': 
2024-04-25 17:20:28,846: INFO - Unable to initialize backend 'rocm': module 'jaxlib.xla_extension' has no attribute 'GpuAllocatorConfig'
2024-04-25 17:20:28,872: INFO - Unable to initialize backend 'tpu': INTERNAL: Failed to open libtpu.so: libtpu.so: cannot open shared object file: No such file or directory
2024-04-25 17:20:28,872: WARNING - An NVIDIA GPU may be present on this machine, but a CUDA-enabled jaxlib is not installed. Falling back to cpu.
{'gene_likelihood': 'zinb'}
{'max_epochs': 400, 'train_size': 0.9, 'early_stopping': True}
{'lr': 0.001, 'n_epochs_kl_warmup': 400, 'reduce_lr_on_plateau': True, 'lr_patience': 8, 'lr_factor': 0.1}

Training:   0%|                                                                                                                  | 0/400 [00:00<?, ?it/s]
Epoch 1/400:   0%|                                                                                                               | 0/400 [00:00<?, ?it/s]
Epoch 1/400:   0%|▎                                                                                                    | 1/400 [00:11<1:14:12, 11.16s/it]
Epoch 1/400:   0%|                                                 | 1/400 [00:11<1:14:12, 11.16s/it, v_num=1, train_loss_step=263, train_loss_epoch=314]
Epoch 2/400:   0%|                                                 | 1/400 [00:11<1:14:12, 11.16s/it, v_num=1, train_loss_step=263, train_loss_epoch=314]
Epoch 2/400:   0%|▎                                                  | 2/400 [00:12<36:42,  5.53s/it, v_num=1, train_loss_step=263, train_loss_epoch=314]
Epoch 2/400:   0%|▎                                                  | 2/400 [00:12<36:42,  5.53s/it, v_num=1, train_loss_step=285, train_loss_epoch=280]
Epoch 3/400:   0%|▎                                                  | 2/400 [00:12<36:42,  5.53s/it, v_num=1, train_loss_step=285, train_loss_epoch=280]
Epoch 3/400:   1%|▍                                                  | 3/400 [00:14<24:39,  3.73s/it, v_num=1, train_loss_step=285, train_loss_epoch=280]
Epoch 3/400:   1%|▍                                                  | 3/400 [00:14<24:39,  3.73s/it, v_num=1, train_loss_step=281, train_loss_epoch=275]
Epoch 4/400:   1%|▍                                                  | 3/400 [00:14<24:39,  3.73s/it, v_num=1, train_loss_step=281, train_loss_epoch=275]
Epoch 4/400:   1%|▌                                                  | 4/400 [00:15<18:57,  2.87s/it, v_num=1, train_loss_step=281, train_loss_epoch=275]

[...]

Epoch 394/400:  98%|██████████████████████████████████████████████▎| 394/400 [10:18<00:09,  1.57s/it, v_num=1, train_loss_step=271, train_loss_epoch=267]
Monitored metric elbo_validation did not improve in the last 45 records. Best score: 271.121. Signaling Trainer to stop.
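
For reference, here is roughly how I understand the printed parameters map onto the scvi-tools calls (a sketch only; the real setup is in the linked panpipes script, and the layer/batch keys below are assumptions on my part):

```python
import scvi

# Assumed keys: the script registers the raw counts layer and a batch column.
scvi.model.SCVI.setup_anndata(adata, layer="raw_counts", batch_key="batch")

model = scvi.model.SCVI(adata, gene_likelihood="zinb")

model.train(
    max_epochs=400,
    train_size=0.9,
    early_stopping=True,  # monitors elbo_validation; default patience is 45 epochs
    plan_kwargs={
        "lr": 0.001,
        "n_epochs_kl_warmup": 400,
        "reduce_lr_on_plateau": True,
        "lr_patience": 8,
        "lr_factor": 0.1,
    },
)
```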

Can you give me a couple of pointers on:

  • What checkpoint saving should I enable and check to understand what's happening?
  • Any general recommendations? Should I increase the patience and/or the learning rate? I see the training losses are not really decreasing much over time (see the sketch after this list).
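
To make the second point concrete, this is the kind of change I had in mind, plus how I was planning to inspect the run (a sketch assuming the standard scvi-tools API, where the early_stopping_* arguments are forwarded to the underlying trainer):

```python
import matplotlib.pyplot as plt

# `model` is the SCVI model set up as in the sketch above.
model.train(
    max_epochs=400,
    train_size=0.9,
    early_stopping=True,
    early_stopping_monitor="elbo_validation",
    early_stopping_patience=100,  # default is 45; raise it if stopping seems premature
    early_stopping_min_delta=0.0,
    plan_kwargs={"lr": 0.001, "n_epochs_kl_warmup": 400,
                 "reduce_lr_on_plateau": True, "lr_patience": 8, "lr_factor": 0.1},
)

# Per-epoch metrics are kept in model.history, so the loss curves can be
# inspected directly rather than relying on checkpoint files.
fig, ax = plt.subplots()
model.history["elbo_train"].plot(ax=ax)
model.history["elbo_validation"].plot(ax=ax)
fig.savefig("scvi_elbo_curves.png")

# Save the trained model so different runs can be reloaded and compared.
model.save("scvi_model/", overwrite=True)
```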

Code at: https://github.com/DendrouLab/panpipes/blob/main/panpipes/python_scripts/batch_correct_scvi.py

thank you!

I'm not sure I understand the issue here. Is the model still improving and stopping early anyway? Or has an optimal model been found, after which training stops? The second behavior is the desired one, and if you don't want it, disabling early_stopping would be the recommendation. scvi-tools actually prints the reason for stopping (elbo_validation did not improve).
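
Concretely, you can tell which of the two cases you are in by looking at the tail of the monitored metric, and only retrain without early stopping if it is still clearly decreasing (a minimal sketch):

```python
# Was elbo_validation still improving when training stopped?
print(model.history["elbo_validation"].tail(50))

# If it is still clearly decreasing, retrain with early stopping disabled
# (or with a larger patience) and let it run for the full max_epochs.
model.train(max_epochs=400, train_size=0.9, early_stopping=False)
```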