Here is the error message from before:
ValueError Traceback (most recent call last)
<ipython-input-29-3de112adfeee> in <module>
----> 1 vae5.train()
/opt/conda/lib/python3.8/site-packages/scvi/model/_totalvi.py in train(self, max_epochs, lr, use_gpu, train_size, validation_size, batch_size, early_stopping, check_val_every_n_epoch, reduce_lr_on_plateau, n_steps_kl_warmup, n_epochs_kl_warmup, adversarial_classifier, plan_kwargs, **kwargs)
257 **kwargs,
258 )
--> 259 return runner()
260
261 @torch.no_grad()
/opt/conda/lib/python3.8/site-packages/scvi/train/_trainrunner.py in __call__(self)
73 self.trainer.fit(self.training_plan, train_dl)
74 else:
---> 75 self.trainer.fit(self.training_plan, train_dl, val_dl)
76 try:
77 self.model.history_ = self.trainer.logger.history
/opt/conda/lib/python3.8/site-packages/scvi/train/_trainer.py in fit(self, *args, **kwargs)
150 message="you defined a validation_step but have no val_dataloader",
151 )
--> 152 super().fit(*args, **kwargs)
/opt/conda/lib/python3.8/site-packages/pytorch_lightning/trainer/trainer.py in fit(self, model, train_dataloader, val_dataloaders, datamodule)
496
497 # dispath `start_training` or `start_testing` or `start_predicting`
--> 498 self.dispatch()
499
500 # plugin will finalized fitting (e.g. ddp_spawn will load trained model)
/opt/conda/lib/python3.8/site-packages/pytorch_lightning/trainer/trainer.py in dispatch(self)
543
544 else:
--> 545 self.accelerator.start_training(self)
546
547 def train_or_test_or_predict(self):
/opt/conda/lib/python3.8/site-packages/pytorch_lightning/accelerators/accelerator.py in start_training(self, trainer)
71
72 def start_training(self, trainer):
---> 73 self.training_type_plugin.start_training(trainer)
74
75 def start_testing(self, trainer):
/opt/conda/lib/python3.8/site-packages/pytorch_lightning/plugins/training_type/training_type_plugin.py in start_training(self, trainer)
112 def start_training(self, trainer: 'Trainer') -> None:
113 # double dispatch to initiate the training loop
--> 114 self._results = trainer.run_train()
115
116 def start_testing(self, trainer: 'Trainer') -> None:
/opt/conda/lib/python3.8/site-packages/pytorch_lightning/trainer/trainer.py in run_train(self)
634 with self.profiler.profile("run_training_epoch"):
635 # run train epoch
--> 636 self.train_loop.run_training_epoch()
637
638 if self.max_steps and self.max_steps <= self.global_step:
/opt/conda/lib/python3.8/site-packages/pytorch_lightning/trainer/training_loop.py in run_training_epoch(self)
491 # ------------------------------------
492 with self.trainer.profiler.profile("run_training_batch"):
--> 493 batch_output = self.run_training_batch(batch, batch_idx, dataloader_idx)
494
495 # when returning -1 from train_step, we end epoch early
/opt/conda/lib/python3.8/site-packages/pytorch_lightning/trainer/training_loop.py in run_training_batch(self, batch, batch_idx, dataloader_idx)
653
654 # optimizer step
--> 655 self.optimizer_step(optimizer, opt_idx, batch_idx, train_step_and_backward_closure)
656
657 else:
/opt/conda/lib/python3.8/site-packages/pytorch_lightning/trainer/training_loop.py in optimizer_step(self, optimizer, opt_idx, batch_idx, train_step_and_backward_closure)
424
425 # model hook
--> 426 model_ref.optimizer_step(
427 self.trainer.current_epoch,
428 batch_idx,
/opt/conda/lib/python3.8/site-packages/pytorch_lightning/core/lightning.py in optimizer_step(self, epoch, batch_idx, optimizer, optimizer_idx, optimizer_closure, on_tpu, using_native_amp, using_lbfgs)
1383 # wraps into LightingOptimizer only for running step
1384 optimizer = LightningOptimizer._to_lightning_optimizer(optimizer, self.trainer, optimizer_idx)
-> 1385 optimizer.step(closure=optimizer_closure)
1386
1387 def optimizer_zero_grad(self, epoch: int, batch_idx: int, optimizer: Optimizer, optimizer_idx: int):
/opt/conda/lib/python3.8/site-packages/pytorch_lightning/core/optimizer.py in step(self, closure, *args, **kwargs)
212 profiler_name = f"optimizer_step_and_closure_{self._optimizer_idx}"
213
--> 214 self.__optimizer_step(*args, closure=closure, profiler_name=profiler_name, **kwargs)
215 self._total_optimizer_step_calls += 1
216
/opt/conda/lib/python3.8/site-packages/pytorch_lightning/core/optimizer.py in __optimizer_step(self, closure, profiler_name, **kwargs)
132
133 with trainer.profiler.profile(profiler_name):
--> 134 trainer.accelerator.optimizer_step(optimizer, self._optimizer_idx, lambda_closure=closure, **kwargs)
135
136 def step(self, *args, closure: Optional[Callable] = None, **kwargs):
/opt/conda/lib/python3.8/site-packages/pytorch_lightning/accelerators/accelerator.py in optimizer_step(self, optimizer, opt_idx, lambda_closure, **kwargs)
275 )
276 if make_optimizer_step:
--> 277 self.run_optimizer_step(optimizer, opt_idx, lambda_closure, **kwargs)
278 self.precision_plugin.post_optimizer_step(optimizer, opt_idx)
279 self.training_type_plugin.post_optimizer_step(optimizer, opt_idx, **kwargs)
/opt/conda/lib/python3.8/site-packages/pytorch_lightning/accelerators/accelerator.py in run_optimizer_step(self, optimizer, optimizer_idx, lambda_closure, **kwargs)
280
281 def run_optimizer_step(self, optimizer: Optimizer, optimizer_idx: int, lambda_closure: Callable, **kwargs):
--> 282 self.training_type_plugin.optimizer_step(optimizer, lambda_closure=lambda_closure, **kwargs)
283
284 def optimizer_zero_grad(self, current_epoch: int, batch_idx: int, optimizer: Optimizer, opt_idx: int) -> None:
/opt/conda/lib/python3.8/site-packages/pytorch_lightning/plugins/training_type/training_type_plugin.py in optimizer_step(self, optimizer, lambda_closure, **kwargs)
161
162 def optimizer_step(self, optimizer: torch.optim.Optimizer, lambda_closure: Callable, **kwargs):
--> 163 optimizer.step(closure=lambda_closure, **kwargs)
/opt/conda/lib/python3.8/site-packages/torch/optim/optimizer.py in wrapper(*args, **kwargs)
87 profile_name = "Optimizer.step#{}.step".format(obj.__class__.__name__)
88 with torch.autograd.profiler.record_function(profile_name):
---> 89 return func(*args, **kwargs)
90 return wrapper
91
/opt/conda/lib/python3.8/site-packages/torch/autograd/grad_mode.py in decorate_context(*args, **kwargs)
25 def decorate_context(*args, **kwargs):
26 with self.__class__():
---> 27 return func(*args, **kwargs)
28 return cast(F, decorate_context)
29
/opt/conda/lib/python3.8/site-packages/torch/optim/adam.py in step(self, closure)
64 if closure is not None:
65 with torch.enable_grad():
---> 66 loss = closure()
67
68 for group in self.param_groups:
/opt/conda/lib/python3.8/site-packages/pytorch_lightning/trainer/training_loop.py in train_step_and_backward_closure()
647
648 def train_step_and_backward_closure():
--> 649 result = self.training_step_and_backward(
650 split_batch, batch_idx, opt_idx, optimizer, self.trainer.hiddens
651 )
/opt/conda/lib/python3.8/site-packages/pytorch_lightning/trainer/training_loop.py in training_step_and_backward(self, split_batch, batch_idx, opt_idx, optimizer, hiddens)
741 with self.trainer.profiler.profile("training_step_and_backward"):
742 # lightning module hook
--> 743 result = self.training_step(split_batch, batch_idx, opt_idx, hiddens)
744 self._curr_step_result = result
745
/opt/conda/lib/python3.8/site-packages/pytorch_lightning/trainer/training_loop.py in training_step(self, split_batch, batch_idx, opt_idx, hiddens)
291 model_ref._results = Result()
292 with self.trainer.profiler.profile("training_step"):
--> 293 training_step_output = self.trainer.accelerator.training_step(args)
294 self.trainer.accelerator.post_training_step()
295
/opt/conda/lib/python3.8/site-packages/pytorch_lightning/accelerators/accelerator.py in training_step(self, args)
154
155 with self.precision_plugin.train_step_context(), self.training_type_plugin.train_step_context():
--> 156 return self.training_type_plugin.training_step(*args)
157
158 def post_training_step(self):
/opt/conda/lib/python3.8/site-packages/pytorch_lightning/plugins/training_type/training_type_plugin.py in training_step(self, *args, **kwargs)
123
124 def training_step(self, *args, **kwargs):
--> 125 return self.lightning_module.training_step(*args, **kwargs)
126
127 def post_training_step(self):
/opt/conda/lib/python3.8/site-packages/scvi/train/_trainingplans.py in training_step(self, batch, batch_idx, optimizer_idx)
346 if optimizer_idx == 1:
347 inference_inputs = self.module._get_inference_input(batch)
--> 348 outputs = self.module.inference(**inference_inputs)
349 z = outputs["z"]
350 loss = self.loss_adversarial_classifier(z.detach(), batch_tensor, True)
/opt/conda/lib/python3.8/site-packages/scvi/module/base/_decorators.py in auto_transfer_args(self, *args, **kwargs)
30 # decorator only necessary after training
31 if self.training:
---> 32 return fn(self, *args, **kwargs)
33
34 device = list(set(p.device for p in self.parameters()))
/opt/conda/lib/python3.8/site-packages/scvi/module/_totalvae.py in inference(self, x, y, batch_index, label, n_samples, transform_batch, cont_covs, cat_covs)
436 else:
437 categorical_input = tuple()
--> 438 qz_m, qz_v, ql_m, ql_v, latent, untran_latent = self.encoder(
439 encoder_input, batch_index, *categorical_input
440 )
/opt/conda/lib/python3.8/site-packages/torch/nn/modules/module.py in _call_impl(self, *input, **kwargs)
887 result = self._slow_forward(*input, **kwargs)
888 else:
--> 889 result = self.forward(*input, **kwargs)
890 for hook in itertools.chain(
891 _global_forward_hooks.values(),
/opt/conda/lib/python3.8/site-packages/scvi/nn/_base_components.py in forward(self, data, *cat_list)
984 qz_m = self.z_mean_encoder(q)
985 qz_v = torch.exp(self.z_var_encoder(q)) + 1e-4
--> 986 z, untran_z = self.reparameterize_transformation(qz_m, qz_v)
987
988 ql_gene = self.l_gene_encoder(data, *cat_list)
/opt/conda/lib/python3.8/site-packages/scvi/nn/_base_components.py in reparameterize_transformation(self, mu, var)
950
951 def reparameterize_transformation(self, mu, var):
--> 952 untran_z = Normal(mu, var.sqrt()).rsample()
953 z = self.z_transformation(untran_z)
954 return z, untran_z
/opt/conda/lib/python3.8/site-packages/torch/distributions/normal.py in __init__(self, loc, scale, validate_args)
48 else:
49 batch_shape = self.loc.size()
---> 50 super(Normal, self).__init__(batch_shape, validate_args=validate_args)
51
52 def expand(self, batch_shape, _instance=None):
/opt/conda/lib/python3.8/site-packages/torch/distributions/distribution.py in __init__(self, batch_shape, event_shape, validate_args)
51 continue # skip checking lazily-constructed args
52 if not constraint.check(getattr(self, param)).all():
---> 53 raise ValueError("The parameter {} has invalid values".format(param))
54 super(Distribution, self).__init__()
55
ValueError: The parameter loc has invalid values