Error when using scVI.model.train

Hi there, I am trying to integrate my data using scVI.

I have tuned the model hyperparameters with ray.
I am subsetting my object to highly variable genes

Here is the code:

# set up the adata object
scvi.model.SCVI.setup_anndata(adata_scvi,
                              layer = 'counts',
                              batch_key = "patient.seqbatch"
                             )

# set up the scVI model
# specify the parameters according to the ray optimisation
vae = scvi.model.SCVI(adata_scvi,
                      n_latent = 30,  
                      n_hidden = 60, 
                      n_layers =  1, 
                      dropout_rate = 0.1,
                      dispersion = 'gene-batch',
                     gene_likelihood='zinb')

# add the parameter for lr
kwargs = {'lr': 0.0023}

# calculate the number of epochs needed which varies according to the cell number
max_epochs_scvi = np.min([round((20000 / adata_ref.n_obs) * 400), 400])
max_epochs_scvi

# run the training (need the GPU queue here)
vae.train(max_epochs = max_epochs_scvi, 
          train_size = 0.9, 
          validation_size = 0.1, 
          accelerator='gpu', 
          check_val_every_n_epoch=1,
          early_stopping=True,
          early_stopping_patience=10,
          early_stopping_monitor="elbo_validation",
          plan_kwargs = kwargs
         )

Output:

GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1]

---------------------------------------------------------------------------
AssertionError                            Traceback (most recent call last)
Cell In[49], line 4
      1 # run the training (need the GPU queue here)
      2 # %%time # will print the amount of CPU/GPU time taken for this chunk
----> 4 vae.train(max_epochs = max_epochs_scvi, 
      5           train_size = 0.9, 
      6           validation_size = 0.1, 
      7           accelerator='gpu', 
      8           check_val_every_n_epoch=1,
      9           early_stopping=True,
     10           early_stopping_patience=10,
     11           early_stopping_monitor="elbo_validation",
     12           plan_kwargs = kwargs
     13          )

File /media/prom/apc1/ccohen/mamba_installation/conda/envs/scvi_env/lib/python3.9/site-packages/scvi/model/base/_training_mixin.py:88, in UnsupervisedTrainingMixin.train(self, max_epochs, use_gpu, accelerator, devices, train_size, validation_size, shuffle_set_split, batch_size, early_stopping, plan_kwargs, **trainer_kwargs)
     75 trainer_kwargs[es] = (
     76     early_stopping if es not in trainer_kwargs.keys() else trainer_kwargs[es]
     77 )
     78 runner = self._train_runner_cls(
     79     self,
     80     training_plan=training_plan,
   (...)
     86     **trainer_kwargs,
     87 )
---> 88 return runner()

File /media/prom/apc1/ccohen/mamba_installation/conda/envs/scvi_env/lib/python3.9/site-packages/scvi/train/_trainrunner.py:99, in TrainRunner.__call__(self)
     96 if hasattr(self.data_splitter, "n_val"):
     97     self.training_plan.n_obs_validation = self.data_splitter.n_val
---> 99 self.trainer.fit(self.training_plan, self.data_splitter)
    100 self._update_history()
    102 # data splitter only gets these attrs after fit

File /media/prom/apc1/ccohen/mamba_installation/conda/envs/scvi_env/lib/python3.9/site-packages/scvi/train/_trainer.py:186, in Trainer.fit(self, *args, **kwargs)
    180 if isinstance(args[0], PyroTrainingPlan):
    181     warnings.filterwarnings(
    182         action="ignore",
    183         category=UserWarning,
    184         message="`LightningModule.configure_optimizers` returned `None`",
    185     )
--> 186 super().fit(*args, **kwargs)

File /media/prom/apc1/ccohen/mamba_installation/conda/envs/scvi_env/lib/python3.9/site-packages/lightning/pytorch/trainer/trainer.py:532, in Trainer.fit(self, model, train_dataloaders, val_dataloaders, datamodule, ckpt_path)
    530 self.strategy._lightning_module = model
    531 _verify_strategy_supports_compile(model, self.strategy)
--> 532 call._call_and_handle_interrupt(
    533     self, self._fit_impl, model, train_dataloaders, val_dataloaders, datamodule, ckpt_path
    534 )

File /media/prom/apc1/ccohen/mamba_installation/conda/envs/scvi_env/lib/python3.9/site-packages/lightning/pytorch/trainer/call.py:43, in _call_and_handle_interrupt(trainer, trainer_fn, *args, **kwargs)
     41     if trainer.strategy.launcher is not None:
     42         return trainer.strategy.launcher.launch(trainer_fn, *args, trainer=trainer, **kwargs)
---> 43     return trainer_fn(*args, **kwargs)
     45 except _TunerExitException:
     46     _call_teardown_hook(trainer)

File /media/prom/apc1/ccohen/mamba_installation/conda/envs/scvi_env/lib/python3.9/site-packages/lightning/pytorch/trainer/trainer.py:571, in Trainer._fit_impl(self, model, train_dataloaders, val_dataloaders, datamodule, ckpt_path)
    561 self._data_connector.attach_data(
    562     model, train_dataloaders=train_dataloaders, val_dataloaders=val_dataloaders, datamodule=datamodule
    563 )
    565 ckpt_path = self._checkpoint_connector._select_ckpt_path(
    566     self.state.fn,
    567     ckpt_path,
    568     model_provided=True,
    569     model_connected=self.lightning_module is not None,
    570 )
--> 571 self._run(model, ckpt_path=ckpt_path)
    573 assert self.state.stopped
    574 self.training = False

File /media/prom/apc1/ccohen/mamba_installation/conda/envs/scvi_env/lib/python3.9/site-packages/lightning/pytorch/trainer/trainer.py:980, in Trainer._run(self, model, ckpt_path)
    975 self._signal_connector.register_signal_handlers()
    977 # ----------------------------
    978 # RUN THE TRAINER
    979 # ----------------------------
--> 980 results = self._run_stage()
    982 # ----------------------------
    983 # POST-Training CLEAN UP
    984 # ----------------------------
    985 log.debug(f"{self.__class__.__name__}: trainer tearing down")

File /media/prom/apc1/ccohen/mamba_installation/conda/envs/scvi_env/lib/python3.9/site-packages/lightning/pytorch/trainer/trainer.py:1023, in Trainer._run_stage(self)
   1021         self._run_sanity_check()
   1022     with torch.autograd.set_detect_anomaly(self._detect_anomaly):
-> 1023         self.fit_loop.run()
   1024     return None
   1025 raise RuntimeError(f"Unexpected state {self.state}")

File /media/prom/apc1/ccohen/mamba_installation/conda/envs/scvi_env/lib/python3.9/site-packages/lightning/pytorch/loops/fit_loop.py:195, in _FitLoop.run(self)
    193 def run(self) -> None:
    194     self.setup_data()
--> 195     if self.skip:
    196         return
    197     self.reset()

File /media/prom/apc1/ccohen/mamba_installation/conda/envs/scvi_env/lib/python3.9/site-packages/lightning/pytorch/loops/fit_loop.py:191, in _FitLoop.skip(self)
    188 """Whether we should skip the training and immediately return from the call to :meth:`run`."""
    189 # if `limit_train_batches == 0` then `setup_data` won't set the `self.max_batches` attribute (checked in `done`)
    190 # so we cannot use it solely
--> 191 return self.done or self.trainer.limit_train_batches == 0

File /media/prom/apc1/ccohen/mamba_installation/conda/envs/scvi_env/lib/python3.9/site-packages/lightning/pytorch/loops/fit_loop.py:172, in _FitLoop.done(self)
    168     return True
    170 # `processed` is increased before `on_train_epoch_end`, the hook where checkpoints are typically saved.
    171 # we use it here because the checkpoint data won't have `completed` increased yet
--> 172 assert isinstance(self.max_epochs, int)
    173 stop_epochs = _is_max_limit_reached(self.epoch_progress.current.processed, self.max_epochs)
    174 if stop_epochs:
    175     # in case they are not equal, override so `trainer.current_epoch` has the expected value

AssertionError: 

Any help would be greatly appreciated. Thanks, Carla

Hi, you have to cast max_epochs_scVI to integer.

Many thanks, a quick fix!