def scVI_integration(config, adata_ref, adata, module_dir):
    """Harmonize a reference scRNA-seq dataset with an STx dataset using scVI.

    The reference (e.g., HLCA) and spatial (e.g., Xenium) datasets are
    concatenated on their shared genes, an scVI model is trained on the
    ``counts`` layer with the dataset of origin as the batch variable, and
    the resulting latent representation is stored in ``obsm`` and plotted as
    a UMAP. The trained model is saved under ``module_dir / "_scvi_ref"``.

    Side effects:
        ``adata.obs[REF_CELL_LABEL_COL]`` is overwritten with the placeholder
        ``"STx_UNKNOWN"`` on the caller's spatial object.

    Args:
        config: Configuration object. Only ``config.module_name`` is read
            here (attribute access), to build the UMAP figure filename.
        adata_ref (anndata.AnnData): Reference scRNA-seq dataset. Must have a
            ``counts`` layer and the ``REF_CELL_LABEL_COL`` annotation.
        adata (anndata.AnnData): Spatial dataset. Must have a ``counts``
            layer.
        module_dir (Path): Output directory in which the model is saved.

    Returns:
        tuple: ``(adata_combined, scvi_model)`` - the concatenated dataset
        (with ``obsm[SCVI_LATENT_KEY]``, PCA, neighbors and UMAP added) and
        the trained scVI model.
    """
    logger.info("Integrating data using scVI...")

    # Training parameters
    MAX_EPOCHS_SCVI = 200

    # Log available layers for debugging
    logger.info(f"Reference layers: {list(adata_ref.layers.keys())}")
    logger.info(f"Spatial layers: {list(adata.layers.keys())}")

    # Log counts layer shapes (a missing layer raises KeyError here)
    logger.info(f"Reference counts shape: {adata_ref.layers['counts'].shape}")
    logger.info(f"Spatial counts shape: {adata.layers['counts'].shape}")

    # Spatial cells have no annotation yet: give them a placeholder so the
    # label column exists in both datasets before they are concatenated.
    adata.obs[REF_CELL_LABEL_COL] = "STx_UNKNOWN"

    logger.info("Verifying datasets for scVI integration...")
    scVI_integration_check(
        adata_ref, batch_key=BATCH_COL, cell_type=REF_CELL_LABEL_COL
    )
    scVI_integration_check(
        adata, batch_key=BATCH_COL, cell_type=REF_CELL_LABEL_COL
    )

    # Combine datasets
    logger.info("Combining reference and spatial data...")
    adata_combined = anndata.concat(
        [adata_ref, adata],
        join="inner",  # only keeps genes present in both datasets
        label=BATCH_COL,  # writes "Ref"/"STx" into obs[BATCH_COL]
        keys=["Ref", "STx"],
        index_unique="_",
    )

    # Verify counts layer was preserved during concat
    if "counts" not in adata_combined.layers:
        logger.error("Counts layer lost during concat! Manually preserving...")
        # The inner join may have dropped genes, so the source count matrices
        # must be restricted (and ordered) to the combined gene set before
        # they can be written into the combined matrix.
        shared_genes = adata_combined.var_names
        ref_mask = (adata_combined.obs[BATCH_COL] == "Ref").to_numpy()
        stx_mask = (adata_combined.obs[BATCH_COL] == "STx").to_numpy()

        # Start from a copy of X purely to get a matrix of the right shape;
        # every row is overwritten below because each cell is Ref or STx.
        # NOTE(review): assumes X and the counts layers use compatible
        # storage (both dense or both sparse) - confirm if this path is hit.
        counts = adata_combined.X.copy()
        counts[ref_mask, :] = adata_ref[:, shared_genes].layers["counts"]
        counts[stx_mask, :] = adata[:, shared_genes].layers["counts"]
        adata_combined.layers["counts"] = counts
        logger.info("Counts layer manually restored!")
    else:
        logger.info("✓ Counts layer preserved during concat")
    logger.info(f"Combined data layers: {list(adata_combined.layers.keys())}")

    # Setup scVI
    logger.info("Setting up scVI model...")
    SCVI.setup_anndata(
        adata_combined,
        layer="counts",  # must be counts layer
        batch_key=BATCH_COL,  # variable you want to perform harmonization over
    )

    logger.info("Initializing scVI model...")
    scvi_model = SCVI(adata_combined, n_layers=2, n_latent=30, n_hidden=128)

    logger.info("Training SCVI model...")
    scvi_model.train(max_epochs=MAX_EPOCHS_SCVI, batch_size=128)

    logger.info("Obtain and visualize latent representation...")
    adata_combined.obsm[SCVI_LATENT_KEY] = scvi_model.get_latent_representation()
    # NOTE(review): the neighbor graph below is built on the scVI latent
    # space, so this PCA does not influence the UMAP; it is kept because
    # later steps may read the PCA result.
    sc.pp.pca(adata_combined)
    sc.pp.neighbors(adata_combined, use_rep=SCVI_LATENT_KEY)
    sc.tl.umap(adata_combined)
    sc.pl.umap(
        adata_combined,
        color=[BATCH_COL, REF_CELL_LABEL_COL],
        frameon=False,
        ncols=1,
        save=f"_{config.module_name}_scvi_umap.png",
    )

    logger.info("Saving scVI model...")
    scvi_model.save(module_dir / "_scvi_ref", overwrite=True)

    return adata_combined, scvi_model
def scANVI_label_transfer(config, adata_combined, scvi_model, module_dir):
    """Transfer reference cell-type labels to STx cells using scANVI.

    A training-label column is built on ``adata_combined`` in which reference
    cells (``obs[BATCH_COL] == "Ref"``) keep their ``REF_CELL_LABEL_COL``
    annotation and every other cell is marked ``"Unknown"``. A scANVI model
    is then initialised from the trained scVI model, trained
    semi-supervised on those labels, and its latent representation is stored
    in ``obsm`` and plotted as a UMAP. The trained model is saved under
    ``module_dir / "_scanvi_ref"``.

    Args:
        config: Configuration object. Only ``config.module_name`` is read
            here (attribute access), to build the UMAP figure filename.
        adata_combined (anndata.AnnData): Concatenated reference + spatial
            dataset. Must contain ``BATCH_COL`` and ``REF_CELL_LABEL_COL`` in
            ``obs``. Modified in place.
        scvi_model (scvi.model.SCVI): Trained scVI model on combined data.
        module_dir (Path): Output directory in which the model is saved.

    Returns:
        tuple: ``(adata_combined, scanvi_model)`` - the annotated dataset and
        the trained scANVI model, or ``(None, None)`` if a required ``obs``
        column is missing (in which case ``adata_combined`` is left
        untouched).
    """
    # Set constant variables
    MAX_EPOCHS_SCANVI = 200
    # New obs column that scANVI will use as training labels.
    SCANVI_CELLTYPE_KEY = "training_celltype_scanvi"
    # Placeholder for cells that do not have known labels, i.e. STx cells.
    UNLABELED_CATEGORY = "Unknown"

    logger.info(f"Columns in adata_combined: {list(adata_combined.obs.columns)}")

    # Validate inputs before touching adata_combined, so that the failure
    # path does not leave a half-initialised label column behind.
    if BATCH_COL not in adata_combined.obs.columns:
        logger.error(f"Batch column '{BATCH_COL}' not found")
        return None, None
    if REF_CELL_LABEL_COL not in adata_combined.obs.columns:
        logger.error(f"Cell type column '{REF_CELL_LABEL_COL}' not found")
        return None, None

    # Mark every cell as unlabeled first, then assign the real labels to the
    # reference cells only; STx cells therefore stay "Unknown".
    ref_mask = adata_combined.obs[BATCH_COL] == "Ref"
    adata_combined.obs[SCANVI_CELLTYPE_KEY] = UNLABELED_CATEGORY
    adata_combined.obs.loc[ref_mask, SCANVI_CELLTYPE_KEY] = adata_combined.obs.loc[
        ref_mask, REF_CELL_LABEL_COL
    ].astype(str)
    # Convert the column to categorical (recommended for scANVI)
    adata_combined.obs[SCANVI_CELLTYPE_KEY] = adata_combined.obs[
        SCANVI_CELLTYPE_KEY
    ].astype("category")

    n_ref = ref_mask.sum()
    n_spatial = (~ref_mask).sum()
    logger.info(f"Reference cells with labels: {n_ref}")
    logger.info(
        f"Percent ref cells labeled: {100 * n_ref / adata_combined.n_obs:.2f}%"
    )
    logger.info(f"Spatial cells (unlabeled): {n_spatial}")
    logger.info(
        f"Percent Spatial cells: {100 * n_spatial / adata_combined.n_obs:.2f}%"
    )
    logger.info(
        f"Unique cell types: {adata_combined.obs[SCANVI_CELLTYPE_KEY].value_counts()}"
    )

    # Initialize scANVI from trained scVI.
    # No separate SCANVI.setup_anndata call is made: from_scvi_model
    # registers the AnnData itself, reusing the scVI model's setup arguments
    # (including the counts layer and batch key) plus the labels given here.
    # NOTE(review): with no ``adata`` argument this operates on the AnnData
    # the scVI model was trained on - presumably the same object as
    # adata_combined; verify against caller.
    logger.info("Initializing scANVI model...")
    scanvi_model = SCANVI.from_scvi_model(
        scvi_model,
        unlabeled_category=UNLABELED_CATEGORY,
        labels_key=SCANVI_CELLTYPE_KEY,
    )

    # Train scANVI
    logger.info("Training scANVI model...")
    scanvi_model.train(max_epochs=MAX_EPOCHS_SCANVI, batch_size=128)
    logger.info("scANVI training completed!")

    logger.info("Get latent representation...")
    adata_combined.obsm[SCANVI_LATENT_KEY] = scanvi_model.get_latent_representation()

    logger.info("Visualizing scANVI latent space...")
    # NOTE(review): the neighbor graph below is built on the scANVI latent
    # space, so this PCA does not influence the UMAP; it is kept because
    # later steps may read the PCA result.
    sc.pp.pca(adata_combined)
    sc.pp.neighbors(adata_combined, use_rep=SCANVI_LATENT_KEY)
    sc.tl.umap(adata_combined)
    sc.pl.umap(
        adata_combined,
        color=[BATCH_COL, REF_CELL_LABEL_COL],
        frameon=False,
        ncols=1,
        save=f"_{config.module_name}_scanvi_umap.png",
    )

    logger.info("Saving scANVI model...")
    scanvi_model.save(module_dir / "_scanvi_ref", overwrite=True)

    return adata_combined, scanvi_model
def run_integration(config: IntegrateModuleConfig, io_config: IOConfig):
"""Integrate scRNAseq and STx data using scANVI and ingest.
adata_ref_subset and adata_ingest are used for ingest integration.
adata_ref and data are used for scANVI integration.
Args:
config (IntegrateModuleConfig): Integration module configuration object.
io_config (IOConfig): IO configuration object.
Returns:
None
"""
# Variables
# Name of the column to store label transfer results in adata.obs
module_dir = io_config.output_dir / config.module_name
# Paths to input data
ref_path = io_config.ref_path
gene_id_dict_path = io_config.gene_id_dict_path
# Create output directories if they do not exist
module_dir.mkdir(exist_ok=True)
# Set figure directory for this module (overrides global setting)
sc.settings.figdir = module_dir
# Get shared colormap from global visualization settings
# This ensures consistency across all modules
viz_assets = configure_scanpy_figures(str(io_config.output_dir))
cmap = viz_assets["cmap"]
logger.info("Starting integration of scRNAseq and spatial transcriptomics data...")
logger.info("Loading scRNAseq data from HLCA ...")
adata_ref = sc.read_h5ad(ref_path)
logger.info("Loading Xenium data...")
adata = sc.read_h5ad(io_config.output_dir / "2_dimension_reduction" / "adata.h5ad")
logger.info("Selecting highly variable genes on reference dataset...")
sc.pp.highly_variable_genes(
adata_ref,
n_top_genes=5000,
layer="counts",
flavor="seurat_v3",
subset=True, # Subset to highly variable genes for integration
)
logger.info(f"Number of highly variable genes selected: {adata_ref.n_vars}")
# 1. INTEGRATION using scVI and scANVI
logger.info("Verify adata compatibility for scVI/scANVI integration...")
verify_counts_layer(adata_ref, "reference")
verify_counts_layer(adata, "spatial")
logger.info("Performing integration using: scANVI...")
logger.info(
"Step 1. Harmonize scRNAseq reference dataset with STx dataset scVI model..."
)
adata_combined, trained_scvi_model = scVI_integration(
config, adata_ref, adata, module_dir
)
logger.info(f"scVI model training complete. adata combined: {adata_combined}")
logger.info(f"scVI model training complete. Model: {trained_scvi_model}")
logger.info("Step 2. Transfer labels using scANVI model...")
adata_combined, trained_scanvi_model = scANVI_label_transfer(
config, adata_combined, trained_scvi_model, module_dir
)
# Extract scANVI predictions and copy to original adata
logger.info("Extracting predicted labels from scANVI...")
if adata_combined is not None and trained_scanvi_model is not None:
adata = extract_predictions_and_visualize(
adata_combined, trained_scanvi_model, adata, module_dir
)
else:
logger.error("scANVI integration failed. Skipping scANVI predictions.")
logger.warning("Continuing with ingest integration only...")