SOLO/scVI .train() error related to batch_size?

I am getting this error when I run vae.train() ultimately for SOLO

“ValueError: Expected more than 1 value per channel when training, got input size torch.Size([1, 128])”

I found a few posts on github, but I'm too daft to understand how to apply this to my code

Here’s a snapshot of my code:

    sample = sc.read_10x_mtx(mapped_data_location, cache = True)
    bdata = sample.copy()

    sc.pp.filter_genes(bdata, min_cells = 10)
    sc.pp.filter_cells(bdata, min_genes = 3)
    sc.pp.highly_variable_genes(bdata, n_top_genes = 2000, subset = True, flavor = 'seurat_v3')
    scvi.model.SCVI.setup_anndata(bdata)
    vae = scvi.model.SCVI(bdata)
    vae.train()

Ah yes, my guess is that this is occurring because the size of your dataset modulo the default batch size (128) is 1, so the last minibatch during an epoch only has one cell. Thus, the batch norm layer is complaining since it can’t compute normalization statistics on just one observation.

The simplest fix for this would be to pass in some batch size other than 128 to train. Hope this helps!

That worked, thanks!!

1 Like

Hi! I’m getting this error (exact message below) even when n_cells%128 is not 1.

ValueError                                Traceback (most recent call last)
Cell In[28], line 7
      5 print((adata_subset[adata_subset.obs['10x_batch']==batch].n_obs)%128)
      6 solobatch = scvi.external.SOLO.from_scvi_model(model, restrict_to_batch = batch)
----> 7 scvi.external.SOLO.train(solobatch, early_stopping = False,batch_size=256)#note: changed batch_size from 128 to 256
      8 predictions = scvi.external.SOLO.predict(solobatch, soft = False)
      9 predictions_all.append(predictions)

File ~/anaconda3/envs/scvi-scanpy/lib/python3.9/site-packages/scvi/external/solo/_model.py:385, in SOLO.train(self, max_epochs, lr, accelerator, devices, train_size, validation_size, shuffle_set_split, batch_size, datasplitter_kwargs, plan_kwargs, early_stopping, early_stopping_patience, early_stopping_min_delta, **kwargs)
    375 training_plan = ClassifierTrainingPlan(self.module, **plan_kwargs)
    376 runner = TrainRunner(
    377     self,
    378     training_plan=training_plan,
   (...)
    383     **kwargs,
    384 )
--> 385 return runner()

File ~/anaconda3/envs/scvi-scanpy/lib/python3.9/site-packages/scvi/train/_trainrunner.py:98, in TrainRunner.__call__(self)
     95 if hasattr(self.data_splitter, "n_val"):
     96     self.training_plan.n_obs_validation = self.data_splitter.n_val
---> 98 self.trainer.fit(self.training_plan, self.data_splitter)
     99 self._update_history()
    101 # data splitter only gets these attrs after fit

File ~/anaconda3/envs/scvi-scanpy/lib/python3.9/site-packages/scvi/train/_trainer.py:219, in Trainer.fit(self, *args, **kwargs)
    213 if isinstance(args[0], PyroTrainingPlan):
    214     warnings.filterwarnings(
    215         action="ignore",
    216         category=UserWarning,
    217         message="`LightningModule.configure_optimizers` returned `None`",
    218     )
--> 219 super().fit(*args, **kwargs)

File ~/anaconda3/envs/scvi-scanpy/lib/python3.9/site-packages/lightning/pytorch/trainer/trainer.py:544, in Trainer.fit(self, model, train_dataloaders, val_dataloaders, datamodule, ckpt_path)
    542 self.state.status = TrainerStatus.RUNNING
    543 self.training = True
--> 544 call._call_and_handle_interrupt(
    545     self, self._fit_impl, model, train_dataloaders, val_dataloaders, datamodule, ckpt_path
    546 )

File ~/anaconda3/envs/scvi-scanpy/lib/python3.9/site-packages/lightning/pytorch/trainer/call.py:44, in _call_and_handle_interrupt(trainer, trainer_fn, *args, **kwargs)
     42     if trainer.strategy.launcher is not None:
     43         return trainer.strategy.launcher.launch(trainer_fn, *args, trainer=trainer, **kwargs)
---> 44     return trainer_fn(*args, **kwargs)
     46 except _TunerExitException:
     47     _call_teardown_hook(trainer)

File ~/anaconda3/envs/scvi-scanpy/lib/python3.9/site-packages/lightning/pytorch/trainer/trainer.py:580, in Trainer._fit_impl(self, model, train_dataloaders, val_dataloaders, datamodule, ckpt_path)
    573 assert self.state.fn is not None
    574 ckpt_path = self._checkpoint_connector._select_ckpt_path(
    575     self.state.fn,
    576     ckpt_path,
    577     model_provided=True,
    578     model_connected=self.lightning_module is not None,
    579 )
--> 580 self._run(model, ckpt_path=ckpt_path)
    582 assert self.state.stopped
    583 self.training = False

File ~/anaconda3/envs/scvi-scanpy/lib/python3.9/site-packages/lightning/pytorch/trainer/trainer.py:989, in Trainer._run(self, model, ckpt_path)
    984 self._signal_connector.register_signal_handlers()
    986 # ----------------------------
    987 # RUN THE TRAINER
    988 # ----------------------------
--> 989 results = self._run_stage()
    991 # ----------------------------
    992 # POST-Training CLEAN UP
    993 # ----------------------------
    994 log.debug(f"{self.__class__.__name__}: trainer tearing down")

File ~/anaconda3/envs/scvi-scanpy/lib/python3.9/site-packages/lightning/pytorch/trainer/trainer.py:1035, in Trainer._run_stage(self)
   1033         self._run_sanity_check()
   1034     with torch.autograd.set_detect_anomaly(self._detect_anomaly):
-> 1035         self.fit_loop.run()
   1036     return None
   1037 raise RuntimeError(f"Unexpected state {self.state}")

File ~/anaconda3/envs/scvi-scanpy/lib/python3.9/site-packages/lightning/pytorch/loops/fit_loop.py:202, in _FitLoop.run(self)
    200 try:
    201     self.on_advance_start()
--> 202     self.advance()
    203     self.on_advance_end()
    204     self._restarting = False

File ~/anaconda3/envs/scvi-scanpy/lib/python3.9/site-packages/lightning/pytorch/loops/fit_loop.py:359, in _FitLoop.advance(self)
    357 with self.trainer.profiler.profile("run_training_epoch"):
    358     assert self._data_fetcher is not None
--> 359     self.epoch_loop.run(self._data_fetcher)

File ~/anaconda3/envs/scvi-scanpy/lib/python3.9/site-packages/lightning/pytorch/loops/training_epoch_loop.py:136, in _TrainingEpochLoop.run(self, data_fetcher)
    134 while not self.done:
    135     try:
--> 136         self.advance(data_fetcher)
    137         self.on_advance_end(data_fetcher)
    138         self._restarting = False

File ~/anaconda3/envs/scvi-scanpy/lib/python3.9/site-packages/lightning/pytorch/loops/training_epoch_loop.py:240, in _TrainingEpochLoop.advance(self, data_fetcher)
    237 with trainer.profiler.profile("run_training_batch"):
    238     if trainer.lightning_module.automatic_optimization:
    239         # in automatic optimization, there can only be one optimizer
--> 240         batch_output = self.automatic_optimization.run(trainer.optimizers[0], batch_idx, kwargs)
    241     else:
    242         batch_output = self.manual_optimization.run(kwargs)

File ~/anaconda3/envs/scvi-scanpy/lib/python3.9/site-packages/lightning/pytorch/loops/optimization/automatic.py:187, in _AutomaticOptimization.run(self, optimizer, batch_idx, kwargs)
    180         closure()
    182 # ------------------------------
    183 # BACKWARD PASS
    184 # ------------------------------
    185 # gradient update with accumulated gradients
    186 else:
--> 187     self._optimizer_step(batch_idx, closure)
    189 result = closure.consume_result()
    190 if result.loss is None:

File ~/anaconda3/envs/scvi-scanpy/lib/python3.9/site-packages/lightning/pytorch/loops/optimization/automatic.py:265, in _AutomaticOptimization._optimizer_step(self, batch_idx, train_step_and_backward_closure)
    262     self.optim_progress.optimizer.step.increment_ready()
    264 # model hook
--> 265 call._call_lightning_module_hook(
    266     trainer,
    267     "optimizer_step",
    268     trainer.current_epoch,
    269     batch_idx,
    270     optimizer,
    271     train_step_and_backward_closure,
    272 )
    274 if not should_accumulate:
    275     self.optim_progress.optimizer.step.increment_completed()

File ~/anaconda3/envs/scvi-scanpy/lib/python3.9/site-packages/lightning/pytorch/trainer/call.py:157, in _call_lightning_module_hook(trainer, hook_name, pl_module, *args, **kwargs)
    154 pl_module._current_fx_name = hook_name
    156 with trainer.profiler.profile(f"[LightningModule]{pl_module.__class__.__name__}.{hook_name}"):
--> 157     output = fn(*args, **kwargs)
    159 # restore current_fx when nested context
    160 pl_module._current_fx_name = prev_fx_name

File ~/anaconda3/envs/scvi-scanpy/lib/python3.9/site-packages/lightning/pytorch/core/module.py:1291, in LightningModule.optimizer_step(self, epoch, batch_idx, optimizer, optimizer_closure)
   1252 def optimizer_step(
   1253     self,
   1254     epoch: int,
   (...)
   1257     optimizer_closure: Optional[Callable[[], Any]] = None,
   1258 ) -> None:
   1259     r"""Override this method to adjust the default way the :class:`~lightning.pytorch.trainer.trainer.Trainer` calls
   1260     the optimizer.
   1261 
   (...)
   1289 
   1290     """
-> 1291     optimizer.step(closure=optimizer_closure)

File ~/anaconda3/envs/scvi-scanpy/lib/python3.9/site-packages/lightning/pytorch/core/optimizer.py:151, in LightningOptimizer.step(self, closure, **kwargs)
    148     raise MisconfigurationException("When `optimizer.step(closure)` is called, the closure should be callable")
    150 assert self._strategy is not None
--> 151 step_output = self._strategy.optimizer_step(self._optimizer, closure, **kwargs)
    153 self._on_after_step()
    155 return step_output

File ~/anaconda3/envs/scvi-scanpy/lib/python3.9/site-packages/lightning/pytorch/strategies/strategy.py:230, in Strategy.optimizer_step(self, optimizer, closure, model, **kwargs)
    228 # TODO(fabric): remove assertion once strategy's optimizer_step typing is fixed
    229 assert isinstance(model, pl.LightningModule)
--> 230 return self.precision_plugin.optimizer_step(optimizer, model=model, closure=closure, **kwargs)

File ~/anaconda3/envs/scvi-scanpy/lib/python3.9/site-packages/lightning/pytorch/plugins/precision/precision.py:117, in Precision.optimizer_step(self, optimizer, model, closure, **kwargs)
    115 """Hook to run the optimizer step."""
    116 closure = partial(self._wrap_closure, model, optimizer, closure)
--> 117 return optimizer.step(closure=closure, **kwargs)

File ~/anaconda3/envs/scvi-scanpy/lib/python3.9/site-packages/torch/optim/optimizer.py:385, in Optimizer.profile_hook_step.<locals>.wrapper(*args, **kwargs)
    380         else:
    381             raise RuntimeError(
    382                 f"{func} must return None or a tuple of (new_args, new_kwargs), but got {result}."
    383             )
--> 385 out = func(*args, **kwargs)
    386 self._optimizer_step_code()
    388 # call optimizer step post hooks

File ~/anaconda3/envs/scvi-scanpy/lib/python3.9/site-packages/torch/optim/optimizer.py:76, in _use_grad_for_differentiable.<locals>._use_grad(self, *args, **kwargs)
     74     torch.set_grad_enabled(self.defaults['differentiable'])
     75     torch._dynamo.graph_break()
---> 76     ret = func(self, *args, **kwargs)
     77 finally:
     78     torch._dynamo.graph_break()

File ~/anaconda3/envs/scvi-scanpy/lib/python3.9/site-packages/torch/optim/adam.py:146, in Adam.step(self, closure)
    144 if closure is not None:
    145     with torch.enable_grad():
--> 146         loss = closure()
    148 for group in self.param_groups:
    149     params_with_grad = []

File ~/anaconda3/envs/scvi-scanpy/lib/python3.9/site-packages/lightning/pytorch/plugins/precision/precision.py:104, in Precision._wrap_closure(self, model, optimizer, closure)
     91 def _wrap_closure(
     92     self,
     93     model: "pl.LightningModule",
     94     optimizer: Optimizer,
     95     closure: Callable[[], Any],
     96 ) -> Any:
     97     """This double-closure allows makes sure the ``closure`` is executed before the ``on_before_optimizer_step``
     98     hook is called.
     99 
   (...)
    102 
    103     """
--> 104     closure_result = closure()
    105     self._after_closure(model, optimizer)
    106     return closure_result

File ~/anaconda3/envs/scvi-scanpy/lib/python3.9/site-packages/lightning/pytorch/loops/optimization/automatic.py:140, in Closure.__call__(self, *args, **kwargs)
    139 def __call__(self, *args: Any, **kwargs: Any) -> Optional[Tensor]:
--> 140     self._result = self.closure(*args, **kwargs)
    141     return self._result.loss

File ~/anaconda3/envs/scvi-scanpy/lib/python3.9/site-packages/torch/utils/_contextlib.py:115, in context_decorator.<locals>.decorate_context(*args, **kwargs)
    112 @functools.wraps(func)
    113 def decorate_context(*args, **kwargs):
    114     with ctx_factory():
--> 115         return func(*args, **kwargs)

File ~/anaconda3/envs/scvi-scanpy/lib/python3.9/site-packages/lightning/pytorch/loops/optimization/automatic.py:126, in Closure.closure(self, *args, **kwargs)
    124 @torch.enable_grad()
    125 def closure(self, *args: Any, **kwargs: Any) -> ClosureResult:
--> 126     step_output = self._step_fn()
    128     if step_output.closure_loss is None:
    129         self.warning_cache.warn("`training_step` returned `None`. If this was on purpose, ignore this warning...")

File ~/anaconda3/envs/scvi-scanpy/lib/python3.9/site-packages/lightning/pytorch/loops/optimization/automatic.py:315, in _AutomaticOptimization._training_step(self, kwargs)
    312 trainer = self.trainer
    314 # manually capture logged metrics
--> 315 training_step_output = call._call_strategy_hook(trainer, "training_step", *kwargs.values())
    316 self.trainer.strategy.post_training_step()  # unused hook - call anyway for backward compatibility
    318 return self.output_result_cls.from_training_step_output(training_step_output, trainer.accumulate_grad_batches)

File ~/anaconda3/envs/scvi-scanpy/lib/python3.9/site-packages/lightning/pytorch/trainer/call.py:309, in _call_strategy_hook(trainer, hook_name, *args, **kwargs)
    306     return None
    308 with trainer.profiler.profile(f"[Strategy]{trainer.strategy.__class__.__name__}.{hook_name}"):
--> 309     output = fn(*args, **kwargs)
    311 # restore current_fx when nested context
    312 pl_module._current_fx_name = prev_fx_name

File ~/anaconda3/envs/scvi-scanpy/lib/python3.9/site-packages/lightning/pytorch/strategies/strategy.py:382, in Strategy.training_step(self, *args, **kwargs)
    380 if self.model != self.lightning_module:
    381     return self._forward_redirection(self.model, self.lightning_module, "training_step", *args, **kwargs)
--> 382 return self.lightning_module.training_step(*args, **kwargs)

File ~/anaconda3/envs/scvi-scanpy/lib/python3.9/site-packages/scvi/train/_trainingplans.py:1128, in ClassifierTrainingPlan.training_step(self, batch, batch_idx)
   1126 def training_step(self, batch, batch_idx):
   1127     """Training step for classifier training."""
-> 1128     soft_prediction = self.forward(batch[self.data_key])
   1129     loss = self.loss_fn(soft_prediction, batch[self.labels_key].view(-1).long())
   1130     self.log("train_loss", loss, on_epoch=True, prog_bar=True)

File ~/anaconda3/envs/scvi-scanpy/lib/python3.9/site-packages/scvi/train/_trainingplans.py:1124, in ClassifierTrainingPlan.forward(self, *args, **kwargs)
   1122 def forward(self, *args, **kwargs):
   1123     """Passthrough to the module's forward function."""
-> 1124     return self.module(*args, **kwargs)

File ~/anaconda3/envs/scvi-scanpy/lib/python3.9/site-packages/torch/nn/modules/module.py:1511, in Module._wrapped_call_impl(self, *args, **kwargs)
   1509     return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]
   1510 else:
-> 1511     return self._call_impl(*args, **kwargs)

File ~/anaconda3/envs/scvi-scanpy/lib/python3.9/site-packages/torch/nn/modules/module.py:1520, in Module._call_impl(self, *args, **kwargs)
   1515 # If we don't have any hooks, we want to skip the rest of the logic in
   1516 # this function, and just call forward.
   1517 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
   1518         or _global_backward_pre_hooks or _global_backward_hooks
   1519         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1520     return forward_call(*args, **kwargs)
   1522 try:
   1523     result = None

File ~/anaconda3/envs/scvi-scanpy/lib/python3.9/site-packages/scvi/module/_classifier.py:78, in Classifier.forward(self, x)
     76 def forward(self, x):
     77     """Forward computation."""
---> 78     return self.classifier(x)

File ~/anaconda3/envs/scvi-scanpy/lib/python3.9/site-packages/torch/nn/modules/module.py:1511, in Module._wrapped_call_impl(self, *args, **kwargs)
   1509     return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]
   1510 else:
-> 1511     return self._call_impl(*args, **kwargs)

File ~/anaconda3/envs/scvi-scanpy/lib/python3.9/site-packages/torch/nn/modules/module.py:1520, in Module._call_impl(self, *args, **kwargs)
   1515 # If we don't have any hooks, we want to skip the rest of the logic in
   1516 # this function, and just call forward.
   1517 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
   1518         or _global_backward_pre_hooks or _global_backward_hooks
   1519         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1520     return forward_call(*args, **kwargs)
   1522 try:
   1523     result = None

File ~/anaconda3/envs/scvi-scanpy/lib/python3.9/site-packages/torch/nn/modules/container.py:217, in Sequential.forward(self, input)
    215 def forward(self, input):
    216     for module in self:
--> 217         input = module(input)
    218     return input

File ~/anaconda3/envs/scvi-scanpy/lib/python3.9/site-packages/torch/nn/modules/module.py:1511, in Module._wrapped_call_impl(self, *args, **kwargs)
   1509     return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]
   1510 else:
-> 1511     return self._call_impl(*args, **kwargs)

File ~/anaconda3/envs/scvi-scanpy/lib/python3.9/site-packages/torch/nn/modules/module.py:1520, in Module._call_impl(self, *args, **kwargs)
   1515 # If we don't have any hooks, we want to skip the rest of the logic in
   1516 # this function, and just call forward.
   1517 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
   1518         or _global_backward_pre_hooks or _global_backward_hooks
   1519         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1520     return forward_call(*args, **kwargs)
   1522 try:
   1523     result = None

File ~/anaconda3/envs/scvi-scanpy/lib/python3.9/site-packages/scvi/nn/_base_components.py:170, in FCLayers.forward(self, x, *cat_list)
    168         x = torch.cat([(layer(slice_x)).unsqueeze(0) for slice_x in x], dim=0)
    169     else:
--> 170         x = layer(x)
    171 else:
    172     if isinstance(layer, nn.Linear) and self.inject_into_layer(i):

File ~/anaconda3/envs/scvi-scanpy/lib/python3.9/site-packages/torch/nn/modules/module.py:1511, in Module._wrapped_call_impl(self, *args, **kwargs)
   1509     return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]
   1510 else:
-> 1511     return self._call_impl(*args, **kwargs)

File ~/anaconda3/envs/scvi-scanpy/lib/python3.9/site-packages/torch/nn/modules/module.py:1520, in Module._call_impl(self, *args, **kwargs)
   1515 # If we don't have any hooks, we want to skip the rest of the logic in
   1516 # this function, and just call forward.
   1517 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
   1518         or _global_backward_pre_hooks or _global_backward_hooks
   1519         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1520     return forward_call(*args, **kwargs)
   1522 try:
   1523     result = None

File ~/anaconda3/envs/scvi-scanpy/lib/python3.9/site-packages/torch/nn/modules/batchnorm.py:175, in _BatchNorm.forward(self, input)
    168     bn_training = (self.running_mean is None) and (self.running_var is None)
    170 r"""
    171 Buffers are only updated if they are to be tracked and we are in training mode. Thus they only need to be
    172 passed when the update should occur (i.e. in training mode when they are tracked), or when buffer stats are
    173 used for normalization (i.e. in eval mode when buffers are not None).
    174 """
--> 175 return F.batch_norm(
    176     input,
    177     # If buffers are not to be tracked, ensure that they won't be updated
    178     self.running_mean
    179     if not self.training or self.track_running_stats
    180     else None,
    181     self.running_var if not self.training or self.track_running_stats else None,
    182     self.weight,
    183     self.bias,
    184     bn_training,
    185     exponential_average_factor,
    186     self.eps,
    187 )

File ~/anaconda3/envs/scvi-scanpy/lib/python3.9/site-packages/torch/nn/functional.py:2480, in batch_norm(input, running_mean, running_var, weight, bias, training, momentum, eps)
   2467     return handle_torch_function(
   2468         batch_norm,
   2469         (input, running_mean, running_var, weight, bias),
   (...)
   2477         eps=eps,
   2478     )
   2479 if training:
-> 2480     _verify_batch_size(input.size())
   2482 return torch.batch_norm(
   2483     input, weight, bias, running_mean, running_var, training, momentum, eps, torch.backends.cudnn.enabled
   2484 )

File ~/anaconda3/envs/scvi-scanpy/lib/python3.9/site-packages/torch/nn/functional.py:2448, in _verify_batch_size(size)
   2446     size_prods *= size[i + 2]
   2447 if size_prods == 1:
-> 2448     raise ValueError(f"Expected more than 1 value per channel when training, got input size {size}")

ValueError: Expected more than 1 value per channel when training, got input size torch.Size([1, 128])

after running:

    solobatch = scvi.external.SOLO.from_scvi_model(model, restrict_to_batch = batch)
    scvi.external.SOLO.train(solobatch, early_stopping = False,batch_size=256)#note: changed batch_size from 128 to 256
    predictions = scvi.external.SOLO.predict(solobatch, soft = False)

Any ideas for what else might throw it? I did just upgrade scvi-tools and my pytorch.

Thanks!

Resolved! Was an error with incompatible versions of cuda/jaxlib.