AnnLoader equivalent for MuData?

It seems that MuData already implements the methods required by a map-style `torch.utils.data.Dataset` (`__len__` and `__getitem__`):

# Check that MuData exposes the map-style Dataset protocol
# (__len__ / __getitem__) that DataLoader relies on.
print(len(mdata))
print(mdata[0])

train_loader = DataLoader(mdata, batch_size=batch_size, shuffle=True)

# Number of batches the loader will produce.
len(train_loader)

gives

11909
View of MuData object with n_obs × n_vars = 1 × 144978
  var:	'gene_ids', 'feature_types', 'genome', 'interval'
  2 modalities
    rna:	1 x 36601
      var:	'gene_ids', 'feature_types', 'genome', 'interval', 'gene_symbol'
    atac:	1 x 108377
      var:	'gene_ids', 'feature_types', 'genome', 'interval'
      uns:	'atac', 'files'

187

However, the loader cannot be iterated: `next(iter(train_loader))` raises:

---------------------------------------------------------------------------
TypeError                                 Traceback (most recent call last)
Cell In[65], line 1
----> 1 next(iter(train_loader))

File /usr/local/python/lib/python3.8/site-packages/torch/utils/data/dataloader.py:628, in _BaseDataLoaderIter.__next__(self)
    625 if self._sampler_iter is None:
    626     # TODO(https://github.com/pytorch/pytorch/issues/76750)
    627     self._reset()  # type: ignore[call-arg]
--> 628 data = self._next_data()
    629 self._num_yielded += 1
    630 if self._dataset_kind == _DatasetKind.Iterable and \
    631         self._IterableDataset_len_called is not None and \
    632         self._num_yielded > self._IterableDataset_len_called:

File /usr/local/python/lib/python3.8/site-packages/torch/utils/data/dataloader.py:671, in _SingleProcessDataLoaderIter._next_data(self)
    669 def _next_data(self):
    670     index = self._next_index()  # may raise StopIteration
--> 671     data = self._dataset_fetcher.fetch(index)  # may raise StopIteration
    672     if self._pin_memory:
    673         data = _utils.pin_memory.pin_memory(data, self._pin_memory_device)

File /usr/local/python/lib/python3.8/site-packages/torch/utils/data/_utils/fetch.py:61, in _MapDatasetFetcher.fetch(self, possibly_batched_index)
     59 else:
     60     data = self.dataset[possibly_batched_index]
---> 61 return self.collate_fn(data)

File /usr/local/python/lib/python3.8/site-packages/torch/utils/data/_utils/collate.py:265, in default_collate(batch)
    204 def default_collate(batch):
    205     r"""
    206         Function that takes in a batch of data and puts the elements within the batch
    207         into a tensor with an additional outer dimension - batch size. The exact output type can be
   (...)
    263             >>> default_collate(batch)  # Handle `CustomType` automatically
    264     """
--> 265     return collate(batch, collate_fn_map=default_collate_fn_map)

File /usr/local/python/lib/python3.8/site-packages/torch/utils/data/_utils/collate.py:151, in collate(batch, collate_fn_map)
    147         except TypeError:
    148             # The sequence type may not support `__init__(iterable)` (e.g., `range`).
    149             return [collate(samples, collate_fn_map=collate_fn_map) for samples in transposed]
--> 151 raise TypeError(default_collate_err_msg_format.format(elem_type))

TypeError: default_collate: batch must contain tensors, numpy arrays, numbers, dicts or lists; found <class 'mudata._core.mudata.MuData'>

Passing a custom `collate_fn` that converts each modality to a PyTorch tensor seems to work:

def sparse_csr_to_tensor(csr: csr_matrix):
    """
    Convert a scipy sparse matrix to a float32 PyTorch sparse COO tensor.

    Parameters
    ----------
    csr
        A scipy sparse matrix. CSR is what the collate path passes, but any
        format supporting ``tocoo()`` works. Explicitly stored zeros are kept
        as entries of the result; they do not change ``to_dense()``.

    Returns
    -------
    torch.Tensor
        Sparse COO tensor of the same shape, with ``int64`` indices and
        ``float32`` values.
    """
    # Read row/col and data from the same COO view so they always line up.
    # ``csr.nonzero()`` drops explicitly stored zeros while ``csr.data`` keeps
    # them, which would make the index and value arrays disagree in length.
    coo = csr.tocoo()
    indices = torch.as_tensor(np.vstack((coo.row, coo.col)), dtype=torch.long)
    values = torch.as_tensor(coo.data, dtype=torch.float32)
    # torch.sparse_coo_tensor replaces the deprecated torch.sparse.FloatTensor
    # legacy constructor.
    return torch.sparse_coo_tensor(indices, values, torch.Size(coo.shape))
    
def sparse_batch_collate(batch: list):
    """
    Collate a list of MuData items into one tensor per modality.

    Intended as the ``collate_fn`` of a ``DataLoader`` over a MuData object.
    For each of the ``'atac'`` and ``'rna'`` modalities, the items' ``.X``
    matrices are stacked row-wise:

    - sparse input becomes a sparse COO tensor (via ``sparse_csr_to_tensor``);
    - dense input becomes a dense ``float32`` tensor.

    Parameters
    ----------
    batch
        Non-empty list of items supporting ``item['atac'].X`` and
        ``item['rna'].X``.

    Returns
    -------
    tuple of torch.Tensor
        ``(atac_batch, rna_batch)``.

    Raises
    ------
    ValueError
        If ``batch`` is empty.
    """
    if not batch:
        raise ValueError("sparse_batch_collate received an empty batch")
    atac_batch = _stack_modality(batch, "atac")
    rna_batch = _stack_modality(batch, "rna")
    return atac_batch, rna_batch


def _stack_modality(batch: list, modality: str):
    """Stack ``item[modality].X`` across ``batch`` into a single torch tensor."""
    # Local import: scipy.sparse is needed for issparse/vstack regardless of
    # which ``vstack`` the surrounding file happens to import.
    from scipy import sparse

    matrices = [item[modality].X for item in batch]
    # issparse matches anndata's sparse views (csr_matrix subclasses) and plain
    # scipy matrices alike, without depending on anndata's private module paths.
    if sparse.issparse(matrices[0]):
        return sparse_csr_to_tensor(sparse.vstack(matrices, format="csr"))
    # Dense rows are stacked with numpy: scipy.sparse.vstack would convert them
    # into a sparse matrix, which cannot be wrapped as a dense tensor.
    return torch.as_tensor(np.asarray(np.vstack(matrices)), dtype=torch.float32)


# The default collate_fn raises TypeError on MuData items, so pass the
# collate function defined in this file. The module name ``sparse_data``
# is not defined here.
loader = DataLoader(
    mdata,
    batch_size=10,
    collate_fn=sparse_batch_collate,
)