import time
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.font_manager
import seaborn as sns
import anndata
import scvi
import scanpy as sc
import scarches as sca
import re
import sklearn
from sklearn.ensemble import RandomForestClassifier
from sklearn.neural_network import MLPClassifier
import os
## load data ----
# Expand "~" explicitly instead of relying on the reader to resolve it.
adata_path = os.path.expanduser("~/data/adata_all_complete.h5ad")
adata_full_new = sc.read_h5ad(adata_path)
# Split the combined object into independent query / reference copies,
# using the labels stored in obs["dataset_name"].
query = adata_full_new[adata_full_new.obs["dataset_name"] == "Query"].copy()
ref = adata_full_new[adata_full_new.obs["dataset_name"] == "Reference"].copy()
## pre-check
# The query must expose the same genes, in the same order, as the reference.
# Index.equals() compares element by element and simply returns False when the
# two indexes differ in length, so the check itself cannot raise.
if query.var.index.equals(ref.var.index):
    print("Gene order is correct.")
else:
    print(
        "WARNING: your gene order does not match the order of the reference. Fix this before continuing!"
    )
# Randomly pick five batches and zero out their protein counts, so that these
# batches look as if no proteins had been measured.
# A dedicated, seeded generator keeps the selection reproducible; the global
# seed is only set further down, after this draw.
mask_rng = np.random.default_rng(0)
batch_categories = ref.obs["batch"].astype("category").cat.categories
rand_cats = mask_rng.permutation(batch_categories)[:5]
# Blank every cell that belongs to one of the selected batches in one go.
ref.obsm["protein_counts"][ref.obs["batch"].isin(rand_cats)] = 0.0
## setup model ----
# Register the reference with totalVI: batch covariate from obs["batch"],
# protein measurements from obsm["protein_counts"].
scvi.model.TOTALVI.setup_anndata(
    ref,
    batch_key="batch",
    protein_expression_obsm_key="protein_counts",
)
N_EPOCHS = 250
scvi.settings.seed = 0
# initialize and train model
arches_params = dict(
    use_layer_norm="both",
    use_batch_norm="none",
    n_layers_decoder=2,
    n_layers_encoder=2,
)
vae = scvi.model.TOTALVI(ref, **arches_params)
vae.train(max_epochs=N_EPOCHS, batch_size=256, lr=4e-3)
# Define the model directory once (with "~" expanded) so that saving here and
# loading later always refer to the same location.
vae_dir = os.path.expanduser("~/out/ref_model")
vae.save(vae_dir, overwrite=True, save_anndata=True)
## generate ref latent emb
ref.obsm["X_totalvi_scarches"] = vae.get_latent_representation()
## scvi-tools method
# Map the whole query onto the saved reference model in a single surgery run.
# The full `query` object is passed here; per-batch subsets are only built
# later, for the iterative (lung-atlas style) variant.
vae_q = scvi.model.TOTALVI.load_query_data(
    query,
    vae_dir,
    freeze_dropout=True,
)
# now train surgery model
vae_q.train(
    200,
    lr=4e-3,
    batch_size=256,
    plan_kwargs=dict(
        weight_decay=0.0,
        scale_adversarial_loss=0.0,
    ),
    # n_steps_kl_warmup=1,
)
# Stack reference and query cells. concatenate() writes its own label column,
# named "batch" by default, which would overwrite the batch covariate the model
# was trained on -- so store that label under a different name.
full_adata = ref.concatenate(query, batch_key="concat_batch")
# The latent representation is a cells x dims matrix, so it belongs in .obsm,
# which is also where sc.pp.neighbors(use_rep=...) looks it up.
full_adata.obsm["X_totalvi_scarches"] = vae_q.get_latent_representation(full_adata)
## generate graph and umap
sc.pp.neighbors(
    full_adata,
    use_rep="X_totalvi_scarches",
    metric="cosine",
    n_neighbors=30,
)
sc.tl.umap(full_adata)
## plot
# UMAP of the joint latent space, coloured by dataset of origin.
fig, ax = plt.subplots(figsize=(6, 6))
sc.pl.umap(
    full_adata,
    color=["dataset_name"],
    frameon=False,
    ncols=1,
    title="Reference and query (scArches)",
    ax=ax,
    palette=None,
    # size=5,
)
# Expand "~" and make sure the target directory exists before writing.
plot_path = os.path.expanduser("~/out/plots/scarches_scvimethod_ref_query.png")
os.makedirs(os.path.dirname(plot_path), exist_ok=True)
fig.savefig(plot_path, bbox_inches="tight", dpi=300)
## Lung atlas method ----
# Wrap the reference latent space in its own AnnData. obs is copied so that
# later edits to the embedding's metadata cannot leak back into ref.obs.
ref_embedding = sc.AnnData(X=ref.obsm["X_totalvi_scarches"], obs=ref.obs.copy())
query_batches = sorted(query.obs["batch"].unique())
batch_variable = "batch"
# Single root directory for the per-batch surgery models, shared by the
# training loop (save) and the embedding loop (load) so the two cannot drift.
surgery_root = os.path.expanduser("~/out/model/iter_query/")

for batch in query_batches:  # this loop is only necessary if you have multiple batches, but will also work for a single batch.
    query_subadata = query[query.obs[batch_variable] == batch, :].copy()
    # load model and set relevant variables:
    model = scvi.model.TOTALVI.load_query_data(
        query_subadata,
        vae_dir,
        freeze_dropout=True,
    )
    # NOTE(review): these private attributes look like a carry-over from a
    # label-aware (scANVI-style) workflow; kept as in the original -- confirm
    # whether TOTALVI reads them at all.
    model._unlabeled_indices = np.arange(query_subadata.n_obs)
    model._labeled_indices = []
    # now train surgery model using reference model and target adata
    model.train(
        200,
        lr=4e-3,
        batch_size=256,
        plan_kwargs=dict(
            weight_decay=0.0,
            scale_adversarial_loss=0.0,
        ),
        # n_steps_kl_warmup=1,
    )
    # str() keeps os.path.join happy when batch labels are not strings.
    surgery_path = os.path.join(surgery_root, str(batch))
    os.makedirs(surgery_path, exist_ok=True)
    model.save(surgery_path, overwrite=True)

# Collect the per-batch latent coordinates in one frame aligned to query.obs.
# A float dtype is requested up front so the values do not end up as objects.
emb_df = pd.DataFrame(
    index=query.obs.index,
    columns=range(ref_embedding.shape[1]),
    dtype=float,
)
for batch in query_batches:  # from small to large datasets
    query_subadata = query[query.obs[batch_variable] == batch, :].copy()
    surgery_path = os.path.join(surgery_root, str(batch))
    model = scvi.model.TOTALVI.load(surgery_path, query_subadata)
    # Rows are written by cell name, so the batch order does not matter.
    emb_df.loc[query_subadata.obs.index, :] = model.get_latent_representation(
        query_subadata
    )
## create anndata of embedding
# Cast to a numeric array: a frame filled via .loc may still be object dtype.
# obs is copied because the dataset label is overwritten right below.
query_embedding = sc.AnnData(
    X=emb_df.to_numpy(dtype=np.float32),
    obs=query.obs.copy(),
)
query_embedding.obs["dataset_name"] = "query"
ref_embedding.obs["dataset_name"] = "reference"
# index_unique=None keeps the original cell names.
combined_emb = ref_embedding.concatenate(query_embedding, index_unique=None)  # index_unique="_", batch_key="dataset") # alternative
## generate graph and umap
# combined_emb.X already holds the latent coordinates, so no use_rep is given.
sc.pp.neighbors(combined_emb, metric="cosine", n_neighbors=30)
sc.tl.umap(combined_emb)
## plot
fig, ax = plt.subplots(figsize=(6, 6))
sc.pl.umap(
    combined_emb,
    color=["celltype.l2"],
    frameon=False,
    ncols=1,
    title="Reference and query (scArches)",
    ax=ax,
    palette=None,
    # size=5,
)
# Expand "~" and make sure the target directory exists before writing.
plot_path = os.path.expanduser("~/out/plots/scarches_hclamethod_ref_query.png")
os.makedirs(os.path.dirname(plot_path), exist_ok=True)
fig.savefig(plot_path, bbox_inches="tight", dpi=300)
## scvi method but with lung atlas method of iterative latent spaces
# Same single surgery model (vae_q) as the scvi-tools method; the latent space
# is just queried one batch at a time and stitched back together by cell name.
emb_df2 = pd.DataFrame(
    index=query.obs.index,
    columns=range(ref_embedding.shape[1]),
    dtype=float,
)
for batch in query_batches:  # from small to large datasets
    query_subadata = query[query.obs["batch"] == batch, :].copy()
    ## only change is to sample per batch from vae_q as opposed to all at once
    emb_df2.loc[query_subadata.obs.index, :] = vae_q.get_latent_representation(
        query_subadata
    )
query_embedding2 = sc.AnnData(X=emb_df2.to_numpy(dtype=np.float32), obs=query.obs)
full_embedding2 = ref_embedding.concatenate(query_embedding2, index_unique=None)
# Reference rows come first, then query rows -- the same order in which
# full_adata was concatenated -- so the matrix can be attached positionally.
# It is a cells x dims matrix and therefore goes into .obsm, not .obs.
full_adata.obsm["X_totalvi_scarches_iterative"] = full_embedding2.X
## generate graph and umap
# Rebuild the neighbour graph on the per-batch latent space; this replaces the
# graph and UMAP previously stored on full_adata.
sc.pp.neighbors(
    full_adata,
    use_rep="X_totalvi_scarches_iterative",
    metric="cosine",
    n_neighbors=30,
)
sc.tl.umap(full_adata)
## plot
fig, ax = plt.subplots(figsize=(6, 6))
sc.pl.umap(
    full_adata,
    color=["celltype.l2"],
    frameon=False,
    ncols=1,
    title="Reference and query (scArches)",
    ax=ax,
    palette=None,
    # size=5,
)
# Expand "~" and make sure the target directory exists before writing.
plot_path = os.path.expanduser("~/out/plots/scarches_scvimethod_iter_ref_query.png")
os.makedirs(os.path.dirname(plot_path), exist_ok=True)
fig.savefig(plot_path, bbox_inches="tight", dpi=300)
# The plots are as follows. (Ignore the big splat in the middle that looks like
# a mix of everything: this dataset had major sample-prep issues, which
# scArches helped to identify.)
#   1. SCVI method
#   2. HCLA method
#   3. SCVI method - iterative latent sampling, HCLA style