Hello,
I am running scvi 1.2.1 in a Jupyter notebook and I want to integrate two datasets with scVI. I checked that the RNA counts do not contain NaN values (a sketch of that check is shown after the code below), yet when I run `model.train()` I get the error at the end of this post. It looks similar to what was reported in https://github.com/DendrouLab/panpipes/issues/306, but I am using the latest version of scvi. This is my code up to the point where I create and train the model:
```python
import scanpy as sc
import scvi

# number of cells per dataset
concatenated_anndata.obs.groupby('dataset').count()

sc.pp.filter_genes(concatenated_anndata, min_counts=3)
concatenated_anndata.layers["counts"] = concatenated_anndata.X.copy()  # preserve counts
sc.pp.normalize_total(concatenated_anndata, target_sum=1e4)
sc.pp.log1p(concatenated_anndata)
concatenated_anndata.raw = concatenated_anndata  # freeze the state in `.raw`

sc.pp.highly_variable_genes(
    concatenated_anndata,
    n_top_genes=1200,
    subset=True,
    layer="counts",
    batch_key="dataset",
)

scvi.model.SCVI.setup_anndata(
    concatenated_anndata,
    layer="counts",
    categorical_covariate_keys=["dataset"],
)
model = scvi.model.SCVI(concatenated_anndata)
model.train()
```
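For reference, this is roughly how I checked the counts for NaN values before training. It is a minimal sketch of the kind of check I mean, not copied verbatim from my notebook, so the exact calls are approximate:

```python
import numpy as np
import scipy.sparse as sp

# approximate reconstruction of the NaN/inf check, not the exact notebook cell
counts = concatenated_anndata.layers["counts"]

# handle both sparse and dense count matrices
data = counts.data if sp.issparse(counts) else np.asarray(counts)

print("any NaN:     ", np.isnan(data).any())
print("any inf:     ", np.isinf(data).any())
print("any negative:", (data < 0).any())
print("max value:   ", data.max())
```

All of these checks came back clean for me. This is the full traceback I get from `model.train()`: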
```
---------------------------------------------------------------------------
ValueError Traceback (most recent call last)
Cell In[205], line 1
----> 1 model.train()
File ~/.conda/envs/MyEnv/lib/python3.11/site-packages/scvi/model/base/_training_mixin.py:145, in UnsupervisedTrainingMixin.train(self, max_epochs, accelerator, devices, train_size, validation_size, shuffle_set_split, load_sparse_tensor, batch_size, early_stopping, datasplitter_kwargs, plan_kwargs, datamodule, **trainer_kwargs)
133 trainer_kwargs[es] = (
134 early_stopping if es not in trainer_kwargs.keys() else trainer_kwargs[es]
135 )
136 runner = self._train_runner_cls(
137 self,
138 training_plan=training_plan,
(...)
143 **trainer_kwargs,
144 )
--> 145 return runner()
File ~/.conda/envs/MyEnv/lib/python3.11/site-packages/scvi/train/_trainrunner.py:96, in TrainRunner.__call__(self)
93 if hasattr(self.data_splitter, "n_val"):
94 self.training_plan.n_obs_validation = self.data_splitter.n_val
---> 96 self.trainer.fit(self.training_plan, self.data_splitter)
97 self._update_history()
99 # data splitter only gets these attrs after fit
File ~/.conda/envs/MyEnv/lib/python3.11/site-packages/scvi/train/_trainer.py:201, in Trainer.fit(self, *args, **kwargs)
195 if isinstance(args[0], PyroTrainingPlan):
196 warnings.filterwarnings(
197 action="ignore",
198 category=UserWarning,
199 message="`LightningModule.configure_optimizers` returned `None`",
200 )
--> 201 super().fit(*args, **kwargs)
File ~/.conda/envs/MyEnv/lib/python3.11/site-packages/lightning/pytorch/trainer/trainer.py:538, in Trainer.fit(self, model, train_dataloaders, val_dataloaders, datamodule, ckpt_path)
536 self.state.status = TrainerStatus.RUNNING
537 self.training = True
--> 538 call._call_and_handle_interrupt(
539 self, self._fit_impl, model, train_dataloaders, val_dataloaders, datamodule, ckpt_path
540 )
File ~/.conda/envs/MyEnv/lib/python3.11/site-packages/lightning/pytorch/trainer/call.py:47, in _call_and_handle_interrupt(trainer, trainer_fn, *args, **kwargs)
45 if trainer.strategy.launcher is not None:
46 return trainer.strategy.launcher.launch(trainer_fn, *args, trainer=trainer, **kwargs)
---> 47 return trainer_fn(*args, **kwargs)
49 except _TunerExitException:
50 _call_teardown_hook(trainer)
File ~/.conda/envs/MyEnv/lib/python3.11/site-packages/lightning/pytorch/trainer/trainer.py:574, in Trainer._fit_impl(self, model, train_dataloaders, val_dataloaders, datamodule, ckpt_path)
567 assert self.state.fn is not None
568 ckpt_path = self._checkpoint_connector._select_ckpt_path(
569 self.state.fn,
570 ckpt_path,
571 model_provided=True,
572 model_connected=self.lightning_module is not None,
573 )
--> 574 self._run(model, ckpt_path=ckpt_path)
576 assert self.state.stopped
577 self.training = False
File ~/.conda/envs/MyEnv/lib/python3.11/site-packages/lightning/pytorch/trainer/trainer.py:981, in Trainer._run(self, model, ckpt_path)
976 self._signal_connector.register_signal_handlers()
978 # ----------------------------
979 # RUN THE TRAINER
980 # ----------------------------
--> 981 results = self._run_stage()
983 # ----------------------------
984 # POST-Training CLEAN UP
985 # ----------------------------
986 log.debug(f"{self.__class__.__name__}: trainer tearing down")
File ~/.conda/envs/MyEnv/lib/python3.11/site-packages/lightning/pytorch/trainer/trainer.py:1025, in Trainer._run_stage(self)
1023 self._run_sanity_check()
1024 with torch.autograd.set_detect_anomaly(self._detect_anomaly):
-> 1025 self.fit_loop.run()
1026 return None
1027 raise RuntimeError(f"Unexpected state {self.state}")
File ~/.conda/envs/MyEnv/lib/python3.11/site-packages/lightning/pytorch/loops/fit_loop.py:205, in _FitLoop.run(self)
203 try:
204 self.on_advance_start()
--> 205 self.advance()
206 self.on_advance_end()
207 self._restarting = False
File ~/.conda/envs/MyEnv/lib/python3.11/site-packages/lightning/pytorch/loops/fit_loop.py:363, in _FitLoop.advance(self)
361 with self.trainer.profiler.profile("run_training_epoch"):
362 assert self._data_fetcher is not None
--> 363 self.epoch_loop.run(self._data_fetcher)
File ~/.conda/envs/MyEnv/lib/python3.11/site-packages/lightning/pytorch/loops/training_epoch_loop.py:140, in _TrainingEpochLoop.run(self, data_fetcher)
138 while not self.done:
139 try:
--> 140 self.advance(data_fetcher)
141 self.on_advance_end(data_fetcher)
142 self._restarting = False
File ~/.conda/envs/MyEnv/lib/python3.11/site-packages/lightning/pytorch/loops/training_epoch_loop.py:250, in _TrainingEpochLoop.advance(self, data_fetcher)
247 with trainer.profiler.profile("run_training_batch"):
248 if trainer.lightning_module.automatic_optimization:
249 # in automatic optimization, there can only be one optimizer
--> 250 batch_output = self.automatic_optimization.run(trainer.optimizers[0], batch_idx, kwargs)
251 else:
252 batch_output = self.manual_optimization.run(kwargs)
File ~/.conda/envs/MyEnv/lib/python3.11/site-packages/lightning/pytorch/loops/optimization/automatic.py:190, in _AutomaticOptimization.run(self, optimizer, batch_idx, kwargs)
183 closure()
185 # ------------------------------
186 # BACKWARD PASS
187 # ------------------------------
188 # gradient update with accumulated gradients
189 else:
--> 190 self._optimizer_step(batch_idx, closure)
192 result = closure.consume_result()
193 if result.loss is None:
File ~/.conda/envs/MyEnv/lib/python3.11/site-packages/lightning/pytorch/loops/optimization/automatic.py:268, in _AutomaticOptimization._optimizer_step(self, batch_idx, train_step_and_backward_closure)
265 self.optim_progress.optimizer.step.increment_ready()
267 # model hook
--> 268 call._call_lightning_module_hook(
269 trainer,
270 "optimizer_step",
271 trainer.current_epoch,
272 batch_idx,
273 optimizer,
274 train_step_and_backward_closure,
275 )
277 if not should_accumulate:
278 self.optim_progress.optimizer.step.increment_completed()
File ~/.conda/envs/MyEnv/lib/python3.11/site-packages/lightning/pytorch/trainer/call.py:167, in _call_lightning_module_hook(trainer, hook_name, pl_module, *args, **kwargs)
164 pl_module._current_fx_name = hook_name
166 with trainer.profiler.profile(f"[LightningModule]{pl_module.__class__.__name__}.{hook_name}"):
--> 167 output = fn(*args, **kwargs)
169 # restore current_fx when nested context
170 pl_module._current_fx_name = prev_fx_name
File ~/.conda/envs/MyEnv/lib/python3.11/site-packages/lightning/pytorch/core/module.py:1306, in LightningModule.optimizer_step(self, epoch, batch_idx, optimizer, optimizer_closure)
1275 def optimizer_step(
1276 self,
1277 epoch: int,
(...)
1280 optimizer_closure: Optional[Callable[[], Any]] = None,
1281 ) -> None:
1282 r"""Override this method to adjust the default way the :class:`~lightning.pytorch.trainer.trainer.Trainer` calls
1283 the optimizer.
1284
(...)
1304
1305 """
-> 1306 optimizer.step(closure=optimizer_closure)
File ~/.conda/envs/MyEnv/lib/python3.11/site-packages/lightning/pytorch/core/optimizer.py:153, in LightningOptimizer.step(self, closure, **kwargs)
150 raise MisconfigurationException("When `optimizer.step(closure)` is called, the closure should be callable")
152 assert self._strategy is not None
--> 153 step_output = self._strategy.optimizer_step(self._optimizer, closure, **kwargs)
155 self._on_after_step()
157 return step_output
File ~/.conda/envs/MyEnv/lib/python3.11/site-packages/lightning/pytorch/strategies/strategy.py:238, in Strategy.optimizer_step(self, optimizer, closure, model, **kwargs)
236 # TODO(fabric): remove assertion once strategy's optimizer_step typing is fixed
237 assert isinstance(model, pl.LightningModule)
--> 238 return self.precision_plugin.optimizer_step(optimizer, model=model, closure=closure, **kwargs)
File ~/.conda/envs/MyEnv/lib/python3.11/site-packages/lightning/pytorch/plugins/precision/precision.py:122, in Precision.optimizer_step(self, optimizer, model, closure, **kwargs)
120 """Hook to run the optimizer step."""
121 closure = partial(self._wrap_closure, model, optimizer, closure)
--> 122 return optimizer.step(closure=closure, **kwargs)
File ~/.conda/envs/MyEnv/lib/python3.11/site-packages/torch/optim/optimizer.py:487, in Optimizer.profile_hook_step.<locals>.wrapper(*args, **kwargs)
482 else:
483 raise RuntimeError(
484 f"{func} must return None or a tuple of (new_args, new_kwargs), but got {result}."
485 )
--> 487 out = func(*args, **kwargs)
488 self._optimizer_step_code()
490 # call optimizer step post hooks
File ~/.conda/envs/MyEnv/lib/python3.11/site-packages/torch/optim/optimizer.py:91, in _use_grad_for_differentiable.<locals>._use_grad(self, *args, **kwargs)
89 torch.set_grad_enabled(self.defaults["differentiable"])
90 torch._dynamo.graph_break()
---> 91 ret = func(self, *args, **kwargs)
92 finally:
93 torch._dynamo.graph_break()
File ~/.conda/envs/MyEnv/lib/python3.11/site-packages/torch/optim/adam.py:202, in Adam.step(self, closure)
200 if closure is not None:
201 with torch.enable_grad():
--> 202 loss = closure()
204 for group in self.param_groups:
205 params_with_grad: List[Tensor] = []
File ~/.conda/envs/MyEnv/lib/python3.11/site-packages/lightning/pytorch/plugins/precision/precision.py:108, in Precision._wrap_closure(self, model, optimizer, closure)
95 def _wrap_closure(
96 self,
97 model: "pl.LightningModule",
98 optimizer: Steppable,
99 closure: Callable[[], Any],
100 ) -> Any:
101 """This double-closure allows makes sure the ``closure`` is executed before the ``on_before_optimizer_step``
102 hook is called.
103
(...)
106
107 """
--> 108 closure_result = closure()
109 self._after_closure(model, optimizer)
110 return closure_result
File ~/.conda/envs/MyEnv/lib/python3.11/site-packages/lightning/pytorch/loops/optimization/automatic.py:144, in Closure.__call__(self, *args, **kwargs)
142 @override
143 def __call__(self, *args: Any, **kwargs: Any) -> Optional[Tensor]:
--> 144 self._result = self.closure(*args, **kwargs)
145 return self._result.loss
File ~/.conda/envs/MyEnv/lib/python3.11/site-packages/torch/utils/_contextlib.py:116, in context_decorator.<locals>.decorate_context(*args, **kwargs)
113 @functools.wraps(func)
114 def decorate_context(*args, **kwargs):
115 with ctx_factory():
--> 116 return func(*args, **kwargs)
File ~/.conda/envs/MyEnv/lib/python3.11/site-packages/lightning/pytorch/loops/optimization/automatic.py:129, in Closure.closure(self, *args, **kwargs)
126 @override
127 @torch.enable_grad()
128 def closure(self, *args: Any, **kwargs: Any) -> ClosureResult:
--> 129 step_output = self._step_fn()
131 if step_output.closure_loss is None:
132 self.warning_cache.warn("`training_step` returned `None`. If this was on purpose, ignore this warning...")
File ~/.conda/envs/MyEnv/lib/python3.11/site-packages/lightning/pytorch/loops/optimization/automatic.py:317, in _AutomaticOptimization._training_step(self, kwargs)
306 """Performs the actual train step with the tied hooks.
307
308 Args:
(...)
313
314 """
315 trainer = self.trainer
--> 317 training_step_output = call._call_strategy_hook(trainer, "training_step", *kwargs.values())
318 self.trainer.strategy.post_training_step() # unused hook - call anyway for backward compatibility
320 if training_step_output is None and trainer.world_size > 1:
File ~/.conda/envs/MyEnv/lib/python3.11/site-packages/lightning/pytorch/trainer/call.py:319, in _call_strategy_hook(trainer, hook_name, *args, **kwargs)
316 return None
318 with trainer.profiler.profile(f"[Strategy]{trainer.strategy.__class__.__name__}.{hook_name}"):
--> 319 output = fn(*args, **kwargs)
321 # restore current_fx when nested context
322 pl_module._current_fx_name = prev_fx_name
File ~/.conda/envs/MyEnv/lib/python3.11/site-packages/lightning/pytorch/strategies/strategy.py:390, in Strategy.training_step(self, *args, **kwargs)
388 if self.model != self.lightning_module:
389 return self._forward_redirection(self.model, self.lightning_module, "training_step", *args, **kwargs)
--> 390 return self.lightning_module.training_step(*args, **kwargs)
File ~/.conda/envs/MyEnv/lib/python3.11/site-packages/scvi/train/_trainingplans.py:350, in TrainingPlan.training_step(self, batch, batch_idx)
348 self.loss_kwargs.update({"kl_weight": kl_weight})
349 self.log("kl_weight", kl_weight, on_step=True, on_epoch=False)
--> 350 _, _, scvi_loss = self.forward(batch, loss_kwargs=self.loss_kwargs)
351 self.log(
352 "train_loss",
353 scvi_loss.loss,
(...)
356 sync_dist=self.use_sync_dist,
357 )
358 self.compute_and_log_metrics(scvi_loss, self.train_metrics, "train")
File ~/.conda/envs/MyEnv/lib/python3.11/site-packages/scvi/train/_trainingplans.py:280, in TrainingPlan.forward(self, *args, **kwargs)
278 def forward(self, *args, **kwargs):
279 """Passthrough to the module's forward method."""
--> 280 return self.module(
281 *args,
282 **kwargs,
283 get_inference_input_kwargs={"full_forward_pass": not self.update_only_decoder},
284 )
File ~/.conda/envs/MyEnv/lib/python3.11/site-packages/torch/nn/modules/module.py:1736, in Module._wrapped_call_impl(self, *args, **kwargs)
1734 return self._compiled_call_impl(*args, **kwargs) # type: ignore[misc]
1735 else:
-> 1736 return self._call_impl(*args, **kwargs)
File ~/.conda/envs/MyEnv/lib/python3.11/site-packages/torch/nn/modules/module.py:1747, in Module._call_impl(self, *args, **kwargs)
1742 # If we don't have any hooks, we want to skip the rest of the logic in
1743 # this function, and just call forward.
1744 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
1745 or _global_backward_pre_hooks or _global_backward_hooks
1746 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1747 return forward_call(*args, **kwargs)
1749 result = None
1750 called_always_called_hooks = set()
File ~/.conda/envs/MyEnv/lib/python3.11/site-packages/scvi/module/base/_decorators.py:32, in auto_move_data.<locals>.auto_transfer_args(self, *args, **kwargs)
30 # decorator only necessary after training
31 if self.training:
---> 32 return fn(self, *args, **kwargs)
34 device = list({p.device for p in self.parameters()})
35 if len(device) > 1:
File ~/.conda/envs/MyEnv/lib/python3.11/site-packages/scvi/module/base/_base_module.py:207, in BaseModuleClass.forward(self, tensors, get_inference_input_kwargs, get_generative_input_kwargs, inference_kwargs, generative_kwargs, loss_kwargs, compute_loss)
176 @auto_move_data
177 def forward(
178 self,
(...)
185 compute_loss=True,
186 ) -> tuple[torch.Tensor, torch.Tensor] | tuple[torch.Tensor, torch.Tensor, LossOutput]:
187 """Forward pass through the network.
188
189 Parameters
(...)
205 another return value.
206 """
--> 207 return _generic_forward(
208 self,
209 tensors,
210 inference_kwargs,
211 generative_kwargs,
212 loss_kwargs,
213 get_inference_input_kwargs,
214 get_generative_input_kwargs,
215 compute_loss,
216 )
File ~/.conda/envs/MyEnv/lib/python3.11/site-packages/scvi/module/base/_base_module.py:747, in _generic_forward(module, tensors, inference_kwargs, generative_kwargs, loss_kwargs, get_inference_input_kwargs, get_generative_input_kwargs, compute_loss)
744 get_inference_input_kwargs.pop("full_forward_pass", None)
746 inference_inputs = module._get_inference_input(tensors, **get_inference_input_kwargs)
--> 747 inference_outputs = module.inference(**inference_inputs, **inference_kwargs)
748 generative_inputs = module._get_generative_input(
749 tensors, inference_outputs, **get_generative_input_kwargs
750 )
751 generative_outputs = module.generative(**generative_inputs, **generative_kwargs)
File ~/.conda/envs/MyEnv/lib/python3.11/site-packages/scvi/module/base/_decorators.py:32, in auto_move_data.<locals>.auto_transfer_args(self, *args, **kwargs)
30 # decorator only necessary after training
31 if self.training:
---> 32 return fn(self, *args, **kwargs)
34 device = list({p.device for p in self.parameters()})
35 if len(device) > 1:
File ~/.conda/envs/MyEnv/lib/python3.11/site-packages/scvi/module/base/_base_module.py:308, in BaseMinifiedModeModuleClass.inference(self, *args, **kwargs)
306 return self._cached_inference(*args, **kwargs)
307 else:
--> 308 return self._regular_inference(*args, **kwargs)
File ~/.conda/envs/MyEnv/lib/python3.11/site-packages/scvi/module/base/_decorators.py:32, in auto_move_data.<locals>.auto_transfer_args(self, *args, **kwargs)
30 # decorator only necessary after training
31 if self.training:
---> 32 return fn(self, *args, **kwargs)
34 device = list({p.device for p in self.parameters()})
35 if len(device) > 1:
File ~/.conda/envs/MyEnv/lib/python3.11/site-packages/scvi/module/_vae.py:386, in VAE._regular_inference(self, x, batch_index, cont_covs, cat_covs, n_samples)
384 qz, z = self.z_encoder(encoder_input, *categorical_input)
385 else:
--> 386 qz, z = self.z_encoder(encoder_input, batch_index, *categorical_input)
388 ql = None
389 if not self.use_observed_lib_size:
File ~/.conda/envs/MyEnv/lib/python3.11/site-packages/torch/nn/modules/module.py:1736, in Module._wrapped_call_impl(self, *args, **kwargs)
1734 return self._compiled_call_impl(*args, **kwargs) # type: ignore[misc]
1735 else:
-> 1736 return self._call_impl(*args, **kwargs)
File ~/.conda/envs/MyEnv/lib/python3.11/site-packages/torch/nn/modules/module.py:1747, in Module._call_impl(self, *args, **kwargs)
1742 # If we don't have any hooks, we want to skip the rest of the logic in
1743 # this function, and just call forward.
1744 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
1745 or _global_backward_pre_hooks or _global_backward_hooks
1746 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1747 return forward_call(*args, **kwargs)
1749 result = None
1750 called_always_called_hooks = set()
File ~/.conda/envs/MyEnv/lib/python3.11/site-packages/scvi/nn/_base_components.py:286, in Encoder.forward(self, x, *cat_list)
284 q_m = self.mean_encoder(q)
285 q_v = self.var_activation(self.var_encoder(q)) + self.var_eps
--> 286 dist = Normal(q_m, q_v.sqrt())
287 latent = self.z_transformation(dist.rsample())
288 if self.return_dist:
File ~/.conda/envs/MyEnv/lib/python3.11/site-packages/torch/distributions/normal.py:59, in Normal.__init__(self, loc, scale, validate_args)
57 else:
58 batch_shape = self.loc.size()
---> 59 super().__init__(batch_shape, validate_args=validate_args)
File ~/.conda/envs/MyEnv/lib/python3.11/site-packages/torch/distributions/distribution.py:71, in Distribution.__init__(self, batch_shape, event_shape, validate_args)
69 valid = constraint.check(value)
70 if not valid.all():
---> 71 raise ValueError(
72 f"Expected parameter {param} "
73 f"({type(value).__name__} of shape {tuple(value.shape)}) "
74 f"of distribution {repr(self)} "
75 f"to satisfy the constraint {repr(constraint)}, "
76 f"but found invalid values:\n{value}"
77 )
78 super().__init__()
ValueError: Expected parameter loc (Tensor of shape (128, 10)) of distribution Normal(loc: torch.Size([128, 10]), scale: torch.Size([128, 10])) to satisfy the constraint Real(), but found invalid values:
tensor([[nan, nan, nan, ..., nan, nan, nan],
[nan, nan, nan, ..., nan, nan, nan],
[nan, nan, nan, ..., nan, nan, nan],
...,
[nan, nan, nan, ..., nan, nan, nan],
[nan, nan, nan, ..., nan, nan, nan],
[nan, nan, nan, ..., nan, nan, nan]], grad_fn=<AddmmBackward0>)
```