Model cannot be initialized and returns an error

Hello,

I am running scvi-tools 1.2.1 in a Jupyter notebook and want to integrate two datasets with scVI. I checked that the RNA counts do not contain NaN values; however, when I run `model.train()` I get the error below. It is similar to what was reported here: https://github.com/DendrouLab/panpipes/issues/306, but I am using the latest version of scvi-tools. This is my code prior to setting up and training the model:

concatenated_anndata.obs.groupby('dataset').count()
sc.pp.filter_genes(concatenated_anndata, min_counts=3)
concatenated_anndata.layers["counts"] = concatenated_anndata.X.copy()  # preserve counts
sc.pp.normalize_total(concatenated_anndata, target_sum=1e4)
sc.pp.log1p(concatenated_anndata)
concatenated_anndata.raw = concatenated_anndata  # freeze the state in `.raw`
sc.pp.highly_variable_genes(
    concatenated_anndata,
    n_top_genes=1200,
    subset=True,
    layer="counts",
    batch_key="dataset",
)

scvi.model.SCVI.setup_anndata(
    concatenated_anndata,
    layer="counts",
    categorical_covariate_keys=["dataset"])

model = scvi.model.SCVI(concatenated_anndata)

model.train()


---------------------------------------------------------------------------
ValueError                                Traceback (most recent call last)
Cell In[205], line 1
----> 1 model.train()

File ~/.conda/envs/MyEnv/lib/python3.11/site-packages/scvi/model/base/_training_mixin.py:145, in UnsupervisedTrainingMixin.train(self, max_epochs, accelerator, devices, train_size, validation_size, shuffle_set_split, load_sparse_tensor, batch_size, early_stopping, datasplitter_kwargs, plan_kwargs, datamodule, **trainer_kwargs)
    133 trainer_kwargs[es] = (
    134     early_stopping if es not in trainer_kwargs.keys() else trainer_kwargs[es]
    135 )
    136 runner = self._train_runner_cls(
    137     self,
    138     training_plan=training_plan,
   (...)
    143     **trainer_kwargs,
    144 )
--> 145 return runner()

File ~/.conda/envs/MyEnv/lib/python3.11/site-packages/scvi/train/_trainrunner.py:96, in TrainRunner.__call__(self)
     93 if hasattr(self.data_splitter, "n_val"):
     94     self.training_plan.n_obs_validation = self.data_splitter.n_val
---> 96 self.trainer.fit(self.training_plan, self.data_splitter)
     97 self._update_history()
     99 # data splitter only gets these attrs after fit

File ~/.conda/envs/MyEnv/lib/python3.11/site-packages/scvi/train/_trainer.py:201, in Trainer.fit(self, *args, **kwargs)
    195 if isinstance(args[0], PyroTrainingPlan):
    196     warnings.filterwarnings(
    197         action="ignore",
    198         category=UserWarning,
    199         message="`LightningModule.configure_optimizers` returned `None`",
    200     )
--> 201 super().fit(*args, **kwargs)

File ~/.conda/envs/MyEnv/lib/python3.11/site-packages/lightning/pytorch/trainer/trainer.py:538, in Trainer.fit(self, model, train_dataloaders, val_dataloaders, datamodule, ckpt_path)
    536 self.state.status = TrainerStatus.RUNNING
    537 self.training = True
--> 538 call._call_and_handle_interrupt(
    539     self, self._fit_impl, model, train_dataloaders, val_dataloaders, datamodule, ckpt_path
    540 )

File ~/.conda/envs/MyEnv/lib/python3.11/site-packages/lightning/pytorch/trainer/call.py:47, in _call_and_handle_interrupt(trainer, trainer_fn, *args, **kwargs)
     45     if trainer.strategy.launcher is not None:
     46         return trainer.strategy.launcher.launch(trainer_fn, *args, trainer=trainer, **kwargs)
---> 47     return trainer_fn(*args, **kwargs)
     49 except _TunerExitException:
     50     _call_teardown_hook(trainer)

File ~/.conda/envs/MyEnv/lib/python3.11/site-packages/lightning/pytorch/trainer/trainer.py:574, in Trainer._fit_impl(self, model, train_dataloaders, val_dataloaders, datamodule, ckpt_path)
    567 assert self.state.fn is not None
    568 ckpt_path = self._checkpoint_connector._select_ckpt_path(
    569     self.state.fn,
    570     ckpt_path,
    571     model_provided=True,
    572     model_connected=self.lightning_module is not None,
    573 )
--> 574 self._run(model, ckpt_path=ckpt_path)
    576 assert self.state.stopped
    577 self.training = False

File ~/.conda/envs/MyEnv/lib/python3.11/site-packages/lightning/pytorch/trainer/trainer.py:981, in Trainer._run(self, model, ckpt_path)
    976 self._signal_connector.register_signal_handlers()
    978 # ----------------------------
    979 # RUN THE TRAINER
    980 # ----------------------------
--> 981 results = self._run_stage()
    983 # ----------------------------
    984 # POST-Training CLEAN UP
    985 # ----------------------------
    986 log.debug(f"{self.__class__.__name__}: trainer tearing down")

File ~/.conda/envs/MyEnv/lib/python3.11/site-packages/lightning/pytorch/trainer/trainer.py:1025, in Trainer._run_stage(self)
   1023         self._run_sanity_check()
   1024     with torch.autograd.set_detect_anomaly(self._detect_anomaly):
-> 1025         self.fit_loop.run()
   1026     return None
   1027 raise RuntimeError(f"Unexpected state {self.state}")

File ~/.conda/envs/MyEnv/lib/python3.11/site-packages/lightning/pytorch/loops/fit_loop.py:205, in _FitLoop.run(self)
    203 try:
    204     self.on_advance_start()
--> 205     self.advance()
    206     self.on_advance_end()
    207     self._restarting = False

File ~/.conda/envs/MyEnv/lib/python3.11/site-packages/lightning/pytorch/loops/fit_loop.py:363, in _FitLoop.advance(self)
    361 with self.trainer.profiler.profile("run_training_epoch"):
    362     assert self._data_fetcher is not None
--> 363     self.epoch_loop.run(self._data_fetcher)

File ~/.conda/envs/MyEnv/lib/python3.11/site-packages/lightning/pytorch/loops/training_epoch_loop.py:140, in _TrainingEpochLoop.run(self, data_fetcher)
    138 while not self.done:
    139     try:
--> 140         self.advance(data_fetcher)
    141         self.on_advance_end(data_fetcher)
    142         self._restarting = False

File ~/.conda/envs/MyEnv/lib/python3.11/site-packages/lightning/pytorch/loops/training_epoch_loop.py:250, in _TrainingEpochLoop.advance(self, data_fetcher)
    247 with trainer.profiler.profile("run_training_batch"):
    248     if trainer.lightning_module.automatic_optimization:
    249         # in automatic optimization, there can only be one optimizer
--> 250         batch_output = self.automatic_optimization.run(trainer.optimizers[0], batch_idx, kwargs)
    251     else:
    252         batch_output = self.manual_optimization.run(kwargs)

File ~/.conda/envs/MyEnv/lib/python3.11/site-packages/lightning/pytorch/loops/optimization/automatic.py:190, in _AutomaticOptimization.run(self, optimizer, batch_idx, kwargs)
    183         closure()
    185 # ------------------------------
    186 # BACKWARD PASS
    187 # ------------------------------
    188 # gradient update with accumulated gradients
    189 else:
--> 190     self._optimizer_step(batch_idx, closure)
    192 result = closure.consume_result()
    193 if result.loss is None:

File ~/.conda/envs/MyEnv/lib/python3.11/site-packages/lightning/pytorch/loops/optimization/automatic.py:268, in _AutomaticOptimization._optimizer_step(self, batch_idx, train_step_and_backward_closure)
    265     self.optim_progress.optimizer.step.increment_ready()
    267 # model hook
--> 268 call._call_lightning_module_hook(
    269     trainer,
    270     "optimizer_step",
    271     trainer.current_epoch,
    272     batch_idx,
    273     optimizer,
    274     train_step_and_backward_closure,
    275 )
    277 if not should_accumulate:
    278     self.optim_progress.optimizer.step.increment_completed()

File ~/.conda/envs/MyEnv/lib/python3.11/site-packages/lightning/pytorch/trainer/call.py:167, in _call_lightning_module_hook(trainer, hook_name, pl_module, *args, **kwargs)
    164 pl_module._current_fx_name = hook_name
    166 with trainer.profiler.profile(f"[LightningModule]{pl_module.__class__.__name__}.{hook_name}"):
--> 167     output = fn(*args, **kwargs)
    169 # restore current_fx when nested context
    170 pl_module._current_fx_name = prev_fx_name

File ~/.conda/envs/MyEnv/lib/python3.11/site-packages/lightning/pytorch/core/module.py:1306, in LightningModule.optimizer_step(self, epoch, batch_idx, optimizer, optimizer_closure)
   1275 def optimizer_step(
   1276     self,
   1277     epoch: int,
   (...)
   1280     optimizer_closure: Optional[Callable[[], Any]] = None,
   1281 ) -> None:
   1282     r"""Override this method to adjust the default way the :class:`~lightning.pytorch.trainer.trainer.Trainer` calls
   1283     the optimizer.
   1284 
   (...)
   1304 
   1305     """
-> 1306     optimizer.step(closure=optimizer_closure)

File ~/.conda/envs/MyEnv/lib/python3.11/site-packages/lightning/pytorch/core/optimizer.py:153, in LightningOptimizer.step(self, closure, **kwargs)
    150     raise MisconfigurationException("When `optimizer.step(closure)` is called, the closure should be callable")
    152 assert self._strategy is not None
--> 153 step_output = self._strategy.optimizer_step(self._optimizer, closure, **kwargs)
    155 self._on_after_step()
    157 return step_output

File ~/.conda/envs/MyEnv/lib/python3.11/site-packages/lightning/pytorch/strategies/strategy.py:238, in Strategy.optimizer_step(self, optimizer, closure, model, **kwargs)
    236 # TODO(fabric): remove assertion once strategy's optimizer_step typing is fixed
    237 assert isinstance(model, pl.LightningModule)
--> 238 return self.precision_plugin.optimizer_step(optimizer, model=model, closure=closure, **kwargs)

File ~/.conda/envs/MyEnv/lib/python3.11/site-packages/lightning/pytorch/plugins/precision/precision.py:122, in Precision.optimizer_step(self, optimizer, model, closure, **kwargs)
    120 """Hook to run the optimizer step."""
    121 closure = partial(self._wrap_closure, model, optimizer, closure)
--> 122 return optimizer.step(closure=closure, **kwargs)

File ~/.conda/envs/MyEnv/lib/python3.11/site-packages/torch/optim/optimizer.py:487, in Optimizer.profile_hook_step.<locals>.wrapper(*args, **kwargs)
    482         else:
    483             raise RuntimeError(
    484                 f"{func} must return None or a tuple of (new_args, new_kwargs), but got {result}."
    485             )
--> 487 out = func(*args, **kwargs)
    488 self._optimizer_step_code()
    490 # call optimizer step post hooks

File ~/.conda/envs/MyEnv/lib/python3.11/site-packages/torch/optim/optimizer.py:91, in _use_grad_for_differentiable.<locals>._use_grad(self, *args, **kwargs)
     89     torch.set_grad_enabled(self.defaults["differentiable"])
     90     torch._dynamo.graph_break()
---> 91     ret = func(self, *args, **kwargs)
     92 finally:
     93     torch._dynamo.graph_break()

File ~/.conda/envs/MyEnv/lib/python3.11/site-packages/torch/optim/adam.py:202, in Adam.step(self, closure)
    200 if closure is not None:
    201     with torch.enable_grad():
--> 202         loss = closure()
    204 for group in self.param_groups:
    205     params_with_grad: List[Tensor] = []

File ~/.conda/envs/MyEnv/lib/python3.11/site-packages/lightning/pytorch/plugins/precision/precision.py:108, in Precision._wrap_closure(self, model, optimizer, closure)
     95 def _wrap_closure(
     96     self,
     97     model: "pl.LightningModule",
     98     optimizer: Steppable,
     99     closure: Callable[[], Any],
    100 ) -> Any:
    101     """This double-closure allows makes sure the ``closure`` is executed before the ``on_before_optimizer_step``
    102     hook is called.
    103 
   (...)
    106 
    107     """
--> 108     closure_result = closure()
    109     self._after_closure(model, optimizer)
    110     return closure_result

File ~/.conda/envs/MyEnv/lib/python3.11/site-packages/lightning/pytorch/loops/optimization/automatic.py:144, in Closure.__call__(self, *args, **kwargs)
    142 @override
    143 def __call__(self, *args: Any, **kwargs: Any) -> Optional[Tensor]:
--> 144     self._result = self.closure(*args, **kwargs)
    145     return self._result.loss

File ~/.conda/envs/MyEnv/lib/python3.11/site-packages/torch/utils/_contextlib.py:116, in context_decorator.<locals>.decorate_context(*args, **kwargs)
    113 @functools.wraps(func)
    114 def decorate_context(*args, **kwargs):
    115     with ctx_factory():
--> 116         return func(*args, **kwargs)

File ~/.conda/envs/MyEnv/lib/python3.11/site-packages/lightning/pytorch/loops/optimization/automatic.py:129, in Closure.closure(self, *args, **kwargs)
    126 @override
    127 @torch.enable_grad()
    128 def closure(self, *args: Any, **kwargs: Any) -> ClosureResult:
--> 129     step_output = self._step_fn()
    131     if step_output.closure_loss is None:
    132         self.warning_cache.warn("`training_step` returned `None`. If this was on purpose, ignore this warning...")

File ~/.conda/envs/MyEnv/lib/python3.11/site-packages/lightning/pytorch/loops/optimization/automatic.py:317, in _AutomaticOptimization._training_step(self, kwargs)
    306 """Performs the actual train step with the tied hooks.
    307 
    308 Args:
   (...)
    313 
    314 """
    315 trainer = self.trainer
--> 317 training_step_output = call._call_strategy_hook(trainer, "training_step", *kwargs.values())
    318 self.trainer.strategy.post_training_step()  # unused hook - call anyway for backward compatibility
    320 if training_step_output is None and trainer.world_size > 1:

File ~/.conda/envs/MyEnv/lib/python3.11/site-packages/lightning/pytorch/trainer/call.py:319, in _call_strategy_hook(trainer, hook_name, *args, **kwargs)
    316     return None
    318 with trainer.profiler.profile(f"[Strategy]{trainer.strategy.__class__.__name__}.{hook_name}"):
--> 319     output = fn(*args, **kwargs)
    321 # restore current_fx when nested context
    322 pl_module._current_fx_name = prev_fx_name

File ~/.conda/envs/MyEnv/lib/python3.11/site-packages/lightning/pytorch/strategies/strategy.py:390, in Strategy.training_step(self, *args, **kwargs)
    388 if self.model != self.lightning_module:
    389     return self._forward_redirection(self.model, self.lightning_module, "training_step", *args, **kwargs)
--> 390 return self.lightning_module.training_step(*args, **kwargs)

File ~/.conda/envs/MyEnv/lib/python3.11/site-packages/scvi/train/_trainingplans.py:350, in TrainingPlan.training_step(self, batch, batch_idx)
    348     self.loss_kwargs.update({"kl_weight": kl_weight})
    349     self.log("kl_weight", kl_weight, on_step=True, on_epoch=False)
--> 350 _, _, scvi_loss = self.forward(batch, loss_kwargs=self.loss_kwargs)
    351 self.log(
    352     "train_loss",
    353     scvi_loss.loss,
   (...)
    356     sync_dist=self.use_sync_dist,
    357 )
    358 self.compute_and_log_metrics(scvi_loss, self.train_metrics, "train")

File ~/.conda/envs/MyEnv/lib/python3.11/site-packages/scvi/train/_trainingplans.py:280, in TrainingPlan.forward(self, *args, **kwargs)
    278 def forward(self, *args, **kwargs):
    279     """Passthrough to the module's forward method."""
--> 280     return self.module(
    281         *args,
    282         **kwargs,
    283         get_inference_input_kwargs={"full_forward_pass": not self.update_only_decoder},
    284     )

File ~/.conda/envs/MyEnv/lib/python3.11/site-packages/torch/nn/modules/module.py:1736, in Module._wrapped_call_impl(self, *args, **kwargs)
   1734     return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]
   1735 else:
-> 1736     return self._call_impl(*args, **kwargs)

File ~/.conda/envs/MyEnv/lib/python3.11/site-packages/torch/nn/modules/module.py:1747, in Module._call_impl(self, *args, **kwargs)
   1742 # If we don't have any hooks, we want to skip the rest of the logic in
   1743 # this function, and just call forward.
   1744 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
   1745         or _global_backward_pre_hooks or _global_backward_hooks
   1746         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1747     return forward_call(*args, **kwargs)
   1749 result = None
   1750 called_always_called_hooks = set()

File ~/.conda/envs/MyEnv/lib/python3.11/site-packages/scvi/module/base/_decorators.py:32, in auto_move_data.<locals>.auto_transfer_args(self, *args, **kwargs)
     30 # decorator only necessary after training
     31 if self.training:
---> 32     return fn(self, *args, **kwargs)
     34 device = list({p.device for p in self.parameters()})
     35 if len(device) > 1:

File ~/.conda/envs/MyEnv/lib/python3.11/site-packages/scvi/module/base/_base_module.py:207, in BaseModuleClass.forward(self, tensors, get_inference_input_kwargs, get_generative_input_kwargs, inference_kwargs, generative_kwargs, loss_kwargs, compute_loss)
    176 @auto_move_data
    177 def forward(
    178     self,
   (...)
    185     compute_loss=True,
    186 ) -> tuple[torch.Tensor, torch.Tensor] | tuple[torch.Tensor, torch.Tensor, LossOutput]:
    187     """Forward pass through the network.
    188 
    189     Parameters
   (...)
    205         another return value.
    206     """
--> 207     return _generic_forward(
    208         self,
    209         tensors,
    210         inference_kwargs,
    211         generative_kwargs,
    212         loss_kwargs,
    213         get_inference_input_kwargs,
    214         get_generative_input_kwargs,
    215         compute_loss,
    216     )

File ~/.conda/envs/MyEnv/lib/python3.11/site-packages/scvi/module/base/_base_module.py:747, in _generic_forward(module, tensors, inference_kwargs, generative_kwargs, loss_kwargs, get_inference_input_kwargs, get_generative_input_kwargs, compute_loss)
    744     get_inference_input_kwargs.pop("full_forward_pass", None)
    746 inference_inputs = module._get_inference_input(tensors, **get_inference_input_kwargs)
--> 747 inference_outputs = module.inference(**inference_inputs, **inference_kwargs)
    748 generative_inputs = module._get_generative_input(
    749     tensors, inference_outputs, **get_generative_input_kwargs
    750 )
    751 generative_outputs = module.generative(**generative_inputs, **generative_kwargs)

File ~/.conda/envs/MyEnv/lib/python3.11/site-packages/scvi/module/base/_decorators.py:32, in auto_move_data.<locals>.auto_transfer_args(self, *args, **kwargs)
     30 # decorator only necessary after training
     31 if self.training:
---> 32     return fn(self, *args, **kwargs)
     34 device = list({p.device for p in self.parameters()})
     35 if len(device) > 1:

File ~/.conda/envs/MyEnv/lib/python3.11/site-packages/scvi/module/base/_base_module.py:308, in BaseMinifiedModeModuleClass.inference(self, *args, **kwargs)
    306     return self._cached_inference(*args, **kwargs)
    307 else:
--> 308     return self._regular_inference(*args, **kwargs)

File ~/.conda/envs/MyEnv/lib/python3.11/site-packages/scvi/module/base/_decorators.py:32, in auto_move_data.<locals>.auto_transfer_args(self, *args, **kwargs)
     30 # decorator only necessary after training
     31 if self.training:
---> 32     return fn(self, *args, **kwargs)
     34 device = list({p.device for p in self.parameters()})
     35 if len(device) > 1:

File ~/.conda/envs/MyEnv/lib/python3.11/site-packages/scvi/module/_vae.py:386, in VAE._regular_inference(self, x, batch_index, cont_covs, cat_covs, n_samples)
    384     qz, z = self.z_encoder(encoder_input, *categorical_input)
    385 else:
--> 386     qz, z = self.z_encoder(encoder_input, batch_index, *categorical_input)
    388 ql = None
    389 if not self.use_observed_lib_size:

File ~/.conda/envs/MyEnv/lib/python3.11/site-packages/torch/nn/modules/module.py:1736, in Module._wrapped_call_impl(self, *args, **kwargs)
   1734     return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]
   1735 else:
-> 1736     return self._call_impl(*args, **kwargs)

File ~/.conda/envs/MyEnv/lib/python3.11/site-packages/torch/nn/modules/module.py:1747, in Module._call_impl(self, *args, **kwargs)
   1742 # If we don't have any hooks, we want to skip the rest of the logic in
   1743 # this function, and just call forward.
   1744 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
   1745         or _global_backward_pre_hooks or _global_backward_hooks
   1746         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1747     return forward_call(*args, **kwargs)
   1749 result = None
   1750 called_always_called_hooks = set()

File ~/.conda/envs/MyEnv/lib/python3.11/site-packages/scvi/nn/_base_components.py:286, in Encoder.forward(self, x, *cat_list)
    284 q_m = self.mean_encoder(q)
    285 q_v = self.var_activation(self.var_encoder(q)) + self.var_eps
--> 286 dist = Normal(q_m, q_v.sqrt())
    287 latent = self.z_transformation(dist.rsample())
    288 if self.return_dist:

File ~/.conda/envs/MyEnv/lib/python3.11/site-packages/torch/distributions/normal.py:59, in Normal.__init__(self, loc, scale, validate_args)
     57 else:
     58     batch_shape = self.loc.size()
---> 59 super().__init__(batch_shape, validate_args=validate_args)

File ~/.conda/envs/MyEnv/lib/python3.11/site-packages/torch/distributions/distribution.py:71, in Distribution.__init__(self, batch_shape, event_shape, validate_args)
     69         valid = constraint.check(value)
     70         if not valid.all():
---> 71             raise ValueError(
     72                 f"Expected parameter {param} "
     73                 f"({type(value).__name__} of shape {tuple(value.shape)}) "
     74                 f"of distribution {repr(self)} "
     75                 f"to satisfy the constraint {repr(constraint)}, "
     76                 f"but found invalid values:\n{value}"
     77             )
     78 super().__init__()

ValueError: Expected parameter loc (Tensor of shape (128, 10)) of distribution Normal(loc: torch.Size([128, 10]), scale: torch.Size([128, 10])) to satisfy the constraint Real(), but found invalid values:
tensor([[nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        ...,
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan]], grad_fn=<AddmmBackward0>)

Hi, your gene filter is very low at min_counts=3, and afterwards you subset to a small number of highly variable genes. Can you please check the count distribution after this step, especially whether you end up with any cells with zero counts? These are not supported in scVI. You can apply the min_counts=3 filter after HVG filtering instead. However, results are usually poor for cells with fewer than ~500 counts (for standard 10x sequencing of 'normal' cells).
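
For example, a minimal sketch of such a check (assuming the `concatenated_anndata` object and the "counts" layer from the snippet above):

import numpy as np
import scipy.sparse as sp

# Per-cell totals in the raw counts layer, computed after gene filtering / HVG subsetting
counts = concatenated_anndata.layers["counts"]
cell_totals = np.asarray(counts.sum(axis=1)).ravel() if sp.issparse(counts) else counts.sum(axis=1)

print("cells with zero counts:", int((cell_totals == 0).sum()))
print("cells with < 500 counts:", int((cell_totals < 500).sum()))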

Hi,

Thanks for your reply. I checked the RNA counts and they look good to me; there are no zero values.

I also tried performing only the HVG filtering:

concatenated_anndata.obs.groupby('dataset').count()
concatenated_anndata.layers["counts"] = concatenated_anndata.X.copy()  # preserve counts
sc.pp.normalize_total(concatenated_anndata, target_sum=1e4)
sc.pp.log1p(concatenated_anndata)
concatenated_anndata.raw = concatenated_anndata  # freeze the state in `.raw`
sc.pp.highly_variable_genes(
    concatenated_anndata,
    n_top_genes=10000,
    subset=True,
    layer="counts",
    batch_key="dataset",
)

Again, I checked the distribution of RNA counts before initializing the model and there are no zero values. However, I still get the same error when training the model…

Hey @Slack90

Seems your input is ok.

The issue might be due to exploding gradients during training. You can set gradient clipping in the train function, which should help, e.g.:
model.train(…, gradient_clip_val=1)

The default is 0 (i.e., no clipping). Of course, this might also affect the results; there is no rule of thumb for the exact value needed, so it comes down to trial and error.

Also, run with early stopping and a large max_epochs, for example:
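
A possible call (just a sketch; the values are starting points, and gradient_clip_val is passed through to the underlying Lightning trainer):

model.train(
    max_epochs=1000,        # large cap; early stopping usually ends training much earlier
    early_stopping=True,    # monitors the validation ELBO and stops when it plateaus
    gradient_clip_val=1.0,  # clip gradients to counter exploding gradients
)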

Hi @ori-kron-wis,

Thanks for the reply. I ran it again with gradient_clip_val=1 (I also tried 10 and other values). I also added early stopping and a large max_epochs as follows:

model.train(gradient_clip_val=1, max_epochs=4000, early_stopping=True)

But I get the same error:

---------------------------------------------------------------------------
ValueError                                Traceback (most recent call last)
Cell In[23], line 2
      1 # model.train(accelerator="gpu", devices="auto", gradient_clip_val=10)
----> 2 model.train(gradient_clip_val=1, max_epochs=400, early_stopping=True)

[... traceback identical to the one in the first post ...]

ValueError: Expected parameter loc (Tensor of shape (128, 10)) of distribution Normal(loc: torch.Size([128, 10]), scale: torch.Size([128, 10])) to satisfy the constraint Real(), but found invalid values:
tensor([[nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        ...,
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan]], grad_fn=<AddmmBackward0>)

Hi, to me it still looks like it doesn't even train for a single step (no progress bar). This only happens if the initialization is off. To help you with this, access to the dataset would be necessary (count filtering would be the most common issue, but you checked that). If training instead fails later on (model.history will then contain losses for several epochs), the other tricks described in the scvi-tools FAQ might help: Frequently asked questions — scvi-tools.
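
For example, a quick way to check this (a sketch; the exact metric keys can vary slightly between versions):

# model.history is a dict of DataFrames, one per logged metric
print(model.history.keys())

# If training ran for some epochs before producing NaNs, the ELBO curve will show it
print(model.history["elbo_train"].tail())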

In case others come across this thread: it turned out the count data was actually scaled, so it was a dataset problem and not an scvi-tools problem.
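
For anyone hitting the same error, a minimal sanity check that the layer passed to setup_anndata really contains raw, non-negative integer counts (a sketch, assuming the "counts" layer used above):

import numpy as np
import scipy.sparse as sp

counts = concatenated_anndata.layers["counts"]
vals = counts.data if sp.issparse(counts) else np.asarray(counts)

print("min value:", vals.min())                               # should be >= 0
print("any NaN:", bool(np.isnan(vals).any()))                 # should be False
print("all integers:", bool(np.all(np.mod(vals, 1) == 0)))    # False suggests normalized/scaled data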
