Error on scvi_model.train(100) for seed label transfer

I have been using the seed label transfer method from this tutorial: Seed labeling with scANVI — scvi-tools.

When I reach the step scvi_model.train(100), I get the error below. I checked my data and I don't have any NaNs (roughly the check shown below), so does anyone know what the issue could be?
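This is a minimal sketch of the kind of check I ran (here `adata` stands for my AnnData object and the counts are assumed to be in `adata.X`):

```python
import numpy as np
from scipy.sparse import issparse

# Grab the values of the matrix scVI trains on; for sparse matrices, check .data
X = adata.X.data if issparse(adata.X) else np.asarray(adata.X)

print("any NaN:", np.isnan(X).any())
print("any inf:", np.isinf(X).any())
print("any negative values:", (X < 0).any())
```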
Thanks

---------------------------------------------------------------------------
ValueError                                Traceback (most recent call last)
Input In [18], in <module>
----> 1 scvi_model.train(100)

File /opt/conda/envs/scvi-env/lib/python3.8/site-packages/scvi/model/base/_training_mixin.py:77, in UnsupervisedTrainingMixin.train(self, max_epochs, use_gpu, train_size, validation_size, batch_size, early_stopping, plan_kwargs, **trainer_kwargs)
     66 trainer_kwargs[es] = (
     67     early_stopping if es not in trainer_kwargs.keys() else trainer_kwargs[es]
     68 )
     69 runner = TrainRunner(
     70     self,
     71     training_plan=training_plan,
   (...)
     75     **trainer_kwargs,
     76 )
---> 77 return runner()

File /opt/conda/envs/scvi-env/lib/python3.8/site-packages/scvi/train/_trainrunner.py:72, in TrainRunner.__call__(self)
     69 if hasattr(self.data_splitter, "n_train"):
     70     self.training_plan.n_obs_training = self.data_splitter.n_train
---> 72 self.trainer.fit(self.training_plan, self.data_splitter)
     73 self._update_history()
     75 # data splitter only gets these attrs after fit

File /opt/conda/envs/scvi-env/lib/python3.8/site-packages/scvi/train/_trainer.py:177, in Trainer.fit(self, *args, **kwargs)
    171 if isinstance(args[0], PyroTrainingPlan):
    172     warnings.filterwarnings(
    173         action="ignore",
    174         category=UserWarning,
    175         message="`LightningModule.configure_optimizers` returned `None`",
    176     )
--> 177 super().fit(*args, **kwargs)

File /opt/conda/envs/scvi-env/lib/python3.8/site-packages/pytorch_lightning/trainer/trainer.py:460, in Trainer.fit(self, model, train_dataloader, val_dataloaders, datamodule)
    455 # links data to the trainer
    456 self.data_connector.attach_data(
    457     model, train_dataloader=train_dataloader, val_dataloaders=val_dataloaders, datamodule=datamodule
    458 )
--> 460 self._run(model)
    462 assert self.state.stopped
    463 self.training = False

File /opt/conda/envs/scvi-env/lib/python3.8/site-packages/pytorch_lightning/trainer/trainer.py:758, in Trainer._run(self, model)
    755 self.pre_dispatch()
    757 # dispatch `start_training` or `start_evaluating` or `start_predicting`
--> 758 self.dispatch()
    760 # plugin will finalized fitting (e.g. ddp_spawn will load trained model)
    761 self.post_dispatch()

File /opt/conda/envs/scvi-env/lib/python3.8/site-packages/pytorch_lightning/trainer/trainer.py:799, in Trainer.dispatch(self)
    797     self.accelerator.start_predicting(self)
    798 else:
--> 799     self.accelerator.start_training(self)

File /opt/conda/envs/scvi-env/lib/python3.8/site-packages/pytorch_lightning/accelerators/accelerator.py:96, in Accelerator.start_training(self, trainer)
     95 def start_training(self, trainer: 'pl.Trainer') -> None:
---> 96     self.training_type_plugin.start_training(trainer)

File /opt/conda/envs/scvi-env/lib/python3.8/site-packages/pytorch_lightning/plugins/training_type/training_type_plugin.py:144, in TrainingTypePlugin.start_training(self, trainer)
    142 def start_training(self, trainer: 'pl.Trainer') -> None:
    143     # double dispatch to initiate the training loop
--> 144     self._results = trainer.run_stage()

File /opt/conda/envs/scvi-env/lib/python3.8/site-packages/pytorch_lightning/trainer/trainer.py:809, in Trainer.run_stage(self)
    807 if self.predicting:
    808     return self.run_predict()
--> 809 return self.run_train()

File /opt/conda/envs/scvi-env/lib/python3.8/site-packages/pytorch_lightning/trainer/trainer.py:871, in Trainer.run_train(self)
    867 self.train_loop.on_train_epoch_start(epoch)
    869 with self.profiler.profile("run_training_epoch"):
    870     # run train epoch
--> 871     self.train_loop.run_training_epoch()
    873 if self.max_steps and self.max_steps <= self.global_step:
    874     self.train_loop.on_train_end()

File /opt/conda/envs/scvi-env/lib/python3.8/site-packages/pytorch_lightning/trainer/training_loop.py:499, in TrainLoop.run_training_epoch(self)
    495 # ------------------------------------
    496 # TRAINING_STEP + TRAINING_STEP_END
    497 # ------------------------------------
    498 with self.trainer.profiler.profile("run_training_batch"):
--> 499     batch_output = self.run_training_batch(batch, batch_idx, dataloader_idx)
    501 # when returning -1 from train_step, we end epoch early
    502 if batch_output.signal == -1:

File /opt/conda/envs/scvi-env/lib/python3.8/site-packages/pytorch_lightning/trainer/training_loop.py:738, in TrainLoop.run_training_batch(self, batch, batch_idx, dataloader_idx)
    735     return None if result is None else result.loss
    737 # optimizer step
--> 738 self.optimizer_step(optimizer, opt_idx, batch_idx, train_step_and_backward_closure)
    739 if len(self.trainer.optimizers) > 1:
    740     # revert back to previous state
    741     self.trainer.lightning_module.untoggle_optimizer(opt_idx)

File /opt/conda/envs/scvi-env/lib/python3.8/site-packages/pytorch_lightning/trainer/training_loop.py:434, in TrainLoop.optimizer_step(self, optimizer, opt_idx, batch_idx, train_step_and_backward_closure)
    431 optimizer = LightningOptimizer._to_lightning_optimizer(optimizer, self.trainer, opt_idx)
    433 # model hook
--> 434 model_ref.optimizer_step(
    435     self.trainer.current_epoch,
    436     batch_idx,
    437     optimizer,
    438     opt_idx,
    439     train_step_and_backward_closure,
    440     on_tpu=self.trainer._device_type == DeviceType.TPU and _TPU_AVAILABLE,
    441     using_native_amp=using_native_amp,
    442     using_lbfgs=is_lbfgs,
    443 )

File /opt/conda/envs/scvi-env/lib/python3.8/site-packages/pytorch_lightning/core/lightning.py:1403, in LightningModule.optimizer_step(self, epoch, batch_idx, optimizer, optimizer_idx, optimizer_closure, on_tpu, using_native_amp, using_lbfgs)
   1330 def optimizer_step(
   1331     self,
   1332     epoch: int = None,
   (...)
   1339     using_lbfgs: bool = None,
   1340 ) -> None:
   1341     r"""
   1342     Override this method to adjust the default way the
   1343     :class:`~pytorch_lightning.trainer.trainer.Trainer` calls each optimizer.
   (...)
   1401 
   1402     """
-> 1403     optimizer.step(closure=optimizer_closure)

File /opt/conda/envs/scvi-env/lib/python3.8/site-packages/pytorch_lightning/core/optimizer.py:214, in LightningOptimizer.step(self, closure, *args, **kwargs)
    211         raise MisconfigurationException("When closure is provided, it should be a function")
    212     profiler_name = f"optimizer_step_and_closure_{self._optimizer_idx}"
--> 214 self.__optimizer_step(*args, closure=closure, profiler_name=profiler_name, **kwargs)
    215 self._total_optimizer_step_calls += 1

File /opt/conda/envs/scvi-env/lib/python3.8/site-packages/pytorch_lightning/core/optimizer.py:134, in LightningOptimizer.__optimizer_step(self, closure, profiler_name, **kwargs)
    131 optimizer = self._optimizer
    133 with trainer.profiler.profile(profiler_name):
--> 134     trainer.accelerator.optimizer_step(optimizer, self._optimizer_idx, lambda_closure=closure, **kwargs)

File /opt/conda/envs/scvi-env/lib/python3.8/site-packages/pytorch_lightning/accelerators/accelerator.py:329, in Accelerator.optimizer_step(self, optimizer, opt_idx, lambda_closure, **kwargs)
    325 make_optimizer_step = self.precision_plugin.pre_optimizer_step(
    326     self.lightning_module, optimizer, opt_idx, lambda_closure, **kwargs
    327 )
    328 if make_optimizer_step:
--> 329     self.run_optimizer_step(optimizer, opt_idx, lambda_closure, **kwargs)
    330 self.precision_plugin.post_optimizer_step(optimizer, opt_idx)
    331 self.training_type_plugin.post_optimizer_step(optimizer, opt_idx, **kwargs)

File /opt/conda/envs/scvi-env/lib/python3.8/site-packages/pytorch_lightning/accelerators/accelerator.py:336, in Accelerator.run_optimizer_step(self, optimizer, optimizer_idx, lambda_closure, **kwargs)
    333 def run_optimizer_step(
    334     self, optimizer: Optimizer, optimizer_idx: int, lambda_closure: Callable, **kwargs: Any
    335 ) -> None:
--> 336     self.training_type_plugin.optimizer_step(optimizer, lambda_closure=lambda_closure, **kwargs)

File /opt/conda/envs/scvi-env/lib/python3.8/site-packages/pytorch_lightning/plugins/training_type/training_type_plugin.py:193, in TrainingTypePlugin.optimizer_step(self, optimizer, lambda_closure, **kwargs)
    192 def optimizer_step(self, optimizer: torch.optim.Optimizer, lambda_closure: Callable, **kwargs):
--> 193     optimizer.step(closure=lambda_closure, **kwargs)

File /opt/conda/envs/scvi-env/lib/python3.8/site-packages/torch/optim/optimizer.py:88, in Optimizer._hook_for_profile.<locals>.profile_hook_step.<locals>.wrapper(*args, **kwargs)
     86 profile_name = "Optimizer.step#{}.step".format(obj.__class__.__name__)
     87 with torch.autograd.profiler.record_function(profile_name):
---> 88     return func(*args, **kwargs)

File /opt/conda/envs/scvi-env/lib/python3.8/site-packages/torch/autograd/grad_mode.py:28, in _DecoratorContextManager.__call__.<locals>.decorate_context(*args, **kwargs)
     25 @functools.wraps(func)
     26 def decorate_context(*args, **kwargs):
     27     with self.__class__():
---> 28         return func(*args, **kwargs)

File /opt/conda/envs/scvi-env/lib/python3.8/site-packages/torch/optim/adam.py:92, in Adam.step(self, closure)
     90 if closure is not None:
     91     with torch.enable_grad():
---> 92         loss = closure()
     94 for group in self.param_groups:
     95     params_with_grad = []

File /opt/conda/envs/scvi-env/lib/python3.8/site-packages/pytorch_lightning/trainer/training_loop.py:732, in TrainLoop.run_training_batch.<locals>.train_step_and_backward_closure()
    731 def train_step_and_backward_closure():
--> 732     result = self.training_step_and_backward(
    733         split_batch, batch_idx, opt_idx, optimizer, self.trainer.hiddens
    734     )
    735     return None if result is None else result.loss

File /opt/conda/envs/scvi-env/lib/python3.8/site-packages/pytorch_lightning/trainer/training_loop.py:823, in TrainLoop.training_step_and_backward(self, split_batch, batch_idx, opt_idx, optimizer, hiddens)
    820 """Wrap forward, zero_grad and backward in a closure so second order methods work"""
    821 with self.trainer.profiler.profile("training_step_and_backward"):
    822     # lightning module hook
--> 823     result = self.training_step(split_batch, batch_idx, opt_idx, hiddens)
    824     self._curr_step_result = result
    826     if not self._skip_backward and self.trainer.lightning_module.automatic_optimization:

File /opt/conda/envs/scvi-env/lib/python3.8/site-packages/pytorch_lightning/trainer/training_loop.py:290, in TrainLoop.training_step(self, split_batch, batch_idx, opt_idx, hiddens)
    288 model_ref._results = Result()
    289 with self.trainer.profiler.profile("training_step"):
--> 290     training_step_output = self.trainer.accelerator.training_step(args)
    291     self.trainer.accelerator.post_training_step()
    293 self.trainer.logger_connector.cache_logged_metrics()

File /opt/conda/envs/scvi-env/lib/python3.8/site-packages/pytorch_lightning/accelerators/accelerator.py:204, in Accelerator.training_step(self, args)
    201 args[0] = self.to_device(args[0])
    203 with self.precision_plugin.train_step_context(), self.training_type_plugin.train_step_context():
--> 204     return self.training_type_plugin.training_step(*args)

File /opt/conda/envs/scvi-env/lib/python3.8/site-packages/pytorch_lightning/plugins/training_type/training_type_plugin.py:155, in TrainingTypePlugin.training_step(self, *args, **kwargs)
    154 def training_step(self, *args, **kwargs):
--> 155     return self.lightning_module.training_step(*args, **kwargs)

File /opt/conda/envs/scvi-env/lib/python3.8/site-packages/scvi/train/_trainingplans.py:152, in TrainingPlan.training_step(self, batch, batch_idx, optimizer_idx)
    150 if "kl_weight" in self.loss_kwargs:
    151     self.loss_kwargs.update({"kl_weight": self.kl_weight})
--> 152 _, _, scvi_loss = self.forward(batch, loss_kwargs=self.loss_kwargs)
    153 reconstruction_loss = scvi_loss.reconstruction_loss
    154 # pytorch lightning automatically backprops on "loss"

File /opt/conda/envs/scvi-env/lib/python3.8/site-packages/scvi/train/_trainingplans.py:147, in TrainingPlan.forward(self, *args, **kwargs)
    145 def forward(self, *args, **kwargs):
    146     """Passthrough to `model.forward()`."""
--> 147     return self.module(*args, **kwargs)

File /opt/conda/envs/scvi-env/lib/python3.8/site-packages/torch/nn/modules/module.py:1102, in Module._call_impl(self, *input, **kwargs)
   1098 # If we don't have any hooks, we want to skip the rest of the logic in
   1099 # this function, and just call forward.
   1100 if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks or _global_backward_hooks
   1101         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1102     return forward_call(*input, **kwargs)
   1103 # Do not call functions when jit is used
   1104 full_backward_hooks, non_full_backward_hooks = [], []

File /opt/conda/envs/scvi-env/lib/python3.8/site-packages/scvi/module/base/_decorators.py:32, in auto_move_data.<locals>.auto_transfer_args(self, *args, **kwargs)
     30 # decorator only necessary after training
     31 if self.training:
---> 32     return fn(self, *args, **kwargs)
     34 device = list(set(p.device for p in self.parameters()))
     35 if len(device) > 1:

File /opt/conda/envs/scvi-env/lib/python3.8/site-packages/scvi/module/base/_base_module.py:145, in BaseModuleClass.forward(self, tensors, get_inference_input_kwargs, get_generative_input_kwargs, inference_kwargs, generative_kwargs, loss_kwargs, compute_loss)
    140 get_generative_input_kwargs = _get_dict_if_none(get_generative_input_kwargs)
    142 inference_inputs = self._get_inference_input(
    143     tensors, **get_inference_input_kwargs
    144 )
--> 145 inference_outputs = self.inference(**inference_inputs, **inference_kwargs)
    146 generative_inputs = self._get_generative_input(
    147     tensors, inference_outputs, **get_generative_input_kwargs
    148 )
    149 generative_outputs = self.generative(**generative_inputs, **generative_kwargs)

File /opt/conda/envs/scvi-env/lib/python3.8/site-packages/scvi/module/base/_decorators.py:32, in auto_move_data.<locals>.auto_transfer_args(self, *args, **kwargs)
     30 # decorator only necessary after training
     31 if self.training:
---> 32     return fn(self, *args, **kwargs)
     34 device = list(set(p.device for p in self.parameters()))
     35 if len(device) > 1:

File /opt/conda/envs/scvi-env/lib/python3.8/site-packages/scvi/module/_vae.py:276, in VAE.inference(self, x, batch_index, cont_covs, cat_covs, n_samples)
    274 else:
    275     categorical_input = tuple()
--> 276 qz_m, qz_v, z = self.z_encoder(encoder_input, batch_index, *categorical_input)
    278 ql_m, ql_v = None, None
    279 if not self.use_observed_lib_size:

File /opt/conda/envs/scvi-env/lib/python3.8/site-packages/torch/nn/modules/module.py:1102, in Module._call_impl(self, *input, **kwargs)
   1098 # If we don't have any hooks, we want to skip the rest of the logic in
   1099 # this function, and just call forward.
   1100 if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks or _global_backward_hooks
   1101         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1102     return forward_call(*input, **kwargs)
   1103 # Do not call functions when jit is used
   1104 full_backward_hooks, non_full_backward_hooks = [], []

File /opt/conda/envs/scvi-env/lib/python3.8/site-packages/scvi/nn/_base_components.py:294, in Encoder.forward(self, x, *cat_list)
    292 q_m = self.mean_encoder(q)
    293 q_v = self.var_activation(self.var_encoder(q)) + self.var_eps
--> 294 latent = self.z_transformation(reparameterize_gaussian(q_m, q_v))
    295 return q_m, q_v, latent

File /opt/conda/envs/scvi-env/lib/python3.8/site-packages/scvi/nn/_base_components.py:13, in reparameterize_gaussian(mu, var)
     12 def reparameterize_gaussian(mu, var):
---> 13     return Normal(mu, var.sqrt()).rsample()

File /opt/conda/envs/scvi-env/lib/python3.8/site-packages/torch/distributions/normal.py:50, in Normal.__init__(self, loc, scale, validate_args)
     48 else:
     49     batch_shape = self.loc.size()
---> 50 super(Normal, self).__init__(batch_shape, validate_args=validate_args)

File /opt/conda/envs/scvi-env/lib/python3.8/site-packages/torch/distributions/distribution.py:55, in Distribution.__init__(self, batch_shape, event_shape, validate_args)
     53 valid = constraint.check(value)
     54 if not valid.all():
---> 55     raise ValueError(
     56         f"Expected parameter {param} "
     57         f"({type(value).__name__} of shape {tuple(value.shape)}) "
     58         f"of distribution {repr(self)} "
     59         f"to satisfy the constraint {repr(constraint)}, "
     60         f"but found invalid values:\n{value}"
     61     )
     62 if not constraint.check(getattr(self, param)).all():
     63     raise ValueError("The parameter {} has invalid values".format(param))

ValueError: Expected parameter loc (Tensor of shape (128, 30)) of distribution Normal(loc: torch.Size([128, 30]), scale: torch.Size([128, 30])) to satisfy the constraint Real(), but found invalid values:
tensor([[nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        ...,
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan]], device='cuda:0',
       grad_fn=<AddmmBackward0>)

It seems you have run into the problem of exploding gradients (which has been mentioned elsewhere: DestVI Tensor Nan Error). Try a smaller learning rate and see if that helps troubleshoot the problem, for example along these lines:
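A sketch of how to pass a smaller learning rate through plan_kwargs (the exact value is just a starting point, adjust as needed):

```python
# Lower the optimizer learning rate via the training plan kwargs
scvi_model.train(max_epochs=100, plan_kwargs={"lr": 1e-4})
```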
Best,