Error in Fine-tuning scANVI Hyperparameters from Pretrained scVI Model

Hello everyone,

I am working with a single-cell RNA-seq dataset of approximately 48,000 cells from the Tabula Muris datasets and an in-house dataset. My goal is to integrate these datasets for label transfer from the reference to our query dataset using scANVI.

After successfully tuning the scVI hyperparameters with “autotune” (ray 2.7.0), I tried to tune the scANVI hyperparameters as well. I focused on n_layers, n_latent, and gene_likelihood, starting from the pretrained scVI model via the from_scvi_model method.

Does it make sense to tune the hyperparameters of the scANVI model after scVI, or are they inherited from the scVI model with no possibility of changing them?

I used the same steps to tune scANVI. However, the ModelTuner.fit call in the code below fails with the error shown after it:

import ray
import scanpy as sc
import scvi
from ray import tune
from scvi import autotune

adata = sc.read_h5ad("../../results/scvi/scvi_harmonized_adata.h5ad")

scvi_model = scvi.model.SCVI.load("./scvi_model_v2/", adata)

model_cls = scvi.model.SCANVI
model_cls.from_scvi_model(scvi_model, adata=adata, unlabeled_category="Unknown", labels_key=SCANVI_CELLTYPE_KEY)

scvi_tuner = autotune.ModelTuner(model_cls)

search_space = {
    "n_layers": tune.choice([1, 2]),
    "n_latent": tune.choice([10, 30]),
    "gene_likelihood": tune.choice(['zinb', 'nb']),
}

ray.init(log_to_driver=False)
results = scvi_tuner.fit(
    adata,
    metric="validation_loss",
    search_space=search_space,
    num_samples=4,
    max_epochs=100,
    resources={"cpu": 10, "gpu": 0},
)
------------------------------------------------------------------------------------------------------
2023-11-17 18:36:44,234	ERROR tune_controller.py:1502 -- Trial task failed for trial _trainable_6e9e0c8e
Traceback (most recent call last):
  File "/home/pameslin/miniconda3/envs/singlecell_env/lib/python3.9/site-packages/ray/air/execution/_internal/event_manager.py", line 110, in resolve_future
    result = ray.get(future)
  File "/home/pameslin/miniconda3/envs/singlecell_env/lib/python3.9/site-packages/ray/_private/auto_init_hook.py", line 24, in auto_init_wrapper
    return fn(*args, **kwargs)
  File "/home/pameslin/miniconda3/envs/singlecell_env/lib/python3.9/site-packages/ray/_private/client_mode_hook.py", line 103, in wrapper
    return func(*args, **kwargs)
  File "/home/pameslin/miniconda3/envs/singlecell_env/lib/python3.9/site-packages/ray/_private/worker.py", line 2547, in get
    raise value.as_instanceof_cause()
ray.exceptions.RayTaskError(AttributeError): ray::ImplicitFunc.train() (pid=3481600, ip=10.93.31.100, actor_id=cd6a74f10b1c3ba8f2926e3101000000, repr=_trainable)
  File "/home/pameslin/miniconda3/envs/singlecell_env/lib/python3.9/site-packages/ray/tune/trainable/trainable.py", line 400, in train
    raise skipped from exception_cause(skipped)
  File "/home/pameslin/miniconda3/envs/singlecell_env/lib/python3.9/site-packages/ray/air/_internal/util.py", line 91, in run
    self._ret = self._target(*self._args, **self._kwargs)
  File "/home/pameslin/miniconda3/envs/singlecell_env/lib/python3.9/site-packages/ray/tune/trainable/function_trainable.py", line 383, in <lambda>
    training_func=lambda: self._trainable_func(self.config),
  File "/home/pameslin/miniconda3/envs/singlecell_env/lib/python3.9/site-packages/ray/tune/trainable/function_trainable.py", line 822, in _trainable_func
    output = fn()
  File "/home/pameslin/miniconda3/envs/singlecell_env/lib/python3.9/site-packages/ray/tune/trainable/util.py", line 321, in inner
    return trainable(config, **fn_kwargs)
  File "/home/pameslin/miniconda3/envs/singlecell_env/lib/python3.9/site-packages/scvi/autotune/_manager.py", line 400, in _trainable
    model.train(
  File "/home/pameslin/miniconda3/envs/singlecell_env/lib/python3.9/site-packages/scvi/model/_scanvi.py", line 424, in train
    trainer_kwargs["callbacks"].concatenate(sampler_callback)
AttributeError: 'list' object has no attribute 'concatenate'

Is it a version incompatibility issue?

Has anyone encountered a similar issue, or can provide insights into resolving this error? Additionally, I have not found any information on this error.

Updating ray to 2.8.0 and scvi-tools to 1.0.4 does not resolve this issue.

Configuring the model with setup_anndata instead of from_scvi_model does not resolve the issue either:

model_cls = scvi.model.SCANVI
model_cls.setup_anndata(adata, unlabeled_category="Unknown", labels_key=SCANVI_CELLTYPE_KEY)

Hi, thank you for your question. We recently caught this bug, and it was fixed in this PR, which will be released in scvi-tools 1.1. Not sure how we didn’t catch this before.
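
For context, the last frame of the traceback calls .concatenate on trainer_kwargs["callbacks"], which is a plain Python list, and lists have no such method. The snippet below only illustrates that failure and one list-based way to merge callbacks. It is a sketch, not the actual patch from the PR; add_sampler_callback and the string placeholders are hypothetical names used for illustration.

```python
def add_sampler_callback(trainer_kwargs, sampler_callback):
    """Return a copy of trainer_kwargs with sampler_callback merged into "callbacks".

    Hypothetical helper: it mirrors the intent of the failing line in the
    traceback, but it is not the code used in scvi-tools.
    """
    callbacks = list(trainer_kwargs.get("callbacks", []))
    callbacks.extend(sampler_callback)  # lists support extend/append, not concatenate
    return {**trainer_kwargs, "callbacks": callbacks}


# Reproduce the AttributeError from the traceback with placeholder values.
try:
    ["tune_callback"].concatenate(["sampler_callback"])
except AttributeError as err:
    message = str(err)

# Merge the same placeholder callbacks with the list-based helper instead.
merged = add_sampler_callback({"callbacks": ["tune_callback"]}, ["sampler_callback"])
```

This would also explain why the error only shows up under autotune: the failing line runs when a "callbacks" entry is already present in the trainer arguments, which seems to be the case when the tuner drives training.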

For now, I would recommend running hyperparameter optimization on just scVI, training scVI on those parameters, and then initializing scANVI from the trained scVI model. This should hopefully bypass the error you’re seeing.
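
A rough sketch of that workaround, assuming the best scVI hyperparameters are already available as a plain dict: the function name, its arguments, and the epoch defaults are hypothetical, and the setup_anndata arguments should match whatever was used during tuning.

```python
def train_scanvi_from_tuned_scvi(
    adata,
    best_model_kwargs,
    labels_key,
    unlabeled_category="Unknown",
    batch_key=None,
    scvi_max_epochs=100,
    scanvi_max_epochs=20,
):
    """Train scVI with tuned hyperparameters, then initialize scANVI from it."""
    import scvi  # imported lazily so the function can be defined before scvi-tools is loaded

    # Register the AnnData object for scVI (assumption: same setup as during tuning).
    scvi.model.SCVI.setup_anndata(adata, batch_key=batch_key, labels_key=labels_key)

    # Step 1: build and train scVI with the hyperparameters selected by autotune,
    # e.g. {"n_layers": 2, "n_latent": 30, "gene_likelihood": "nb"} (illustrative values).
    scvi_model = scvi.model.SCVI(adata, **best_model_kwargs)
    scvi_model.train(max_epochs=scvi_max_epochs)

    # Step 2: initialize scANVI from the trained scVI model.
    scanvi_model = scvi.model.SCANVI.from_scvi_model(
        scvi_model,
        unlabeled_category=unlabeled_category,
        labels_key=labels_key,
    )

    # Step 3: train scANVI directly, outside the tuner, so that no tuner
    # callbacks are passed to SCANVI.train.
    scanvi_model.train(max_epochs=scanvi_max_epochs)
    return scvi_model, scanvi_model
```

With the names from the question, a call could look like train_scanvi_from_tuned_scvi(adata, best_model_kwargs, labels_key=SCANVI_CELLTYPE_KEY), where best_model_kwargs holds the values found by tuning scVI alone.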