Thank you for scvi-tools; it looks like a great tool set.
I'm trying to use SCANVI for supervised learning but am getting bad results. I ran it for many epochs, but the results are the same as after the first epoch.
I trained a simple MLP in PyTorch and already got good results, but I wanted to get even better accuracy with SCANVI.
With the simple MLP, accuracy is ~90%.
With SCANVI, it is 45%.
I’m probably doing something wrong.
I have several .h5ad scRNA-seq datasets from several sources, ~1M cells in total.
I train on all the datasets except one, which is used for testing.
For the test dataset, I set its labels to -1 (defined as unknown).
I tried training on raw counts and also on normalized/log-transformed data; both gave the same results.
Merge all datasets.
Process all datasets:
sc.pp.normalize_total(adata, target_sum=1e4)  # Also tried without normalization and log
sc.pp.highly_variable_genes(adata, n_top_genes=HVG_PARAM, flavor='seurat_v3', subset=True)
Save the test dataset's original labels and set the unknown value on the test dataset:
test_adata = adata[adata.obs[args.batch_key] == TESTED_DATASET].copy()
adata.obs[args.labels_key][adata.obs[args.batch_key] == TESTED_DATASET] = UNLEBELED
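To make the masking step concrete, here is a minimal sketch on a toy pandas DataFrame standing in for adata.obs. The column names, the "Unknown" placeholder, and the helper mask_held_out_labels are hypothetical and not from my actual script. The sketch assigns through .loc instead of chained indexing, and it registers the placeholder as a category first, since pandas refuses to write a value into a categorical column that does not already list it.

```python
import pandas as pd


def mask_held_out_labels(obs, batch_key, labels_key, held_out, unlabeled):
    """Return (masked copy of obs, original labels of the held-out batch)."""
    obs = obs.copy()  # leave the caller's frame untouched
    mask = obs[batch_key] == held_out

    # Keep the ground truth for later evaluation, as plain strings.
    true_labels = obs.loc[mask, labels_key].astype(str)

    # A categorical column only accepts values that are known categories.
    labels = obs[labels_key]
    if isinstance(labels.dtype, pd.CategoricalDtype):
        if unlabeled not in labels.cat.categories:
            obs[labels_key] = labels.cat.add_categories([unlabeled])

    # Single .loc assignment, so the write lands on obs itself.
    obs.loc[mask, labels_key] = unlabeled
    return obs, true_labels


# Toy stand-in for adata.obs (hypothetical values).
toy_obs = pd.DataFrame(
    {
        "dataset": ["A", "A", "B", "B", "C"],
        "cell_type": pd.Categorical(["T", "B", "T", "NK", "B"]),
    }
)
masked_obs, held_out_truth = mask_held_out_labels(
    toy_obs, "dataset", "cell_type", held_out="B", unlabeled="Unknown"
)
print(masked_obs["cell_type"].tolist())
print(held_out_truth.tolist())
```

After this step it is worth confirming that the held-out rows really carry the placeholder and that the training rows still carry their original labels.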
Train and check accuracy after each epoch:
scvi.model.SCANVI.setup_anndata(adata, labels_key=args.labels_key, unlabeled_category=UNLEBELED, batch_key=args.batch_key)
lvae = scvi.model.SCANVI(adata=adata, n_hidden=512, n_latent=64, n_layers=2)
for i in range(10):
    lvae.train(max_epochs=1, batch_size=4096, train_size=0.9)
    adata.obsm["preds"] = lvae.predict()
    test_adata.obsm["preds"] = adata.obsm["preds"][adata.obs[args.batch_key] == TESTED_DATASET]
    acc_score = accuracy_score(test_adata.obs[args.labels_key], test_adata.obsm["preds"])
Any idea why the results are so bad?
I do see that 95% of the test cells are predicted to have the same label.
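For reference, this is roughly how the collapse onto one label can be quantified. It is a minimal sketch in plain Python, assuming the true and predicted labels are available as two equal-length lists; the function prediction_report and the toy labels are hypothetical and only illustrate the check, they are not my real data.

```python
from collections import Counter


def prediction_report(y_true, y_pred):
    """Summarize accuracy, the dominant predicted label, and per-class recall."""
    n = len(y_true)
    pred_counts = Counter(y_pred)
    top_label, top_count = pred_counts.most_common(1)[0]

    # Count correct predictions per true class.
    correct = Counter(t for t, p in zip(y_true, y_pred) if t == p)
    totals = Counter(y_true)
    recall = {label: correct[label] / totals[label] for label in totals}

    return {
        "accuracy": sum(correct.values()) / n,
        "top_predicted_label": top_label,
        "top_predicted_fraction": top_count / n,
        "per_class_recall": recall,
    }


# Toy labels (hypothetical): most predictions collapse onto "T".
y_true = ["T", "T", "B", "NK", "B", "T"]
y_pred = ["T", "T", "T", "T", "B", "T"]
report = prediction_report(y_true, y_pred)
for key, value in report.items():
    print(key, value)
```

Comparing the overall accuracy with the per-class recall shows whether the score is carried almost entirely by the one label that dominates the predictions.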