Supervised training with SCANVI does not converge: any idea why?

Hi,
Thank you for scvi-tools; it looks like a great toolset.
I’m trying to use SCANVI for supervised learning, but I’m getting poor results. I trained for many epochs, but the results stay the same as after the first epoch.
I trained a simple MLP in PyTorch and already got good results, but I wanted even better accuracy with SCANVI.
With a simple MLP: accuracy ~90%
With SCANVI: accuracy 45%
I’m probably doing something wrong.
Process:
I have several .h5ad scRNA-seq datasets from different sources, ~1M cells in total.
I train on all of the datasets except one, which is held out for testing.
In the test dataset I set all labels to -1 (defined as the unknown category).
I tried training on both raw counts and normalized/log-transformed data; both gave the same results.
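A minimal plain-Python sketch of the leave-one-dataset-out masking described above: every cell from the held-out dataset gets the unknown label, and its true labels are kept aside for scoring. `mask_heldout_labels` is a hypothetical helper, not part of scvi-tools or scanpy. With AnnData the equivalent pandas assignment would be `adata.obs.loc[mask, labels_key] = -1`; if the label column is a pandas categorical, the unknown value has to be added as a category first, otherwise pandas raises an error.

```python
from typing import Hashable, List, Sequence, Tuple


def mask_heldout_labels(
    batches: Sequence[Hashable],
    labels: Sequence[Hashable],
    test_batch: Hashable,
    unknown: Hashable = -1,
) -> Tuple[List[Hashable], List[Hashable]]:
    """Hypothetical helper: hide the labels of one dataset (batch).

    Returns (masked_labels, held_out_truth):
    - masked_labels: copy of `labels` where every cell from `test_batch`
      is replaced by `unknown`, so the model never sees those labels;
    - held_out_truth: the original labels of the held-out cells, in the
      same order as they appear in `labels`, saved for scoring later.
    """
    if len(batches) != len(labels):
        raise ValueError("batches and labels must have the same length")
    masked = [unknown if b == test_batch else lab for b, lab in zip(batches, labels)]
    truth = [lab for b, lab in zip(batches, labels) if b == test_batch]
    return masked, truth


masked, truth = mask_heldout_labels(["A", "A", "B", "C", "B"], ["T", "B", "T", "NK", "B"], "B")
print(masked)  # ['T', 'B', -1, 'NK', -1]
print(truth)   # ['T', 'B']
```

Keeping the held-out truth in the same order as the cells mirrors `adata[mask].copy()`, which also preserves the original cell order.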
The steps:

  1. Merge all datasets

  2. Preprocess the merged data:
    import scanpy as sc
    sc.pp.filter_cells(adata, min_genes=MIN_GENES)
    sc.pp.filter_genes(adata, min_cells=MIN_CELLS)
    sc.pp.normalize_total(adata, target_sum=1e4)  # also tried without normalization and log1p
    sc.pp.log1p(adata)
    sc.pp.highly_variable_genes(adata, n_top_genes=HVG_PARAM, flavor="seurat_v3", subset=True)

  3. Save the original labels of the test dataset, then set its labels to the unknown value:
    test_adata = adata[adata.obs[args.batch_key] == TESTED_DATASET].copy()
    adata.obs.loc[adata.obs[args.batch_key] == TESTED_DATASET, args.labels_key] = UNLEBELED

  4. Train, checking accuracy after each epoch:

    import scvi
    from sklearn.metrics import accuracy_score

    scvi.model.SCANVI.setup_anndata(adata, labels_key=args.labels_key, unlabeled_category=UNLEBELED, batch_key=args.batch_key)
    lvae = scvi.model.SCANVI(adata=adata, n_hidden=512, n_latent=64, n_layers=2)
    for i in range(10):
        # train for one epoch per call, then predict labels for all cells
        lvae.train(max_epochs=1, batch_size=4096, train_size=0.9)
        adata.obsm["preds"] = lvae.predict()
        test_adata.obsm["preds"] = adata.obsm["preds"][adata.obs[args.batch_key] == TESTED_DATASET]
        acc_score = accuracy_score(test_adata.obs[args.labels_key], test_adata.obsm["preds"])
        print(f"epoch {i + 1}: accuracy = {acc_score:.3f}")
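Overall accuracy can hide which cell types fail. A minimal plain-Python sketch of per-class recall on the held-out cells, whose mean over the true classes is balanced accuracy (scikit-learn provides the same metric as `sklearn.metrics.balanced_accuracy_score`). `per_class_recall` and `balanced_accuracy` are hypothetical helper names used only for illustration.

```python
from collections import Counter, defaultdict
from typing import Dict, Hashable, Sequence


def per_class_recall(y_true: Sequence[Hashable], y_pred: Sequence[Hashable]) -> Dict[Hashable, float]:
    """Recall for each true class: correctly predicted cells / cells of that class."""
    if len(y_true) != len(y_pred):
        raise ValueError("y_true and y_pred must have the same length")
    totals = Counter(y_true)
    correct: Dict[Hashable, int] = defaultdict(int)
    for t, p in zip(y_true, y_pred):
        if t == p:
            correct[t] += 1
    return {cls: correct[cls] / n for cls, n in totals.items()}


def balanced_accuracy(y_true: Sequence[Hashable], y_pred: Sequence[Hashable]) -> float:
    """Mean per-class recall over the classes present in y_true."""
    recalls = per_class_recall(y_true, y_pred)
    if not recalls:
        raise ValueError("y_true is empty")
    return sum(recalls.values()) / len(recalls)


y_true = ["T", "T", "T", "B", "NK"]
y_pred = ["T", "T", "T", "T", "T"]   # every cell predicted as "T"
print(per_class_recall(y_true, y_pred))   # {'T': 1.0, 'B': 0.0, 'NK': 0.0}
print(balanced_accuracy(y_true, y_pred))  # 0.333...
```

In this toy case plain accuracy is 0.6 even though two of the three cell types are never recognized, which is why a per-class breakdown is useful when most predictions share one label.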

Any idea why the results are so poor?
I also see that 95% of the predictions on the test dataset are the same label.
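One way to quantify this symptom: find the most frequent predicted label, its share of all predictions, and the accuracy a constant predictor that always outputs that label would reach on the same cells. If the model's accuracy is close to that constant-predictor accuracy, the classifier is essentially outputting a single label. A minimal plain-Python sketch; `prediction_collapse` is a hypothetical helper and the toy labels are illustrative only.

```python
from collections import Counter
from typing import Hashable, Sequence, Tuple


def prediction_collapse(
    y_true: Sequence[Hashable], y_pred: Sequence[Hashable]
) -> Tuple[Hashable, float, float]:
    """Return (dominant predicted label, its share of predictions,
    accuracy of always predicting that label)."""
    if len(y_pred) == 0 or len(y_true) != len(y_pred):
        raise ValueError("need non-empty y_true and y_pred of equal length")
    label, count = Counter(y_pred).most_common(1)[0]
    share = count / len(y_pred)
    constant_acc = sum(t == label for t in y_true) / len(y_true)
    return label, share, constant_acc


y_true = ["T"] * 4 + ["B"] * 4 + ["NK"] * 2
y_pred = ["T"] * 9 + ["B"]
label, share, constant_acc = prediction_collapse(y_true, y_pred)
print(label, share, constant_acc)  # T 0.9 0.4
```

In this toy example the model's own accuracy is also 0.4, exactly what always predicting "T" would give.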