Label Transfer Discrepancy in scANVI Model Training

Hello All,

I am currently working on training a model using scANVI, and I have encountered a puzzling situation regarding label transfer. Here’s a brief overview:

  • Reference data: 200,000 cells with 3 labels.
  • Query datasets: three sets of varying size (one with 500,000 cells and two with close to 800,000 cells each).

The issue arises when I perform label transfer using the same set of parameters for the three different query datasets separately. The reference prediction accuracy varies across these datasets. Here’s an example to illustrate:

  1. Query data 1:
  • Class label 1: 0.98
  • Class label 2: 0.94
  • Class label 3: 0.99
  2. Query data 2:
  • Class label 1: 0.88
  • Class label 2: 0.92
  • Class label 3: 0.91
  3. Query data 3:
  • Class label 1: 0.77
  • Class label 2: 0.89
  • Class label 3: 0.88

My question is: why do I observe different reference label prediction accuracies across these query datasets, even though I am using the same genes and parameters for prediction in all cases?

I have also attached my code below for reference. Any insights or assistance on this matter would be greatly appreciated.

import anndata
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import scanpy as sc
import scvi
import seaborn as sns
import torch
from scvi.model.utils import mde

# Setting a global seed for reproducibility.

scvi.settings.seed = 0
#######################################
###COMMANDS FOR THE SPECIFIC ANALYSIS###
#######################################

# Setting the batch key name. 

batch_key_name = "ds_ID"

layer_name = "counts"

# Setting saving path.

dir_path = "/MTG_tau/test1/"

name_of_the_new_col = "celltype_scanvi"

# Loading the combined reference and query data.

adata = anndata.read_h5ad("/MTG_tau/test1/tau.h5ad")

# Using the raw counts as the working matrix.
adata.X = adata.layers["counts"]

adata.var_names_make_unique()
adata.obs_names_make_unique()

sc.pp.normalize_total(adata, target_sum=1e4)

sc.pp.log1p(adata)

# Setup the AnnData object for scvi-tools.
scvi.model.SCVI.setup_anndata(adata, layer=layer_name, batch_key=batch_key_name)


vae = scvi.model.SCVI(adata, n_layers=2, n_latent=30)

# Training the model.

vae.train(max_epochs=400)

adata.obsm["X_scVI"] = vae.get_latent_representation()

vae.save("tau_reference_model", overwrite=True)

#adata.obsm["X_mde"] = mde(adata.obsm["X_scVI"])

# Casting obs columns to str to avoid issues when writing h5ad.
adata.obs = adata.obs.astype(str)

adata.write("/MTG_tau/test1/reference_train.h5ad")
# Preparing object for label transfer.

adata.obs[name_of_the_new_col] = "Unknown"

# Copying the known labels from the "SORT" column for reference cells;
# query cells keep the "Unknown" placeholder.

ref_mask = adata.obs["ds_ID"] == "ref"

adata.obs.loc[ref_mask, name_of_the_new_col] = adata.obs.loc[ref_mask, "SORT"].values

# Specifying the label transfer model.

lvae = scvi.model.SCANVI.from_scvi_model(
    vae,
    adata=adata,
    unlabeled_category="Unknown",
    labels_key=name_of_the_new_col,
    linear_classifier=True,
    var_activation=torch.nn.functional.softplus,
)

# Training the scANVI model. n_samples_per_label controls how many labeled
# cells per class are subsampled each epoch.

lvae.train(max_epochs=200, n_samples_per_label=1000)

lvae.save("tau_tune_model", overwrite=True)

# Per-class prediction probabilities for every cell.
softprediction = lvae.predict(adata, soft=True)

softprediction.to_csv("/MTG_tau/test1/soft_pred.csv")

adata.obsm["X_scANVI"] = lvae.get_latent_representation(adata)

adata.obsm["X_mde_scanvi"] = mde(adata.obsm["X_scANVI"])

# Saving anndata and model.

adata.obs = adata.obs.astype(str)

adata.write(dir_path + "adata_lvae.h5ad")

### Finding the predicted label and confidence value from the soft predictions.

df = pd.read_csv("/MTG_tau/test1/soft_pred.csv", index_col=0)

# Find maximum values for numeric columns
numeric_cols = df.select_dtypes(include='number').columns
max_values = df[numeric_cols].max(axis=1)

# Find column name for maximum values
max_columns = df[numeric_cols].idxmax(axis=1)

# Add information to the dataframe
df['max_value'] = max_values
df['max_column'] = max_columns

df.to_csv("/MTG_tau/test1/prediction_conf.csv")
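As a cross-check at the end of the script, the hard labels can also be obtained directly from the model, since predict defaults to soft=False; these should agree with the max_column derived from the probabilities. A minimal sketch, assuming lvae and adata from the script above (the column name "predicted_label" is just for illustration):

# Hard labels straight from the scANVI classifier; should match "max_column".
adata.obs["predicted_label"] = lvae.predict(adata)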

Maybe I’m not understanding your code correctly, but it looks like you’re setting all labels to "Unknown" in the reference data. I am referring to this line:

adata.obs[name_of_the_new_col] = "Unknown"

Could you clarify what’s going on here?

To also clarify my understanding of the issue, are you doing additional training with the query data and then performing prediction on the reference dataset?

Hello,

I would like to clarify the following code snippet:


name_of_the_new_col = "celltype_scanvi"

adata.obs[name_of_the_new_col] = "Unknown"

In this part of the code, I am creating a new column named “celltype_scanvi” in the AnnData object (adata) to store prediction labels. Initially, before training, all labels in this column are set to “Unknown.” After training and obtaining prediction results, this column will be updated with the predicted labels.


ref_mask = adata.obs["ds_ID"] == "ref"

adata.obs.loc[ref_mask, name_of_the_new_col] = adata.obs.loc[ref_mask, "SORT"].values

Following that, a Boolean mask (ref_mask) is created to filter rows where the “ds_ID” is equal to “ref.” Subsequently, the values in the “celltype_scanvi” column for these rows are replaced with the corresponding values from the “SORT” column.
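A quick way to verify this assignment (a small sanity check, assuming adata and the columns defined above) is to count labels per dataset; reference cells should carry their SORT labels and query cells should all be "Unknown":

# Counting labels per dataset to verify the assignment.
print(adata.obs.groupby("ds_ID")[name_of_the_new_col].value_counts())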

After training, the per-class prediction probabilities are obtained by calling predict with soft=True:


softprediction = lvae.predict(adata, soft=True)

From this probability matrix, a “max_column” is derived for each cell, holding the class label with the highest probability (see the sketch below).
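The derivation and the join back into adata.obs look roughly like this (a condensed sketch of the steps in my script; it assumes the row order of softprediction matches adata.obs):

# idxmax gives the most likely class per cell, max gives its probability.
adata.obs["max_column"] = softprediction.idxmax(axis=1).values
adata.obs["max_value"] = softprediction.max(axis=1).values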

When I calculate the confusion matrix for the reference data from this “max_column” and the original class labels (“SORT”) for each of the three runs mentioned above, the per-class accuracies differ from run to run, despite using the same parameters.

I hope this clarifies any confusion.

Any assistance or insights would be appreciated.

Confusion matrix calculation:

import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt
## Subset the reference data.
ref = adata[adata.obs["ds_ID"] == "ref"]

## Select the predicted max_column and the original class label from "SORT".
df = ref.obs
confusion_matrix = pd.crosstab(
    df["max_column"],
    df["SORT"],
    rownames=["max_column"],
    colnames=["SORT"],
)
# Normalize each row so it sums to 1.
confusion_matrix = confusion_matrix.div(confusion_matrix.sum(axis=1), axis=0)

# Reorder the confusion matrix based on the row labels
order = confusion_matrix.index.sort_values()
confusion_matrix = confusion_matrix.reindex(index=order, columns=order)

fig, ax = plt.subplots(figsize=(5, 5))
sns.heatmap(
    confusion_matrix,
    cmap=sns.diverging_palette(245, 320, s=60, as_cmap=True),
    ax=ax,
    square=True,
    cbar_kws=dict(shrink=0.4, aspect=12),
)

# Annotate matrix entries above 0.5.
for i in range(confusion_matrix.shape[0]):
    for j in range(confusion_matrix.shape[1]):
        if confusion_matrix.iloc[i, j] > 0.5:
            ax.text(j + 0.5, i + 0.5, f"{confusion_matrix.iloc[i, j]:.2f}",
                    ha='center', va='center', color='white', fontsize=14)
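For reference, the per-class values I reported at the top of the thread come from the diagonal of this normalized matrix (a minimal sketch, assuming the row normalization above):

# Diagonal entries: fraction of each predicted class matching the original SORT label.
import numpy as np
per_class_acc = pd.Series(np.diag(confusion_matrix), index=confusion_matrix.index)
print(per_class_acc)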

Thank you.