MultiVI error with only paired data

Hello,

I was following the MultiVI tutorial with the same 10x Multiome dataset used in the tutorial. The only change I made was to remove the step where the data is split into paired, gene-expression-only, and peak-only subsets, i.e. there is only paired data: adata_mvi = scvi.data.organize_multiome_anndatas(adata). The rest of the tutorial code is unchanged, as sketched below.
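For reference, this is roughly the setup (a minimal sketch based on the tutorial code; adata is the 10x Multiome AnnData loaded earlier in the tutorial):

    import scvi

    # Tutorial version: split into paired, RNA-only, and ATAC-only subsets first, e.g.
    #   adata_mvi = scvi.data.organize_multiome_anndatas(adata_paired, adata_rna, adata_atac)
    # My version: keep every cell paired and pass only the multiome AnnData.
    adata_mvi = scvi.data.organize_multiome_anndatas(adata)

    scvi.model.MULTIVI.setup_anndata(adata_mvi, batch_key="modality")
    model = scvi.model.MULTIVI(
        adata_mvi,
        n_genes=(adata_mvi.var["modality"] == "Gene Expression").sum(),
        n_regions=(adata_mvi.var["modality"] == "Peaks").sum(),
    )
    model.train()  # the ValueError below is raised before the first epoch completes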

When I start training the model, I receive the following error before the first training epoch:

# First few stack frames omitted

File ~/notebooks/../../scvi-tools/scvi/train/_trainingplans.py:574, in AdversarialTrainingPlan.training_step(self, batch, batch_idx)
    571 else:
    572     opt1, opt2 = opts
--> 574 inference_outputs, _, scvi_loss = self.forward(
    575     batch, loss_kwargs=self.loss_kwargs
    576 )
    577 z = inference_outputs["z"]
    578 loss = scvi_loss.loss

File ~/notebooks/../../scvi-tools/scvi/train/_trainingplans.py:283, in TrainingPlan.forward(self, *args, **kwargs)
    281 def forward(self, *args, **kwargs):
    282     """Passthrough to the module's forward method."""
--> 283     return self.module(*args, **kwargs)

File /arc/project/st-jiaruid-1/yinian/pytorch2/lib/python3.11/site-packages/torch/nn/modules/module.py:1501, in Module._call_impl(self, *args, **kwargs)
   1496 # If we don't have any hooks, we want to skip the rest of the logic in
   1497 # this function, and just call forward.
   1498 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
   1499         or _global_backward_pre_hooks or _global_backward_hooks
   1500         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1501     return forward_call(*args, **kwargs)
   1502 # Do not call functions when jit is used
   1503 full_backward_hooks, non_full_backward_hooks = [], []

File ~/notebooks/../../scvi-tools/scvi/module/base/_decorators.py:32, in auto_move_data.<locals>.auto_transfer_args(self, *args, **kwargs)
     30 # decorator only necessary after training
     31 if self.training:
---> 32     return fn(self, *args, **kwargs)
     34 device = list({p.device for p in self.parameters()})
     35 if len(device) > 1:

File ~/notebooks/../../scvi-tools/scvi/module/base/_base_module.py:198, in BaseModuleClass.forward(self, tensors, get_inference_input_kwargs, get_generative_input_kwargs, inference_kwargs, generative_kwargs, loss_kwargs, compute_loss)
    164 @auto_move_data
    165 def forward(
    166     self,
   (...)
    176     | tuple[torch.Tensor, torch.Tensor, LossOutput]
    177 ):
    178     """Forward pass through the network.
    179 
    180     Parameters
   (...)
    196         another return value.
    197     """
--> 198     return _generic_forward(
    199         self,
    200         tensors,
    201         inference_kwargs,
    202         generative_kwargs,
    203         loss_kwargs,
    204         get_inference_input_kwargs,
    205         get_generative_input_kwargs,
    206         compute_loss,
    207     )

File ~/notebooks/../../scvi-tools/scvi/module/base/_base_module.py:742, in _generic_forward(module, tensors, inference_kwargs, generative_kwargs, loss_kwargs, get_inference_input_kwargs, get_generative_input_kwargs, compute_loss)
    737 get_generative_input_kwargs = _get_dict_if_none(get_generative_input_kwargs)
    739 inference_inputs = module._get_inference_input(
    740     tensors, **get_inference_input_kwargs
    741 )
--> 742 inference_outputs = module.inference(**inference_inputs, **inference_kwargs)
    743 generative_inputs = module._get_generative_input(
    744     tensors, inference_outputs, **get_generative_input_kwargs
    745 )
    746 generative_outputs = module.generative(**generative_inputs, **generative_kwargs)

File ~/notebooks/../../scvi-tools/scvi/module/base/_decorators.py:32, in auto_move_data.<locals>.auto_transfer_args(self, *args, **kwargs)
     30 # decorator only necessary after training
     31 if self.training:
---> 32     return fn(self, *args, **kwargs)
     34 device = list({p.device for p in self.parameters()})
     35 if len(device) > 1:

File ~/notebooks/../../scvi-tools/scvi/module/_multivae.py:615, in MULTIVAE.inference(self, x, y, batch_index, cont_covs, cat_covs, label, cell_idx, n_samples)
    612     categorical_input = ()
    614 # Z Encoders
--> 615 qzm_acc, qzv_acc, z_acc = self.z_encoder_accessibility(
    616     encoder_input_accessibility, batch_index, *categorical_input
    617 )
    618 qzm_expr, qzv_expr, z_expr = self.z_encoder_expression(
    619     encoder_input_expression, batch_index, *categorical_input
    620 )
    621 qzm_pro, qzv_pro, z_pro = self.z_encoder_protein(
    622     encoder_input_protein, batch_index, *categorical_input
    623 )

File /arc/project/st-jiaruid-1/yinian/pytorch2/lib/python3.11/site-packages/torch/nn/modules/module.py:1501, in Module._call_impl(self, *args, **kwargs)
   1496 # If we don't have any hooks, we want to skip the rest of the logic in
   1497 # this function, and just call forward.
   1498 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
   1499         or _global_backward_pre_hooks or _global_backward_hooks
   1500         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1501     return forward_call(*args, **kwargs)
   1502 # Do not call functions when jit is used
   1503 full_backward_hooks, non_full_backward_hooks = [], []

File ~/notebooks/../../scvi-tools/scvi/nn/_base_components.py:289, in Encoder.forward(self, x, *cat_list)
    287 q_m = self.mean_encoder(q)
    288 q_v = self.var_activation(self.var_encoder(q)) + self.var_eps
--> 289 dist = Normal(q_m, q_v.sqrt())
    290 latent = self.z_transformation(dist.rsample())
    291 if self.return_dist:

File /arc/project/st-jiaruid-1/yinian/pytorch2/lib/python3.11/site-packages/torch/distributions/normal.py:56, in Normal.__init__(self, loc, scale, validate_args)
     54 else:
     55     batch_shape = self.loc.size()
---> 56 super().__init__(batch_shape, validate_args=validate_args)

File /arc/project/st-jiaruid-1/yinian/pytorch2/lib/python3.11/site-packages/torch/distributions/distribution.py:62, in Distribution.__init__(self, batch_shape, event_shape, validate_args)
     60         valid = constraint.check(value)
     61         if not valid.all():
---> 62             raise ValueError(
     63                 f"Expected parameter {param} "
     64                 f"({type(value).__name__} of shape {tuple(value.shape)}) "
     65                 f"of distribution {repr(self)} "
     66                 f"to satisfy the constraint {repr(constraint)}, "
     67                 f"but found invalid values:\n{value}"
     68             )
     69 super().__init__()

ValueError: Expected parameter loc (Tensor of shape (128, 17)) of distribution Normal(loc: torch.Size([128, 17]), scale: torch.Size([128, 17])) to satisfy the constraint Real(), but found invalid values:
tensor([[nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        ...,
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan]], grad_fn=<AddmmBackward0>)

When I run the tutorial as-is (with single-modality observations included), the error does not occur. I have also tried other joint RNA+ATAC datasets with only paired data, and the same error occurs.

I do not know if this is a factor, but I am using PyTorch 2.0.1.

Thank you for any insight.


I also face this problem. I don't know whether it's because there are so many NaN values in the single-cell data.

What I found is that as long as there is at least one cell with only RNA information and at least one cell with only ATAC information, training runs. The workaround I applied was to copy one cell, unpair its modalities, and feed the two halves in as single-modality cells, as in the sketch below.
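In code, the workaround looks roughly like this (a sketch, assuming the tutorial's modality annotations in adata.var; the choice of the first cell is arbitrary):

    import scvi

    # Take one paired cell and split it into an RNA-only copy and an
    # ATAC-only copy, mirroring the tutorial's feature subsetting.
    adata_rna = adata[:1, adata.var["modality"] == "Gene Expression"].copy()
    adata_atac = adata[:1, adata.var["modality"] == "Peaks"].copy()

    # The organized object now contains the paired cells plus one cell of
    # each single modality, which is enough for training to run.
    adata_mvi = scvi.data.organize_multiome_anndatas(adata, adata_rna, adata_atac)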