Running totalVI on data in which a subset of cells does not have CITE-seq data

Hello,

I am trying to use scvi-tools to analyze datasets from multiple sources. Since about 80% of the data contains CITE-seq information, I wanted to see if I could run totalVI on all of it using the defined batch key. However, because a subset of the batched experiments does not contain protein information, this fails with a ValueError while parameters are being estimated. I have confirmed that I can run totalVI when looking only at samples with protein expression, and SCVI when looking only at gene expression across all datasets.

Do you know if there is some way to achieve the original intent, i.e. run totalVI on all the data even though chunks of the datasets are missing protein information?

Can you post your script and the traceback? Your workflow should work.

Here is the script I attempted:

import scanpy as sc
import pandas as pd
import scvi

# adata1 => has CITE-seq data from 5 datasets
# adata2 => gene expression only from 1 dataset
# adata3 => gene expression only from 1 dataset
adata_all = sc.concat([adata1, adata2, adata3])

# load the CITE-seq antibody count matrix for the cells in adata1
cite_seq_ab_df = pd.read_csv('CiteseqMarkers.txt', sep='\t', index_col=0)
# add a zero matrix for the cell barcodes in adata2 and adata3 (cells without protein data)
add_bc = adata_all.obs.index.difference(cite_seq_ab_df.index)
fill_in_data = pd.DataFrame(
    index=add_bc,
    columns=cite_seq_ab_df.columns,
    dtype=float,  # keep the filled matrix numeric after fillna(0)
)
total_vi_pexp = pd.concat([
    cite_seq_ab_df,
    fill_in_data
]).loc[adata_all.obs.index].fillna(0)
adata_all.obsm['protein_expression'] = total_vi_pexp
scvi.data.setup_anndata(
    adata_all, layer="counts", batch_key="sample", labels_key='label',
    protein_expression_obsm_key='protein_expression'
)

vae2 = scvi.model.TOTALVI(
    adata_all, 
    gene_dispersion='gene-batch',
    protein_dispersion='protein-batch',
    gene_likelihood='nb',
    latent_distribution='normal'
)
vae2.train()
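
For completeness, here is a small sanity check (a sketch, not part of the run above) that the zero-filled protein matrix lines up with adata_all and is fully numeric before calling setup_anndata:

# Hypothetical sanity check: the protein matrix rows should match adata_all.obs
# exactly, and there should be no NaNs left after the fillna(0) step above.
pexp = adata_all.obsm['protein_expression']
assert (pexp.index == adata_all.obs.index).all()
assert not pexp.isna().any().any()
print(pexp.dtypes.unique())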

The above script raises a ValueError, I believe while it is estimating parameters.

I have a feeling about what it might be, but could you post the error? In the meantime, you might try not using the "protein-batch" dispersion.
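
Concretely, that just means swapping the protein_dispersion argument when constructing the model, something like this sketch (all other arguments kept from the script above):

# Sketch: same model as above, but with one dispersion parameter per protein
# ("protein", the default) instead of one per protein/batch combination.
vae_no_pb = scvi.model.TOTALVI(
    adata_all,
    gene_dispersion='gene-batch',
    protein_dispersion='protein',
    gene_likelihood='nb',
    latent_distribution='normal'
)
vae_no_pb.train()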

Actually, it might be caused by a few antibody markers that are detected at very low counts. Let me try filtering out markers with low total counts (<5000) across all cells, as in the sketch below.
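
Something along these lines (a sketch of the filtering step, using the 5000-count cutoff mentioned above and the total_vi_pexp DataFrame from the earlier script):

# Sketch: drop antibody markers whose total count across all cells is below
# 5000 before attaching the protein matrix to adata_all.
marker_totals = total_vi_pexp.sum(axis=0)
keep_markers = marker_totals[marker_totals >= 5000].index
adata_all.obsm['protein_expression'] = total_vi_pexp[keep_markers]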

Ah yes, OK, it seems like some of the markers with low background counts were causing the ValueError. It is running now, thanks!

Great, it would still be helpful for us to know exactly what error occurred. Could you paste it?

Here is the error message from before:

ValueError                                Traceback (most recent call last)
<ipython-input-29-3de112adfeee> in <module>
----> 1 vae5.train()

/opt/conda/lib/python3.8/site-packages/scvi/model/_totalvi.py in train(self, max_epochs, lr, use_gpu, train_size, validation_size, batch_size, early_stopping, check_val_every_n_epoch, reduce_lr_on_plateau, n_steps_kl_warmup, n_epochs_kl_warmup, adversarial_classifier, plan_kwargs, **kwargs)
    257             **kwargs,
    258         )
--> 259         return runner()
    260 
    261     @torch.no_grad()

/opt/conda/lib/python3.8/site-packages/scvi/train/_trainrunner.py in __call__(self)
     73             self.trainer.fit(self.training_plan, train_dl)
     74         else:
---> 75             self.trainer.fit(self.training_plan, train_dl, val_dl)
     76         try:
     77             self.model.history_ = self.trainer.logger.history

/opt/conda/lib/python3.8/site-packages/scvi/train/_trainer.py in fit(self, *args, **kwargs)
    150                 message="you defined a validation_step but have no val_dataloader",
    151             )
--> 152             super().fit(*args, **kwargs)

/opt/conda/lib/python3.8/site-packages/pytorch_lightning/trainer/trainer.py in fit(self, model, train_dataloader, val_dataloaders, datamodule)
    496 
    497         # dispath `start_training` or `start_testing` or `start_predicting`
--> 498         self.dispatch()
    499 
    500         # plugin will finalized fitting (e.g. ddp_spawn will load trained model)

/opt/conda/lib/python3.8/site-packages/pytorch_lightning/trainer/trainer.py in dispatch(self)
    543 
    544         else:
--> 545             self.accelerator.start_training(self)
    546 
    547     def train_or_test_or_predict(self):

/opt/conda/lib/python3.8/site-packages/pytorch_lightning/accelerators/accelerator.py in start_training(self, trainer)
     71 
     72     def start_training(self, trainer):
---> 73         self.training_type_plugin.start_training(trainer)
     74 
     75     def start_testing(self, trainer):

/opt/conda/lib/python3.8/site-packages/pytorch_lightning/plugins/training_type/training_type_plugin.py in start_training(self, trainer)
    112     def start_training(self, trainer: 'Trainer') -> None:
    113         # double dispatch to initiate the training loop
--> 114         self._results = trainer.run_train()
    115 
    116     def start_testing(self, trainer: 'Trainer') -> None:

/opt/conda/lib/python3.8/site-packages/pytorch_lightning/trainer/trainer.py in run_train(self)
    634                 with self.profiler.profile("run_training_epoch"):
    635                     # run train epoch
--> 636                     self.train_loop.run_training_epoch()
    637 
    638                 if self.max_steps and self.max_steps <= self.global_step:

/opt/conda/lib/python3.8/site-packages/pytorch_lightning/trainer/training_loop.py in run_training_epoch(self)
    491             # ------------------------------------
    492             with self.trainer.profiler.profile("run_training_batch"):
--> 493                 batch_output = self.run_training_batch(batch, batch_idx, dataloader_idx)
    494 
    495             # when returning -1 from train_step, we end epoch early

/opt/conda/lib/python3.8/site-packages/pytorch_lightning/trainer/training_loop.py in run_training_batch(self, batch, batch_idx, dataloader_idx)
    653 
    654                         # optimizer step
--> 655                         self.optimizer_step(optimizer, opt_idx, batch_idx, train_step_and_backward_closure)
    656 
    657                     else:

/opt/conda/lib/python3.8/site-packages/pytorch_lightning/trainer/training_loop.py in optimizer_step(self, optimizer, opt_idx, batch_idx, train_step_and_backward_closure)
    424 
    425         # model hook
--> 426         model_ref.optimizer_step(
    427             self.trainer.current_epoch,
    428             batch_idx,

/opt/conda/lib/python3.8/site-packages/pytorch_lightning/core/lightning.py in optimizer_step(self, epoch, batch_idx, optimizer, optimizer_idx, optimizer_closure, on_tpu, using_native_amp, using_lbfgs)
   1383             # wraps into LightingOptimizer only for running step
   1384             optimizer = LightningOptimizer._to_lightning_optimizer(optimizer, self.trainer, optimizer_idx)
-> 1385         optimizer.step(closure=optimizer_closure)
   1386 
   1387     def optimizer_zero_grad(self, epoch: int, batch_idx: int, optimizer: Optimizer, optimizer_idx: int):

/opt/conda/lib/python3.8/site-packages/pytorch_lightning/core/optimizer.py in step(self, closure, *args, **kwargs)
    212             profiler_name = f"optimizer_step_and_closure_{self._optimizer_idx}"
    213 
--> 214         self.__optimizer_step(*args, closure=closure, profiler_name=profiler_name, **kwargs)
    215         self._total_optimizer_step_calls += 1
    216 

/opt/conda/lib/python3.8/site-packages/pytorch_lightning/core/optimizer.py in __optimizer_step(self, closure, profiler_name, **kwargs)
    132 
    133         with trainer.profiler.profile(profiler_name):
--> 134             trainer.accelerator.optimizer_step(optimizer, self._optimizer_idx, lambda_closure=closure, **kwargs)
    135 
    136     def step(self, *args, closure: Optional[Callable] = None, **kwargs):

/opt/conda/lib/python3.8/site-packages/pytorch_lightning/accelerators/accelerator.py in optimizer_step(self, optimizer, opt_idx, lambda_closure, **kwargs)
    275         )
    276         if make_optimizer_step:
--> 277             self.run_optimizer_step(optimizer, opt_idx, lambda_closure, **kwargs)
    278         self.precision_plugin.post_optimizer_step(optimizer, opt_idx)
    279         self.training_type_plugin.post_optimizer_step(optimizer, opt_idx, **kwargs)

/opt/conda/lib/python3.8/site-packages/pytorch_lightning/accelerators/accelerator.py in run_optimizer_step(self, optimizer, optimizer_idx, lambda_closure, **kwargs)
    280 
    281     def run_optimizer_step(self, optimizer: Optimizer, optimizer_idx: int, lambda_closure: Callable, **kwargs):
--> 282         self.training_type_plugin.optimizer_step(optimizer, lambda_closure=lambda_closure, **kwargs)
    283 
    284     def optimizer_zero_grad(self, current_epoch: int, batch_idx: int, optimizer: Optimizer, opt_idx: int) -> None:

/opt/conda/lib/python3.8/site-packages/pytorch_lightning/plugins/training_type/training_type_plugin.py in optimizer_step(self, optimizer, lambda_closure, **kwargs)
    161 
    162     def optimizer_step(self, optimizer: torch.optim.Optimizer, lambda_closure: Callable, **kwargs):
--> 163         optimizer.step(closure=lambda_closure, **kwargs)

/opt/conda/lib/python3.8/site-packages/torch/optim/optimizer.py in wrapper(*args, **kwargs)
     87                 profile_name = "Optimizer.step#{}.step".format(obj.__class__.__name__)
     88                 with torch.autograd.profiler.record_function(profile_name):
---> 89                     return func(*args, **kwargs)
     90             return wrapper
     91 

/opt/conda/lib/python3.8/site-packages/torch/autograd/grad_mode.py in decorate_context(*args, **kwargs)
     25         def decorate_context(*args, **kwargs):
     26             with self.__class__():
---> 27                 return func(*args, **kwargs)
     28         return cast(F, decorate_context)
     29 

/opt/conda/lib/python3.8/site-packages/torch/optim/adam.py in step(self, closure)
     64         if closure is not None:
     65             with torch.enable_grad():
---> 66                 loss = closure()
     67 
     68         for group in self.param_groups:

/opt/conda/lib/python3.8/site-packages/pytorch_lightning/trainer/training_loop.py in train_step_and_backward_closure()
    647 
    648                         def train_step_and_backward_closure():
--> 649                             result = self.training_step_and_backward(
    650                                 split_batch, batch_idx, opt_idx, optimizer, self.trainer.hiddens
    651                             )

/opt/conda/lib/python3.8/site-packages/pytorch_lightning/trainer/training_loop.py in training_step_and_backward(self, split_batch, batch_idx, opt_idx, optimizer, hiddens)
    741         with self.trainer.profiler.profile("training_step_and_backward"):
    742             # lightning module hook
--> 743             result = self.training_step(split_batch, batch_idx, opt_idx, hiddens)
    744             self._curr_step_result = result
    745 

/opt/conda/lib/python3.8/site-packages/pytorch_lightning/trainer/training_loop.py in training_step(self, split_batch, batch_idx, opt_idx, hiddens)
    291             model_ref._results = Result()
    292             with self.trainer.profiler.profile("training_step"):
--> 293                 training_step_output = self.trainer.accelerator.training_step(args)
    294                 self.trainer.accelerator.post_training_step()
    295 

/opt/conda/lib/python3.8/site-packages/pytorch_lightning/accelerators/accelerator.py in training_step(self, args)
    154 
    155         with self.precision_plugin.train_step_context(), self.training_type_plugin.train_step_context():
--> 156             return self.training_type_plugin.training_step(*args)
    157 
    158     def post_training_step(self):

/opt/conda/lib/python3.8/site-packages/pytorch_lightning/plugins/training_type/training_type_plugin.py in training_step(self, *args, **kwargs)
    123 
    124     def training_step(self, *args, **kwargs):
--> 125         return self.lightning_module.training_step(*args, **kwargs)
    126 
    127     def post_training_step(self):

/opt/conda/lib/python3.8/site-packages/scvi/train/_trainingplans.py in training_step(self, batch, batch_idx, optimizer_idx)
    346         if optimizer_idx == 1:
    347             inference_inputs = self.module._get_inference_input(batch)
--> 348             outputs = self.module.inference(**inference_inputs)
    349             z = outputs["z"]
    350             loss = self.loss_adversarial_classifier(z.detach(), batch_tensor, True)

/opt/conda/lib/python3.8/site-packages/scvi/module/base/_decorators.py in auto_transfer_args(self, *args, **kwargs)
     30         # decorator only necessary after training
     31         if self.training:
---> 32             return fn(self, *args, **kwargs)
     33 
     34         device = list(set(p.device for p in self.parameters()))

/opt/conda/lib/python3.8/site-packages/scvi/module/_totalvae.py in inference(self, x, y, batch_index, label, n_samples, transform_batch, cont_covs, cat_covs)
    436         else:
    437             categorical_input = tuple()
--> 438         qz_m, qz_v, ql_m, ql_v, latent, untran_latent = self.encoder(
    439             encoder_input, batch_index, *categorical_input
    440         )

/opt/conda/lib/python3.8/site-packages/torch/nn/modules/module.py in _call_impl(self, *input, **kwargs)
    887             result = self._slow_forward(*input, **kwargs)
    888         else:
--> 889             result = self.forward(*input, **kwargs)
    890         for hook in itertools.chain(
    891                 _global_forward_hooks.values(),

/opt/conda/lib/python3.8/site-packages/scvi/nn/_base_components.py in forward(self, data, *cat_list)
    984         qz_m = self.z_mean_encoder(q)
    985         qz_v = torch.exp(self.z_var_encoder(q)) + 1e-4
--> 986         z, untran_z = self.reparameterize_transformation(qz_m, qz_v)
    987 
    988         ql_gene = self.l_gene_encoder(data, *cat_list)

/opt/conda/lib/python3.8/site-packages/scvi/nn/_base_components.py in reparameterize_transformation(self, mu, var)
    950 
    951     def reparameterize_transformation(self, mu, var):
--> 952         untran_z = Normal(mu, var.sqrt()).rsample()
    953         z = self.z_transformation(untran_z)
    954         return z, untran_z

/opt/conda/lib/python3.8/site-packages/torch/distributions/normal.py in __init__(self, loc, scale, validate_args)
     48         else:
     49             batch_shape = self.loc.size()
---> 50         super(Normal, self).__init__(batch_shape, validate_args=validate_args)
     51 
     52     def expand(self, batch_shape, _instance=None):

/opt/conda/lib/python3.8/site-packages/torch/distributions/distribution.py in __init__(self, batch_shape, event_shape, validate_args)
     51                     continue  # skip checking lazily-constructed args
     52                 if not constraint.check(getattr(self, param)).all():
---> 53                     raise ValueError("The parameter {} has invalid values".format(param))
     54         super(Distribution, self).__init__()
     55 

ValueError: The parameter loc has invalid values

Thanks! It could also have to do with our empirical initialization of the background parameters, which is new. See empirical_protein_background_prior here.
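
If the empirical background initialization does turn out to be involved, one thing to try (a sketch; it assumes the installed scvi-tools version exposes this keyword on the TOTALVI constructor) is to disable it:

# Sketch: turn off the empirical initialization of the protein background
# prior; everything else matches the earlier model construction.
vae2 = scvi.model.TOTALVI(
    adata_all,
    gene_dispersion='gene-batch',
    gene_likelihood='nb',
    latent_distribution='normal',
    empirical_protein_background_prior=False,
)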