MULTIVI training fails with a ValueError at the start of the first epoch

Hi, I am trying to integrate mostly paired scRNA and scATAC data following your MULTIVI tutorial. Creating the joint adata_mvi object and registering it with

scvi.model.MULTIVI.setup_anndata(adata_mvi, batch_key='modality')

works fine. However, creating the model and calling train() with

mvi = scvi.model.MULTIVI(
    adata_mvi,
    n_genes=(adata_mvi.var['modality']=='Gene Expression').sum(),
    n_regions=(adata_mvi.var['modality']=='Peaks').sum(),
)
mvi.train()

results in the error:

GPU available: False, used: False
TPU available: False, using: 0 TPU cores
Epoch 1/500:   0%|          | 0/500 [00:00<?, ?it/s]
---------------------------------------------------------------------------
ValueError                                Traceback (most recent call last)
<ipython-input-14-b4772eb1a555> in <module>
      4     n_regions=(adata_mvi.var['modality']=='Peaks').sum(),
      5 )
----> 6 mvi.train()

~/miniconda3/lib/python3.7/site-packages/scvi/model/_multivi.py in train(self, max_epochs, lr, use_gpu, train_size, validation_size, batch_size, weight_decay, eps, early_stopping, save_best, check_val_every_n_epoch, n_steps_kl_warmup, n_epochs_kl_warmup, adversarial_mixing, plan_kwargs, **kwargs)
    278             **kwargs,
    279         )
--> 280         return runner()
    281 
    282     @torch.no_grad()

~/miniconda3/lib/python3.7/site-packages/scvi/train/_trainrunner.py in __call__(self)
     70             self.training_plan.n_obs_training = self.data_splitter.n_train
     71 
---> 72         self.trainer.fit(self.training_plan, self.data_splitter)
     73         self._update_history()
     74 

~/miniconda3/lib/python3.7/site-packages/scvi/train/_trainer.py in fit(self, *args, **kwargs)
    175                     message="`LightningModule.configure_optimizers` returned `None`",
    176                 )
--> 177             super().fit(*args, **kwargs)

~/miniconda3/lib/python3.7/site-packages/pytorch_lightning/trainer/trainer.py in fit(self, model, train_dataloader, val_dataloaders, datamodule)
    458         )
    459 
--> 460         self._run(model)
    461 
    462         assert self.state.stopped

~/miniconda3/lib/python3.7/site-packages/pytorch_lightning/trainer/trainer.py in _run(self, model)
    756 
    757         # dispatch `start_training` or `start_evaluating` or `start_predicting`
--> 758         self.dispatch()
    759 
    760         # plugin will finalized fitting (e.g. ddp_spawn will load trained model)

~/miniconda3/lib/python3.7/site-packages/pytorch_lightning/trainer/trainer.py in dispatch(self)
    797             self.accelerator.start_predicting(self)
    798         else:
--> 799             self.accelerator.start_training(self)
    800 
    801     def run_stage(self):

~/miniconda3/lib/python3.7/site-packages/pytorch_lightning/accelerators/accelerator.py in start_training(self, trainer)
     94 
     95     def start_training(self, trainer: 'pl.Trainer') -> None:
---> 96         self.training_type_plugin.start_training(trainer)
     97 
     98     def start_evaluating(self, trainer: 'pl.Trainer') -> None:

~/miniconda3/lib/python3.7/site-packages/pytorch_lightning/plugins/training_type/training_type_plugin.py in start_training(self, trainer)
    142     def start_training(self, trainer: 'pl.Trainer') -> None:
    143         # double dispatch to initiate the training loop
--> 144         self._results = trainer.run_stage()
    145 
    146     def start_evaluating(self, trainer: 'pl.Trainer') -> None:

~/miniconda3/lib/python3.7/site-packages/pytorch_lightning/trainer/trainer.py in run_stage(self)
    807         if self.predicting:
    808             return self.run_predict()
--> 809         return self.run_train()
    810 
    811     def _pre_training_routine(self):

~/miniconda3/lib/python3.7/site-packages/pytorch_lightning/trainer/trainer.py in run_train(self)
    869                 with self.profiler.profile("run_training_epoch"):
    870                     # run train epoch
--> 871                     self.train_loop.run_training_epoch()
    872 
    873                 if self.max_steps and self.max_steps <= self.global_step:

~/miniconda3/lib/python3.7/site-packages/pytorch_lightning/trainer/training_loop.py in run_training_epoch(self)
    497             # ------------------------------------
    498             with self.trainer.profiler.profile("run_training_batch"):
--> 499                 batch_output = self.run_training_batch(batch, batch_idx, dataloader_idx)
    500 
    501             # when returning -1 from train_step, we end epoch early

~/miniconda3/lib/python3.7/site-packages/pytorch_lightning/trainer/training_loop.py in run_training_batch(self, batch, batch_idx, dataloader_idx)
    736 
    737                         # optimizer step
--> 738                         self.optimizer_step(optimizer, opt_idx, batch_idx, train_step_and_backward_closure)
    739                         if len(self.trainer.optimizers) > 1:
    740                             # revert back to previous state

~/miniconda3/lib/python3.7/site-packages/pytorch_lightning/trainer/training_loop.py in optimizer_step(self, optimizer, opt_idx, batch_idx, train_step_and_backward_closure)
    440             on_tpu=self.trainer._device_type == DeviceType.TPU and _TPU_AVAILABLE,
    441             using_native_amp=using_native_amp,
--> 442             using_lbfgs=is_lbfgs,
    443         )
    444 

~/miniconda3/lib/python3.7/site-packages/pytorch_lightning/core/lightning.py in optimizer_step(self, epoch, batch_idx, optimizer, optimizer_idx, optimizer_closure, on_tpu, using_native_amp, using_lbfgs)
   1401 
   1402         """
-> 1403         optimizer.step(closure=optimizer_closure)
   1404 
   1405     def optimizer_zero_grad(self, epoch: int, batch_idx: int, optimizer: Optimizer, optimizer_idx: int):

~/miniconda3/lib/python3.7/site-packages/pytorch_lightning/core/optimizer.py in step(self, closure, *args, **kwargs)
    212             profiler_name = f"optimizer_step_and_closure_{self._optimizer_idx}"
    213 
--> 214         self.__optimizer_step(*args, closure=closure, profiler_name=profiler_name, **kwargs)
    215         self._total_optimizer_step_calls += 1
    216 

~/miniconda3/lib/python3.7/site-packages/pytorch_lightning/core/optimizer.py in __optimizer_step(self, closure, profiler_name, **kwargs)
    132 
    133         with trainer.profiler.profile(profiler_name):
--> 134             trainer.accelerator.optimizer_step(optimizer, self._optimizer_idx, lambda_closure=closure, **kwargs)
    135 
    136     def step(self, *args, closure: Optional[Callable] = None, **kwargs):

~/miniconda3/lib/python3.7/site-packages/pytorch_lightning/accelerators/accelerator.py in optimizer_step(self, optimizer, opt_idx, lambda_closure, **kwargs)
    327         )
    328         if make_optimizer_step:
--> 329             self.run_optimizer_step(optimizer, opt_idx, lambda_closure, **kwargs)
    330         self.precision_plugin.post_optimizer_step(optimizer, opt_idx)
    331         self.training_type_plugin.post_optimizer_step(optimizer, opt_idx, **kwargs)

~/miniconda3/lib/python3.7/site-packages/pytorch_lightning/accelerators/accelerator.py in run_optimizer_step(self, optimizer, optimizer_idx, lambda_closure, **kwargs)
    334         self, optimizer: Optimizer, optimizer_idx: int, lambda_closure: Callable, **kwargs: Any
    335     ) -> None:
--> 336         self.training_type_plugin.optimizer_step(optimizer, lambda_closure=lambda_closure, **kwargs)
    337 
    338     def optimizer_zero_grad(self, current_epoch: int, batch_idx: int, optimizer: Optimizer, opt_idx: int) -> None:

~/miniconda3/lib/python3.7/site-packages/pytorch_lightning/plugins/training_type/training_type_plugin.py in optimizer_step(self, optimizer, lambda_closure, **kwargs)
    191 
    192     def optimizer_step(self, optimizer: torch.optim.Optimizer, lambda_closure: Callable, **kwargs):
--> 193         optimizer.step(closure=lambda_closure, **kwargs)
    194 
    195     @property

~/miniconda3/lib/python3.7/site-packages/torch/optim/optimizer.py in wrapper(*args, **kwargs)
     86                 profile_name = "Optimizer.step#{}.step".format(obj.__class__.__name__)
     87                 with torch.autograd.profiler.record_function(profile_name):
---> 88                     return func(*args, **kwargs)
     89             return wrapper
     90 

~/miniconda3/lib/python3.7/site-packages/torch/autograd/grad_mode.py in decorate_context(*args, **kwargs)
     26         def decorate_context(*args, **kwargs):
     27             with self.__class__():
---> 28                 return func(*args, **kwargs)
     29         return cast(F, decorate_context)
     30 

~/miniconda3/lib/python3.7/site-packages/torch/optim/adam.py in step(self, closure)
     90         if closure is not None:
     91             with torch.enable_grad():
---> 92                 loss = closure()
     93 
     94         for group in self.param_groups:

~/miniconda3/lib/python3.7/site-packages/pytorch_lightning/trainer/training_loop.py in train_step_and_backward_closure()
    731                         def train_step_and_backward_closure():
    732                             result = self.training_step_and_backward(
--> 733                                 split_batch, batch_idx, opt_idx, optimizer, self.trainer.hiddens
    734                             )
    735                             return None if result is None else result.loss

~/miniconda3/lib/python3.7/site-packages/pytorch_lightning/trainer/training_loop.py in training_step_and_backward(self, split_batch, batch_idx, opt_idx, optimizer, hiddens)
    821         with self.trainer.profiler.profile("training_step_and_backward"):
    822             # lightning module hook
--> 823             result = self.training_step(split_batch, batch_idx, opt_idx, hiddens)
    824             self._curr_step_result = result
    825 

~/miniconda3/lib/python3.7/site-packages/pytorch_lightning/trainer/training_loop.py in training_step(self, split_batch, batch_idx, opt_idx, hiddens)
    288             model_ref._results = Result()
    289             with self.trainer.profiler.profile("training_step"):
--> 290                 training_step_output = self.trainer.accelerator.training_step(args)
    291                 self.trainer.accelerator.post_training_step()
    292 

~/miniconda3/lib/python3.7/site-packages/pytorch_lightning/accelerators/accelerator.py in training_step(self, args)
    202 
    203         with self.precision_plugin.train_step_context(), self.training_type_plugin.train_step_context():
--> 204             return self.training_type_plugin.training_step(*args)
    205 
    206     def post_training_step(self) -> None:

~/miniconda3/lib/python3.7/site-packages/pytorch_lightning/plugins/training_type/training_type_plugin.py in training_step(self, *args, **kwargs)
    153 
    154     def training_step(self, *args, **kwargs):
--> 155         return self.lightning_module.training_step(*args, **kwargs)
    156 
    157     def post_training_step(self):

~/miniconda3/lib/python3.7/site-packages/scvi/train/_trainingplans.py in training_step(self, batch, batch_idx, optimizer_idx)
    362             loss_kwargs = dict(kl_weight=self.kl_weight)
    363             inference_outputs, _, scvi_loss = self.forward(
--> 364                 batch, loss_kwargs=loss_kwargs
    365             )
    366             loss = scvi_loss.loss

~/miniconda3/lib/python3.7/site-packages/scvi/train/_trainingplans.py in forward(self, *args, **kwargs)
    145     def forward(self, *args, **kwargs):
    146         """Passthrough to `model.forward()`."""
--> 147         return self.module(*args, **kwargs)
    148 
    149     def training_step(self, batch, batch_idx, optimizer_idx=0):

~/miniconda3/lib/python3.7/site-packages/torch/nn/modules/module.py in _call_impl(self, *input, **kwargs)
   1100         if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks or _global_backward_hooks
   1101                 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1102             return forward_call(*input, **kwargs)
   1103         # Do not call functions when jit is used
   1104         full_backward_hooks, non_full_backward_hooks = [], []

~/miniconda3/lib/python3.7/site-packages/scvi/module/base/_decorators.py in auto_transfer_args(self, *args, **kwargs)
     30         # decorator only necessary after training
     31         if self.training:
---> 32             return fn(self, *args, **kwargs)
     33 
     34         device = list(set(p.device for p in self.parameters()))

~/miniconda3/lib/python3.7/site-packages/scvi/module/base/_base_module.py in forward(self, tensors, get_inference_input_kwargs, get_generative_input_kwargs, inference_kwargs, generative_kwargs, loss_kwargs, compute_loss)
    143             tensors, **get_inference_input_kwargs
    144         )
--> 145         inference_outputs = self.inference(**inference_inputs, **inference_kwargs)
    146         generative_inputs = self._get_generative_input(
    147             tensors, inference_outputs, **get_generative_input_kwargs

~/miniconda3/lib/python3.7/site-packages/scvi/module/base/_decorators.py in auto_transfer_args(self, *args, **kwargs)
     30         # decorator only necessary after training
     31         if self.training:
---> 32             return fn(self, *args, **kwargs)
     33 
     34         device = list(set(p.device for p in self.parameters()))

~/miniconda3/lib/python3.7/site-packages/scvi/module/_multivae.py in inference(self, x, batch_index, cont_covs, cat_covs, n_samples)
    293         # Z Encoders
    294         qzm_acc, qzv_acc, z_acc = self.z_encoder_accessibility(
--> 295             encoder_input_accessibility, batch_index, *categorical_input
    296         )
    297         qzm_expr, qzv_expr, z_expr = self.z_encoder_expression(

~/miniconda3/lib/python3.7/site-packages/torch/nn/modules/module.py in _call_impl(self, *input, **kwargs)
   1100         if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks or _global_backward_hooks
   1101                 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1102             return forward_call(*input, **kwargs)
   1103         # Do not call functions when jit is used
   1104         full_backward_hooks, non_full_backward_hooks = [], []

~/miniconda3/lib/python3.7/site-packages/scvi/nn/_base_components.py in forward(self, x, *cat_list)
    292         q_m = self.mean_encoder(q)
    293         q_v = self.var_activation(self.var_encoder(q)) + self.var_eps
--> 294         latent = self.z_transformation(reparameterize_gaussian(q_m, q_v))
    295         return q_m, q_v, latent
    296 

~/miniconda3/lib/python3.7/site-packages/scvi/nn/_base_components.py in reparameterize_gaussian(mu, var)
     11 
     12 def reparameterize_gaussian(mu, var):
---> 13     return Normal(mu, var.sqrt()).rsample()
     14 
     15 

~/miniconda3/lib/python3.7/site-packages/torch/distributions/normal.py in __init__(self, loc, scale, validate_args)
     48         else:
     49             batch_shape = self.loc.size()
---> 50         super(Normal, self).__init__(batch_shape, validate_args=validate_args)
     51 
     52     def expand(self, batch_shape, _instance=None):

~/miniconda3/lib/python3.7/site-packages/torch/distributions/distribution.py in __init__(self, batch_shape, event_shape, validate_args)
     54                 if not valid.all():
     55                     raise ValueError(
---> 56                         f"Expected parameter {param} "
     57                         f"({type(value).__name__} of shape {tuple(value.shape)}) "
     58                         f"of distribution {repr(self)} "

ValueError: Expected parameter loc (Tensor of shape (128, 19)) of distribution Normal(loc: torch.Size([128, 19]), scale: torch.Size([128, 19])) to satisfy the constraint Real(), but found invalid values:
tensor([[ 0.1109,  0.5803,  0.3902,  ..., -0.4835, -0.8638,  0.0870],
        [ 0.7019,  0.4671,  0.4204,  ...,  0.0102, -0.8212,  0.0126],
        [ 0.1285,  1.1512, -0.1905,  ..., -0.4889, -0.0262,  0.0351],
        ...,
        [-0.0526,  0.3792,  0.6689,  ...,  0.2085,  0.0496,  0.4914],
        [ 0.5148,  0.4604,  0.4606,  ..., -0.0603, -0.3616,  0.4082],
        [ 0.5613,  0.6148,  0.5383,  ..., -0.1725, -1.2356, -0.1374]],
       grad_fn=<AddmmBackward0>)
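
The values in the truncated tensor above all look real, so I assume the offending entries (presumably NaN or inf coming out of the accessibility encoder) are outside the printed portion. One thing I considered is checking the input matrix for non-finite or negative entries before training, since MULTIVI expects non-negative counts; a minimal sketch, assuming the joint RNA + ATAC matrix lives in adata_mvi.X:

import numpy as np
import scipy.sparse as sp

# Pull the joint matrix that MULTIVI trains on.
X = adata_mvi.X
data = X.data if sp.issparse(X) else np.asarray(X)

# Non-finite or negative entries would propagate through the encoders
# and could trip the Real() constraint of the Normal distribution above.
print("non-finite entries:", (~np.isfinite(data)).sum())
print("negative entries:", (data < 0).sum())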

Can you help me with this?

Thanks in advance!
Moritz
