Hi scvi-tools Team,
I have been trying out scVI and scANVI and am evaluating how well label transfer works. For my current test setup, I have 4 different datasets (human liver), which I have manually labeled. To test scANVI, I overwrite the labels for one dataset with “Unknown”. Otherwise I pretty much just follow your scANVI tutorial.
The label transfer works pretty well for the dataset with the “Unknown” cell types. However, when I use the SCANVI.predict() function, it also generates wrong labels for cells whose cell types are actually known (i.e. not marked as “Unknown”). And bad ones at that.
Now my question: is that expected, i.e. is scANVI supposed to predict labels for these cells as well? And the more difficult question: any idea why it behaves this way? I’m still assuming that I’m just doing something wrong, but I can’t figure out what. I have tried both starting from a pre-trained scVI model and training a scANVI model from scratch. Any help would be highly appreciated!
I’ll put code and a figure below. It’s pretty clear when looking at the NKT cluster in celltype_scanvi and then comparing to the same cluster in C_scANVI. This cluster doesn’t contain unlabeled cells.
Cheers,
Kevin
import numpy as np
import scanpy as sc
import scvi

# Start with every cell marked as unlabeled
adata.obs["celltype_scanvi"] = "Unknown"
# Copy the manual labels for datasets 0, 1, 2; the remaining dataset stays "Unknown".
# A single .loc assignment avoids pandas chained-assignment problems.
labeled = adata.obs["batch"].isin(["0", "1", "2"])
adata.obs.loc[labeled, "celltype_scanvi"] = adata.obs.loc[labeled, "celltype"].astype(str)
adata.obs["celltype_scanvi"] = adata.obs["celltype_scanvi"].astype("str")
# Sanity check: number of cells per label, including "Unknown"
np.unique(adata.obs["celltype_scanvi"], return_counts=True)
scvi.data.setup_anndata(
    adata,
    layer="counts",
    batch_key="batch",
    labels_key="celltype_scanvi",
)
lvae = scvi.model.SCANVI(adata, "Unknown", n_latent=30, n_layers=2)
lvae.train(n_samples_per_label=100)
adata.obs["C_scANVI"] = lvae.predict(adata)
adata.obsm["X_scANVI"] = lvae.get_latent_representation(adata)
sc.pp.neighbors(adata, use_rep="X_scANVI")
sc.tl.umap(adata)
sc.pl.umap(adata, color=["celltype_scanvi", "C_scANVI", "batch"], ncols=1, frameon=False)
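To put a number on the mismatch instead of only eyeballing it on the UMAP, something like the following could be used. It is a minimal sketch on a made-up toy table standing in for adata.obs: the helper name label_agreement and the toy values are hypothetical, and only the column names celltype_scanvi and C_scANVI come from the code above. It restricts the comparison to cells that were not marked “Unknown” and reports overall agreement, per-cell-type agreement, and a confusion table.

```python
import pandas as pd

# Toy stand-in for adata.obs (hypothetical values, for illustration only).
obs = pd.DataFrame({
    "batch": ["0", "0", "1", "1", "2", "3", "3"],
    "celltype_scanvi": ["NKT", "Hepatocyte", "NKT", "NKT",
                        "Hepatocyte", "Unknown", "Unknown"],
    "C_scANVI": ["NKT", "Hepatocyte", "Hepatocyte", "NKT",
                 "Hepatocyte", "NKT", "Hepatocyte"],
})


def label_agreement(obs, true_key="celltype_scanvi", pred_key="C_scANVI",
                    unlabeled="Unknown"):
    """Compare predicted labels with the given labels on labeled cells only."""
    # Drop cells that had no label during training.
    labeled = obs[obs[true_key] != unlabeled]
    truth = labeled[true_key].astype(str)
    pred = labeled[pred_key].astype(str)
    match = truth == pred
    # Fraction of correctly recovered cells per given cell type.
    per_type = match.groupby(truth).mean()
    # Rows: given label, columns: predicted label.
    confusion = pd.crosstab(truth, pred)
    return match.mean(), per_type, confusion


overall, per_type, confusion = label_agreement(obs)
print(f"agreement on labeled cells: {overall:.2f}")
print(per_type)
print(confusion)
```

On the real data, passing adata.obs in place of the toy table should show which known cell types (e.g. NKT) get reassigned and to what.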