Thank you for sharing these awesome tools with the bioinformatics community.
I had been using totalVI without any issues, but training my totalVI models recently started to fail, and I am writing to ask for advice on how to resolve this. The totalVI example from the tutorial still runs fine, but the error appears when I run something analogous on my own data:
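import numpy as np
import scanpy as sc
import scvi

# concatenated_adata is my combined AnnData object (built earlier, not shown)
x = concatenated_adata.copy()

# keep the 4000 most variable genes, selected per batch on the raw counts
sc.pp.highly_variable_genes(
    x,
    batch_key="batch",
    flavor="seurat_v3",
    layer="counts",
    n_top_genes=4000,
    subset=True,
)

# register counts, batch and protein data with scvi-tools
scvi.data.setup_anndata(
    x,
    batch_key="batch",
    layer="counts",
    protein_expression_obsm_key="protein_expression",
)

x_model = scvi.model.TOTALVI(x, latent_distribution="normal", n_layers_decoder=2)
x_model.train(max_epochs=50, use_gpu=True)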
which results in the following error(s):
anaconda3/envs/scvi_tools_env/lib/python3.8/site-packages/ipykernel/ipkernel.py:283: DeprecationWarning: `should_run_async` will not call `transform_cell` automatically in the future. Please pass the result to `transformed_cell` argument and any exception that happen during thetransform in `preprocessing_exc_tuple` in IPython 7.17 and above.
and should_run_async(code)
If you pass `n_top_genes`, all cutoffs are ignored.
extracting highly variable genes
/lib/python3.8/site-packages/anndata-0.7.5-py3.8.egg/anndata/_core/anndata.py:1116: SettingWithCopyWarning:
A value is trying to be set on a copy of a slice from a DataFrame.
Try using .loc[row_indexer,col_indexer] = value instead
See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy
df_sub[k] = df_sub[k].cat.remove_unused_categories()
anaconda3/envs/scvi_tools_env/lib/python3.8/site-packages/scanpy/preprocessing/_highly_variable_genes.py:144: FutureWarning: Slicing a positional slice with .loc is not supported, and will raise TypeError in a future version. Use .loc with labels or .iloc with positions instead.
df.loc[: int(n_top_genes), 'highly_variable'] = True
--> added
'highly_variable', boolean vector (adata.var)
'highly_variable_rank', float vector (adata.var)
'means', float vector (adata.var)
'variances', float vector (adata.var)
'variances_norm', float vector (adata.var)
INFO Using batches from adata.obs["batch"]
INFO No label_key inputted, assuming all cells have same label
INFO Using data from adata.layers["counts"]
INFO Computing library size prior per batch
/lib/python3.8/site-packages/anndata-0.7.5-py3.8.egg/anndata/_core/anndata.py:1116: SettingWithCopyWarning:
A value is trying to be set on a copy of a slice from a DataFrame.
Try using .loc[row_indexer,col_indexer] = value instead
See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy
df_sub[k] = df_sub[k].cat.remove_unused_categories()
INFO Using protein expression from adata.obsm['protein_expression']
INFO Using protein names from columns of adata.obsm['protein_expression']
INFO Found batches with missing protein expression
INFO Successfully registered anndata object containing 35523 cells, 4000 vars, 6 batches,
1 labels, and 3 proteins. Also registered 0 extra categorical covariates and 0 extra
continuous covariates.
INFO Please do not further modify adata until model is trained.
GPU available: True, used: True
TPU available: None, using: 0 TPU cores
Set SLURM handle signals.
Epoch 1/50: 0%| | 0/50 [00:03<?, ?it/s]
---------------------------------------------------------------------------
ValueError Traceback (most recent call last)
<ipython-input-13-50c61792beb1> in <module>
21 x_model = scvi.model.TOTALVI(x, latent_distribution = "normal", n_layers_decoder = 2)
22
---> 23 x_model.train(max_epochs = 50,
24 use_gpu = True)
anaconda3/envs/scvi_tools_env/lib/python3.8/site-packages/scvi/model/_totalvi.py in train(self, max_epochs, lr, use_gpu, train_size, validation_size, batch_size, early_stopping, check_val_every_n_epoch, reduce_lr_on_plateau, n_steps_kl_warmup, n_epochs_kl_warmup, adversarial_classifier, plan_kwargs, **kwargs)
257 **kwargs,
258 )
--> 259 return runner()
260
261 @torch.no_grad()
anaconda3/envs/scvi_tools_env/lib/python3.8/site-packages/scvi/train/_trainrunner.py in __call__(self)
73 self.trainer.fit(self.training_plan, train_dl)
74 else:
---> 75 self.trainer.fit(self.training_plan, train_dl, val_dl)
76 try:
77 self.model.history_ = self.trainer.logger.history
anaconda3/envs/scvi_tools_env/lib/python3.8/site-packages/scvi/train/_trainer.py in fit(self, *args, **kwargs)
150 message="you defined a validation_step but have no val_dataloader",
151 )
--> 152 super().fit(*args, **kwargs)
anaconda3/envs/scvi_tools_env/lib/python3.8/site-packages/pytorch_lightning/trainer/trainer.py in fit(self, model, train_dataloader, val_dataloaders, datamodule)
512
513 # dispath `start_training` or `start_testing` or `start_predicting`
--> 514 self.dispatch()
515
516 # plugin will finalized fitting (e.g. ddp_spawn will load trained model)
anaconda3/envs/scvi_tools_env/lib/python3.8/site-packages/pytorch_lightning/trainer/trainer.py in dispatch(self)
552
553 else:
--> 554 self.accelerator.start_training(self)
555
556 def train_or_test_or_predict(self):
anaconda3/envs/scvi_tools_env/lib/python3.8/site-packages/pytorch_lightning/accelerators/accelerator.py in start_training(self, trainer)
72
73 def start_training(self, trainer):
---> 74 self.training_type_plugin.start_training(trainer)
75
76 def start_testing(self, trainer):
anaconda3/envs/scvi_tools_env/lib/python3.8/site-packages/pytorch_lightning/plugins/training_type/training_type_plugin.py in start_training(self, trainer)
109 def start_training(self, trainer: 'Trainer') -> None:
110 # double dispatch to initiate the training loop
--> 111 self._results = trainer.run_train()
112
113 def start_testing(self, trainer: 'Trainer') -> None:
anaconda3/envs/scvi_tools_env/lib/python3.8/site-packages/pytorch_lightning/trainer/trainer.py in run_train(self)
643 with self.profiler.profile("run_training_epoch"):
644 # run train epoch
--> 645 self.train_loop.run_training_epoch()
646
647 if self.max_steps and self.max_steps <= self.global_step:
anaconda3/envs/scvi_tools_env/lib/python3.8/site-packages/pytorch_lightning/trainer/training_loop.py in run_training_epoch(self)
491 # ------------------------------------
492 with self.trainer.profiler.profile("run_training_batch"):
--> 493 batch_output = self.run_training_batch(batch, batch_idx, dataloader_idx)
494
495 # when returning -1 from train_step, we end epoch early
anaconda3/envs/scvi_tools_env/lib/python3.8/site-packages/pytorch_lightning/trainer/training_loop.py in run_training_batch(self, batch, batch_idx, dataloader_idx)
653
654 # optimizer step
--> 655 self.optimizer_step(optimizer, opt_idx, batch_idx, train_step_and_backward_closure)
656
657 else:
anaconda3/envs/scvi_tools_env/lib/python3.8/site-packages/pytorch_lightning/trainer/training_loop.py in optimizer_step(self, optimizer, opt_idx, batch_idx, train_step_and_backward_closure)
424
425 # model hook
--> 426 model_ref.optimizer_step(
427 self.trainer.current_epoch,
428 batch_idx,
anaconda3/envs/scvi_tools_env/lib/python3.8/site-packages/pytorch_lightning/core/lightning.py in optimizer_step(self, epoch, batch_idx, optimizer, optimizer_idx, optimizer_closure, on_tpu, using_native_amp, using_lbfgs)
1382 # wraps into LightingOptimizer only for running step
1383 optimizer = LightningOptimizer._to_lightning_optimizer(optimizer, self.trainer, optimizer_idx)
-> 1384 optimizer.step(closure=optimizer_closure)
1385
1386 def optimizer_zero_grad(self, epoch: int, batch_idx: int, optimizer: Optimizer, optimizer_idx: int):
anaconda3/envs/scvi_tools_env/lib/python3.8/site-packages/pytorch_lightning/core/optimizer.py in step(self, closure, *args, **kwargs)
212 profiler_name = f"optimizer_step_and_closure_{self._optimizer_idx}"
213
--> 214 self.__optimizer_step(*args, closure=closure, profiler_name=profiler_name, **kwargs)
215 self._total_optimizer_step_calls += 1
216
anaconda3/envs/scvi_tools_env/lib/python3.8/site-packages/pytorch_lightning/core/optimizer.py in __optimizer_step(self, closure, profiler_name, **kwargs)
132
133 with trainer.profiler.profile(profiler_name):
--> 134 trainer.accelerator.optimizer_step(optimizer, self._optimizer_idx, lambda_closure=closure, **kwargs)
135
136 def step(self, *args, closure: Optional[Callable] = None, **kwargs):
anaconda3/envs/scvi_tools_env/lib/python3.8/site-packages/pytorch_lightning/accelerators/accelerator.py in optimizer_step(self, optimizer, opt_idx, lambda_closure, **kwargs)
276 )
277 if make_optimizer_step:
--> 278 self.run_optimizer_step(optimizer, opt_idx, lambda_closure, **kwargs)
279 self.precision_plugin.post_optimizer_step(optimizer, opt_idx)
280 self.training_type_plugin.post_optimizer_step(optimizer, opt_idx, **kwargs)
anaconda3/envs/scvi_tools_env/lib/python3.8/site-packages/pytorch_lightning/accelerators/accelerator.py in run_optimizer_step(self, optimizer, optimizer_idx, lambda_closure, **kwargs)
281
282 def run_optimizer_step(self, optimizer: Optimizer, optimizer_idx: int, lambda_closure: Callable, **kwargs):
--> 283 self.training_type_plugin.optimizer_step(optimizer, lambda_closure=lambda_closure, **kwargs)
284
285 def optimizer_zero_grad(self, current_epoch: int, batch_idx: int, optimizer: Optimizer, opt_idx: int) -> None:
anaconda3/envs/scvi_tools_env/lib/python3.8/site-packages/pytorch_lightning/plugins/training_type/training_type_plugin.py in optimizer_step(self, optimizer, lambda_closure, **kwargs)
158
159 def optimizer_step(self, optimizer: torch.optim.Optimizer, lambda_closure: Callable, **kwargs):
--> 160 optimizer.step(closure=lambda_closure, **kwargs)
anaconda3/envs/scvi_tools_env/lib/python3.8/site-packages/torch/optim/optimizer.py in wrapper(*args, **kwargs)
87 profile_name = "Optimizer.step#{}.step".format(obj.__class__.__name__)
88 with torch.autograd.profiler.record_function(profile_name):
---> 89 return func(*args, **kwargs)
90 return wrapper
91
anaconda3/envs/scvi_tools_env/lib/python3.8/site-packages/torch/autograd/grad_mode.py in decorate_context(*args, **kwargs)
25 def decorate_context(*args, **kwargs):
26 with self.__class__():
---> 27 return func(*args, **kwargs)
28 return cast(F, decorate_context)
29
anaconda3/envs/scvi_tools_env/lib/python3.8/site-packages/torch/optim/adam.py in step(self, closure)
64 if closure is not None:
65 with torch.enable_grad():
---> 66 loss = closure()
67
68 for group in self.param_groups:
anaconda3/envs/scvi_tools_env/lib/python3.8/site-packages/pytorch_lightning/trainer/training_loop.py in train_step_and_backward_closure()
647
648 def train_step_and_backward_closure():
--> 649 result = self.training_step_and_backward(
650 split_batch, batch_idx, opt_idx, optimizer, self.trainer.hiddens
651 )
anaconda3/envs/scvi_tools_env/lib/python3.8/site-packages/pytorch_lightning/trainer/training_loop.py in training_step_and_backward(self, split_batch, batch_idx, opt_idx, optimizer, hiddens)
741 with self.trainer.profiler.profile("training_step_and_backward"):
742 # lightning module hook
--> 743 result = self.training_step(split_batch, batch_idx, opt_idx, hiddens)
744 self._curr_step_result = result
745
anaconda3/envs/scvi_tools_env/lib/python3.8/site-packages/pytorch_lightning/trainer/training_loop.py in training_step(self, split_batch, batch_idx, opt_idx, hiddens)
291 model_ref._results = Result()
292 with self.trainer.profiler.profile("training_step"):
--> 293 training_step_output = self.trainer.accelerator.training_step(args)
294 self.trainer.accelerator.post_training_step()
295
anaconda3/envs/scvi_tools_env/lib/python3.8/site-packages/pytorch_lightning/accelerators/accelerator.py in training_step(self, args)
155
156 with self.precision_plugin.train_step_context(), self.training_type_plugin.train_step_context():
--> 157 return self.training_type_plugin.training_step(*args)
158
159 def post_training_step(self):
anaconda3/envs/scvi_tools_env/lib/python3.8/site-packages/pytorch_lightning/plugins/training_type/training_type_plugin.py in training_step(self, *args, **kwargs)
120
121 def training_step(self, *args, **kwargs):
--> 122 return self.lightning_module.training_step(*args, **kwargs)
123
124 def post_training_step(self):
anaconda3/envs/scvi_tools_env/lib/python3.8/site-packages/scvi/train/_trainingplans.py in training_step(self, batch, batch_idx, optimizer_idx)
346 if optimizer_idx == 1:
347 inference_inputs = self.module._get_inference_input(batch)
--> 348 outputs = self.module.inference(**inference_inputs)
349 z = outputs["z"]
350 loss = self.loss_adversarial_classifier(z.detach(), batch_tensor, True)
anaconda3/envs/scvi_tools_env/lib/python3.8/site-packages/scvi/module/base/_decorators.py in auto_transfer_args(self, *args, **kwargs)
30 # decorator only necessary after training
31 if self.training:
---> 32 return fn(self, *args, **kwargs)
33
34 device = list(set(p.device for p in self.parameters()))
anaconda3/envs/scvi_tools_env/lib/python3.8/site-packages/scvi/module/_totalvae.py in inference(self, x, y, batch_index, label, n_samples, transform_batch, cont_covs, cat_covs)
436 else:
437 categorical_input = tuple()
--> 438 qz_m, qz_v, ql_m, ql_v, latent, untran_latent = self.encoder(
439 encoder_input, batch_index, *categorical_input
440 )
anaconda3/envs/scvi_tools_env/lib/python3.8/site-packages/torch/nn/modules/module.py in _call_impl(self, *input, **kwargs)
887 result = self._slow_forward(*input, **kwargs)
888 else:
--> 889 result = self.forward(*input, **kwargs)
890 for hook in itertools.chain(
891 _global_forward_hooks.values(),
anaconda3/envs/scvi_tools_env/lib/python3.8/site-packages/scvi/nn/_base_components.py in forward(self, data, *cat_list)
984 qz_m = self.z_mean_encoder(q)
985 qz_v = torch.exp(self.z_var_encoder(q)) + 1e-4
--> 986 z, untran_z = self.reparameterize_transformation(qz_m, qz_v)
987
988 ql_gene = self.l_gene_encoder(data, *cat_list)
anaconda3/envs/scvi_tools_env/lib/python3.8/site-packages/scvi/nn/_base_components.py in reparameterize_transformation(self, mu, var)
950
951 def reparameterize_transformation(self, mu, var):
--> 952 untran_z = Normal(mu, var.sqrt()).rsample()
953 z = self.z_transformation(untran_z)
954 return z, untran_z
anaconda3/envs/scvi_tools_env/lib/python3.8/site-packages/torch/distributions/normal.py in __init__(self, loc, scale, validate_args)
48 else:
49 batch_shape = self.loc.size()
---> 50 super(Normal, self).__init__(batch_shape, validate_args=validate_args)
51
52 def expand(self, batch_shape, _instance=None):
anaconda3/envs/scvi_tools_env/lib/python3.8/site-packages/torch/distributions/distribution.py in __init__(self, batch_shape, event_shape, validate_args)
51 continue # skip checking lazily-constructed args
52 if not constraint.check(getattr(self, param)).all():
---> 53 raise ValueError("The parameter {} has invalid values".format(param))
54 super(Distribution, self).__init__()
55
ValueError: The parameter loc has invalid values
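As far as I can tell from the last traceback frame, this ValueError is raised by the argument validation in torch.distributions.Normal, which rejects a loc tensor containing NaN or infinite values, so the encoder mean (qz_m) seems to become non-finite during the first epoch. One thing that can be checked is whether the inputs themselves look unusual. The following is only a minimal sketch under the assumption that the problem originates in the input data, which is not confirmed; summarize_matrix is a hypothetical helper name and not part of scvi-tools:

```python
import numpy as np


def summarize_matrix(mat, name="matrix"):
    """Count entries that might plausibly destabilise training.

    Hypothetical helper for illustration only. `mat` must be dense;
    for a scipy sparse layer, call .toarray() on it first.
    """
    arr = np.asarray(mat, dtype=float)
    return {
        "name": name,
        # NaN or +/-inf entries
        "n_nonfinite": int((~np.isfinite(arr)).sum()),
        # count data should never be negative
        "n_negative": int((arr < 0).sum()),
        # cells whose entries are all exactly zero
        "n_all_zero_rows": int((arr == 0).all(axis=1).sum()),
        # largest finite-or-inf value, ignoring NaN
        "max_value": float(np.nanmax(arr)),
    }


# Toy example: one NaN, one negative entry, one all-zero row.
toy = np.array([[1.0, 0.0, 2.0],
                [0.0, 0.0, 0.0],
                [3.0, np.nan, -1.0]])
print(summarize_matrix(toy, "toy"))

# On the real object this would be called on the registered inputs, e.g.
# x.layers["counts"] (densified if sparse) and x.obsm["protein_expression"].
```

Non-zero counts in any of the first three fields would point to a data problem; all-zero protein rows are expected for the batches without protein measurements, so that field needs to be read per batch.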
This is the AnnData setup summary:
Anndata setup with scvi-tools version 0.9.0.
Data Summary
┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━┓
┃ Data ┃ Count ┃
┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━┩
│ Cells │ 35523 │
│ Vars │ 10620 │
│ Labels │ 1 │
│ Batches │ 6 │
│ Proteins │ 3 │
│ Extra Categorical Covariates │ 0 │
│ Extra Continuous Covariates │ 0 │
└──────────────────────────────┴───────┘
SCVI Data Registry
┏━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┓
┃ Data ┃ scvi-tools Location ┃
┡━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┩
│ X │ adata.layers['counts'] │
│ batch_indices │ adata.obs['_scvi_batch'] │
│ local_l_mean │ adata.obs['_scvi_local_l_mean'] │
│ local_l_var │ adata.obs['_scvi_local_l_var'] │
│ labels │ adata.obs['_scvi_labels'] │
│ protein_expression │ adata.obsm['protein_expression'] │
└────────────────────┴──────────────────────────────────┘
Label Categories
┏━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━┓
┃ Source Location ┃ Categories ┃ scvi-tools Encoding ┃
┡━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━┩
│ adata.obs['_scvi_labels'] │ 0 │ 0 │
└───────────────────────────┴────────────┴─────────────────────┘
Batch Categories
┏━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━┓
┃ Source Location ┃ Categories ┃ scvi-tools Encoding ┃
┡━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━┩
│ adata.obs['batch'] │ PB_baseline │ 0 │
│ │ PB_primary_CMV │ 1 │
│ │ PB_steady_state │ 2 │
│ │ LN_steady_state │ 3 │
│ │ PB_CMV_rechallenge │ 4 │
│ │ LN_CMV_rechallenge │ 5 │
└────────────────────┴────────────────────┴─────────────────────┘
Any recommendations on what I might try? Thanks in advance!