Heatmap cannot be assigned as a subplot

Hi,

According to the documentation, scanpy plots can be drawn onto matplotlib subplots by passing the target axes via the ax argument. However, this does not work for heatmaps, since they are built internally from several axes. This minimal example shows the issue:

import scanpy as sc
import matplotlib.pyplot as plt

# loading data for minimal example
pbmc = sc.datasets.pbmc68k_reduced()
marker_genes_dict = {
    'B-cell': ['CD79A', 'MS4A1'],
    'Dendritic': ['FCER1A', 'CST3'],
    'Monocytes': ['FCGR3A'],
    'NK': ['GNLY', 'NKG7'],
    'Other': ['IGLL1'],
    'Plasma': ['IGJ'],
    'T-cell': ['CD3D'],
}

# plotting the heatmap/matrixplot individually
sc.pl.heatmap(pbmc, marker_genes_dict, groupby='bulk_labels', cmap='viridis')
sc.pl.matrixplot(pbmc, marker_genes_dict, groupby='bulk_labels', cmap='viridis')

fig, (ax1, ax2, ax3) = plt.subplots(1, 3, figsize=(20,4), gridspec_kw={'wspace':0.9})

ax1_dict = sc.pl.dotplot(pbmc, marker_genes_dict, groupby='bulk_labels', ax=ax1, show=False)
ax2_dict = sc.pl.stacked_violin(pbmc, marker_genes_dict, groupby='bulk_labels', ax=ax2, show=False)
# this works
# ax3_dict = sc.pl.matrixplot(pbmc, marker_genes_dict, groupby='bulk_labels', ax=ax3, show=False, cmap='viridis')
# this doesn't
ax3_dict = sc.pl.heatmap(pbmc, marker_genes_dict, groupby='bulk_labels', ax=ax3, show=False, cmap='viridis')

Running this code raises TypeError: Axes.imshow() got multiple values for argument 'ax'. Is there a recommended workaround for this? Being able to place heatmaps in subplots would be really useful.
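The TypeError itself seems to come from how the ax keyword travels into matplotlib. This is an assumption about scanpy's internals: sc.pl.heatmap appears to have no ax parameter of its own, so ax= ends up among the extra keyword arguments it forwards to Axes.imshow, and matplotlib's wrapper around imshow already uses the name ax for the Axes the method is bound to. The sketch below uses matplotlib alone (no scanpy) and should reproduce the same message; extra_kwds is a hypothetical stand-in for the forwarded keyword arguments.

```python
import numpy as np
import matplotlib.pyplot as plt

fig, (target_ax, other_ax) = plt.subplots(1, 2)
data = np.arange(6).reshape(2, 3)

# Hypothetical stand-in for keyword arguments that a plotting function
# passes straight through to Axes.imshow (assumption, see above).
extra_kwds = {"ax": other_ax}

message = None
try:
    # imshow is already bound to target_ax, so a second "ax" collides.
    target_ax.imshow(data, **extra_kwds)
except TypeError as err:
    message = str(err)
    print(message)
```

If this reproduces the error, the failure is a keyword clash rather than a layout problem, which fits with matrixplot and dotplot working in the same figure.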

Thank you very much!

I’ve been having this issue as well. It currently seems impossible to assign heatmaps to a matplotlib subplot.
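Until sc.pl.heatmap accepts an ax argument, one possible workaround is to draw a simplified heatmap directly on the target subplot with plain matplotlib. This is a minimal sketch under stated assumptions: the expression values are a pandas DataFrame (cells × genes) with one group label per cell, and heatmap_on_ax is a hypothetical helper, not part of scanpy. It omits scanpy's gene-group brackets and orders groups alphabetically rather than by category order.

```python
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt


def heatmap_on_ax(expr, groups, ax, cmap="viridis"):
    """Draw a cells x genes heatmap onto one existing matplotlib Axes.

    Hypothetical helper, not a scanpy function.
    expr   -- DataFrame, rows are cells, columns are genes
    groups -- array-like of group labels, one per row of expr
    """
    labels = np.asarray(groups).astype(str)
    # Stable sort groups the cells while keeping their order within a group.
    order = np.argsort(labels, kind="stable")
    values = expr.to_numpy()[order]
    sorted_labels = labels[order]

    im = ax.imshow(values, aspect="auto", interpolation="nearest", cmap=cmap)

    # One x tick per gene.
    ax.set_xticks(np.arange(expr.shape[1]))
    ax.set_xticklabels(expr.columns, rotation=90)

    # One y tick per group, centred on that group's block of rows.
    cats, starts, counts = np.unique(
        sorted_labels, return_index=True, return_counts=True
    )
    ax.set_yticks(starts + counts / 2 - 0.5)
    ax.set_yticklabels(cats)

    # Thin separators between consecutive groups.
    for s in starts[1:]:
        ax.axhline(s - 0.5, color="white", linewidth=0.8)

    # The colorbar takes its space from this Axes only, not its neighbours.
    ax.figure.colorbar(im, ax=ax)
    return im


# Demo with synthetic data standing in for the pbmc example.
rng = np.random.default_rng(0)
expr = pd.DataFrame(rng.random((30, 4)), columns=["CD79A", "MS4A1", "CST3", "NKG7"])
groups = rng.choice(["B-cell", "Dendritic", "NK"], size=30)

fig, (ax1, ax2, ax3) = plt.subplots(1, 3, figsize=(20, 4), gridspec_kw={"wspace": 0.9})
im = heatmap_on_ax(expr, groups, ax3)
```

With the data from the original example, expr could come from sc.get.obs_df(pbmc, keys=[g for genes in marker_genes_dict.values() for g in genes], use_raw=True) (use_raw=True reads from pbmc.raw, which sc.pl.heatmap uses by default when raw data are present) and the labels from pbmc.obs['bulk_labels'], with ax3 passed as the target.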