totalVI NaN loss with few proteins

Thank you for sharing these awesome tools with the bioinformatics community.

After using totalVI without any issues, I recently started experiencing problems in training my totalVI models and am writing to solicit some advice on how to resolve this. The totalVI example in the tutorial continues to run fine, but when I try something analogous for my data, I see the following:

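import numpy as np
import scanpy as sc
import scvi

x = concatenated_adata.copy()

# The next two calls lost their first lines in the post; the function names and
# the HVG arguments are inferred from the log output below.
sc.pp.highly_variable_genes(x, n_top_genes = 4000, flavor = "seurat_v3",
                            subset = True,
                            layer = 'counts')

scvi.data.setup_anndata(x,
                        batch_key = "batch",
                        layer = 'counts',
                        protein_expression_obsm_key = "protein_expression")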

x_model = scvi.model.TOTALVI(x, latent_distribution = "normal", n_layers_decoder = 2)

x_model.train(max_epochs = 50,
              use_gpu = True)

which results in the following error(s):

anaconda3/envs/scvi_tools_env/lib/python3.8/site-packages/ipykernel/ DeprecationWarning: `should_run_async` will not call `transform_cell` automatically in the future. Please pass the result to `transformed_cell` argument and any exception that happen during thetransform in `preprocessing_exc_tuple` in IPython 7.17 and above.
  and should_run_async(code)
If you pass `n_top_genes`, all cutoffs are ignored.
extracting highly variable genes
/lib/python3.8/site-packages/anndata-0.7.5-py3.8.egg/anndata/_core/ SettingWithCopyWarning: 
A value is trying to be set on a copy of a slice from a DataFrame.
Try using .loc[row_indexer,col_indexer] = value instead

See the caveats in the documentation:
  df_sub[k] = df_sub[k].cat.remove_unused_categories()
anaconda3/envs/scvi_tools_env/lib/python3.8/site-packages/scanpy/preprocessing/ FutureWarning: Slicing a positional slice with .loc is not supported, and will raise TypeError in a future version.  Use .loc with labels or .iloc with positions instead.
  df.loc[: int(n_top_genes), 'highly_variable'] = True
--> added
    'highly_variable', boolean vector (adata.var)
    'highly_variable_rank', float vector (adata.var)
    'means', float vector (adata.var)
    'variances', float vector (adata.var)
    'variances_norm', float vector (adata.var)
INFO     Using batches from adata.obs["batch"]                                               
INFO     No label_key inputted, assuming all cells have same label                           
INFO     Using data from adata.layers["counts"]                                              
INFO     Computing library size prior per batch                                              
/lib/python3.8/site-packages/anndata-0.7.5-py3.8.egg/anndata/_core/ SettingWithCopyWarning: 
A value is trying to be set on a copy of a slice from a DataFrame.
Try using .loc[row_indexer,col_indexer] = value instead

See the caveats in the documentation:
  df_sub[k] = df_sub[k].cat.remove_unused_categories()
INFO     Using protein expression from adata.obsm['protein_expression']                      
INFO     Using protein names from columns of adata.obsm['protein_expression']                
INFO     Found batches with missing protein expression                                       
INFO     Successfully registered anndata object containing 35523 cells, 4000 vars, 6 batches,
         1 labels, and 3 proteins. Also registered 0 extra categorical covariates and 0 extra
         continuous covariates.                                                              
INFO     Please do not further modify adata until model is trained.                          
GPU available: True, used: True
TPU available: None, using: 0 TPU cores
Set SLURM handle signals.
Epoch 1/50:   0%|          | 0/50 [00:03<?, ?it/s]
ValueError                                Traceback (most recent call last)
<ipython-input-13-50c61792beb1> in <module>
     21 x_model = scvi.model.TOTALVI(x, latent_distribution = "normal", n_layers_decoder = 2)
---> 23 x_model.train(max_epochs = 50,
     24               use_gpu = True)

anaconda3/envs/scvi_tools_env/lib/python3.8/site-packages/scvi/model/ in train(self, max_epochs, lr, use_gpu, train_size, validation_size, batch_size, early_stopping, check_val_every_n_epoch, reduce_lr_on_plateau, n_steps_kl_warmup, n_epochs_kl_warmup, adversarial_classifier, plan_kwargs, **kwargs)
    257             **kwargs,
    258         )
--> 259         return runner()
    261     @torch.no_grad()

anaconda3/envs/scvi_tools_env/lib/python3.8/site-packages/scvi/train/ in __call__(self)
     73   , train_dl)
     74         else:
---> 75   , train_dl, val_dl)
     76         try:
     77             self.model.history_ = self.trainer.logger.history

anaconda3/envs/scvi_tools_env/lib/python3.8/site-packages/scvi/train/ in fit(self, *args, **kwargs)
    150                 message="you defined a validation_step but have no val_dataloader",
    151             )
--> 152             super().fit(*args, **kwargs)

anaconda3/envs/scvi_tools_env/lib/python3.8/site-packages/pytorch_lightning/trainer/ in fit(self, model, train_dataloader, val_dataloaders, datamodule)
    513         # dispath `start_training` or `start_testing` or `start_predicting`
--> 514         self.dispatch()
    516         # plugin will finalized fitting (e.g. ddp_spawn will load trained model)

anaconda3/envs/scvi_tools_env/lib/python3.8/site-packages/pytorch_lightning/trainer/ in dispatch(self)
    553         else:
--> 554             self.accelerator.start_training(self)
    556     def train_or_test_or_predict(self):

anaconda3/envs/scvi_tools_env/lib/python3.8/site-packages/pytorch_lightning/accelerators/ in start_training(self, trainer)
     73     def start_training(self, trainer):
---> 74         self.training_type_plugin.start_training(trainer)
     76     def start_testing(self, trainer):

anaconda3/envs/scvi_tools_env/lib/python3.8/site-packages/pytorch_lightning/plugins/training_type/ in start_training(self, trainer)
    109     def start_training(self, trainer: 'Trainer') -> None:
    110         # double dispatch to initiate the training loop
--> 111         self._results = trainer.run_train()
    113     def start_testing(self, trainer: 'Trainer') -> None:

anaconda3/envs/scvi_tools_env/lib/python3.8/site-packages/pytorch_lightning/trainer/ in run_train(self)
    643                 with self.profiler.profile("run_training_epoch"):
    644                     # run train epoch
--> 645                     self.train_loop.run_training_epoch()
    647                 if self.max_steps and self.max_steps <= self.global_step:

anaconda3/envs/scvi_tools_env/lib/python3.8/site-packages/pytorch_lightning/trainer/ in run_training_epoch(self)
    491             # ------------------------------------
    492             with self.trainer.profiler.profile("run_training_batch"):
--> 493                 batch_output = self.run_training_batch(batch, batch_idx, dataloader_idx)
    495             # when returning -1 from train_step, we end epoch early

anaconda3/envs/scvi_tools_env/lib/python3.8/site-packages/pytorch_lightning/trainer/ in run_training_batch(self, batch, batch_idx, dataloader_idx)
    654                         # optimizer step
--> 655                         self.optimizer_step(optimizer, opt_idx, batch_idx, train_step_and_backward_closure)
    657                     else:

anaconda3/envs/scvi_tools_env/lib/python3.8/site-packages/pytorch_lightning/trainer/ in optimizer_step(self, optimizer, opt_idx, batch_idx, train_step_and_backward_closure)
    425         # model hook
--> 426         model_ref.optimizer_step(
    427             self.trainer.current_epoch,
    428             batch_idx,

anaconda3/envs/scvi_tools_env/lib/python3.8/site-packages/pytorch_lightning/core/ in optimizer_step(self, epoch, batch_idx, optimizer, optimizer_idx, optimizer_closure, on_tpu, using_native_amp, using_lbfgs)
   1382             # wraps into LightingOptimizer only for running step
   1383             optimizer = LightningOptimizer._to_lightning_optimizer(optimizer, self.trainer, optimizer_idx)
-> 1384         optimizer.step(closure=optimizer_closure)
   1386     def optimizer_zero_grad(self, epoch: int, batch_idx: int, optimizer: Optimizer, optimizer_idx: int):

anaconda3/envs/scvi_tools_env/lib/python3.8/site-packages/pytorch_lightning/core/ in step(self, closure, *args, **kwargs)
    212             profiler_name = f"optimizer_step_and_closure_{self._optimizer_idx}"
--> 214         self.__optimizer_step(*args, closure=closure, profiler_name=profiler_name, **kwargs)
    215         self._total_optimizer_step_calls += 1

anaconda3/envs/scvi_tools_env/lib/python3.8/site-packages/pytorch_lightning/core/ in __optimizer_step(self, closure, profiler_name, **kwargs)
    133         with trainer.profiler.profile(profiler_name):
--> 134             trainer.accelerator.optimizer_step(optimizer, self._optimizer_idx, lambda_closure=closure, **kwargs)
    136     def step(self, *args, closure: Optional[Callable] = None, **kwargs):

anaconda3/envs/scvi_tools_env/lib/python3.8/site-packages/pytorch_lightning/accelerators/ in optimizer_step(self, optimizer, opt_idx, lambda_closure, **kwargs)
    276         )
    277         if make_optimizer_step:
--> 278             self.run_optimizer_step(optimizer, opt_idx, lambda_closure, **kwargs)
    279         self.precision_plugin.post_optimizer_step(optimizer, opt_idx)
    280         self.training_type_plugin.post_optimizer_step(optimizer, opt_idx, **kwargs)

anaconda3/envs/scvi_tools_env/lib/python3.8/site-packages/pytorch_lightning/accelerators/ in run_optimizer_step(self, optimizer, optimizer_idx, lambda_closure, **kwargs)
    282     def run_optimizer_step(self, optimizer: Optimizer, optimizer_idx: int, lambda_closure: Callable, **kwargs):
--> 283         self.training_type_plugin.optimizer_step(optimizer, lambda_closure=lambda_closure, **kwargs)
    285     def optimizer_zero_grad(self, current_epoch: int, batch_idx: int, optimizer: Optimizer, opt_idx: int) -> None:

anaconda3/envs/scvi_tools_env/lib/python3.8/site-packages/pytorch_lightning/plugins/training_type/ in optimizer_step(self, optimizer, lambda_closure, **kwargs)
    159     def optimizer_step(self, optimizer: torch.optim.Optimizer, lambda_closure: Callable, **kwargs):
--> 160         optimizer.step(closure=lambda_closure, **kwargs)

anaconda3/envs/scvi_tools_env/lib/python3.8/site-packages/torch/optim/ in wrapper(*args, **kwargs)
     87                 profile_name = "Optimizer.step#{}.step".format(obj.__class__.__name__)
     88                 with torch.autograd.profiler.record_function(profile_name):
---> 89                     return func(*args, **kwargs)
     90             return wrapper

anaconda3/envs/scvi_tools_env/lib/python3.8/site-packages/torch/autograd/ in decorate_context(*args, **kwargs)
     25         def decorate_context(*args, **kwargs):
     26             with self.__class__():
---> 27                 return func(*args, **kwargs)
     28         return cast(F, decorate_context)

anaconda3/envs/scvi_tools_env/lib/python3.8/site-packages/torch/optim/ in step(self, closure)
     64         if closure is not None:
     65             with torch.enable_grad():
---> 66                 loss = closure()
     68         for group in self.param_groups:

anaconda3/envs/scvi_tools_env/lib/python3.8/site-packages/pytorch_lightning/trainer/ in train_step_and_backward_closure()
    648                         def train_step_and_backward_closure():
--> 649                             result = self.training_step_and_backward(
    650                                 split_batch, batch_idx, opt_idx, optimizer, self.trainer.hiddens
    651                             )

anaconda3/envs/scvi_tools_env/lib/python3.8/site-packages/pytorch_lightning/trainer/ in training_step_and_backward(self, split_batch, batch_idx, opt_idx, optimizer, hiddens)
    741         with self.trainer.profiler.profile("training_step_and_backward"):
    742             # lightning module hook
--> 743             result = self.training_step(split_batch, batch_idx, opt_idx, hiddens)
    744             self._curr_step_result = result

anaconda3/envs/scvi_tools_env/lib/python3.8/site-packages/pytorch_lightning/trainer/ in training_step(self, split_batch, batch_idx, opt_idx, hiddens)
    291             model_ref._results = Result()
    292             with self.trainer.profiler.profile("training_step"):
--> 293                 training_step_output = self.trainer.accelerator.training_step(args)
    294                 self.trainer.accelerator.post_training_step()

anaconda3/envs/scvi_tools_env/lib/python3.8/site-packages/pytorch_lightning/accelerators/ in training_step(self, args)
    156         with self.precision_plugin.train_step_context(), self.training_type_plugin.train_step_context():
--> 157             return self.training_type_plugin.training_step(*args)
    159     def post_training_step(self):

anaconda3/envs/scvi_tools_env/lib/python3.8/site-packages/pytorch_lightning/plugins/training_type/ in training_step(self, *args, **kwargs)
    121     def training_step(self, *args, **kwargs):
--> 122         return self.lightning_module.training_step(*args, **kwargs)
    124     def post_training_step(self):

anaconda3/envs/scvi_tools_env/lib/python3.8/site-packages/scvi/train/ in training_step(self, batch, batch_idx, optimizer_idx)
    346         if optimizer_idx == 1:
    347             inference_inputs = self.module._get_inference_input(batch)
--> 348             outputs = self.module.inference(**inference_inputs)
    349             z = outputs["z"]
    350             loss = self.loss_adversarial_classifier(z.detach(), batch_tensor, True)

anaconda3/envs/scvi_tools_env/lib/python3.8/site-packages/scvi/module/base/ in auto_transfer_args(self, *args, **kwargs)
     30         # decorator only necessary after training
     31         if
---> 32             return fn(self, *args, **kwargs)
     34         device = list(set(p.device for p in self.parameters()))

anaconda3/envs/scvi_tools_env/lib/python3.8/site-packages/scvi/module/ in inference(self, x, y, batch_index, label, n_samples, transform_batch, cont_covs, cat_covs)
    436         else:
    437             categorical_input = tuple()
--> 438         qz_m, qz_v, ql_m, ql_v, latent, untran_latent = self.encoder(
    439             encoder_input, batch_index, *categorical_input
    440         )

anaconda3/envs/scvi_tools_env/lib/python3.8/site-packages/torch/nn/modules/ in _call_impl(self, *input, **kwargs)
    887             result = self._slow_forward(*input, **kwargs)
    888         else:
--> 889             result = self.forward(*input, **kwargs)
    890         for hook in itertools.chain(
    891                 _global_forward_hooks.values(),

anaconda3/envs/scvi_tools_env/lib/python3.8/site-packages/scvi/nn/ in forward(self, data, *cat_list)
    984         qz_m = self.z_mean_encoder(q)
    985         qz_v = torch.exp(self.z_var_encoder(q)) + 1e-4
--> 986         z, untran_z = self.reparameterize_transformation(qz_m, qz_v)
    988         ql_gene = self.l_gene_encoder(data, *cat_list)

anaconda3/envs/scvi_tools_env/lib/python3.8/site-packages/scvi/nn/ in reparameterize_transformation(self, mu, var)
    951     def reparameterize_transformation(self, mu, var):
--> 952         untran_z = Normal(mu, var.sqrt()).rsample()
    953         z = self.z_transformation(untran_z)
    954         return z, untran_z

anaconda3/envs/scvi_tools_env/lib/python3.8/site-packages/torch/distributions/ in __init__(self, loc, scale, validate_args)
     48         else:
     49             batch_shape = self.loc.size()
---> 50         super(Normal, self).__init__(batch_shape, validate_args=validate_args)
     52     def expand(self, batch_shape, _instance=None):

anaconda3/envs/scvi_tools_env/lib/python3.8/site-packages/torch/distributions/ in __init__(self, batch_shape, event_shape, validate_args)
     51                     continue  # skip checking lazily-constructed args
     52                 if not constraint.check(getattr(self, param)).all():
---> 53                     raise ValueError("The parameter {} has invalid values".format(param))
     54         super(Distribution, self).__init__()

ValueError: The parameter loc has invalid values

This is the AnnData setup summary:

Anndata setup with scvi-tools version 0.9.0.
              Data Summary              
┃             Data             ┃ Count ┃
│            Cells             │ 35523 │
│             Vars             │ 10620 │
│            Labels            │   1   │
│           Batches            │   6   │
│           Proteins           │   3   │
│ Extra Categorical Covariates │   0   │
│ Extra Continuous Covariates  │   0   │
                   SCVI Data Registry                    
┃        Data        ┃       scvi-tools Location        ┃
│         X          │      adata.layers['counts']      │
│   batch_indices    │     adata.obs['_scvi_batch']     │
│    local_l_mean    │ adata.obs['_scvi_local_l_mean']  │
│    local_l_var     │  adata.obs['_scvi_local_l_var']  │
│       labels       │    adata.obs['_scvi_labels']     │
│ protein_expression │ adata.obsm['protein_expression'] │
                        Label Categories                        
┃      Source Location      ┃ Categories ┃ scvi-tools Encoding ┃
│ adata.obs['_scvi_labels'] │     0      │          0          │
                        Batch Categories                         
┃  Source Location   ┃     Categories     ┃ scvi-tools Encoding ┃
│ adata.obs['batch'] │    PB_baseline     │          0          │
│                    │   PB_primary_CMV   │          1          │
│                    │  PB_steady_state   │          2          │
│                    │  LN_steady_state   │          3          │
│                    │ PB_CMV_rechallenge │          4          │
│                    │ LN_CMV_rechallenge │          5          │

Any recommendations on what I might try? Thanks in advance!

It’s hard to tell what might be the problem just from this traceback. But you might try two things:

  1. See if you can run your model with just scVI (scvi.model.SCVI), ignoring the protein data for now.
  2. Turn down the learning rate of totalVI in the train method, e.g. to 2e-3.
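Both suggestions can be sketched roughly as follows, assuming scvi-tools 0.9.x as in the traceback above (where `train` accepts `max_epochs` and `lr`). The helper names are hypothetical, and `adata` is assumed to have been registered with `setup_anndata` already.

```python
def fit_scvi_baseline(adata, max_epochs=50):
    """Suggestion 1: train plain scVI on the RNA counts, ignoring the protein data."""
    import scvi  # imported lazily so the helper can be defined without scvi-tools installed

    model = scvi.model.SCVI(adata)
    model.train(max_epochs=max_epochs)
    return model


def fit_totalvi_low_lr(adata, lr=2e-3, max_epochs=50):
    """Suggestion 2: train totalVI with a reduced learning rate."""
    import scvi

    model = scvi.model.TOTALVI(adata, latent_distribution="normal", n_layers_decoder=2)
    model.train(max_epochs=max_epochs, lr=lr)
    return model
```

If the scVI baseline trains cleanly while totalVI still fails at a lower learning rate, the protein side of the model becomes the more likely culprit.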

And just to double check, the protein data is count data, yes?
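A quick way to answer that question is to test whether a matrix contains only finite, non-negative, integer-valued entries. The sketch below uses plain NumPy; `looks_like_counts` is a hypothetical helper, and the toy arrays stand in for `adata.obsm["protein_expression"]`.

```python
import numpy as np


def looks_like_counts(matrix):
    """Return True if every entry is finite, non-negative and integer-valued."""
    values = np.asarray(matrix, dtype=float)
    if not np.isfinite(values).all():
        return False  # NaN or inf present
    if (values < 0).any():
        return False  # negative values suggest centred or transformed data
    return bool(np.equal(np.mod(values, 1), 0).all())


# Toy matrices standing in for adata.obsm["protein_expression"]
raw_counts = np.array([[0, 12, 3], [5, 0, 40]])
normalized = np.array([[0.0, 2.56, 1.39], [1.79, 0.0, 3.71]])

print(looks_like_counts(raw_counts))   # True
print(looks_like_counts(normalized))   # False
```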

I should have stated this: scVI and scANVI continue to work just fine as does the totalVI example in the scvi-tools tutorial. All attempts to run totalVI on my data, including turning down the learning rate, have been unsuccessful. Interestingly it completes a variable number of epochs - even without changing the learning rate - before raising this exception.

I confirm I am indeed trying to run totalVI on count data for both mRNA and protein.

I realise how incredibly annoying it is to share and develop nice code and then get asked support-type questions. I’m really just looking for your gut reaction: is it something with my data and should I focus my attention on finding some kind of problem there … or is it more worthwhile for me to try to step through your very nice code with ipdb or spyder?

Thanks again for the benefit of your advice.

It’s either the data or some hyperparam; though I’ve never seen this sort of error before on the mean of the latent space.

Does scVI run successfully on your data? If it does, I could quickly try to run totalVI if you are open to providing the data.

If scVI does not run successfully – are there any cells with all 0 counts?
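Checking for empty cells takes only a few lines. This is a minimal sketch with a hypothetical helper, `all_zero_rows`; it accepts a dense NumPy array or a SciPy sparse matrix, such as the matrix stored in `adata.layers["counts"]`.

```python
import numpy as np


def all_zero_rows(counts):
    """Return the indices of rows (cells) whose counts sum to zero."""
    totals = np.asarray(counts.sum(axis=1)).ravel()
    return np.flatnonzero(totals == 0)


# Toy count matrix: the second cell has no counts at all
rna = np.array([[3, 0, 1], [0, 0, 0], [2, 5, 0]])
print(all_zero_rows(rna))  # [1]
```

The same check can be applied to the protein matrix to find cells without any protein counts.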

Thanks for sending the data. After some inspection, I suggest the following:

x_model = scvi.model.TOTALVI(adata, empirical_protein_background_prior=False, n_layers_decoder=2)

Basically, I think that because you have only 3 proteins, the new empirical initialization of the protein background prior is getting thrown off and producing bad values. This is something we added after the fact to better initialize the parameters that represent the protein background.

I can add a warning message when there are fewer than 10 proteins to alert users to this potential issue.

Thanks so much!! Everything appears to work fine now. The data come from a species (and a cell type within that species) for which there aren’t many cross-reactive CITE-seq antibodies. Even those three antibodies help us quite a bit and thanks to totalVI we are able to overcome some of those limitations by mapping into references of other species with more CITE-seq antibodies.

Do you think it might make sense for the default value of empirical_protein_background_prior to be (the boolean value of) n_proteins >= 10, so that it is switched off automatically for small panels?
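The warning and the proposed default could be combined roughly as follows. This is only a sketch of the idea, not the scvi-tools implementation: `resolve_empirical_prior` is a hypothetical helper, and it disables the prior for small panels because the failure in this thread occurred with only 3 proteins.

```python
import warnings

MIN_PROTEINS_FOR_EMPIRICAL_PRIOR = 10  # threshold mentioned in the thread


def resolve_empirical_prior(n_proteins, empirical_protein_background_prior=None):
    """Pick a value for the flag, warning when few proteins are present."""
    few_proteins = n_proteins < MIN_PROTEINS_FOR_EMPIRICAL_PRIOR
    if empirical_protein_background_prior is None:
        # No explicit choice: enable the empirical prior only with enough proteins.
        return not few_proteins
    if empirical_protein_background_prior and few_proteins:
        warnings.warn(
            f"Only {n_proteins} proteins: the empirical background prior may be "
            "poorly initialised; consider empirical_protein_background_prior=False."
        )
    return empirical_protein_background_prior


print(resolve_empirical_prior(3))    # False
print(resolve_empirical_prior(25))   # True
```

An explicit `True` or `False` from the user is always respected; only the unset case falls back to the panel-size rule.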

Thanks again for sharing your awesome tool with this community.

Yes we can definitely do this. Would you be willing to make an issue on GitHub and link to this discussion?

The first thing to understand is that SettingWithCopyWarning is a warning, not an error. The underlying problem is that it is generally hard to predict whether an indexing operation returns a view or a copy. The warning was created to flag “chained assignment”, and in most cases it is raised because two indexing operations have been chained together. This is easy to spot when [] (square brackets) is used twice in a row, but the same applies to other access methods such as .loc[] and .iloc[].

Moreover, you can change how pandas reacts to chained assignment through pd.options.mode.chained_assignment, which accepts three values: None, “raise” and “warn”.
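As an illustration on a toy DataFrame (not the AnnData internals that emit the warnings in the log above), the usual way to avoid chained assignment is a single `.loc` call, or an explicit `.copy()` when an independent subset is wanted:

```python
import pandas as pd

df = pd.DataFrame({"batch": ["a", "b", "a", "b"], "score": [1, 2, 3, 4]})

# Chained indexing such as df[df["batch"] == "a"]["score"] = 0 is the pattern
# the warning flags: the first [] may return a copy, so the write can be lost.

# A single .loc call selects rows and column in one step and writes to df itself.
df.loc[df["batch"] == "a", "score"] = 0

# If an independent subset is wanted, make the copy explicit before modifying it.
subset = df[df["batch"] == "b"].copy()
subset["score"] = -1

print(df["score"].tolist())      # [0, 2, 0, 4]
print(subset["score"].tolist())  # [-1, -1]
```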