Hello,
I’ve been using the scArches implementation in scvi-tools for query-to-reference mapping with trained SCVI models. I wanted to explore how the reconstruction loss evolves on the training and validation sets, to detect possible overfitting. During training of the reference scVI model, I see the expected trend: the training loss is lower than the validation loss. Conversely, during training of the query model, the validation loss is consistently lower than the training loss, and I find this puzzling.
Following the example in the reference-mapping tutorial:
```python
import warnings
warnings.simplefilter(action='ignore', category=FutureWarning)
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import anndata
import scvi
import scanpy as sc
sc.set_figure_params(figsize=(4, 4))
scvi.settings.seed = 94705
url = "https://figshare.com/ndownloader/files/24539828"
adata = sc.read("pancreas.h5ad", backup_url=url)
print(adata)
## Define query dataset
query = np.array([s in ["smartseq2", "celseq2"] for s in adata.obs.tech])
adata_ref = adata[~query].copy()
adata_query = adata[query].copy()
sc.pp.highly_variable_genes(
    adata_ref,
    n_top_genes=2000,
    batch_key="tech",
    subset=True
)
adata_query = adata_query[:, adata_ref.var_names].copy()
## Train reference
scvi.model.SCVI.setup_anndata(adata_ref, batch_key="tech", layer="counts")
arches_params = dict(
    use_layer_norm="both",
    use_batch_norm="none",
    encode_covariates=True,
    dropout_rate=0.2,
    n_layers=2,
)
vae_ref = scvi.model.SCVI(
    adata_ref,
    **arches_params
)
vae_ref.train(check_val_every_n_epoch=1)
```
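For context on what the validation curve measures: as far as I understand, train() holds out a random subset of cells for validation, with the training fraction set by its train_size argument (0.9 by default, if I read the docs correctly). The NumPy sketch below only illustrates that kind of random cell-level split; split_cells is a hypothetical helper, not scvi-tools' own splitting code.

```python
import numpy as np


def split_cells(n_cells, train_size=0.9, seed=0):
    """Randomly partition cell indices into training and validation sets.

    Hypothetical helper for illustration only; not part of scvi-tools.
    """
    rng = np.random.default_rng(seed)
    # Shuffle all cell indices once, then cut the permutation in two.
    permutation = rng.permutation(n_cells)
    n_train = int(round(train_size * n_cells))
    return permutation[:n_train], permutation[n_train:]


train_idx, val_idx = split_cells(1000, train_size=0.9, seed=94705)
print(len(train_idx), len(val_idx))  # 900 100
```

Each cell lands in exactly one of the two sets, so the validation loss is computed on cells the model never receives gradient updates from.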
Plotting the training and validation reconstruction loss for the reference model:
```python
plt.plot(vae_ref.history['reconstruction_loss_train']['reconstruction_loss_train'], label='train');
plt.plot(vae_ref.history['reconstruction_loss_validation']['reconstruction_loss_validation'], label='validation');
plt.legend()
```
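Since the same pair of plot calls is repeated for each model, they can be wrapped in a small helper so both models' curves are drawn identically. A minimal sketch, assuming history is shaped like vae_ref.history in the snippet above (a dict of pandas DataFrames keyed by metric name, each with one column of the same name); plot_reconstruction_loss is a hypothetical helper:

```python
import matplotlib.pyplot as plt


def plot_reconstruction_loss(history, ax=None, title=None):
    """Plot training and validation reconstruction loss from a model history.

    Hypothetical helper. Assumes `history` is a dict of pandas DataFrames
    keyed by metric name, each holding one column with that same name.
    """
    if ax is None:
        _, ax = plt.subplots()
    for split in ("train", "validation"):
        key = f"reconstruction_loss_{split}"
        series = history[key][key]
        # Use the DataFrame index as the x-axis (one entry per logged epoch).
        ax.plot(series.index, series.to_numpy(), label=split)
    ax.set_xlabel("epoch")
    ax.set_ylabel("reconstruction loss")
    if title is not None:
        ax.set_title(title)
    ax.legend()
    return ax
```

Usage would then be plot_reconstruction_loss(vae_ref.history, title="reference") and, for the query model, plot_reconstruction_loss(vae_q.history, title="query").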
Plotting the training and validation reconstruction loss for the query model:
```python
plt.plot(vae_q.history['reconstruction_loss_train']['reconstruction_loss_train'], label='train');
plt.plot(vae_q.history['reconstruction_loss_validation']['reconstruction_loss_validation'], label='validation');
plt.legend()
```
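To check whether the validation loss sits below the training loss at every epoch, rather than only on average, the per-epoch difference can be computed from the same history entries. A minimal sketch, assuming each history entry is a pandas DataFrame indexed by epoch as in the plotting calls above; reconstruction_gap is a hypothetical helper:

```python
import pandas as pd


def reconstruction_gap(history):
    """Return validation minus training reconstruction loss per epoch.

    Hypothetical helper. Assumes `history` is a dict of pandas DataFrames
    keyed by metric name (like `vae_q.history` above), indexed by epoch.
    """
    train = history["reconstruction_loss_train"]["reconstruction_loss_train"]
    val = history["reconstruction_loss_validation"]["reconstruction_loss_validation"]
    # Subtraction aligns on the index, so epochs logged in only one of the
    # two series become NaN and are dropped.
    return (val - train).dropna()
```

For example, (reconstruction_gap(vae_q.history) < 0).mean() gives the fraction of epochs where the validation loss was below the training loss.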
Any intuition on why this might be going on?
Many thanks in advance,
Emma