Tensor nan error when training scANVI model

Hi, I ran into a tensor NaN error when training a scANVI model initialized from a trained scVI model. I hope someone can help me.

  1. Set up the AnnData object with:
scvi.model.SCVI.setup_anndata(combined_adata, layer = "counts",
                             categorical_covariate_keys=["Batch"],
                             continuous_covariate_keys=['pct_counts_mt', 'total_counts'])
  1. Hyperparameter tuning:
search_space = {
    "n_hidden": tune.choice([128, 256, 512, 1024]),
    "n_layers": tune.choice([1, 2, 3, 4, 5]),
    "lr": tune.loguniform(1e-4, 1e-2),
    "n_latent":tune.choice([10, 15, 20, 30, 50]),
    "gene_likelihood": tune.choice(["nb", "zinb"]),
}
# ray.init(log_to_driver=False)
results = scvi_tuner.fit(
    combined_adata,
    metric="validation_loss",
    search_space=search_space,
    num_samples=60,
    max_epochs=35,
    resources={"cpu": 40, "gpu": 2,"memory":4e11 },
)

The tensor NaN error first occurred here, during hyperparameter tuning. The following trials failed with it:

_trainable_bb90b45b	1	/mnt/data2/2019_cell/data/0810/ray/tune_scvi_2023-08-10-21:27:53/_trainable_bb90b45b_3_gene_likelihood=nb,lr=0.0035,n_hidden=1024,n_latent=50,n_layers=4_2023-08-10_22-26-49/error.txt
_trainable_13012238	1	/mnt/data2/2019_cell/data/0810/ray/tune_scvi_2023-08-10-21:27:53/_trainable_13012238_4_gene_likelihood=nb,lr=0.0059,n_hidden=1024,n_latent=50,n_layers=5_2023-08-10_23-18-34/error.txt
_trainable_fc1efd50	1	/mnt/data2/2019_cell/data/0810/ray/tune_scvi_2023-08-10-21:27:53/_trainable_fc1efd50_8_gene_likelihood=zinb,lr=0.0094,n_hidden=128,n_latent=50,n_layers=5_2023-08-11_00-05-04/error.txt
_trainable_b0548edc	1	/mnt/data2/2019_cell/data/0810/ray/tune_scvi_2023-08-10-21:27:53/_trainable_b0548edc_15_gene_likelihood=nb,lr=0.0055,n_hidden=512,n_latent=50,n_layers=5_2023-08-11_01-14-20/error.txt
_trainable_386f6dab	1	/mnt/data2/2019_cell/data/0810/ray/tune_scvi_2023-08-10-21:27:53/_trainable_386f6dab_20_gene_likelihood=nb,lr=0.0096,n_hidden=128,n_latent=50,n_layers=3_2023-08-11_03-07-35/error.txt
_trainable_2e31bdad	1	/mnt/data2/2019_cell/data/0810/ray/tune_scvi_2023-08-10-21:27:53/_trainable_2e31bdad_33_gene_likelihood=zinb,lr=0.0031,n_hidden=1024,n_latent=30,n_layers=2_2023-08-11_04-40-16/error.txt
_trainable_f4800d7e	1	/mnt/data2/2019_cell/data/0810/ray/tune_scvi_2023-08-10-21:27:53/_trainable_f4800d7e_35_gene_likelihood=zinb,lr=0.0029,n_hidden=1024,n_latent=30,n_layers=2_2023-08-11_05-26-24/error.txt
_trainable_4a870bf4	1	/mnt/data2/2019_cell/data/0810/ray/tune_scvi_2023-08-10-21:27:53/_trainable_4a870bf4_45_gene_likelihood=zinb,lr=0.0018,n_hidden=1024,n_latent=30,n_layers=2_2023-08-11_08-23-46/error.txt
_trainable_1f338d31	1	/mnt/data2/2019_cell/data/0810/ray/tune_scvi_2023-08-10-21:27:53/_trainable_1f338d31_55_gene_likelihood=nb,lr=0.0045,n_hidden=1024,n_latent=50,n_layers=4_2023-08-11_08-50-35/error.txt
_trainable_41b7e0bf	1	/mnt/data2/2019_cell/data/0810/ray/tune_scvi_2023-08-10-21:27:53/_trainable_41b7e0bf_59_gene_likelihood=zinb,lr=0.0059,n_hidden=256,n_latent=20,n_layers=5_2023-08-11_09-17-35/error.txt
  1. Some parameter combinations did finish successfully, so I fit the scVI model anyway, using the best parameters returned by Ray Tune.
  2. Set up scANVI from the trained scVI model:
lvae = scvi.model.SCANVI.from_scvi_model(scvi_model, adata = combined_adata, unlabeled_category = 'Unknown',
                                        labels_key = 'ct')

lvae.train(max_epochs=100, n_samples_per_label=100,early_stopping=True,)

At first, training proceeded, but it stopped at the 78th epoch with the following error:

Epoch 78/100:  77%|███████▋  | 77/100 [3:40:47<1:05:18, 170.36s/it, v_num=1, train_loss_step=1.63e+3, train_loss_epoch=1.22e+3]
---------------------------------------------------------------------------
ValueError                                Traceback (most recent call last)
Cell In[58], line 5
      1 # 0811
      2 lvae = scvi.model.SCANVI.from_scvi_model(scvi_model, adata = combined_adata, unlabeled_category = 'Unknown',
      3                                         labels_key = 'ct')
----> 5 lvae.train(max_epochs=100, n_samples_per_label=100,early_stopping=True,)

File ~/miniconda3/envs/new_scvi/lib/python3.10/site-packages/scvi/model/_scanvi.py:439, in SCANVI.train(self, max_epochs, n_samples_per_label, check_val_every_n_epoch, train_size, validation_size, shuffle_set_split, batch_size, use_gpu, accelerator, devices, plan_kwargs, **trainer_kwargs)
    426     trainer_kwargs["callbacks"] = sampler_callback
    428 runner = TrainRunner(
    429     self,
    430     training_plan=training_plan,
   (...)
    437     **trainer_kwargs,
    438 )
--> 439 return runner()

File ~/miniconda3/envs/new_scvi/lib/python3.10/site-packages/scvi/train/_trainrunner.py:99, in TrainRunner.__call__(self)
     96 if hasattr(self.data_splitter, "n_val"):
     97     self.training_plan.n_obs_validation = self.data_splitter.n_val
---> 99 self.trainer.fit(self.training_plan, self.data_splitter)
    100 self._update_history()
    102 # data splitter only gets these attrs after fit

File ~/miniconda3/envs/new_scvi/lib/python3.10/site-packages/scvi/train/_trainer.py:186, in Trainer.fit(self, *args, **kwargs)
    180 if isinstance(args[0], PyroTrainingPlan):
    181     warnings.filterwarnings(
    182         action="ignore",
    183         category=UserWarning,
    184         message="`LightningModule.configure_optimizers` returned `None`",
    185     )
--> 186 super().fit(*args, **kwargs)

File ~/.local/lib/python3.10/site-packages/lightning/pytorch/trainer/trainer.py:520, in Trainer.fit(self, model, train_dataloaders, val_dataloaders, datamodule, ckpt_path)
    518 model = _maybe_unwrap_optimized(model)
    519 self.strategy._lightning_module = model
--> 520 call._call_and_handle_interrupt(
    521     self, self._fit_impl, model, train_dataloaders, val_dataloaders, datamodule, ckpt_path
    522 )

File ~/.local/lib/python3.10/site-packages/lightning/pytorch/trainer/call.py:44, in _call_and_handle_interrupt(trainer, trainer_fn, *args, **kwargs)
     42         return trainer.strategy.launcher.launch(trainer_fn, *args, trainer=trainer, **kwargs)
     43     else:
---> 44         return trainer_fn(*args, **kwargs)
     46 except _TunerExitException:
     47     _call_teardown_hook(trainer)

File ~/.local/lib/python3.10/site-packages/lightning/pytorch/trainer/trainer.py:559, in Trainer._fit_impl(self, model, train_dataloaders, val_dataloaders, datamodule, ckpt_path)
    549 self._data_connector.attach_data(
    550     model, train_dataloaders=train_dataloaders, val_dataloaders=val_dataloaders, datamodule=datamodule
    551 )
    553 ckpt_path = self._checkpoint_connector._select_ckpt_path(
    554     self.state.fn,
    555     ckpt_path,
    556     model_provided=True,
    557     model_connected=self.lightning_module is not None,
    558 )
--> 559 self._run(model, ckpt_path=ckpt_path)
    561 assert self.state.stopped
    562 self.training = False

File ~/.local/lib/python3.10/site-packages/lightning/pytorch/trainer/trainer.py:935, in Trainer._run(self, model, ckpt_path)
    930 self._signal_connector.register_signal_handlers()
    932 # ----------------------------
    933 # RUN THE TRAINER
    934 # ----------------------------
--> 935 results = self._run_stage()
    937 # ----------------------------
    938 # POST-Training CLEAN UP
    939 # ----------------------------
    940 log.debug(f"{self.__class__.__name__}: trainer tearing down")

File ~/.local/lib/python3.10/site-packages/lightning/pytorch/trainer/trainer.py:978, in Trainer._run_stage(self)
    976         self._run_sanity_check()
    977     with torch.autograd.set_detect_anomaly(self._detect_anomaly):
--> 978         self.fit_loop.run()
    979     return None
    980 raise RuntimeError(f"Unexpected state {self.state}")

File ~/.local/lib/python3.10/site-packages/lightning/pytorch/loops/fit_loop.py:201, in _FitLoop.run(self)
    199 try:
    200     self.on_advance_start()
--> 201     self.advance()
    202     self.on_advance_end()
    203     self._restarting = False

File ~/.local/lib/python3.10/site-packages/lightning/pytorch/loops/fit_loop.py:354, in _FitLoop.advance(self)
    352 self._data_fetcher.setup(combined_loader)
    353 with self.trainer.profiler.profile("run_training_epoch"):
--> 354     self.epoch_loop.run(self._data_fetcher)

File ~/.local/lib/python3.10/site-packages/lightning/pytorch/loops/training_epoch_loop.py:133, in _TrainingEpochLoop.run(self, data_fetcher)
    131 while not self.done:
    132     try:
--> 133         self.advance(data_fetcher)
    134         self.on_advance_end()
    135         self._restarting = False

File ~/.local/lib/python3.10/site-packages/lightning/pytorch/loops/training_epoch_loop.py:218, in _TrainingEpochLoop.advance(self, data_fetcher)
    215 with trainer.profiler.profile("run_training_batch"):
    216     if trainer.lightning_module.automatic_optimization:
    217         # in automatic optimization, there can only be one optimizer
--> 218         batch_output = self.automatic_optimization.run(trainer.optimizers[0], kwargs)
    219     else:
    220         batch_output = self.manual_optimization.run(kwargs)

File ~/.local/lib/python3.10/site-packages/lightning/pytorch/loops/optimization/automatic.py:185, in _AutomaticOptimization.run(self, optimizer, kwargs)
    178         closure()
    180 # ------------------------------
    181 # BACKWARD PASS
    182 # ------------------------------
    183 # gradient update with accumulated gradients
    184 else:
--> 185     self._optimizer_step(kwargs.get("batch_idx", 0), closure)
    187 result = closure.consume_result()
    188 if result.loss is None:

File ~/.local/lib/python3.10/site-packages/lightning/pytorch/loops/optimization/automatic.py:261, in _AutomaticOptimization._optimizer_step(self, batch_idx, train_step_and_backward_closure)
    258     self.optim_progress.optimizer.step.increment_ready()
    260 # model hook
--> 261 call._call_lightning_module_hook(
    262     trainer,
    263     "optimizer_step",
    264     trainer.current_epoch,
    265     batch_idx,
    266     optimizer,
    267     train_step_and_backward_closure,
    268 )
    270 if not should_accumulate:
    271     self.optim_progress.optimizer.step.increment_completed()

File ~/.local/lib/python3.10/site-packages/lightning/pytorch/trainer/call.py:142, in _call_lightning_module_hook(trainer, hook_name, pl_module, *args, **kwargs)
    139 pl_module._current_fx_name = hook_name
    141 with trainer.profiler.profile(f"[LightningModule]{pl_module.__class__.__name__}.{hook_name}"):
--> 142     output = fn(*args, **kwargs)
    144 # restore current_fx when nested context
    145 pl_module._current_fx_name = prev_fx_name

File ~/.local/lib/python3.10/site-packages/lightning/pytorch/core/module.py:1265, in LightningModule.optimizer_step(self, epoch, batch_idx, optimizer, optimizer_closure)
   1226 def optimizer_step(
   1227     self,
   1228     epoch: int,
   (...)
   1231     optimizer_closure: Optional[Callable[[], Any]] = None,
   1232 ) -> None:
   1233     r"""
   1234     Override this method to adjust the default way the :class:`~lightning.pytorch.trainer.trainer.Trainer` calls
   1235     the optimizer.
   (...)
   1263                     pg["lr"] = lr_scale * self.learning_rate
   1264     """
-> 1265     optimizer.step(closure=optimizer_closure)

File ~/.local/lib/python3.10/site-packages/lightning/pytorch/core/optimizer.py:158, in LightningOptimizer.step(self, closure, **kwargs)
    155     raise MisconfigurationException("When `optimizer.step(closure)` is called, the closure should be callable")
    157 assert self._strategy is not None
--> 158 step_output = self._strategy.optimizer_step(self._optimizer, closure, **kwargs)
    160 self._on_after_step()
    162 return step_output

File ~/.local/lib/python3.10/site-packages/lightning/pytorch/strategies/strategy.py:224, in Strategy.optimizer_step(self, optimizer, closure, model, **kwargs)
    222 # TODO(fabric): remove assertion once strategy's optimizer_step typing is fixed
    223 assert isinstance(model, pl.LightningModule)
--> 224 return self.precision_plugin.optimizer_step(optimizer, model=model, closure=closure, **kwargs)

File ~/.local/lib/python3.10/site-packages/lightning/pytorch/plugins/precision/precision_plugin.py:114, in PrecisionPlugin.optimizer_step(self, optimizer, model, closure, **kwargs)
    112 """Hook to run the optimizer step."""
    113 closure = partial(self._wrap_closure, model, optimizer, closure)
--> 114 return optimizer.step(closure=closure, **kwargs)

File ~/.local/lib/python3.10/site-packages/torch/optim/optimizer.py:280, in Optimizer.profile_hook_step.<locals>.wrapper(*args, **kwargs)
    276         else:
    277             raise RuntimeError(f"{func} must return None or a tuple of (new_args, new_kwargs),"
    278                                f"but got {result}.")
--> 280 out = func(*args, **kwargs)
    281 self._optimizer_step_code()
    283 # call optimizer step post hooks

File ~/.local/lib/python3.10/site-packages/torch/optim/optimizer.py:33, in _use_grad_for_differentiable.<locals>._use_grad(self, *args, **kwargs)
     31 try:
     32     torch.set_grad_enabled(self.defaults['differentiable'])
---> 33     ret = func(self, *args, **kwargs)
     34 finally:
     35     torch.set_grad_enabled(prev_grad)

File ~/.local/lib/python3.10/site-packages/torch/optim/adam.py:121, in Adam.step(self, closure)
    119 if closure is not None:
    120     with torch.enable_grad():
--> 121         loss = closure()
    123 for group in self.param_groups:
    124     params_with_grad = []

File ~/.local/lib/python3.10/site-packages/lightning/pytorch/plugins/precision/precision_plugin.py:101, in PrecisionPlugin._wrap_closure(self, model, optimizer, closure)
     89 def _wrap_closure(
     90     self,
     91     model: "pl.LightningModule",
     92     optimizer: Optimizer,
     93     closure: Callable[[], Any],
     94 ) -> Any:
     95     """This double-closure allows makes sure the ``closure`` is executed before the
     96     ``on_before_optimizer_step`` hook is called.
     97 
     98     The closure (generally) runs ``backward`` so this allows inspecting gradients in this hook. This structure is
     99     consistent with the ``PrecisionPlugin`` subclasses that cannot pass ``optimizer.step(closure)`` directly.
    100     """
--> 101     closure_result = closure()
    102     self._after_closure(model, optimizer)
    103     return closure_result

File ~/.local/lib/python3.10/site-packages/lightning/pytorch/loops/optimization/automatic.py:140, in Closure.__call__(self, *args, **kwargs)
    139 def __call__(self, *args: Any, **kwargs: Any) -> Optional[Tensor]:
--> 140     self._result = self.closure(*args, **kwargs)
    141     return self._result.loss

File ~/.local/lib/python3.10/site-packages/lightning/pytorch/loops/optimization/automatic.py:126, in Closure.closure(self, *args, **kwargs)
    125 def closure(self, *args: Any, **kwargs: Any) -> ClosureResult:
--> 126     step_output = self._step_fn()
    128     if step_output.closure_loss is None:
    129         self.warning_cache.warn("`training_step` returned `None`. If this was on purpose, ignore this warning...")

File ~/.local/lib/python3.10/site-packages/lightning/pytorch/loops/optimization/automatic.py:308, in _AutomaticOptimization._training_step(self, kwargs)
    305 trainer = self.trainer
    307 # manually capture logged metrics
--> 308 training_step_output = call._call_strategy_hook(trainer, "training_step", *kwargs.values())
    309 self.trainer.strategy.post_training_step()
    311 result = self.output_result_cls.from_training_step_output(training_step_output, trainer.accumulate_grad_batches)

File ~/.local/lib/python3.10/site-packages/lightning/pytorch/trainer/call.py:288, in _call_strategy_hook(trainer, hook_name, *args, **kwargs)
    285     return
    287 with trainer.profiler.profile(f"[Strategy]{trainer.strategy.__class__.__name__}.{hook_name}"):
--> 288     output = fn(*args, **kwargs)
    290 # restore current_fx when nested context
    291 pl_module._current_fx_name = prev_fx_name

File ~/.local/lib/python3.10/site-packages/lightning/pytorch/strategies/strategy.py:366, in Strategy.training_step(self, *args, **kwargs)
    364 with self.precision_plugin.train_step_context():
    365     assert isinstance(self.model, TrainingStep)
--> 366     return self.model.training_step(*args, **kwargs)

File ~/miniconda3/envs/new_scvi/lib/python3.10/site-packages/scvi/train/_trainingplans.py:797, in SemiSupervisedTrainingPlan.training_step(self, batch, batch_idx)
    792 input_kwargs = {
    793     "feed_labels": False,
    794     "labelled_tensors": labelled_dataset,
    795 }
    796 input_kwargs.update(self.loss_kwargs)
--> 797 _, _, loss_output = self.forward(full_dataset, loss_kwargs=input_kwargs)
    798 loss = loss_output.loss
    799 self.log(
    800     "train_loss",
    801     loss,
   (...)
    804     prog_bar=True,
    805 )

File ~/miniconda3/envs/new_scvi/lib/python3.10/site-packages/scvi/train/_trainingplans.py:278, in TrainingPlan.forward(self, *args, **kwargs)
    276 def forward(self, *args, **kwargs):
    277     """Passthrough to the module's forward method."""
--> 278     return self.module(*args, **kwargs)

File ~/.local/lib/python3.10/site-packages/torch/nn/modules/module.py:1501, in Module._call_impl(self, *args, **kwargs)
   1496 # If we don't have any hooks, we want to skip the rest of the logic in
   1497 # this function, and just call forward.
   1498 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
   1499         or _global_backward_pre_hooks or _global_backward_hooks
   1500         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1501     return forward_call(*args, **kwargs)
   1502 # Do not call functions when jit is used
   1503 full_backward_hooks, non_full_backward_hooks = [], []

File ~/miniconda3/envs/new_scvi/lib/python3.10/site-packages/scvi/module/base/_decorators.py:32, in auto_move_data.<locals>.auto_transfer_args(self, *args, **kwargs)
     30 # decorator only necessary after training
     31 if self.training:
---> 32     return fn(self, *args, **kwargs)
     34 device = list({p.device for p in self.parameters()})
     35 if len(device) > 1:

File ~/miniconda3/envs/new_scvi/lib/python3.10/site-packages/scvi/module/base/_base_module.py:199, in BaseModuleClass.forward(self, tensors, get_inference_input_kwargs, get_generative_input_kwargs, inference_kwargs, generative_kwargs, loss_kwargs, compute_loss)
    165 @auto_move_data
    166 def forward(
    167     self,
   (...)
    177     | tuple[torch.Tensor, torch.Tensor, LossOutput]
    178 ):
    179     """Forward pass through the network.
    180 
    181     Parameters
   (...)
    197         another return value.
    198     """
--> 199     return _generic_forward(
    200         self,
    201         tensors,
    202         inference_kwargs,
    203         generative_kwargs,
    204         loss_kwargs,
    205         get_inference_input_kwargs,
    206         get_generative_input_kwargs,
    207         compute_loss,
    208     )

File ~/miniconda3/envs/new_scvi/lib/python3.10/site-packages/scvi/module/base/_base_module.py:743, in _generic_forward(module, tensors, inference_kwargs, generative_kwargs, loss_kwargs, get_inference_input_kwargs, get_generative_input_kwargs, compute_loss)
    738 get_generative_input_kwargs = _get_dict_if_none(get_generative_input_kwargs)
    740 inference_inputs = module._get_inference_input(
    741     tensors, **get_inference_input_kwargs
    742 )
--> 743 inference_outputs = module.inference(**inference_inputs, **inference_kwargs)
    744 generative_inputs = module._get_generative_input(
    745     tensors, inference_outputs, **get_generative_input_kwargs
    746 )
    747 generative_outputs = module.generative(**generative_inputs, **generative_kwargs)

File ~/miniconda3/envs/new_scvi/lib/python3.10/site-packages/scvi/module/base/_decorators.py:32, in auto_move_data.<locals>.auto_transfer_args(self, *args, **kwargs)
     30 # decorator only necessary after training
     31 if self.training:
---> 32     return fn(self, *args, **kwargs)
     34 device = list({p.device for p in self.parameters()})
     35 if len(device) > 1:

File ~/miniconda3/envs/new_scvi/lib/python3.10/site-packages/scvi/module/base/_base_module.py:303, in BaseMinifiedModeModuleClass.inference(self, *args, **kwargs)
    301     return self._cached_inference(*args, **kwargs)
    302 else:
--> 303     return self._regular_inference(*args, **kwargs)

File ~/miniconda3/envs/new_scvi/lib/python3.10/site-packages/scvi/module/base/_decorators.py:32, in auto_move_data.<locals>.auto_transfer_args(self, *args, **kwargs)
     30 # decorator only necessary after training
     31 if self.training:
---> 32     return fn(self, *args, **kwargs)
     34 device = list({p.device for p in self.parameters()})
     35 if len(device) > 1:

File ~/miniconda3/envs/new_scvi/lib/python3.10/site-packages/scvi/module/_vae.py:336, in VAE._regular_inference(self, x, batch_index, cont_covs, cat_covs, n_samples)
    334 else:
    335     categorical_input = ()
--> 336 qz, z = self.z_encoder(encoder_input, batch_index, *categorical_input)
    337 ql = None
    338 if not self.use_observed_lib_size:

File ~/.local/lib/python3.10/site-packages/torch/nn/modules/module.py:1501, in Module._call_impl(self, *args, **kwargs)
   1496 # If we don't have any hooks, we want to skip the rest of the logic in
   1497 # this function, and just call forward.
   1498 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
   1499         or _global_backward_pre_hooks or _global_backward_hooks
   1500         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1501     return forward_call(*args, **kwargs)
   1502 # Do not call functions when jit is used
   1503 full_backward_hooks, non_full_backward_hooks = [], []

File ~/miniconda3/envs/new_scvi/lib/python3.10/site-packages/scvi/nn/_base_components.py:289, in Encoder.forward(self, x, *cat_list)
    287 q_m = self.mean_encoder(q)
    288 q_v = self.var_activation(self.var_encoder(q)) + self.var_eps
--> 289 dist = Normal(q_m, q_v.sqrt())
    290 latent = self.z_transformation(dist.rsample())
    291 if self.return_dist:

File ~/.local/lib/python3.10/site-packages/torch/distributions/normal.py:56, in Normal.__init__(self, loc, scale, validate_args)
     54 else:
     55     batch_shape = self.loc.size()
---> 56 super().__init__(batch_shape, validate_args=validate_args)

File ~/.local/lib/python3.10/site-packages/torch/distributions/distribution.py:62, in Distribution.__init__(self, batch_shape, event_shape, validate_args)
     60         valid = constraint.check(value)
     61         if not valid.all():
---> 62             raise ValueError(
     63                 f"Expected parameter {param} "
     64                 f"({type(value).__name__} of shape {tuple(value.shape)}) "
     65                 f"of distribution {repr(self)} "
     66                 f"to satisfy the constraint {repr(constraint)}, "
     67                 f"but found invalid values:\n{value}"
     68             )
     69 super().__init__()

ValueError: Expected parameter loc (Tensor of shape (128, 30)) of distribution Normal(loc: torch.Size([128, 30]), scale: torch.Size([128, 30])) to satisfy the constraint Real(), but found invalid values:
tensor([[nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        ...,
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan]], device='cuda:0',
       grad_fn=<AddmmBackward0>)

Then I tried re-running it twice: once without changing any parameters, and once with lr set to 0.0001. Both times training did not finish the 1st epoch and failed with the same tensor NaN error.

Output of running `lvae.view_anndata_setup(combined_adata)`:

Anndata setup with scvi-tools version 1.0.2.
Setup via `SCANVI.setup_anndata` with arguments:
{
│   'labels_key': 'ct',
│   'unlabeled_category': 'Unknown',
│   'layer': 'counts',
│   'batch_key': None,
│   'size_factor_key': None,
│   'categorical_covariate_keys': ['Batch'],
│   'continuous_covariate_keys': ['pct_counts_mt', 'total_counts']
}
         Summary Statistics          
┏━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━┓
┃     Summary Stat Key     ┃ Value  ┃
┡━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━┩
│         n_batch          │   1    │
│         n_cells          │ 620233 │
│ n_extra_categorical_covs │   1    │
│ n_extra_continuous_covs  │   2    │
│         n_labels         │   39   │
│          n_vars          │ 14658  │
└──────────────────────────┴────────┘
                             Data Registry                             
┏━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┓
┃      Registry Key      ┃            scvi-tools Location             ┃
┡━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┩
│           X            │           adata.layers['counts']           │
│         batch          │          adata.obs['_scvi_batch']          │
│ extra_categorical_covs │ adata.obsm['_scvi_extra_categorical_covs'] │
│ extra_continuous_covs  │ adata.obsm['_scvi_extra_continuous_covs']  │
│         labels         │         adata.obs['_scvi_labels']          │
└────────────────────────┴────────────────────────────────────────────┘

Thanks