It seems that MuData implements the methods required by torch.utils.data.Dataset (__len__ and __getitem__):
print(mdata.__len__())
print(mdata.__getitem__(0))
train_loader = DataLoader(
mdata,
batch_size=batch_size,
shuffle=True
)
len(train_loader)
gives
11909
View of MuData object with n_obs × n_vars = 1 × 144978
var: 'gene_ids', 'feature_types', 'genome', 'interval'
2 modalities
rna: 1 x 36601
var: 'gene_ids', 'feature_types', 'genome', 'interval', 'gene_symbol'
atac: 1 x 108377
var: 'gene_ids', 'feature_types', 'genome', 'interval'
uns: 'atac', 'files'
187
However, the loader cannot be iterated: next(iter(train_loader))
gives:
---------------------------------------------------------------------------
TypeError Traceback (most recent call last)
Cell In[65], line 1
----> 1 next(iter(train_loader))
File /usr/local/python/lib/python3.8/site-packages/torch/utils/data/dataloader.py:628, in _BaseDataLoaderIter.__next__(self)
625 if self._sampler_iter is None:
626 # TODO(https://github.com/pytorch/pytorch/issues/76750)
627 self._reset() # type: ignore[call-arg]
--> 628 data = self._next_data()
629 self._num_yielded += 1
630 if self._dataset_kind == _DatasetKind.Iterable and \
631 self._IterableDataset_len_called is not None and \
632 self._num_yielded > self._IterableDataset_len_called:
File /usr/local/python/lib/python3.8/site-packages/torch/utils/data/dataloader.py:671, in _SingleProcessDataLoaderIter._next_data(self)
669 def _next_data(self):
670 index = self._next_index() # may raise StopIteration
--> 671 data = self._dataset_fetcher.fetch(index) # may raise StopIteration
672 if self._pin_memory:
673 data = _utils.pin_memory.pin_memory(data, self._pin_memory_device)
File /usr/local/python/lib/python3.8/site-packages/torch/utils/data/_utils/fetch.py:61, in _MapDatasetFetcher.fetch(self, possibly_batched_index)
59 else:
60 data = self.dataset[possibly_batched_index]
---> 61 return self.collate_fn(data)
File /usr/local/python/lib/python3.8/site-packages/torch/utils/data/_utils/collate.py:265, in default_collate(batch)
204 def default_collate(batch):
205 r"""
206 Function that takes in a batch of data and puts the elements within the batch
207 into a tensor with an additional outer dimension - batch size. The exact output type can be
(...)
263 >>> default_collate(batch) # Handle `CustomType` automatically
264 """
--> 265 return collate(batch, collate_fn_map=default_collate_fn_map)
File /usr/local/python/lib/python3.8/site-packages/torch/utils/data/_utils/collate.py:151, in collate(batch, collate_fn_map)
147 except TypeError:
148 # The sequence type may not support `__init__(iterable)` (e.g., `range`).
149 return [collate(samples, collate_fn_map=collate_fn_map) for samples in transposed]
--> 151 raise TypeError(default_collate_err_msg_format.format(elem_type))
TypeError: default_collate: batch must contain tensors, numpy arrays, numbers, dicts or lists; found <class 'mudata._core.mudata.MuData'>