Tensor NaN error when training scANVI model

Hi, I ran into a tensor NaN error when using scANVI initialized from an scVI model. I hope someone can help.

  1. Set up the AnnData with:
scvi.model.SCVI.setup_anndata(combined_adata, layer = "counts",
                             categorical_covariate_keys=["Batch"],
                             continuous_covariate_keys=['pct_counts_mt', 'total_counts'])
  2. Hyperparameter tuning:
search_space = {
    "n_hidden": tune.choice([128, 256, 512, 1024]),
    "n_layers": tune.choice([1, 2, 3, 4, 5]),
    "lr": tune.loguniform(1e-4, 1e-2),
    "n_latent":tune.choice([10, 15, 20, 30, 50]),
    "gene_likelihood": tune.choice(["nb", "zinb"]),
}
# ray.init(log_to_driver=False)
results = scvi_tuner.fit(
    combined_adata,
    metric="validation_loss",
    search_space=search_space,
    num_samples=60,
    max_epochs=35,
    resources={"cpu": 40, "gpu": 2,"memory":4e11 },
)

The tensor NaN error first occurred during tuning:

_trainable_bb90b45b	1	/mnt/data2/2019_cell/data/0810/ray/tune_scvi_2023-08-10-21:27:53/_trainable_bb90b45b_3_gene_likelihood=nb,lr=0.0035,n_hidden=1024,n_latent=50,n_layers=4_2023-08-10_22-26-49/error.txt
_trainable_13012238	1	/mnt/data2/2019_cell/data/0810/ray/tune_scvi_2023-08-10-21:27:53/_trainable_13012238_4_gene_likelihood=nb,lr=0.0059,n_hidden=1024,n_latent=50,n_layers=5_2023-08-10_23-18-34/error.txt
_trainable_fc1efd50	1	/mnt/data2/2019_cell/data/0810/ray/tune_scvi_2023-08-10-21:27:53/_trainable_fc1efd50_8_gene_likelihood=zinb,lr=0.0094,n_hidden=128,n_latent=50,n_layers=5_2023-08-11_00-05-04/error.txt
_trainable_b0548edc	1	/mnt/data2/2019_cell/data/0810/ray/tune_scvi_2023-08-10-21:27:53/_trainable_b0548edc_15_gene_likelihood=nb,lr=0.0055,n_hidden=512,n_latent=50,n_layers=5_2023-08-11_01-14-20/error.txt
_trainable_386f6dab	1	/mnt/data2/2019_cell/data/0810/ray/tune_scvi_2023-08-10-21:27:53/_trainable_386f6dab_20_gene_likelihood=nb,lr=0.0096,n_hidden=128,n_latent=50,n_layers=3_2023-08-11_03-07-35/error.txt
_trainable_2e31bdad	1	/mnt/data2/2019_cell/data/0810/ray/tune_scvi_2023-08-10-21:27:53/_trainable_2e31bdad_33_gene_likelihood=zinb,lr=0.0031,n_hidden=1024,n_latent=30,n_layers=2_2023-08-11_04-40-16/error.txt
_trainable_f4800d7e	1	/mnt/data2/2019_cell/data/0810/ray/tune_scvi_2023-08-10-21:27:53/_trainable_f4800d7e_35_gene_likelihood=zinb,lr=0.0029,n_hidden=1024,n_latent=30,n_layers=2_2023-08-11_05-26-24/error.txt
_trainable_4a870bf4	1	/mnt/data2/2019_cell/data/0810/ray/tune_scvi_2023-08-10-21:27:53/_trainable_4a870bf4_45_gene_likelihood=zinb,lr=0.0018,n_hidden=1024,n_latent=30,n_layers=2_2023-08-11_08-23-46/error.txt
_trainable_1f338d31	1	/mnt/data2/2019_cell/data/0810/ray/tune_scvi_2023-08-10-21:27:53/_trainable_1f338d31_55_gene_likelihood=nb,lr=0.0045,n_hidden=1024,n_latent=50,n_layers=4_2023-08-11_08-50-35/error.txt
_trainable_41b7e0bf	1	/mnt/data2/2019_cell/data/0810/ray/tune_scvi_2023-08-10-21:27:53/_trainable_41b7e0bf_59_gene_likelihood=zinb,lr=0.0059,n_hidden=256,n_latent=20,n_layers=5_2023-08-11_09-17-35/error.txt
  3. Some parameter combinations did finish testing, so I fit the model with the best parameters returned by Ray anyway.
  4. Set up scANVI from the scVI model:
lvae = scvi.model.SCANVI.from_scvi_model(scvi_model, adata = combined_adata, unlabeled_category = 'Unknown',
                                        labels_key = 'ct')

lvae.train(max_epochs=100, n_samples_per_label=100,early_stopping=True,)

At first it proceeded, but it stopped at the 78th epoch:

Epoch 78/100:  77%|β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–‹  | 77/100 [3:40:47<1:05:18, 170.36s/it, v_num=1, train_loss_step=1.63e+3, train_loss_epoch=1.22e+3]
---------------------------------------------------------------------------
ValueError                                Traceback (most recent call last)
Cell In[58], line 5
      1 # 0811
      2 lvae = scvi.model.SCANVI.from_scvi_model(scvi_model, adata = combined_adata, unlabeled_category = 'Unknown',
      3                                         labels_key = 'ct')
----> 5 lvae.train(max_epochs=100, n_samples_per_label=100,early_stopping=True,)

File ~/miniconda3/envs/new_scvi/lib/python3.10/site-packages/scvi/model/_scanvi.py:439, in SCANVI.train(self, max_epochs, n_samples_per_label, check_val_every_n_epoch, train_size, validation_size, shuffle_set_split, batch_size, use_gpu, accelerator, devices, plan_kwargs, **trainer_kwargs)
    426     trainer_kwargs["callbacks"] = sampler_callback
    428 runner = TrainRunner(
    429     self,
    430     training_plan=training_plan,
   (...)
    437     **trainer_kwargs,
    438 )
--> 439 return runner()

File ~/miniconda3/envs/new_scvi/lib/python3.10/site-packages/scvi/train/_trainrunner.py:99, in TrainRunner.__call__(self)
     96 if hasattr(self.data_splitter, "n_val"):
     97     self.training_plan.n_obs_validation = self.data_splitter.n_val
---> 99 self.trainer.fit(self.training_plan, self.data_splitter)
    100 self._update_history()
    102 # data splitter only gets these attrs after fit

File ~/miniconda3/envs/new_scvi/lib/python3.10/site-packages/scvi/train/_trainer.py:186, in Trainer.fit(self, *args, **kwargs)
    180 if isinstance(args[0], PyroTrainingPlan):
    181     warnings.filterwarnings(
    182         action="ignore",
    183         category=UserWarning,
    184         message="`LightningModule.configure_optimizers` returned `None`",
    185     )
--> 186 super().fit(*args, **kwargs)

File ~/.local/lib/python3.10/site-packages/lightning/pytorch/trainer/trainer.py:520, in Trainer.fit(self, model, train_dataloaders, val_dataloaders, datamodule, ckpt_path)
    518 model = _maybe_unwrap_optimized(model)
    519 self.strategy._lightning_module = model
--> 520 call._call_and_handle_interrupt(
    521     self, self._fit_impl, model, train_dataloaders, val_dataloaders, datamodule, ckpt_path
    522 )

File ~/.local/lib/python3.10/site-packages/lightning/pytorch/trainer/call.py:44, in _call_and_handle_interrupt(trainer, trainer_fn, *args, **kwargs)
     42         return trainer.strategy.launcher.launch(trainer_fn, *args, trainer=trainer, **kwargs)
     43     else:
---> 44         return trainer_fn(*args, **kwargs)
     46 except _TunerExitException:
     47     _call_teardown_hook(trainer)

File ~/.local/lib/python3.10/site-packages/lightning/pytorch/trainer/trainer.py:559, in Trainer._fit_impl(self, model, train_dataloaders, val_dataloaders, datamodule, ckpt_path)
    549 self._data_connector.attach_data(
    550     model, train_dataloaders=train_dataloaders, val_dataloaders=val_dataloaders, datamodule=datamodule
    551 )
    553 ckpt_path = self._checkpoint_connector._select_ckpt_path(
    554     self.state.fn,
    555     ckpt_path,
    556     model_provided=True,
    557     model_connected=self.lightning_module is not None,
    558 )
--> 559 self._run(model, ckpt_path=ckpt_path)
    561 assert self.state.stopped
    562 self.training = False

File ~/.local/lib/python3.10/site-packages/lightning/pytorch/trainer/trainer.py:935, in Trainer._run(self, model, ckpt_path)
    930 self._signal_connector.register_signal_handlers()
    932 # ----------------------------
    933 # RUN THE TRAINER
    934 # ----------------------------
--> 935 results = self._run_stage()
    937 # ----------------------------
    938 # POST-Training CLEAN UP
    939 # ----------------------------
    940 log.debug(f"{self.__class__.__name__}: trainer tearing down")

File ~/.local/lib/python3.10/site-packages/lightning/pytorch/trainer/trainer.py:978, in Trainer._run_stage(self)
    976         self._run_sanity_check()
    977     with torch.autograd.set_detect_anomaly(self._detect_anomaly):
--> 978         self.fit_loop.run()
    979     return None
    980 raise RuntimeError(f"Unexpected state {self.state}")

File ~/.local/lib/python3.10/site-packages/lightning/pytorch/loops/fit_loop.py:201, in _FitLoop.run(self)
    199 try:
    200     self.on_advance_start()
--> 201     self.advance()
    202     self.on_advance_end()
    203     self._restarting = False

File ~/.local/lib/python3.10/site-packages/lightning/pytorch/loops/fit_loop.py:354, in _FitLoop.advance(self)
    352 self._data_fetcher.setup(combined_loader)
    353 with self.trainer.profiler.profile("run_training_epoch"):
--> 354     self.epoch_loop.run(self._data_fetcher)

File ~/.local/lib/python3.10/site-packages/lightning/pytorch/loops/training_epoch_loop.py:133, in _TrainingEpochLoop.run(self, data_fetcher)
    131 while not self.done:
    132     try:
--> 133         self.advance(data_fetcher)
    134         self.on_advance_end()
    135         self._restarting = False

File ~/.local/lib/python3.10/site-packages/lightning/pytorch/loops/training_epoch_loop.py:218, in _TrainingEpochLoop.advance(self, data_fetcher)
    215 with trainer.profiler.profile("run_training_batch"):
    216     if trainer.lightning_module.automatic_optimization:
    217         # in automatic optimization, there can only be one optimizer
--> 218         batch_output = self.automatic_optimization.run(trainer.optimizers[0], kwargs)
    219     else:
    220         batch_output = self.manual_optimization.run(kwargs)

File ~/.local/lib/python3.10/site-packages/lightning/pytorch/loops/optimization/automatic.py:185, in _AutomaticOptimization.run(self, optimizer, kwargs)
    178         closure()
    180 # ------------------------------
    181 # BACKWARD PASS
    182 # ------------------------------
    183 # gradient update with accumulated gradients
    184 else:
--> 185     self._optimizer_step(kwargs.get("batch_idx", 0), closure)
    187 result = closure.consume_result()
    188 if result.loss is None:

File ~/.local/lib/python3.10/site-packages/lightning/pytorch/loops/optimization/automatic.py:261, in _AutomaticOptimization._optimizer_step(self, batch_idx, train_step_and_backward_closure)
    258     self.optim_progress.optimizer.step.increment_ready()
    260 # model hook
--> 261 call._call_lightning_module_hook(
    262     trainer,
    263     "optimizer_step",
    264     trainer.current_epoch,
    265     batch_idx,
    266     optimizer,
    267     train_step_and_backward_closure,
    268 )
    270 if not should_accumulate:
    271     self.optim_progress.optimizer.step.increment_completed()

File ~/.local/lib/python3.10/site-packages/lightning/pytorch/trainer/call.py:142, in _call_lightning_module_hook(trainer, hook_name, pl_module, *args, **kwargs)
    139 pl_module._current_fx_name = hook_name
    141 with trainer.profiler.profile(f"[LightningModule]{pl_module.__class__.__name__}.{hook_name}"):
--> 142     output = fn(*args, **kwargs)
    144 # restore current_fx when nested context
    145 pl_module._current_fx_name = prev_fx_name

File ~/.local/lib/python3.10/site-packages/lightning/pytorch/core/module.py:1265, in LightningModule.optimizer_step(self, epoch, batch_idx, optimizer, optimizer_closure)
   1226 def optimizer_step(
   1227     self,
   1228     epoch: int,
   (...)
   1231     optimizer_closure: Optional[Callable[[], Any]] = None,
   1232 ) -> None:
   1233     r"""
   1234     Override this method to adjust the default way the :class:`~lightning.pytorch.trainer.trainer.Trainer` calls
   1235     the optimizer.
   (...)
   1263                     pg["lr"] = lr_scale * self.learning_rate
   1264     """
-> 1265     optimizer.step(closure=optimizer_closure)

File ~/.local/lib/python3.10/site-packages/lightning/pytorch/core/optimizer.py:158, in LightningOptimizer.step(self, closure, **kwargs)
    155     raise MisconfigurationException("When `optimizer.step(closure)` is called, the closure should be callable")
    157 assert self._strategy is not None
--> 158 step_output = self._strategy.optimizer_step(self._optimizer, closure, **kwargs)
    160 self._on_after_step()
    162 return step_output

File ~/.local/lib/python3.10/site-packages/lightning/pytorch/strategies/strategy.py:224, in Strategy.optimizer_step(self, optimizer, closure, model, **kwargs)
    222 # TODO(fabric): remove assertion once strategy's optimizer_step typing is fixed
    223 assert isinstance(model, pl.LightningModule)
--> 224 return self.precision_plugin.optimizer_step(optimizer, model=model, closure=closure, **kwargs)

File ~/.local/lib/python3.10/site-packages/lightning/pytorch/plugins/precision/precision_plugin.py:114, in PrecisionPlugin.optimizer_step(self, optimizer, model, closure, **kwargs)
    112 """Hook to run the optimizer step."""
    113 closure = partial(self._wrap_closure, model, optimizer, closure)
--> 114 return optimizer.step(closure=closure, **kwargs)

File ~/.local/lib/python3.10/site-packages/torch/optim/optimizer.py:280, in Optimizer.profile_hook_step.<locals>.wrapper(*args, **kwargs)
    276         else:
    277             raise RuntimeError(f"{func} must return None or a tuple of (new_args, new_kwargs),"
    278                                f"but got {result}.")
--> 280 out = func(*args, **kwargs)
    281 self._optimizer_step_code()
    283 # call optimizer step post hooks

File ~/.local/lib/python3.10/site-packages/torch/optim/optimizer.py:33, in _use_grad_for_differentiable.<locals>._use_grad(self, *args, **kwargs)
     31 try:
     32     torch.set_grad_enabled(self.defaults['differentiable'])
---> 33     ret = func(self, *args, **kwargs)
     34 finally:
     35     torch.set_grad_enabled(prev_grad)

File ~/.local/lib/python3.10/site-packages/torch/optim/adam.py:121, in Adam.step(self, closure)
    119 if closure is not None:
    120     with torch.enable_grad():
--> 121         loss = closure()
    123 for group in self.param_groups:
    124     params_with_grad = []

File ~/.local/lib/python3.10/site-packages/lightning/pytorch/plugins/precision/precision_plugin.py:101, in PrecisionPlugin._wrap_closure(self, model, optimizer, closure)
     89 def _wrap_closure(
     90     self,
     91     model: "pl.LightningModule",
     92     optimizer: Optimizer,
     93     closure: Callable[[], Any],
     94 ) -> Any:
     95     """This double-closure allows makes sure the ``closure`` is executed before the
     96     ``on_before_optimizer_step`` hook is called.
     97 
     98     The closure (generally) runs ``backward`` so this allows inspecting gradients in this hook. This structure is
     99     consistent with the ``PrecisionPlugin`` subclasses that cannot pass ``optimizer.step(closure)`` directly.
    100     """
--> 101     closure_result = closure()
    102     self._after_closure(model, optimizer)
    103     return closure_result

File ~/.local/lib/python3.10/site-packages/lightning/pytorch/loops/optimization/automatic.py:140, in Closure.__call__(self, *args, **kwargs)
    139 def __call__(self, *args: Any, **kwargs: Any) -> Optional[Tensor]:
--> 140     self._result = self.closure(*args, **kwargs)
    141     return self._result.loss

File ~/.local/lib/python3.10/site-packages/lightning/pytorch/loops/optimization/automatic.py:126, in Closure.closure(self, *args, **kwargs)
    125 def closure(self, *args: Any, **kwargs: Any) -> ClosureResult:
--> 126     step_output = self._step_fn()
    128     if step_output.closure_loss is None:
    129         self.warning_cache.warn("`training_step` returned `None`. If this was on purpose, ignore this warning...")

File ~/.local/lib/python3.10/site-packages/lightning/pytorch/loops/optimization/automatic.py:308, in _AutomaticOptimization._training_step(self, kwargs)
    305 trainer = self.trainer
    307 # manually capture logged metrics
--> 308 training_step_output = call._call_strategy_hook(trainer, "training_step", *kwargs.values())
    309 self.trainer.strategy.post_training_step()
    311 result = self.output_result_cls.from_training_step_output(training_step_output, trainer.accumulate_grad_batches)

File ~/.local/lib/python3.10/site-packages/lightning/pytorch/trainer/call.py:288, in _call_strategy_hook(trainer, hook_name, *args, **kwargs)
    285     return
    287 with trainer.profiler.profile(f"[Strategy]{trainer.strategy.__class__.__name__}.{hook_name}"):
--> 288     output = fn(*args, **kwargs)
    290 # restore current_fx when nested context
    291 pl_module._current_fx_name = prev_fx_name

File ~/.local/lib/python3.10/site-packages/lightning/pytorch/strategies/strategy.py:366, in Strategy.training_step(self, *args, **kwargs)
    364 with self.precision_plugin.train_step_context():
    365     assert isinstance(self.model, TrainingStep)
--> 366     return self.model.training_step(*args, **kwargs)

File ~/miniconda3/envs/new_scvi/lib/python3.10/site-packages/scvi/train/_trainingplans.py:797, in SemiSupervisedTrainingPlan.training_step(self, batch, batch_idx)
    792 input_kwargs = {
    793     "feed_labels": False,
    794     "labelled_tensors": labelled_dataset,
    795 }
    796 input_kwargs.update(self.loss_kwargs)
--> 797 _, _, loss_output = self.forward(full_dataset, loss_kwargs=input_kwargs)
    798 loss = loss_output.loss
    799 self.log(
    800     "train_loss",
    801     loss,
   (...)
    804     prog_bar=True,
    805 )

File ~/miniconda3/envs/new_scvi/lib/python3.10/site-packages/scvi/train/_trainingplans.py:278, in TrainingPlan.forward(self, *args, **kwargs)
    276 def forward(self, *args, **kwargs):
    277     """Passthrough to the module's forward method."""
--> 278     return self.module(*args, **kwargs)

File ~/.local/lib/python3.10/site-packages/torch/nn/modules/module.py:1501, in Module._call_impl(self, *args, **kwargs)
   1496 # If we don't have any hooks, we want to skip the rest of the logic in
   1497 # this function, and just call forward.
   1498 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
   1499         or _global_backward_pre_hooks or _global_backward_hooks
   1500         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1501     return forward_call(*args, **kwargs)
   1502 # Do not call functions when jit is used
   1503 full_backward_hooks, non_full_backward_hooks = [], []

File ~/miniconda3/envs/new_scvi/lib/python3.10/site-packages/scvi/module/base/_decorators.py:32, in auto_move_data.<locals>.auto_transfer_args(self, *args, **kwargs)
     30 # decorator only necessary after training
     31 if self.training:
---> 32     return fn(self, *args, **kwargs)
     34 device = list({p.device for p in self.parameters()})
     35 if len(device) > 1:

File ~/miniconda3/envs/new_scvi/lib/python3.10/site-packages/scvi/module/base/_base_module.py:199, in BaseModuleClass.forward(self, tensors, get_inference_input_kwargs, get_generative_input_kwargs, inference_kwargs, generative_kwargs, loss_kwargs, compute_loss)
    165 @auto_move_data
    166 def forward(
    167     self,
   (...)
    177     | tuple[torch.Tensor, torch.Tensor, LossOutput]
    178 ):
    179     """Forward pass through the network.
    180 
    181     Parameters
   (...)
    197         another return value.
    198     """
--> 199     return _generic_forward(
    200         self,
    201         tensors,
    202         inference_kwargs,
    203         generative_kwargs,
    204         loss_kwargs,
    205         get_inference_input_kwargs,
    206         get_generative_input_kwargs,
    207         compute_loss,
    208     )

File ~/miniconda3/envs/new_scvi/lib/python3.10/site-packages/scvi/module/base/_base_module.py:743, in _generic_forward(module, tensors, inference_kwargs, generative_kwargs, loss_kwargs, get_inference_input_kwargs, get_generative_input_kwargs, compute_loss)
    738 get_generative_input_kwargs = _get_dict_if_none(get_generative_input_kwargs)
    740 inference_inputs = module._get_inference_input(
    741     tensors, **get_inference_input_kwargs
    742 )
--> 743 inference_outputs = module.inference(**inference_inputs, **inference_kwargs)
    744 generative_inputs = module._get_generative_input(
    745     tensors, inference_outputs, **get_generative_input_kwargs
    746 )
    747 generative_outputs = module.generative(**generative_inputs, **generative_kwargs)

File ~/miniconda3/envs/new_scvi/lib/python3.10/site-packages/scvi/module/base/_decorators.py:32, in auto_move_data.<locals>.auto_transfer_args(self, *args, **kwargs)
     30 # decorator only necessary after training
     31 if self.training:
---> 32     return fn(self, *args, **kwargs)
     34 device = list({p.device for p in self.parameters()})
     35 if len(device) > 1:

File ~/miniconda3/envs/new_scvi/lib/python3.10/site-packages/scvi/module/base/_base_module.py:303, in BaseMinifiedModeModuleClass.inference(self, *args, **kwargs)
    301     return self._cached_inference(*args, **kwargs)
    302 else:
--> 303     return self._regular_inference(*args, **kwargs)

File ~/miniconda3/envs/new_scvi/lib/python3.10/site-packages/scvi/module/base/_decorators.py:32, in auto_move_data.<locals>.auto_transfer_args(self, *args, **kwargs)
     30 # decorator only necessary after training
     31 if self.training:
---> 32     return fn(self, *args, **kwargs)
     34 device = list({p.device for p in self.parameters()})
     35 if len(device) > 1:

File ~/miniconda3/envs/new_scvi/lib/python3.10/site-packages/scvi/module/_vae.py:336, in VAE._regular_inference(self, x, batch_index, cont_covs, cat_covs, n_samples)
    334 else:
    335     categorical_input = ()
--> 336 qz, z = self.z_encoder(encoder_input, batch_index, *categorical_input)
    337 ql = None
    338 if not self.use_observed_lib_size:

File ~/.local/lib/python3.10/site-packages/torch/nn/modules/module.py:1501, in Module._call_impl(self, *args, **kwargs)
   1496 # If we don't have any hooks, we want to skip the rest of the logic in
   1497 # this function, and just call forward.
   1498 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
   1499         or _global_backward_pre_hooks or _global_backward_hooks
   1500         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1501     return forward_call(*args, **kwargs)
   1502 # Do not call functions when jit is used
   1503 full_backward_hooks, non_full_backward_hooks = [], []

File ~/miniconda3/envs/new_scvi/lib/python3.10/site-packages/scvi/nn/_base_components.py:289, in Encoder.forward(self, x, *cat_list)
    287 q_m = self.mean_encoder(q)
    288 q_v = self.var_activation(self.var_encoder(q)) + self.var_eps
--> 289 dist = Normal(q_m, q_v.sqrt())
    290 latent = self.z_transformation(dist.rsample())
    291 if self.return_dist:

File ~/.local/lib/python3.10/site-packages/torch/distributions/normal.py:56, in Normal.__init__(self, loc, scale, validate_args)
     54 else:
     55     batch_shape = self.loc.size()
---> 56 super().__init__(batch_shape, validate_args=validate_args)

File ~/.local/lib/python3.10/site-packages/torch/distributions/distribution.py:62, in Distribution.__init__(self, batch_shape, event_shape, validate_args)
     60         valid = constraint.check(value)
     61         if not valid.all():
---> 62             raise ValueError(
     63                 f"Expected parameter {param} "
     64                 f"({type(value).__name__} of shape {tuple(value.shape)}) "
     65                 f"of distribution {repr(self)} "
     66                 f"to satisfy the constraint {repr(constraint)}, "
     67                 f"but found invalid values:\n{value}"
     68             )
     69 super().__init__()

ValueError: Expected parameter loc (Tensor of shape (128, 30)) of distribution Normal(loc: torch.Size([128, 30]), scale: torch.Size([128, 30])) to satisfy the constraint Real(), but found invalid values:
tensor([[nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        ...,
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan]], device='cuda:0',
       grad_fn=<AddmmBackward0>)

Then I tried re-running without changing the parameters, and also with lr set to 0.0001; it did not finish the first epoch and raised the same tensor NaN error.
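For context, a lower learning rate for scANVI is usually passed through plan_kwargs, which (in scvi-tools 1.0.x, as assumed here) is forwarded to the semi-supervised training plan; a minimal sketch:

lvae.train(
    max_epochs=100,
    n_samples_per_label=100,
    early_stopping=True,
    plan_kwargs={"lr": 1e-4},  # default is typically 1e-3; this is the reduced rate mentioned above
)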

Running lvae.view_anndata_setup(combined_adata) gives:

Anndata setup with scvi-tools version 1.0.2.
Setup via `SCANVI.setup_anndata` with arguments:
{
β”‚   'labels_key': 'ct',
β”‚   'unlabeled_category': 'Unknown',
β”‚   'layer': 'counts',
β”‚   'batch_key': None,
β”‚   'size_factor_key': None,
β”‚   'categorical_covariate_keys': ['Batch'],
β”‚   'continuous_covariate_keys': ['pct_counts_mt', 'total_counts']
}
         Summary Statistics          
┏━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━┓
┃     Summary Stat Key     ┃ Value  ┃
┑━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━┩
β”‚         n_batch          β”‚   1    β”‚
β”‚         n_cells          β”‚ 620233 β”‚
β”‚ n_extra_categorical_covs β”‚   1    β”‚
β”‚ n_extra_continuous_covs  β”‚   2    β”‚
β”‚         n_labels         β”‚   39   β”‚
β”‚          n_vars          β”‚ 14658  β”‚
β””β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”΄β”€β”€β”€β”€β”€β”€β”€β”€β”˜
                             Data Registry                             
┏━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┓
┃      Registry Key      ┃            scvi-tools Location             ┃
┑━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┩
β”‚           X            β”‚           adata.layers['counts']           β”‚
β”‚         batch          β”‚          adata.obs['_scvi_batch']          β”‚
β”‚ extra_categorical_covs β”‚ adata.obsm['_scvi_extra_categorical_covs'] β”‚
β”‚ extra_continuous_covs  β”‚ adata.obsm['_scvi_extra_continuous_covs']  β”‚
β”‚         labels         β”‚         adata.obs['_scvi_labels']          β”‚
β””β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”΄β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”˜

Thanks


Hi, I have the same error. After training the scVI model with the best parameters from tuning, I initialized a scANVI model from it, and after 14 epochs I got the same tensor NaN error. Were you able to find a solution in the end?

Hi, it has been a long time since the first post, but in the meantime we have added an FAQ for NaN errors: https://docs.scvi-tools.org/en/1.2.0/faq.html/
Specifically for scANVI: check that your data has no very low count cells (for standard 10x, below 200 counts after HVG filtering) and that you use count data as input. If the issue persists, you can set var_activation=torch.nn.functional.softplus to increase numerical stability during model training.
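A minimal sketch of both suggestions, using the object names from the original post; the 200-count threshold is only illustrative, and var_activation is passed through from_scvi_model as in the code further down this thread:

import numpy as np
import torch

# Per-cell totals from the layer that scVI/scANVI actually train on
cell_counts = np.asarray(combined_adata.layers["counts"].sum(axis=1)).ravel()
print("cells below 200 counts:", int((cell_counts < 200).sum()))

# Softplus variance activation when building scANVI from the trained scVI model
lvae = scvi.model.SCANVI.from_scvi_model(
    scvi_model,
    adata=combined_adata,
    unlabeled_category="Unknown",
    labels_key="ct",
    var_activation=torch.nn.functional.softplus,
)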

Thanks. I did indeed find the discussion on the scvi-tools GitHub, and I used var_activation=torch.nn.functional.softplus as well as a very small learning rate. The problem is that I could not train for many epochs and had to stop early, otherwise the training and validation losses went up; I had to stop at around 86 epochs.
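If manual stopping becomes tedious, recent scvi-tools releases also accept early-stopping trainer keywords such as early_stopping_monitor and early_stopping_patience (assumed available here); a hedged sketch:

# Sketch: stop on the validation ELBO with a short patience instead of a hand-picked epoch cutoff
scanvi_model.train(
    max_epochs=200,
    n_samples_per_label=100,
    early_stopping=True,
    early_stopping_monitor="elbo_validation",  # assumed default monitor name
    early_stopping_patience=10,                # stop after 10 epochs without improvement
)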

Also, how do I make sure I used count data as input? The count data is in adata.layers['counts'] (a quick sanity check is sketched after the code below).
My code is:
scanvi_model = scvi.model.SCANVI.from_scvi_model(
    scvi_model,
    adata=adata,
    unlabeled_category="unknown",
    labels_key=SCANVI_CELLTYPE_KEY,
    linear_classifier=True,
    var_activation=torch.nn.functional.softplus,
)
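A hedged way to sanity-check that this layer really holds raw counts (non-negative, integer-valued, not normalized or log-transformed); the variable names follow the code above:

import numpy as np
import scipy.sparse as sp

# Works for both sparse and dense layers; only the stored values are inspected
counts = adata.layers["counts"]
vals = counts.data if sp.issparse(counts) else np.asarray(counts)
print("min:", vals.min(), "max:", vals.max())                 # raw counts should be >= 0
print("integer-valued:", bool(np.all(np.mod(vals, 1) == 0)))  # normalized/log data fails this

If either check fails, the layer most likely contains processed values rather than counts.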

@cane11 By chance does this FAQ still exist? The link provided gives a 404 error and I’m running into similar NaN errors for both scANVI and totalANVI models. I was hoping to consult it before posting any questions that could be redundant.

Thanks,
Chris

Frequently asked questions β€” scvi-tools
We are working on a more graceful shutdown when NaNs are encountered.
As for the layer argument: you define it in setup_anndata for the corresponding scVI model, and all models and datasets used afterwards have to be consistent with it.
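In other words, the layer is declared once when the base scVI model's AnnData is set up, and the scANVI model built from it inherits the same data registry, so it trains on that layer too; a minimal sketch assuming the names used in this thread:

import scvi

# Declare the counts layer once, on the scVI side
scvi.model.SCVI.setup_anndata(adata, layer="counts", categorical_covariate_keys=["Batch"])
scvi_model = scvi.model.SCVI(adata)
scvi_model.train()

# scANVI built from this model reuses the same registry, so it also trains on the counts layer
scanvi_model = scvi.model.SCANVI.from_scvi_model(
    scvi_model,
    adata=adata,
    unlabeled_category="unknown",
    labels_key=SCANVI_CELLTYPE_KEY,  # label column name taken from the code above
)
scanvi_model.train(max_epochs=100, n_samples_per_label=100)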