This works:
In [1]: import scvi
Global seed set to 0
ad/Users/adamgayoso/.pyenv/versions/3.9.6/envs/scvi-tools-dev/lib/python3.9/site-packages/jax/_src/lib/__init__.py:33: UserWarning: JAX on Mac ARM machines is experimental and minimally tested. Please see https://github.com/google/jax/issues/5501 in the event of problems.
warnings.warn("JAX on Mac ARM machines is experimental and minimally tested. "
^R
In [2]: adata = scvi.data.synthetic_iid()
In [3]: bdata = adata[:100].copy()
In [4]: cdata = adata[100:].copy()
In [5]: scvi.model.SCVI.setup_anndata(bdata)
In [6]: model = scvi.model.SCVI(bdata)
In [7]: model.train(1)
GPU available: False, used: False
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
/Users/adamgayoso/.pyenv/versions/3.9.6/envs/scvi-tools-dev/lib/python3.9/site-packages/pytorch_lightning/trainer/data_loading.py:432: UserWarning: The number of training samples (1) is smaller than the logging interval Trainer(log_every_n_steps=10). Set a lower value for log_every_n_steps if you want to see logs for the training epoch.
rank_zero_warn(
Epoch 1/1: 100%|βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ| 1/1 [00:00<00:00, 38.81it/s, loss=339, v_num=1]
In [8]: model.get_latent_representation(cdata)
INFO Input AnnData not setup with scvi-tools. attempting to transfer AnnData setup
INFO:scvi.model.base._base_model:Input AnnData not setup with scvi-tools. attempting to transfer AnnData setup
Out[9]:
array([[ 0.08140695, 0.65238976, -0.33962733, ..., -0.5329712 ,
0.624915 , 0.03501214],
[-0.20623109, 0.8029219 , -0.7239966 , ..., -0.5970819 ,
0.93602294, -0.07002395],
[-0.01066309, 0.5152629 , -0.35380945, ..., -0.3922435 ,
1.0264165 , 0.01999969],
...,
[-0.33414528, 0.4864136 , 0.40686986, ..., -0.7275924 ,
0.7657391 , 0.22869033],
[ 0.06635017, 0.5002558 , -0.41849643, ..., -0.43128732,
0.53202224, 0.07187884],
[-0.8352648 , 0.71783984, -0.28277442, ..., 0.04417736,
0.7622042 , -0.11547179]], dtype=float32)