ScArches-TotalVI

Hi all,

Having read the recent lung atlas (HLCA) preprint, I have realised that it implements a different approach to atlas extension from the one described in the scvi-tools and scArches tutorials.

In the reproducibility code, they appear to instantiate a separate query model from the reference model for each query batch (with frozen dropout), and then construct an extended embedding by concatenating the latent representations inferred independently for each batch.

In the tutorials, a query may contain multiple batches identified by a batch_key. These are trained together with the scArches surgery approach, and a joint latent space is inferred from a newly concatenated “full adata”.

In my mind, these appear to be vastly different approaches, and the lung atlas approach doesn’t seem to actually update the core reference model with the data extensions (see 2_scArches_mapping.ipynb in the LungCellAtlas/HLCA_reproducibility repository on GitHub).

Any help understanding this approach versus the tutorials, and what best practice might be, would be extremely helpful. Thanks in advance!

These are actually the same thing, since the only trainable parameters are batch-category specific; therefore I would follow the scvi-tools tutorial, as the code is a bit more concise.

To make this clearer, consider also that the latent representation of the reference data does not change after the query training updates.
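A toy NumPy sketch of why the two approaches coincide, under two simplifying assumptions: at inference time the encoder maps each cell independently of the other cells it is processed with, and surgery only trains parameters tied to the new batch categories. The linear encoder and the per-batch offsets below are illustrative stand-ins, not totalVI's actual architecture.

```python
import numpy as np

rng = np.random.default_rng(0)
n_genes, n_latent = 20, 5

# Toy stand-in for the frozen shared encoder weights learned on the reference.
W_shared = rng.normal(size=(n_genes, n_latent))

# One learned offset per batch category: the only part surgery is allowed to train.
batch_offset = {"ref_1": rng.normal(size=n_latent), "ref_2": rng.normal(size=n_latent)}


def encode(x, batches):
    """Per-cell toy encoder z = x @ W_shared + offset[batch]; no cross-cell terms."""
    return x @ W_shared + np.stack([batch_offset[b] for b in batches])


x_ref = rng.normal(size=(6, n_genes))
ref_batches = ["ref_1"] * 3 + ["ref_2"] * 3
z_ref_before = encode(x_ref, ref_batches)

# "Surgery": add offsets for two unseen query batches; W_shared is left untouched.
batch_offset["q_1"] = rng.normal(size=n_latent)
batch_offset["q_2"] = rng.normal(size=n_latent)

x_q = rng.normal(size=(5, n_genes))
q_batches = ["q_2", "q_1", "q_2", "q_1", "q_1"]

# Tutorial style: one pass over all query cells at once.
z_joint = encode(x_q, q_batches)

# HLCA style: one pass per batch, written back into the original cell order.
z_per_batch = np.empty_like(z_joint)
for name in sorted(set(q_batches)):
    idx = [i for i, b in enumerate(q_batches) if b == name]
    z_per_batch[idx] = encode(x_q[idx], [name] * len(idx))

z_ref_after = encode(x_ref, ref_batches)
print(np.allclose(z_joint, z_per_batch), np.array_equal(z_ref_before, z_ref_after))  # True True
```

Because every cell's latent vector depends only on its own counts and its own batch parameters, encoding batch by batch and concatenating gives the same result as encoding everything at once, and adding query parameters leaves the reference embedding untouched.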

I see.

But if I wanted future queries to “update” the model and account for previously unseen cells, would I then update the model sequentially? E.g. vae > vae_q > vae_q2 > vae_q3, etc.

(although I appreciate the surgery approach is designed not to change the base model weights excessively)

This feels more intuitive, as a truer extension of an atlas, than repeatedly training each new query against the base reference model, since training against the base reference never updates the model itself.
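A toy sketch of the sequential chain vae > vae_q > vae_q2, assuming surgery works as described earlier in the thread: each step appends freshly initialised weights for the new batch categories and trains only those, keeping everything learned before fixed. The weight matrix, helper names and gradient values are hypothetical. What it illustrates is that chaining preserves the reference (and earlier queries') parameters exactly, so the chain grows the set of known batches but does not refit the shared model on the new cells.

```python
import numpy as np

rng = np.random.default_rng(1)
N_HIDDEN = 8  # toy width of the first encoder layer


def extend_batch_weights(W, n_new):
    """Append one freshly initialised column per new batch category (hypothetical helper)."""
    new_cols = rng.normal(scale=0.01, size=(W.shape[0], n_new))
    return np.concatenate([W, new_cols], axis=1)


def surgery_step(W, n_old, lr=0.1):
    """One toy update in which only columns n_old: (the new batches) receive gradient."""
    grad = rng.normal(size=W.shape)
    mask = np.zeros(W.shape[1])
    mask[n_old:] = 1.0  # columns of already-known batches get a zero gradient
    return W - lr * grad * mask


W_vae = rng.normal(size=(N_HIDDEN, 3))                        # reference: 3 batches
W_vae_q = surgery_step(extend_batch_weights(W_vae, 2), 3)     # + 2 query batches
W_vae_q2 = surgery_step(extend_batch_weights(W_vae_q, 1), 5)  # + 1 more batch

print(W_vae_q2.shape)                            # (8, 6)
print(np.array_equal(W_vae_q2[:, :3], W_vae))    # True: reference columns untouched
print(np.array_equal(W_vae_q2[:, :5], W_vae_q))  # True: first query's columns untouched
```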

Additionally, I have tried both approaches, and the scvi-tools approach for scArches seems to result in a less well-defined embedding (if I am allowed to interpret the UMAPs here).

I assume that scvi.model.TOTALVI.load_query_data() automatically detects the batch, because it does not allow me to assign a batch_key explicitly?

Is there an explanation for why the embeddings should differ between training per batch and training all together with batch identifiers in adata.obs?

OK, having done some tests: as you state, @adamgayoso, the better-defined UMAP of the totalVI embeddings does not depend on the training method. It appears to depend on how you generate the latent space.

If you generate the latent embedding batch by batch within the query and concatenate it with the reference (as the lung atlas code does), the latent representation has better separation and subjectively better-defined clusters: e.g. NK cells do not merge into CD8, and classification with the random forest model is better.

If you generate the latent embedding from the merged query and reference, it becomes more homogeneous and less well separated.

Validity of UMAP visual interpretations aside, I imagine this would inadvertently affect de novo clustering.

This should not be the case, as the reference latent embedding should be the same up to numerical accuracy either way. Can you share your code?
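A direct way to test the "same up to numerical accuracy" expectation is to compare the two embeddings of the same cells coordinate by coordinate, rather than by eye on a UMAP. A minimal NumPy sketch; the helper name and the default tolerance (chosen with float32-scale noise in mind) are assumptions, not anything from scvi-tools:

```python
import numpy as np


def embedding_agreement(z_a, z_b, atol=1e-5):
    """Compare two latent embeddings of the same cells, rows in the same order.

    Returns (largest absolute coordinate difference, whether all coordinates
    agree to within `atol`).
    """
    z_a = np.asarray(z_a, dtype=np.float64)
    z_b = np.asarray(z_b, dtype=np.float64)
    if z_a.shape != z_b.shape:
        raise ValueError(f"shape mismatch: {z_a.shape} vs {z_b.shape}")
    max_diff = float(np.max(np.abs(z_a - z_b)))
    return max_diff, bool(np.allclose(z_a, z_b, rtol=0.0, atol=atol))


rng = np.random.default_rng(0)
z = rng.normal(size=(100, 20)).astype(np.float32)
print(embedding_agreement(z, z + np.float32(1e-6)))  # tiny max difference, agree=True
```

In practice, pass the joint and the per-batch query embeddings (same cells, same row order); a large maximum difference would point at something other than floating-point noise.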

In particular, I would make sure weight_decay is set properly, as in the tutorials.
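Why weight_decay matters, as a toy: suppose (my understanding of scvi-tools' surgery, but treat it as an assumption) that reference weights sharing a tensor with the new batch weights are frozen by zeroing their gradients rather than by removing them from the optimizer. Coupled weight decay, as implemented by the weight_decay argument of PyTorch's SGD and Adam, adds weight_decay * w to the gradient after that masking, so the "frozen" reference weights still drift and the reference latent space shifts. The tutorials therefore pass plan_kwargs=dict(weight_decay=0.0) when training the query model.

```python
import numpy as np


def toy_sgd_step(w, grad, frozen, lr=0.1, weight_decay=0.0):
    """One SGD step with coupled L2 weight decay (decay term added to the gradient).

    `frozen` marks entries whose loss gradient is zeroed by a hook, a toy stand-in
    for how reference weights are held fixed during surgery (an assumption).
    The decay term is added after masking, so it still moves frozen entries.
    """
    g = np.where(frozen, 0.0, grad)  # hook: zero the loss gradient of frozen entries
    g = g + weight_decay * w         # optimizer: add weight_decay * w to every entry
    return w - lr * g


w = np.array([1.0, 2.0, 3.0])          # first two entries = "reference", last = "new batch"
frozen = np.array([True, True, False])
grad = np.ones_like(w)

print(toy_sgd_step(w, grad, frozen, weight_decay=0.0))   # reference entries unchanged
print(toy_sgd_step(w, grad, frozen, weight_decay=0.01))  # reference entries shrink slightly
```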

Yes, it will use the batch_key originally used to train the reference model.
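Since load_query_data reuses the reference model's batch_key, the query's .obs needs a column with that same name. A small pre-flight check can confirm this and list which query batches are new to the reference (these are the categories that receive new, trainable batch parameters during surgery). The helper below is hypothetical and not part of scvi-tools; it only uses pandas.

```python
import pandas as pd


def check_query_batch_key(ref_obs: pd.DataFrame, query_obs: pd.DataFrame, batch_key: str) -> list:
    """Hypothetical pre-flight check: the query must carry the reference batch column.

    Returns the query batch categories that the reference has never seen.
    """
    if batch_key not in query_obs.columns:
        raise KeyError(f"query .obs is missing the reference batch_key {batch_key!r}")
    ref_batches = set(ref_obs[batch_key].astype(str))
    query_batches = set(query_obs[batch_key].astype(str))
    return sorted(query_batches - ref_batches)


# Example with toy .obs tables
ref_obs = pd.DataFrame({"batch": ["r1", "r1", "r2"]})
query_obs = pd.DataFrame({"batch": ["q1", "q2", "r2"]})
print(check_query_batch_key(ref_obs, query_obs, "batch"))  # ['q1', 'q2']
```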

import time

import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.font_manager
import seaborn as sns
import anndata
import scvi
import scanpy as sc
import scarches as sca
import re

import sklearn
from sklearn.ensemble import RandomForestClassifier
from sklearn.neural_network import MLPClassifier

import os


##load data----
adata_full_new = sc.read_h5ad("~/data/adata_all_complete.h5ad")

query = adata_full_new[adata_full_new.obs["dataset_name"]=="Query"].copy()

ref = adata_full_new[adata_full_new.obs["dataset_name"]=="Reference"].copy()

##pre-check: the query genes must be in the same order as the reference genes
if query.var_names.equals(ref.var_names):
  print("Gene order is correct.")
else:
  print(
    "WARNING: your gene order does not match the order of the reference. Fix this before continuing!"
  )

#simulate missing protein data: set all protein counts to zero for 5 randomly chosen reference batches
rand_cats = np.random.permutation(ref.obs["batch"].astype("category").cat.categories)[:5]
for r in rand_cats:
  ref.obsm["protein_counts"][(ref.obs["batch"] == r).to_numpy()] = 0.0


##setup model ----
scvi.model.TOTALVI.setup_anndata(
    ref,
    batch_key="batch",
    protein_expression_obsm_key="protein_counts",
)


N_EPOCHS=250

scvi.settings.seed = 0


# initialize and train model
arches_params = dict(
    use_layer_norm="both",
    use_batch_norm="none",
    n_layers_decoder=2,
    n_layers_encoder=2,
)

vae = scvi.model.TOTALVI(
    ref, 
    **arches_params
)
vae.train(max_epochs=N_EPOCHS, batch_size=256, lr=4e-3)
vae_dir = "~/out/ref_model"
vae.save(vae_dir, overwrite=True, save_anndata=True)

##generate ref latent emb
ref.obsm["X_totalvi_scarches"] = vae.get_latent_representation()

##scvi-tools method: one surgery model trained on all query batches together
vae_q = scvi.model.TOTALVI.load_query_data(
    query,
    vae_dir,
    freeze_dropout=True,
)

# now train surgery model 
vae_q.train(
    200, 
    lr=4e-3, 
    batch_size=256, 
    plan_kwargs=dict(
        weight_decay=0.0,
        scale_adversarial_loss=0.0
    ),
  # n_steps_kl_warmup=1,
)


full_adata = ref.concatenate(query)
# latent embeddings are 2-D arrays, so they belong in .obsm, not .obs
full_adata.obsm["X_totalvi_scarches"] = vae_q.get_latent_representation(full_adata)

##generate graph and umap  
sc.pp.neighbors(full_adata, use_rep="X_totalvi_scarches", metric="cosine",n_neighbors=30)
sc.tl.umap(full_adata)

##plot
fig, ax = plt.subplots(figsize=(6, 6))
sc.pl.umap(
  full_adata,
  color=["dataset_name"],
  frameon=False,
  ncols=1,
  title="Reference and query (scArches)",
  ax=ax,
  palette=None,
  #     size=5,
)
fig.savefig("~/out/plots/scarches_scvimethod_ref_query.png", bbox_inches="tight", dpi=300)


##Lung atlas method ----

ref_embedding = sc.AnnData(X=ref.obsm["X_totalvi_scarches"], obs=ref.obs)

query_batches = sorted(query.obs["batch"].unique())
batch_variable = "batch"


for batch in query_batches: # this loop is only necessary if you have multiple batches, but will also work for a single batch.
    query_subadata = query[query.obs["batch"] == batch,:].copy()
    # load model and set relevant variables:
    model = scvi.model.TOTALVI.load_query_data(
        query_subadata, 
        vae_dir,
        freeze_dropout=True,
    )
    # scANVI-specific attributes carried over from the HLCA notebook; TOTALVI does not use them
    model._unlabeled_indices = np.arange(query_subadata.n_obs)
    model._labeled_indices = []
    # now train surgery model using reference model and target adata
    model.train(
        200, 
        lr=4e-3, 
        batch_size=256, 
        plan_kwargs=dict(
            weight_decay=0.0,
            scale_adversarial_loss=0.0
        ),
        # n_steps_kl_warmup=1,
    )
    surgery_path = os.path.join("~/out/model/iter_query/", str(batch))
    os.makedirs(surgery_path, exist_ok=True)
    model.save(surgery_path, overwrite=True)

    
emb_df = pd.DataFrame(index=query.obs.index, columns=range(0, ref_embedding.shape[1]), dtype=float)

for batch in query_batches: # iterate over the query batches in sorted order
    query_subadata = query[query.obs["batch"] == batch,:].copy()
    # load from the same directory the models were saved to above
    surgery_path = os.path.join("~/out/model/iter_query/", str(batch))
    model = scvi.model.TOTALVI.load(surgery_path, adata=query_subadata)
    query_subadata_latent = sc.AnnData(model.get_latent_representation(query_subadata))
    query_subadata_latent.obs = query.obs.loc[query_subadata.obs.index,:]
    query_subadata.var.index = query_subadata.var.index.astype(str).tolist()
    emb_df.loc[query_subadata.obs.index,:] = query_subadata_latent.X

##create anndata of embedding 
query_embedding = sc.AnnData(X=emb_df.values, obs=query.obs)
    
query_embedding.obs['dataset_name'] = "query"
ref_embedding.obs['dataset_name'] = "reference"

combined_emb = ref_embedding.concatenate(query_embedding, index_unique=None) # index_unique="_", batch_key="dataset") # alternative
   
##generate graph and umap  
sc.pp.neighbors(combined_emb, metric="cosine", n_neighbors=30)
sc.tl.umap(combined_emb)

##plot
fig, ax = plt.subplots(figsize=(6, 6))
sc.pl.umap(
  combined_emb,
  color=["celltype.l2"],
  frameon=False,
  ncols=1,
  title="Reference and query (scArches)",
  ax=ax,
  palette=None,
  #     size=5,
)
fig.savefig("~/out/plots/scarches_hclamethod_ref_query.png", bbox_inches="tight", dpi=300)

    

##scvi method but with lung atlas method of iterative latent spaces
emb_df2 = pd.DataFrame(index=query.obs.index, columns=range(0, ref_embedding.shape[1]), dtype=float)

for batch in query_batches: # iterate over the query batches in sorted order
  query_subadata = query[query.obs["batch"] == batch,:].copy()
  ##only change: encode each batch separately with vae_q, as opposed to all cells at once
  query_subadata_latent = sc.AnnData(vae_q.get_latent_representation(query_subadata))
  # copy over .obs
  query_subadata_latent.obs = query.obs.loc[query_subadata.obs.index,:]
  query_subadata.var.index = query_subadata.var.index.astype(str).tolist()
  emb_df2.loc[query_subadata.obs.index,:] = query_subadata_latent.X

query_embedding2 = sc.AnnData(X=emb_df2.values, obs=query.obs)
full_embedding2 = ref_embedding.concatenate(query_embedding2, index_unique=None)

full_adata.obsm["X_totalvi_scarches_iterative"] = full_embedding2.X

##generate graph and umap  
sc.pp.neighbors(full_adata, use_rep="X_totalvi_scarches_iterative", metric="cosine", n_neighbors=30)
sc.tl.umap(full_adata)

##plot
fig, ax = plt.subplots(figsize=(6, 6))
sc.pl.umap(
  full_adata,
  color=["celltype.l2"],
  frameon=False,
  ncols=1,
  title="Reference and query (scArches)",
  ax=ax,
  palette=None,
  #     size=5,
)
fig.savefig("~/out/plots/scarches_scvimethod_iter_ref_query.png", bbox_inches="tight", dpi=300)

The plots are as follows (ignore the big splat in the middle that looks like a mix of everything; this dataset had major issues with sample prep, which scArches helped identify):

SCVI method

HLCA method

SCVI method - iterative latent sampling HLCA style

query_embedding2 = sc.AnnData(X=emb_df2.values, obs=query.obs)
  
full_adata.obs["X_totalvi_scarches_iterative"] = query_embedding2.X


##generate graph and umap  
sc.pp.neighbors(full_adata, use_rep="X_totalvi_scarches_iterative", metric="cosine", n_neighbors=30)
sc.tl.umap(full_adata)

Thank you for sharing. I’m a bit confused about how this code works. The code here has no "X_totalvi_scarches_iterative" for the reference data, and isn't this in .obs rather than .obsm?

Something else important to keep in mind is that what you’re calling the HLCA method trains each query batch for 200 epochs, whereas the “scvi” method trains all the query batches for 200 epochs total. That’s a lot more training for the HLCA method. You can try changing 200 → 400 for the “SCVI” method.


I think I might see the issue, which could be a recent bug on our side when the full data is concatenated.

OK, not a bug; two things are happening:

  1. Numerical accuracy of PyTorch (see here)
  2. Much more query training happening in your “HLCA” flavor, as I mentioned above.
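Floating-point addition is not associative, which is the root of the numerical-accuracy caveat: the same per-cell computation, run inside a differently sized minibatch, may be dispatched to a different kernel with a different summation order and round differently. To my understanding, PyTorch's numerical-accuracy notes state that batched and sliced computations are not guaranteed to be bitwise identical. A tiny float32 illustration:

```python
import numpy as np

vals = np.array([1e8, 1.0, -1e8], dtype=np.float32)

left_to_right = (vals[0] + vals[1]) + vals[2]  # 1e8 + 1 rounds back to 1e8 in float32
reordered = (vals[0] + vals[2]) + vals[1]      # cancelling first keeps the 1.0

print(left_to_right, reordered)  # 0.0 1.0
```

The same numbers summed in a different order give different answers; in a network the effect is usually confined to the last few digits of each coordinate.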


Regarding the missing reference embedding and .obs vs .obsm: sorry, that was a mistake made when copying and pasting and removing project-identifiable object names. It's corrected now.


Regarding the HLCA method training each query batch for 200 epochs: of course! Now this makes sense; it had completely escaped me. I will give this a go and report back. Thanks a bunch for the feedback!

Actually, thinking about this a bit more, this isn’t the reason.

For “SCVI method - iterative latent sampling HLCA style”, the only difference is passing each batch separately through the trained vae_q, not the number of epochs used to train vae_q.

Therefore, it must have something to do with either the concatenation, or the fact that passing each batch separately through the VAE generates different latent coordinates than passing all cells through at once after concatenation. Maybe it is the numerical accuracy issue you suggested.

I tested this locally, and there doesn’t seem to be a bug beyond numerical accuracy. Something else to try, for the first approach you took, is changing the UMAP random seed: UMAP is a stochastic algorithm, and maybe you got a qualitatively “bad” seed.
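Scanpy's sc.tl.umap takes a random_state argument, so rerunning the layout with a few seeds is cheap. A complementary, seed-independent check (a suggestion of mine, not something from the thread) is to compare the k-nearest-neighbour sets of the two embeddings directly, using cosine distance to match metric="cosine" in the posted sc.pp.neighbors calls. This is a dense NumPy sketch with hypothetical helper names, so subsample to a few thousand cells first:

```python
import numpy as np


def knn_indices(z, k):
    """k nearest neighbours of each row under cosine distance, excluding the row itself."""
    unit = z / np.linalg.norm(z, axis=1, keepdims=True)  # cosine order == Euclidean order on unit vectors
    dist = np.linalg.norm(unit[:, None, :] - unit[None, :, :], axis=-1)  # dense (n, n) matrix
    np.fill_diagonal(dist, np.inf)  # a cell is not its own neighbour
    return np.argsort(dist, axis=1)[:, :k]


def mean_knn_overlap(z_a, z_b, k=15):
    """Mean fraction of k-nearest neighbours shared per cell between two embeddings.

    Rows of z_a and z_b must describe the same cells in the same order.
    1.0 means identical neighbourhoods; the value does not depend on any UMAP seed.
    """
    nn_a = knn_indices(np.asarray(z_a, dtype=float), k)
    nn_b = knn_indices(np.asarray(z_b, dtype=float), k)
    shared = [len(set(a) & set(b)) / k for a, b in zip(nn_a, nn_b)]
    return float(np.mean(shared))


rng = np.random.default_rng(0)
z = rng.normal(size=(40, 6))
print(mean_knn_overlap(z, z, k=5))  # 1.0
```

An overlap close to 1.0 between the joint and per-batch embeddings would suggest that visual differences between the UMAPs come from the layout rather than from the latent space.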