scVI reproducibility seed issue

I have seen a few similar issues on Discourse, but none quite the same.

Background:
I have been using scVI across two Jupyter notebooks within the same conda environment.
The first notebook runs the top three models identified by autotune; the second runs the model of choice.

Issue:
I have noticed that the two scVI models, despite setting the seed and using identical parameters, produce different results. The cluster node used for both was identical.

I ran this again with the following code in two notebooks and still got different scVI outputs.

import os
import sys
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.font_manager
import seaborn as sns
import anndata
import scvi
import scanpy as sc

scvi.settings.seed = 9627

## data_dir is defined earlier in the notebook
adata = sc.read_h5ad(data_dir + "postQC_CSF_atlas.h5ad")

with open(data_dir + "hvg4000.txt") as f:
    hvg4000 = f.read().splitlines()

adata_ref = adata[:, hvg4000].copy()

early_stopping_kwargs = {
    "early_stopping": True,
    "early_stopping_monitor": "elbo_validation",
    "early_stopping_patience": 10,
    "early_stopping_min_delta": 0.001,
}

plan_kwargs = {
    "reduce_lr_on_plateau": False,
    "lr_patience": 10,
    "lr": 0.0011,
}


zinb_params = {
    "model4": dict(
        use_layer_norm="both",
        use_batch_norm="none",
        encode_covariates=True,
        dropout_rate=0.2,
        gene_likelihood="nb",
        deeply_inject_covariates=False,
        n_hidden=128,
        n_layers=2,
        n_latent=20,
        dispersion="gene-cell",
    ),
    "model6": dict(
        use_layer_norm="both",
        use_batch_norm="none",
        encode_covariates=True,
        dropout_rate=0.2,
        gene_likelihood="zinb",
        deeply_inject_covariates=False,
        n_hidden=128,
        n_layers=2,
        n_latent=30,
        dispersion="gene-cell",
    ),
}

scvi.model.SCVI.setup_anndata(adata_ref, layer="counts", batch_key="batch")

zinb_models = {}

for key in zinb_params:
    scvi.settings.seed = 9627
    zinb_models[key] = scvi.model.SCVI(adata_ref, **zinb_params[key])
    zinb_models[key].train(
        max_epochs=500,
        batch_size=1024,
        plan_kwargs=plan_kwargs,
        check_val_every_n_epoch=1,
        use_gpu=0,
        **early_stopping_kwargs,
    )

## generate latent embeddings
for key in zinb_models:
    scvi.settings.seed = 9627
    scvi_key = "X_scVI_zinb_" + key
    umap_key = "X_umap_zinb_" + key
    adata_ref.obsm[scvi_key] = zinb_models[key].get_latent_representation()
    ## create neighbourhood graph
    sc.pp.neighbors(adata_ref, use_rep=scvi_key, metric="cosine", n_neighbors=30)

    ## run umap
    sc.tl.umap(adata_ref)
    adata_ref.obsm[umap_key] = adata_ref.obsm["X_umap"].copy()

## plot
for key in zinb_models:
    umap_key = "X_umap_zinb_" + key
    sc.pl.embedding(
        adata_ref,
        basis=umap_key,
        color=[
            "X_10x_kit", "Dataset", "Disease", "batch", "Source",
            "QC_label", "total_counts", "n_genes_by_counts", "celltype_l2",
        ],
        frameon=False,
        title=key,
        # ncols=1,
    )

@Justin_Hong, I saw you previously addressed an issue about needing to set seeds within a loop, but I have also tried setting the seed for two different models without loops, and the same thing happens, roughly as in the sketch below.
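For completeness, this is roughly the no-loop pattern I also tried (a sketch reusing the objects defined above):

scvi.settings.seed = 9627
model_a = scvi.model.SCVI(adata_ref, **zinb_params["model4"])
model_a.train(max_epochs=500, batch_size=1024, plan_kwargs=plan_kwargs, **early_stopping_kwargs)

scvi.settings.seed = 9627
model_b = scvi.model.SCVI(adata_ref, **zinb_params["model4"])
model_b.train(max_epochs=500, batch_size=1024, plan_kwargs=plan_kwargs, **early_stopping_kwargs)

## model_a and model_b still come out different for me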

It is a bit concerning, as I often run the models interactively before setting the notebook to run to completion in order to timestamp/checkpoint the analysis.

For reference, this issue did not seem to occur before release 1.0.0, when scVI used to automatically set the global seed; if I revert to an earlier version, the models remain reproducible.
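In the meantime, manually seeding every global RNG, which is my understanding of what the pre-1.0 import-time behaviour did, seems like a reasonable belt-and-braces workaround. A sketch (the helper name is my own):

import random

import numpy as np
import torch
import scvi


def set_all_seeds(seed: int) -> None:
    # Seed every global RNG that training might touch, not just
    # scvi's own settings object.
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    scvi.settings.seed = seed


set_all_seeds(9627)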

Thanks for flagging this @Nusob888. Unfortunately I’ve been more out of the loop with recent changes. @martinkim0, do you know what may have caused this? Perhaps it is related to the change to the JAX random seed?

Hi @Nusob888, thanks for bringing up this issue. It’s not immediately clear to me why this is occurring, as the recent changes only addressed setting a default seed, so manually setting your seed should be fine.

I tried to reproduce your results with a different dataset but with the same parameters, and I’m seeing that the two models have the same validation ELBO and latent representation.

import numpy as np
import scvi

scvi.settings.seed = 9627

adata = scvi.data.heart_cell_atlas_subsampled(save_path="./data")
early_stopping_kwargs = {
    "early_stopping": True,
    "early_stopping_monitor": "elbo_validation",
    "early_stopping_patience": 10,
    "early_stopping_min_delta": 0.001,
}

plan_kwargs = {
    "reduce_lr_on_plateau": False,
    "lr_patience": 10,
    "lr": 0.0011,
}
model_kwargs = {
    "use_layer_norm": "both",
    "use_batch_norm": "none",
    "encode_covariates": True,
    "dropout_rate": 0.2,
    "gene_likelihood": "nb",
    "deeply_inject_covariates": False,
    "n_layers": 2,
    "n_hidden": 128,
    "n_latent": 20,
    "dispersion": "gene-cell",
}
scvi.model.SCVI.setup_anndata(adata, batch_key="cell_source")

models = []
for i in range(2):
    scvi.settings.seed = 9627
    model = scvi.model.SCVI(adata, **model_kwargs)
    model.train(
        max_epochs=500,
        batch_size=1024,
        check_val_every_n_epoch=1,
        plan_kwargs=plan_kwargs,
        **early_stopping_kwargs,
    )
    models.append(model)

# .equals() compares the full DataFrames element-wise; plain all() over a
# DataFrame only iterates its column labels, which is always truthy
assert models[0].history["elbo_validation"].equals(models[1].history["elbo_validation"])
assert np.array_equal(models[0].get_latent_representation(), models[1].get_latent_representation())
assert np.array_equal(
    models[0].module.decoder.px_decoder.fc_layers[0][0].weight.detach().cpu().numpy(),
    models[1].module.decoder.px_decoder.fc_layers[0][0].weight.detach().cpu().numpy(),
)

For reference, this is using the tip of the main branch (which is essentially the same as 1.0.3) with PyTorch 2.0.1 and Lightning 2.0.8, running on CUDA.
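If anyone else wants to compare environments, this is the quick check I would run (a sketch):

import lightning
import torch
import scvi

# Print the versions most relevant to training reproducibility.
print("scvi-tools:", scvi.__version__)
print("torch:", torch.__version__)
print("lightning:", lightning.__version__)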

Hi,
I ran into a similar issue and found that if the scvi.settings.seed = 9627 line is at the start of a script, I get different outputs, but if I set scvi.settings.seed = 9627 right before model = scvi.model.SCVI(adata), the runs replicate.
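For reference, a minimal sketch of the ordering that replicates for me, using scvi’s built-in synthetic dataset purely for illustration:

import scvi

adata = scvi.data.synthetic_iid()  # toy dataset, just for illustration
scvi.model.SCVI.setup_anndata(adata, batch_key="batch")

# Setting the seed immediately before model construction replicates;
# setting it only at the top of the script did not.
scvi.settings.seed = 9627
model = scvi.model.SCVI(adata)
model.train(max_epochs=5)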

Can you share your full notebook or code? I don’t think this is a problem with recent versions of scvi-tools. The problem described above was JAX-specific and is fixed on our main branch.
My guess at what might be happening here: you import another package that resets the seed, or part of it (like the numpy seed). In your case, it’s safe to just set the seed before running train.
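To illustrate that failure mode, here is a sketch where np.random.seed stands in for a hypothetical package that reseeds the RNG at import time:

import numpy as np
import scvi

scvi.settings.seed = 9627  # seeds the python, numpy, and torch RNGs

# Hypothetical offender: a module that calls np.random.seed() at
# import time silently replaces the state that scvi just set.
np.random.seed(0)

# Re-seeding immediately before model creation/training restores a
# known RNG state regardless of what earlier imports did.
scvi.settings.seed = 9627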