Predict cell type with scANVI for spatial transcriptomics data (Xenium)

Hi,

Thanks so much for the amazing tools!

I was hoping someone could help me. I have an in-house dataset of Xenium 5k data and I would like to predict the cell types using a reference dataset.

To do this, I am trying to use scANVI for label transfer to improve the accuracy. However, I have a few questions and problems.

Questions for scANVI:

  1. Should I subset my reference data to just the HVGs to improve accuracy?
  2. Should I subset my reference data to only the genes also found in my STx experiment?
  3. How do differences between the raw counts of the scRNA-seq reference and the raw counts of the STx data (different expression levels and numbers of genes) affect the model? Can the model handle this? How can I adjust for these differences?

Is there a better tool or approach I should be using instead of scANVI? I also tried sc.tl.ingest but got an overrepresentation of rare cell populations.

Thanks so much!

Hey,

  1. Yes.
  2. No, that would be “peeking” into the query set (in that case, you could just train on the query set).
  3. This is what scArches is all about: you train a reference model and then “fine-tune” it on the query data, despite their differences.

I suggest following this manual: Reference mapping with SCVI-Tools — scvi-tools

However, for a more modern spatial method, you might also try RESOLVI: ResolVI to address noise and biases in spatial transcriptomics — scvi-tools, which works on Xenium data as well.

Thanks so much @ori-kron-wis!

I had read about those tools but am not sure if I am implementing them correctly. If this is something you have expertise in, would you be willing to advise if I post the code I’m using?

you can post, up to you

def scVI_integration(config, adata_ref, adata, module_dir):
    """Integrates STx data with a reference scRNA-seq dataset using scVI.

    This function harmonizes a reference single-cell dataset
    (e.g., HLCA) to a STx dataset (e.g., Xenium).

    Args:
        config (SimpleNamespace or dict): Configuration object containing
            module-specific parameters, including the module name used to
            construct figure filenames.
        adata_ref (anndata.AnnData): Reference scRNA-seq AnnData object containing
            precomputed embeddings (PCA or UMAP) and cell-type annotations.
            Should have matching genes with adata_ingest.
        adata (anndata.AnnData): Spatial dataset formatted for integration,
            aligned by gene set with the reference (same genes, same order).
        module_dir (Path): Output directory for saving results

    Returns:
        tuple: (adata_combined, scvi_model) - Combined dataset and trained scVI model
    """
    logger.info("Integrating data using scVI...")

    # Training parameters
    MAX_EPOCHS_SCVI = 200

    # Log available layers for debugging
    logger.info(f"Reference layers: {list(adata_ref.layers.keys())}")
    logger.info(f"Spatial layers: {list(adata.layers.keys())}")

    # Log counts layer shapes (already verified to exist)
    logger.info(f"Reference counts shape: {adata_ref.layers['counts'].shape}")
    logger.info(f"Spatial counts shape: {adata.layers['counts'].shape}")

    # Ensure REF_CELL_LABEL_COL column exists in the spatial data (placeholder label)
    adata.obs[REF_CELL_LABEL_COL] = (
        "STx_UNKNOWN"  # ensure column exists in spatial data
    )

    logger.info("Verifying datasets for scVI integration...")
    scVI_integration_check(adata_ref, batch_key=BATCH_COL, cell_type=REF_CELL_LABEL_COL)
    scVI_integration_check(adata, batch_key=BATCH_COL, cell_type=REF_CELL_LABEL_COL)

    # Combine datasets
    logger.info("Combining reference and spatial data...")
    adata_combined = anndata.concat(
        [adata_ref, adata],
        join="inner",  # only keeps genes present in both datasets
        label=BATCH_COL,
        keys=["Ref", "STx"],
        index_unique="_",
    )

    # Verify counts layer was preserved during concat
    if "counts" not in adata_combined.layers:
        logger.error("Counts layer lost during concat! Manually preserving...")
        # Manually recreate counts layer from original data
        ref_mask = adata_combined.obs[BATCH_COL] == "Ref"
        stx_mask = adata_combined.obs[BATCH_COL] == "STx"

        # Initialize counts layer
        adata_combined.layers["counts"] = adata_combined.X.copy()  # temporary

        # Fill with actual counts
        adata_combined.layers["counts"][ref_mask, :] = adata_ref.layers["counts"]
        adata_combined.layers["counts"][stx_mask, :] = adata.layers["counts"]
        logger.info("Counts layer manually restored!")
    else:
        logger.info("✓ Counts layer preserved during concat")

    logger.info(f"Combined data layers: {list(adata_combined.layers.keys())}")

    # Setup scVI
    logger.info("Setting up scVI model...")
    SCVI.setup_anndata(
        adata_combined,
        layer="counts",  # must be counts layer
        batch_key=BATCH_COL,  # variable you want to perform harmonization over
    )

    logger.info("Initializing scVI model...")
    scvi_model = SCVI(adata_combined, n_layers=2, n_latent=30, n_hidden=128)
    logger.info("Training SCVI model...")
    scvi_model.train(max_epochs=MAX_EPOCHS_SCVI, batch_size=128)

    logger.info("Obtain and visualize latent representation...")
    adata_combined.obsm[SCVI_LATENT_KEY] = scvi_model.get_latent_representation()
    sc.pp.pca(adata_combined)
    sc.pp.neighbors(adata_combined, use_rep=SCVI_LATENT_KEY)
    sc.tl.umap(adata_combined)
    sc.pl.umap(
        adata_combined,
        color=[BATCH_COL, REF_CELL_LABEL_COL],
        frameon=False,
        ncols=1,
        save=f"_{config.module_name}_scvi_umap.png",
    )

    logger.info("Saving scVI model...")
    scvi_model.save(module_dir / "_scvi_ref", overwrite=True)

    return adata_combined, scvi_model


def scANVI_label_transfer(config, adata_combined, scvi_model, module_dir):
    """Transfers labels from a scRNA-seq reference to STx data using scANVI.

    This function performs label transfer from a reference single-cell dataset
    (e.g., HLCA) to a STx dataset (e.g., Xenium).
    scANVI uses a semi-supervised approach to transfer cell-type labels.

    Args:
        config (SimpleNamespace or dict): Configuration object containing
            module-specific parameters, including the module name used to
            construct figure filenames.
        adata_combined (anndata.AnnData): Combined reference and STx AnnData
            object returned by scVI_integration, with BATCH_COL and
            REF_CELL_LABEL_COL in .obs.
        scvi_model (scvi.model.SCVI): Trained scVI model on combined data.
        module_dir (Path): Output directory for saving results

    Returns:
        tuple: (adata_combined, scanvi_model) - Combined dataset and trained model
    """
    # Set constant variables
    MAX_EPOCHS_SCANVI = 200

    # Format labels for scANVI
    # Create scANVI labels (reference has labels, spatial is 'Unknown')
    SCANVI_CELLTYPE_KEY = (
        "training_celltype_scanvi"  # new column that scANVI will use as training labels.
    )
    UNLABELED_CATEGORY = (
        "Unknown"  # placeholder for cells that do not have known labels aka STx cells
    )

    # Set the label column (SCANVI_CELLTYPE_KEY) to "Unknown" for all cells initially.
    # This ensures that STx cells are marked as unlabeled before we assign ref labels
    adata_combined.obs[SCANVI_CELLTYPE_KEY] = UNLABELED_CATEGORY

    # Check columns in adata_combined
    logger.info(f"Columns in adata_combined: {list(adata_combined.obs.columns)}")

    # Ensure the batch column exists
    if BATCH_COL not in adata_combined.obs.columns:
        logger.error(f"Batch column '{BATCH_COL}' not found")
        return None, None

    # Assign real labels to reference cells
    ref_mask = adata_combined.obs[BATCH_COL] == "Ref"

    # Ensure reference label column exists
    if REF_CELL_LABEL_COL not in adata_combined.obs.columns:
        logger.error(f"Cell type column '{REF_CELL_LABEL_COL}' not found")
        return None, None

    # Assign reference labels safely
    adata_combined.obs.loc[ref_mask, SCANVI_CELLTYPE_KEY] = adata_combined.obs.loc[
        ref_mask, REF_CELL_LABEL_COL
    ].astype(str)

    # Convert the column to categorical (recommended for scANVI)
    adata_combined.obs[SCANVI_CELLTYPE_KEY] = adata_combined.obs[
        SCANVI_CELLTYPE_KEY
    ].astype("category")

    logger.info(
        f"Reference cells with labels: {ref_mask.sum()}"
    )  # number of reference cells with labels.
    logger.info(
        f"Percent ref cells labeled: {100 * ref_mask.sum() / adata_combined.n_obs:.2f}%"
    )
    logger.info(
        f"Spatial cells (unlabeled): {(~ref_mask).sum()}"
    )  # number of spatial (unlabeled) cells.
    logger.info(
        f"Percent Spatial cells: {100 * (~ref_mask).sum() / adata_combined.n_obs:.2f}%"
    )
    logger.info(
        f"Unique cell types: {adata_combined.obs[SCANVI_CELLTYPE_KEY].value_counts()}"
    )

    # Initialize scANVI from trained scVI
    logger.info("Initializing scANVI model...")
    SCANVI.setup_anndata(
        adata_combined,
        labels_key=SCANVI_CELLTYPE_KEY,
        unlabeled_category=UNLABELED_CATEGORY,
        batch_key=BATCH_COL,
    )
    scanvi_model = SCANVI.from_scvi_model(
        scvi_model,
        unlabeled_category=UNLABELED_CATEGORY,
        labels_key=SCANVI_CELLTYPE_KEY,
    )

    # Train scANVI
    logger.info("Training scANVI model...")
    scanvi_model.train(max_epochs=MAX_EPOCHS_SCANVI, batch_size=128)
    logger.info("scANVI training completed!")

    logger.info("Get latent representation...")
    adata_combined.obsm[SCANVI_LATENT_KEY] = scanvi_model.get_latent_representation()

    logger.info("Visualizing scANVI latent space...")
    sc.pp.pca(adata_combined)
    sc.pp.neighbors(adata_combined, use_rep=SCANVI_LATENT_KEY)
    sc.tl.umap(adata_combined)
    sc.pl.umap(
        adata_combined,
        color=[BATCH_COL, REF_CELL_LABEL_COL],
        frameon=False,
        ncols=1,
        save=f"_{config.module_name}_scanvi_umap.png",
    )

    logger.info("Saving scANVI model...")
    scanvi_model.save(module_dir / "_scanvi_ref", overwrite=True)

    return adata_combined, scanvi_model

def run_integration(config: IntegrateModuleConfig, io_config: IOConfig):
    """Integrate scRNAseq and STx data using scANVI and ingest.

    adata_ref_subset and adata_ingest are used for ingest integration.
    adata_ref and adata are used for scANVI integration.

    Args:
        config (IntegrateModuleConfig): Integration module configuration object.
        io_config (IOConfig): IO configuration object.

    Returns:
        None
    """
    # Variables

    # Output directory for this module
    module_dir = io_config.output_dir / config.module_name

    # Paths to input data
    ref_path = io_config.ref_path
    gene_id_dict_path = io_config.gene_id_dict_path

    # Create output directories if they do not exist
    module_dir.mkdir(exist_ok=True)

    # Set figure directory for this module (overrides global setting)
    sc.settings.figdir = module_dir

    # Get shared colormap from global visualization settings
    # This ensures consistency across all modules
    viz_assets = configure_scanpy_figures(str(io_config.output_dir))
    cmap = viz_assets["cmap"]

    logger.info("Starting integration of scRNAseq and spatial transcriptomics data...")

    logger.info("Loading scRNAseq data from HLCA ...")
    adata_ref = sc.read_h5ad(ref_path)

    logger.info("Loading Xenium data...")
    adata = sc.read_h5ad(io_config.output_dir / "2_dimension_reduction" / "adata.h5ad")

    logger.info("Selecting highly variable genes on reference dataset...")
    sc.pp.highly_variable_genes(
        adata_ref,
        n_top_genes=5000,
        layer="counts",
        flavor="seurat_v3",
        subset=True,  # Subset to highly variable genes for integration
    )
    logger.info(f"Number of highly variable genes selected: {adata_ref.n_vars}")

    # 1. INTEGRATION using scVI and scANVI
    logger.info("Verify adata compatibility for scVI/scANVI integration...")
    verify_counts_layer(adata_ref, "reference")
    verify_counts_layer(adata, "spatial")

    logger.info("Performing integration using: scANVI...")
    logger.info(
        "Step 1. Harmonize scRNAseq reference dataset with STx dataset scVI model..."
    )
    adata_combined, trained_scvi_model = scVI_integration(
        config, adata_ref, adata, module_dir
    )

    logger.info(f"scVI model training complete. adata combined: {adata_combined}")
    logger.info(f"scVI model training complete. Model: {trained_scvi_model}")

    logger.info("Step 2. Transfer labels using scANVI model...")
    adata_combined, trained_scanvi_model = scANVI_label_transfer(
        config, adata_combined, trained_scvi_model, module_dir
    )

    # Extract scANVI predictions and copy to original adata
    logger.info("Extracting predicted labels from scANVI...")
    if adata_combined is not None and trained_scanvi_model is not None:
        adata = extract_predictions_and_visualize(
            adata_combined, trained_scanvi_model, adata, module_dir
        )
    else:
        logger.error("scANVI integration failed. Skipping scANVI predictions.")
        logger.warning("Continuing with ingest integration only...")

I would also train on shared genes. However, scRNA-seq data is usually sparser, so you might lose some resolution when subsetting genes. The code seems correct but a bit bloated. Try adding the new columns before concatenating and removing the layers before concatenating. It makes it less error-prone. You can transfer labels using the scANVI approach. However, for some data we find that a kNN classifier in the scANVI latent space performs better. ResolVI only supports spatial data, not the integration of spatial and scRNA-seq data.
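
To illustrate the kNN suggestion above, here is a minimal sketch. It assumes the reference and query embeddings have already been extracted from the scANVI latent space (for example with get_latent_representation). The function name knn_transfer_labels, the choice of k=15, and the synthetic embeddings are assumptions for illustration and do not come from the thread. Brute-force distances are used only to keep the example dependency-free.

```python
import numpy as np


def knn_transfer_labels(ref_latent, ref_labels, query_latent, k=15):
    """Majority-vote kNN label transfer in a shared latent space.

    Hypothetical helper: returns the predicted label for each query cell
    and the fraction of its k nearest reference cells that agree.
    """
    ref_labels = np.asarray(ref_labels)
    classes, ref_codes = np.unique(ref_labels, return_inverse=True)
    k = min(k, ref_latent.shape[0])

    # Squared Euclidean distances, shape (n_query, n_ref).
    diff = query_latent[:, None, :] - ref_latent[None, :, :]
    d2 = (diff ** 2).sum(axis=-1)

    # Indices of the k nearest reference cells for every query cell.
    nn_idx = np.argpartition(d2, k - 1, axis=1)[:, :k]
    nn_codes = ref_codes[nn_idx]

    # Count the votes per class and pick the most frequent one.
    votes = np.stack(
        [(nn_codes == c).sum(axis=1) for c in range(len(classes))], axis=1
    )
    predicted = classes[votes.argmax(axis=1)]
    confidence = votes.max(axis=1) / k
    return predicted, confidence


# Synthetic stand-ins for latent embeddings (not real data).
rng = np.random.default_rng(0)
ref_latent = np.vstack(
    [rng.normal(0.0, 1.0, size=(50, 10)), rng.normal(10.0, 1.0, size=(50, 10))]
)
ref_labels = np.array(["A"] * 50 + ["B"] * 50)
query_latent = np.vstack(
    [rng.normal(0.0, 1.0, size=(5, 10)), rng.normal(10.0, 1.0, size=(5, 10))]
)

predicted, confidence = knn_transfer_labels(ref_latent, ref_labels, query_latent)
print(predicted, confidence)
```

On real data the pairwise distance matrix would be too large to hold in memory, so an approximate nearest-neighbour index would be the more realistic choice. The confidence value can be used to flag query cells whose neighbours disagree.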

I tend to cluster the spatial data after training ResolVI and then let an LLM help with annotation or do it manually. I would encourage you to do both and double-check the results for the integrated space.

Thanks so much for your feedback!

What do you mean by “Try adding new columns before concatenating and remove layers before concatenating.”?

Would you recommend using ResolVI instead of scVI/scANVI, or do you typically run both and compare the annotations?