Train + Prediction Split

Hello,

From my understanding, scVI assumes that gene counts follow a ZINB distribution, fits the parameters of this distribution to the input data (conditional on batch), and can then output normalized data/counts and a latent-space representation.
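For reference, my (possibly imperfect) reading is that the ZINB likelihood has the general form

$$
p(x \mid \mu, \theta, \pi) = \pi \, \delta_0(x) + (1 - \pi) \, \mathrm{NB}(x \mid \mu, \theta),
$$

where $\mu$ is the mean (in scVI, a function of the latent variable and the batch), $\theta$ is the inverse dispersion, and $\pi$ is the zero-inflation probability.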

Is it possible to train the model to fit the parameters on part of my dataset (a portion of each batch), and then predict normalized data/counts on another part of my dataset (a different portion of each batch, with no new batches)?

I’ve spent some time looking at the VAE class but have gotten a bit confused along the way. Can anyone shed some light on this? I’d be happy to clarify my intentions further.

Thanks

This pattern works:

# subset and copy so scvi-tools doesn't operate on AnnData views
train_adata = adata[train_indices].copy()
heldout = adata[heldout_indices].copy()

# register only the training subset with the model
scvi.model.SCVI.setup_anndata(train_adata, batch_key="batch", ...)
model = scvi.model.SCVI(train_adata)
model.train(train_size=1.0)  # use all of train_adata for training

# project held-out cells into the trained latent space
held_out_latent = model.get_latent_representation(heldout)

as long as heldout.obs["batch"] does not contain any categories that were not present during training.
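For completeness, here is one way you might build those index lists so every batch appears in both splits (a sketch; the 0.8 fraction and the "batch" column name are assumptions):

    import numpy as np

    # Stratified split: take a fraction of cells from every batch for training,
    # so the held-out set contains no batch categories the model hasn't seen.
    rng = np.random.default_rng(0)
    train_frac = 0.8  # assumed fraction; tune to your use case

    train_indices, heldout_indices = [], []
    for batch in adata.obs["batch"].unique():
        idx = np.where(adata.obs["batch"] == batch)[0]
        rng.shuffle(idx)
        n_train = int(train_frac * len(idx))
        train_indices.extend(idx[:n_train])
        heldout_indices.extend(idx[n_train:])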

Do you not need to do scvi.model.SCVI.setup_anndata(heldout, batch_key="batch", ...) too?

By the way, this is an interesting idea for working on humongous datasets. Would be interesting to know how much you lose by using different fractions of the full dataset for training!
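One way to quantify that could be held-out reconstruction error as a function of training fraction, along these lines (a sketch; make_train_indices is a hypothetical helper, and I'm assuming get_reconstruction_error accepts an external AnnData like the other inference methods):

    import scvi

    for frac in (0.1, 0.25, 0.5, 0.9):
        # make_train_indices is hypothetical: a per-batch split at the given fraction
        train_adata = adata[make_train_indices(adata, frac)].copy()
        scvi.model.SCVI.setup_anndata(train_adata, batch_key="batch")
        model = scvi.model.SCVI(train_adata)
        model.train()
        # held-out reconstruction error as a proxy for what the
        # smaller training fraction costs
        print(frac, model.get_reconstruction_error(heldout))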

/Valentine

You actually don’t, and doing so would be redundant. On this call:

held_out_latent = model.get_latent_representation(heldout)

the following happens:

  1. The model instance recognizes that the heldout object has not been associated with this model instance.
  2. The model instance (via _validate_anndata) attempts to transfer (impute) the setup state stored from the original AnnData setup (when __init__ was called) onto heldout.

Even if you were to run setup_anndata on heldout, the model instance would still recognize that heldout has not been validated against its own AnnData setup state, and so it would transfer/impute anyway.

Thus, the tl;dr is that every AnnData a model instance sees after training gets validated (which triggers imputation of the setup state onto the new AnnData), so setup_anndata only needs to be run once, on the AnnData used to initialize the model.
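To make that concrete, both of these should work on the heldout object from above (a sketch; batch_key="batch" mirrors the original setup call):

    # No setup call needed: the model transfers its stored setup state onto heldout.
    z1 = model.get_latent_representation(heldout)

    # Running setup_anndata first is harmless but redundant; the model still
    # re-validates heldout against its own registry before running inference.
    scvi.model.SCVI.setup_anndata(heldout, batch_key="batch")
    z2 = model.get_latent_representation(heldout)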

As a side note, starting in v0.15.0 the following will also work:

    scvi.model.SCVI.setup_anndata(adata)
    m1 = scvi.model.SCVI(adata)
    m1.train(1)

    scvi.model.SCVI.setup_anndata(adata, batch_key="batch")
    m2 = scvi.model.SCVI(adata)
    m2.train(1)

    # m1 recognizes that adata has since been associated with a new model,
    # so it replays `setup_anndata` using the state registry stored at m1's init.
    # i.e., m1 imputes its initial setup state back onto the anndata it was
    # originally trained on, since it knows another model might have changed it.
    m1.get_latent_representation()

Just to follow up here,

I actually did find that you need to run
scvi.model.SCVI.setup_anndata(heldout, batch_key="batch", ...), because if you don’t you get a
data_registry key not found error.

If, instead of the latent representation, I just wanted the batch-corrected count matrix, could I just run the following?
held_out_expression = model.get_normalized_expression(heldout)

It seems like the right type of object, but I just wanted to double-check that this is as good as feeding the latent representation back in somewhere…?

Did you by any chance run setup_anndata on the full dataset before splitting up the data?

Yes, any method will work this way.
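For example, to decode the held-out cells as if they all came from one reference batch (a sketch; transform_batch and library_size are optional arguments of get_normalized_expression, and "batch_0" is a placeholder for one of your actual batch categories):

    held_out_expression = model.get_normalized_expression(
        heldout,
        transform_batch="batch_0",  # placeholder: decode all cells as this batch
        library_size="latent",      # scale by the model's latent library size
    )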

This works:

In [1]: import scvi
Global seed set to 0
/Users/adamgayoso/.pyenv/versions/3.9.6/envs/scvi-tools-dev/lib/python3.9/site-packages/jax/_src/lib/__init__.py:33: UserWarning: JAX on Mac ARM machines is experimental and minimally tested. Please see https://github.com/google/jax/issues/5501 in the event of problems.
  warnings.warn("JAX on Mac ARM machines is experimental and minimally tested. "
  
In [2]: adata = scvi.data.synthetic_iid()

In [3]: bdata = adata[:100].copy()

In [4]: cdata = adata[100:].copy()

In [5]: scvi.model.SCVI.setup_anndata(bdata)

In [6]: model = scvi.model.SCVI(bdata)

In [7]: model.train(1)
GPU available: False, used: False
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
/Users/adamgayoso/.pyenv/versions/3.9.6/envs/scvi-tools-dev/lib/python3.9/site-packages/pytorch_lightning/trainer/data_loading.py:432: UserWarning: The number of training samples (1) is smaller than the logging interval Trainer(log_every_n_steps=10). Set a lower value for log_every_n_steps if you want to see logs for the training epoch.
  rank_zero_warn(
Epoch 1/1: 100%|β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆ| 1/1 [00:00<00:00, 38.81it/s, loss=339, v_num=1]

In [8]: model.get_latent_representation(cdata)
INFO     Input AnnData not setup with scvi-tools. attempting to transfer AnnData setup
Out[8]:
array([[ 0.08140695,  0.65238976, -0.33962733, ..., -0.5329712 ,
         0.624915  ,  0.03501214],
       [-0.20623109,  0.8029219 , -0.7239966 , ..., -0.5970819 ,
         0.93602294, -0.07002395],
       [-0.01066309,  0.5152629 , -0.35380945, ..., -0.3922435 ,
         1.0264165 ,  0.01999969],
       ...,
       [-0.33414528,  0.4864136 ,  0.40686986, ..., -0.7275924 ,
         0.7657391 ,  0.22869033],
       [ 0.06635017,  0.5002558 , -0.41849643, ..., -0.43128732,
         0.53202224,  0.07187884],
       [-0.8352648 ,  0.71783984, -0.28277442, ...,  0.04417736,
         0.7622042 , -0.11547179]], dtype=float32)