Hi! I’m getting the error below (full traceback included) even when `n_cells % 128` is not 1.
ValueError Traceback (most recent call last)
Cell In[28], line 7
5 print((adata_subset[adata_subset.obs['10x_batch']==batch].n_obs)%128)
6 solobatch = scvi.external.SOLO.from_scvi_model(model, restrict_to_batch = batch)
----> 7 scvi.external.SOLO.train(solobatch, early_stopping = False,batch_size=256)#note changed from 128 to 64
8 predictions = scvi.external.SOLO.predict(solobatch, soft = False)
9 predictions_all.append(predictions)
File ~/anaconda3/envs/scvi-scanpy/lib/python3.9/site-packages/scvi/external/solo/_model.py:385, in SOLO.train(self, max_epochs, lr, accelerator, devices, train_size, validation_size, shuffle_set_split, batch_size, datasplitter_kwargs, plan_kwargs, early_stopping, early_stopping_patience, early_stopping_min_delta, **kwargs)
375 training_plan = ClassifierTrainingPlan(self.module, **plan_kwargs)
376 runner = TrainRunner(
377 self,
378 training_plan=training_plan,
(...)
383 **kwargs,
384 )
--> 385 return runner()
File ~/anaconda3/envs/scvi-scanpy/lib/python3.9/site-packages/scvi/train/_trainrunner.py:98, in TrainRunner.__call__(self)
95 if hasattr(self.data_splitter, "n_val"):
96 self.training_plan.n_obs_validation = self.data_splitter.n_val
---> 98 self.trainer.fit(self.training_plan, self.data_splitter)
99 self._update_history()
101 # data splitter only gets these attrs after fit
File ~/anaconda3/envs/scvi-scanpy/lib/python3.9/site-packages/scvi/train/_trainer.py:219, in Trainer.fit(self, *args, **kwargs)
213 if isinstance(args[0], PyroTrainingPlan):
214 warnings.filterwarnings(
215 action="ignore",
216 category=UserWarning,
217 message="`LightningModule.configure_optimizers` returned `None`",
218 )
--> 219 super().fit(*args, **kwargs)
File ~/anaconda3/envs/scvi-scanpy/lib/python3.9/site-packages/lightning/pytorch/trainer/trainer.py:544, in Trainer.fit(self, model, train_dataloaders, val_dataloaders, datamodule, ckpt_path)
542 self.state.status = TrainerStatus.RUNNING
543 self.training = True
--> 544 call._call_and_handle_interrupt(
545 self, self._fit_impl, model, train_dataloaders, val_dataloaders, datamodule, ckpt_path
546 )
File ~/anaconda3/envs/scvi-scanpy/lib/python3.9/site-packages/lightning/pytorch/trainer/call.py:44, in _call_and_handle_interrupt(trainer, trainer_fn, *args, **kwargs)
42 if trainer.strategy.launcher is not None:
43 return trainer.strategy.launcher.launch(trainer_fn, *args, trainer=trainer, **kwargs)
---> 44 return trainer_fn(*args, **kwargs)
46 except _TunerExitException:
47 _call_teardown_hook(trainer)
File ~/anaconda3/envs/scvi-scanpy/lib/python3.9/site-packages/lightning/pytorch/trainer/trainer.py:580, in Trainer._fit_impl(self, model, train_dataloaders, val_dataloaders, datamodule, ckpt_path)
573 assert self.state.fn is not None
574 ckpt_path = self._checkpoint_connector._select_ckpt_path(
575 self.state.fn,
576 ckpt_path,
577 model_provided=True,
578 model_connected=self.lightning_module is not None,
579 )
--> 580 self._run(model, ckpt_path=ckpt_path)
582 assert self.state.stopped
583 self.training = False
File ~/anaconda3/envs/scvi-scanpy/lib/python3.9/site-packages/lightning/pytorch/trainer/trainer.py:989, in Trainer._run(self, model, ckpt_path)
984 self._signal_connector.register_signal_handlers()
986 # ----------------------------
987 # RUN THE TRAINER
988 # ----------------------------
--> 989 results = self._run_stage()
991 # ----------------------------
992 # POST-Training CLEAN UP
993 # ----------------------------
994 log.debug(f"{self.__class__.__name__}: trainer tearing down")
File ~/anaconda3/envs/scvi-scanpy/lib/python3.9/site-packages/lightning/pytorch/trainer/trainer.py:1035, in Trainer._run_stage(self)
1033 self._run_sanity_check()
1034 with torch.autograd.set_detect_anomaly(self._detect_anomaly):
-> 1035 self.fit_loop.run()
1036 return None
1037 raise RuntimeError(f"Unexpected state {self.state}")
File ~/anaconda3/envs/scvi-scanpy/lib/python3.9/site-packages/lightning/pytorch/loops/fit_loop.py:202, in _FitLoop.run(self)
200 try:
201 self.on_advance_start()
--> 202 self.advance()
203 self.on_advance_end()
204 self._restarting = False
File ~/anaconda3/envs/scvi-scanpy/lib/python3.9/site-packages/lightning/pytorch/loops/fit_loop.py:359, in _FitLoop.advance(self)
357 with self.trainer.profiler.profile("run_training_epoch"):
358 assert self._data_fetcher is not None
--> 359 self.epoch_loop.run(self._data_fetcher)
File ~/anaconda3/envs/scvi-scanpy/lib/python3.9/site-packages/lightning/pytorch/loops/training_epoch_loop.py:136, in _TrainingEpochLoop.run(self, data_fetcher)
134 while not self.done:
135 try:
--> 136 self.advance(data_fetcher)
137 self.on_advance_end(data_fetcher)
138 self._restarting = False
File ~/anaconda3/envs/scvi-scanpy/lib/python3.9/site-packages/lightning/pytorch/loops/training_epoch_loop.py:240, in _TrainingEpochLoop.advance(self, data_fetcher)
237 with trainer.profiler.profile("run_training_batch"):
238 if trainer.lightning_module.automatic_optimization:
239 # in automatic optimization, there can only be one optimizer
--> 240 batch_output = self.automatic_optimization.run(trainer.optimizers[0], batch_idx, kwargs)
241 else:
242 batch_output = self.manual_optimization.run(kwargs)
File ~/anaconda3/envs/scvi-scanpy/lib/python3.9/site-packages/lightning/pytorch/loops/optimization/automatic.py:187, in _AutomaticOptimization.run(self, optimizer, batch_idx, kwargs)
180 closure()
182 # ------------------------------
183 # BACKWARD PASS
184 # ------------------------------
185 # gradient update with accumulated gradients
186 else:
--> 187 self._optimizer_step(batch_idx, closure)
189 result = closure.consume_result()
190 if result.loss is None:
File ~/anaconda3/envs/scvi-scanpy/lib/python3.9/site-packages/lightning/pytorch/loops/optimization/automatic.py:265, in _AutomaticOptimization._optimizer_step(self, batch_idx, train_step_and_backward_closure)
262 self.optim_progress.optimizer.step.increment_ready()
264 # model hook
--> 265 call._call_lightning_module_hook(
266 trainer,
267 "optimizer_step",
268 trainer.current_epoch,
269 batch_idx,
270 optimizer,
271 train_step_and_backward_closure,
272 )
274 if not should_accumulate:
275 self.optim_progress.optimizer.step.increment_completed()
File ~/anaconda3/envs/scvi-scanpy/lib/python3.9/site-packages/lightning/pytorch/trainer/call.py:157, in _call_lightning_module_hook(trainer, hook_name, pl_module, *args, **kwargs)
154 pl_module._current_fx_name = hook_name
156 with trainer.profiler.profile(f"[LightningModule]{pl_module.__class__.__name__}.{hook_name}"):
--> 157 output = fn(*args, **kwargs)
159 # restore current_fx when nested context
160 pl_module._current_fx_name = prev_fx_name
File ~/anaconda3/envs/scvi-scanpy/lib/python3.9/site-packages/lightning/pytorch/core/module.py:1291, in LightningModule.optimizer_step(self, epoch, batch_idx, optimizer, optimizer_closure)
1252 def optimizer_step(
1253 self,
1254 epoch: int,
(...)
1257 optimizer_closure: Optional[Callable[[], Any]] = None,
1258 ) -> None:
1259 r"""Override this method to adjust the default way the :class:`~lightning.pytorch.trainer.trainer.Trainer` calls
1260 the optimizer.
1261
(...)
1289
1290 """
-> 1291 optimizer.step(closure=optimizer_closure)
File ~/anaconda3/envs/scvi-scanpy/lib/python3.9/site-packages/lightning/pytorch/core/optimizer.py:151, in LightningOptimizer.step(self, closure, **kwargs)
148 raise MisconfigurationException("When `optimizer.step(closure)` is called, the closure should be callable")
150 assert self._strategy is not None
--> 151 step_output = self._strategy.optimizer_step(self._optimizer, closure, **kwargs)
153 self._on_after_step()
155 return step_output
File ~/anaconda3/envs/scvi-scanpy/lib/python3.9/site-packages/lightning/pytorch/strategies/strategy.py:230, in Strategy.optimizer_step(self, optimizer, closure, model, **kwargs)
228 # TODO(fabric): remove assertion once strategy's optimizer_step typing is fixed
229 assert isinstance(model, pl.LightningModule)
--> 230 return self.precision_plugin.optimizer_step(optimizer, model=model, closure=closure, **kwargs)
File ~/anaconda3/envs/scvi-scanpy/lib/python3.9/site-packages/lightning/pytorch/plugins/precision/precision.py:117, in Precision.optimizer_step(self, optimizer, model, closure, **kwargs)
115 """Hook to run the optimizer step."""
116 closure = partial(self._wrap_closure, model, optimizer, closure)
--> 117 return optimizer.step(closure=closure, **kwargs)
File ~/anaconda3/envs/scvi-scanpy/lib/python3.9/site-packages/torch/optim/optimizer.py:385, in Optimizer.profile_hook_step.<locals>.wrapper(*args, **kwargs)
380 else:
381 raise RuntimeError(
382 f"{func} must return None or a tuple of (new_args, new_kwargs), but got {result}."
383 )
--> 385 out = func(*args, **kwargs)
386 self._optimizer_step_code()
388 # call optimizer step post hooks
File ~/anaconda3/envs/scvi-scanpy/lib/python3.9/site-packages/torch/optim/optimizer.py:76, in _use_grad_for_differentiable.<locals>._use_grad(self, *args, **kwargs)
74 torch.set_grad_enabled(self.defaults['differentiable'])
75 torch._dynamo.graph_break()
---> 76 ret = func(self, *args, **kwargs)
77 finally:
78 torch._dynamo.graph_break()
File ~/anaconda3/envs/scvi-scanpy/lib/python3.9/site-packages/torch/optim/adam.py:146, in Adam.step(self, closure)
144 if closure is not None:
145 with torch.enable_grad():
--> 146 loss = closure()
148 for group in self.param_groups:
149 params_with_grad = []
File ~/anaconda3/envs/scvi-scanpy/lib/python3.9/site-packages/lightning/pytorch/plugins/precision/precision.py:104, in Precision._wrap_closure(self, model, optimizer, closure)
91 def _wrap_closure(
92 self,
93 model: "pl.LightningModule",
94 optimizer: Optimizer,
95 closure: Callable[[], Any],
96 ) -> Any:
97 """This double-closure allows makes sure the ``closure`` is executed before the ``on_before_optimizer_step``
98 hook is called.
99
(...)
102
103 """
--> 104 closure_result = closure()
105 self._after_closure(model, optimizer)
106 return closure_result
File ~/anaconda3/envs/scvi-scanpy/lib/python3.9/site-packages/lightning/pytorch/loops/optimization/automatic.py:140, in Closure.__call__(self, *args, **kwargs)
139 def __call__(self, *args: Any, **kwargs: Any) -> Optional[Tensor]:
--> 140 self._result = self.closure(*args, **kwargs)
141 return self._result.loss
File ~/anaconda3/envs/scvi-scanpy/lib/python3.9/site-packages/torch/utils/_contextlib.py:115, in context_decorator.<locals>.decorate_context(*args, **kwargs)
112 @functools.wraps(func)
113 def decorate_context(*args, **kwargs):
114 with ctx_factory():
--> 115 return func(*args, **kwargs)
File ~/anaconda3/envs/scvi-scanpy/lib/python3.9/site-packages/lightning/pytorch/loops/optimization/automatic.py:126, in Closure.closure(self, *args, **kwargs)
124 @torch.enable_grad()
125 def closure(self, *args: Any, **kwargs: Any) -> ClosureResult:
--> 126 step_output = self._step_fn()
128 if step_output.closure_loss is None:
129 self.warning_cache.warn("`training_step` returned `None`. If this was on purpose, ignore this warning...")
File ~/anaconda3/envs/scvi-scanpy/lib/python3.9/site-packages/lightning/pytorch/loops/optimization/automatic.py:315, in _AutomaticOptimization._training_step(self, kwargs)
312 trainer = self.trainer
314 # manually capture logged metrics
--> 315 training_step_output = call._call_strategy_hook(trainer, "training_step", *kwargs.values())
316 self.trainer.strategy.post_training_step() # unused hook - call anyway for backward compatibility
318 return self.output_result_cls.from_training_step_output(training_step_output, trainer.accumulate_grad_batches)
File ~/anaconda3/envs/scvi-scanpy/lib/python3.9/site-packages/lightning/pytorch/trainer/call.py:309, in _call_strategy_hook(trainer, hook_name, *args, **kwargs)
306 return None
308 with trainer.profiler.profile(f"[Strategy]{trainer.strategy.__class__.__name__}.{hook_name}"):
--> 309 output = fn(*args, **kwargs)
311 # restore current_fx when nested context
312 pl_module._current_fx_name = prev_fx_name
File ~/anaconda3/envs/scvi-scanpy/lib/python3.9/site-packages/lightning/pytorch/strategies/strategy.py:382, in Strategy.training_step(self, *args, **kwargs)
380 if self.model != self.lightning_module:
381 return self._forward_redirection(self.model, self.lightning_module, "training_step", *args, **kwargs)
--> 382 return self.lightning_module.training_step(*args, **kwargs)
File ~/anaconda3/envs/scvi-scanpy/lib/python3.9/site-packages/scvi/train/_trainingplans.py:1128, in ClassifierTrainingPlan.training_step(self, batch, batch_idx)
1126 def training_step(self, batch, batch_idx):
1127 """Training step for classifier training."""
-> 1128 soft_prediction = self.forward(batch[self.data_key])
1129 loss = self.loss_fn(soft_prediction, batch[self.labels_key].view(-1).long())
1130 self.log("train_loss", loss, on_epoch=True, prog_bar=True)
File ~/anaconda3/envs/scvi-scanpy/lib/python3.9/site-packages/scvi/train/_trainingplans.py:1124, in ClassifierTrainingPlan.forward(self, *args, **kwargs)
1122 def forward(self, *args, **kwargs):
1123 """Passthrough to the module's forward function."""
-> 1124 return self.module(*args, **kwargs)
File ~/anaconda3/envs/scvi-scanpy/lib/python3.9/site-packages/torch/nn/modules/module.py:1511, in Module._wrapped_call_impl(self, *args, **kwargs)
1509 return self._compiled_call_impl(*args, **kwargs) # type: ignore[misc]
1510 else:
-> 1511 return self._call_impl(*args, **kwargs)
File ~/anaconda3/envs/scvi-scanpy/lib/python3.9/site-packages/torch/nn/modules/module.py:1520, in Module._call_impl(self, *args, **kwargs)
1515 # If we don't have any hooks, we want to skip the rest of the logic in
1516 # this function, and just call forward.
1517 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
1518 or _global_backward_pre_hooks or _global_backward_hooks
1519 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1520 return forward_call(*args, **kwargs)
1522 try:
1523 result = None
File ~/anaconda3/envs/scvi-scanpy/lib/python3.9/site-packages/scvi/module/_classifier.py:78, in Classifier.forward(self, x)
76 def forward(self, x):
77 """Forward computation."""
---> 78 return self.classifier(x)
File ~/anaconda3/envs/scvi-scanpy/lib/python3.9/site-packages/torch/nn/modules/module.py:1511, in Module._wrapped_call_impl(self, *args, **kwargs)
1509 return self._compiled_call_impl(*args, **kwargs) # type: ignore[misc]
1510 else:
-> 1511 return self._call_impl(*args, **kwargs)
File ~/anaconda3/envs/scvi-scanpy/lib/python3.9/site-packages/torch/nn/modules/module.py:1520, in Module._call_impl(self, *args, **kwargs)
1515 # If we don't have any hooks, we want to skip the rest of the logic in
1516 # this function, and just call forward.
1517 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
1518 or _global_backward_pre_hooks or _global_backward_hooks
1519 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1520 return forward_call(*args, **kwargs)
1522 try:
1523 result = None
File ~/anaconda3/envs/scvi-scanpy/lib/python3.9/site-packages/torch/nn/modules/container.py:217, in Sequential.forward(self, input)
215 def forward(self, input):
216 for module in self:
--> 217 input = module(input)
218 return input
File ~/anaconda3/envs/scvi-scanpy/lib/python3.9/site-packages/torch/nn/modules/module.py:1511, in Module._wrapped_call_impl(self, *args, **kwargs)
1509 return self._compiled_call_impl(*args, **kwargs) # type: ignore[misc]
1510 else:
-> 1511 return self._call_impl(*args, **kwargs)
File ~/anaconda3/envs/scvi-scanpy/lib/python3.9/site-packages/torch/nn/modules/module.py:1520, in Module._call_impl(self, *args, **kwargs)
1515 # If we don't have any hooks, we want to skip the rest of the logic in
1516 # this function, and just call forward.
1517 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
1518 or _global_backward_pre_hooks or _global_backward_hooks
1519 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1520 return forward_call(*args, **kwargs)
1522 try:
1523 result = None
File ~/anaconda3/envs/scvi-scanpy/lib/python3.9/site-packages/scvi/nn/_base_components.py:170, in FCLayers.forward(self, x, *cat_list)
168 x = torch.cat([(layer(slice_x)).unsqueeze(0) for slice_x in x], dim=0)
169 else:
--> 170 x = layer(x)
171 else:
172 if isinstance(layer, nn.Linear) and self.inject_into_layer(i):
File ~/anaconda3/envs/scvi-scanpy/lib/python3.9/site-packages/torch/nn/modules/module.py:1511, in Module._wrapped_call_impl(self, *args, **kwargs)
1509 return self._compiled_call_impl(*args, **kwargs) # type: ignore[misc]
1510 else:
-> 1511 return self._call_impl(*args, **kwargs)
File ~/anaconda3/envs/scvi-scanpy/lib/python3.9/site-packages/torch/nn/modules/module.py:1520, in Module._call_impl(self, *args, **kwargs)
1515 # If we don't have any hooks, we want to skip the rest of the logic in
1516 # this function, and just call forward.
1517 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
1518 or _global_backward_pre_hooks or _global_backward_hooks
1519 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1520 return forward_call(*args, **kwargs)
1522 try:
1523 result = None
File ~/anaconda3/envs/scvi-scanpy/lib/python3.9/site-packages/torch/nn/modules/batchnorm.py:175, in _BatchNorm.forward(self, input)
168 bn_training = (self.running_mean is None) and (self.running_var is None)
170 r"""
171 Buffers are only updated if they are to be tracked and we are in training mode. Thus they only need to be
172 passed when the update should occur (i.e. in training mode when they are tracked), or when buffer stats are
173 used for normalization (i.e. in eval mode when buffers are not None).
174 """
--> 175 return F.batch_norm(
176 input,
177 # If buffers are not to be tracked, ensure that they won't be updated
178 self.running_mean
179 if not self.training or self.track_running_stats
180 else None,
181 self.running_var if not self.training or self.track_running_stats else None,
182 self.weight,
183 self.bias,
184 bn_training,
185 exponential_average_factor,
186 self.eps,
187 )
File ~/anaconda3/envs/scvi-scanpy/lib/python3.9/site-packages/torch/nn/functional.py:2480, in batch_norm(input, running_mean, running_var, weight, bias, training, momentum, eps)
2467 return handle_torch_function(
2468 batch_norm,
2469 (input, running_mean, running_var, weight, bias),
(...)
2477 eps=eps,
2478 )
2479 if training:
-> 2480 _verify_batch_size(input.size())
2482 return torch.batch_norm(
2483 input, weight, bias, running_mean, running_var, training, momentum, eps, torch.backends.cudnn.enabled
2484 )
File ~/anaconda3/envs/scvi-scanpy/lib/python3.9/site-packages/torch/nn/functional.py:2448, in _verify_batch_size(size)
2446 size_prods *= size[i + 2]
2447 if size_prods == 1:
-> 2448 raise ValueError(f"Expected more than 1 value per channel when training, got input size {size}")
ValueError: Expected more than 1 value per channel when training, got input size torch.Size([1, 128])
This happens after running:
solobatch = scvi.external.SOLO.from_scvi_model(model, restrict_to_batch = batch)
scvi.external.SOLO.train(solobatch, early_stopping = False,batch_size=256)#note changed from 128 to 64
predictions = scvi.external.SOLO.predict(solobatch, soft = False)
Any ideas about what else might trigger this error? I recently upgraded both scvi-tools and PyTorch.
Thanks!