Function in scANVI that can help compute the probability of a latent cell state given a cell type

Hi,

I would like to compute p(z|c) from a trained scANVI model, where z represents a latent state and c represents a cell type. I know that the scANVI framework has a learnable conditional distribution p(z|c,u), where u is a latent cell-type-specific state sampled from N(0,1), so I think computing p(z|c) from it should be possible. Could you tell me which function I need to call to compute p(z|c,u) in scANVI?

Thanks!
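For reference, the marginal asked about follows from integrating out u: p(z|c) = ∫ p(z|c,u) p(u) du with p(u) = N(0, I), which can be estimated by averaging p(z|c,u_s) over draws u_s ~ N(0, I). The numpy sketch below shows that Monte Carlo estimate under stated assumptions: it assumes p(z|c,u) is a diagonal Gaussian whose mean and variance come from some (u, c) -> z decoder, and `decoder`, `log_p_z_given_c` and `log_normal_diag` are hypothetical names. The sketch does not show how to extract the real decoder from a trained scANVI model.

```python
import numpy as np


def log_normal_diag(z, mean, var):
    """Log density of a diagonal Gaussian N(mean, diag(var)) evaluated at z.

    Broadcasts over leading dimensions and sums over the last (latent) axis.
    """
    return -0.5 * np.sum(np.log(2.0 * np.pi * var) + (z - mean) ** 2 / var, axis=-1)


def log_p_z_given_c(z, c, decoder, n_u, n_samples=1000, seed=0):
    """Monte Carlo estimate of log p(z | c) = log E_{u ~ N(0, I)}[p(z | c, u)].

    `decoder(u, c)` is a HYPOTHETICAL stand-in for the model's (u, c) -> z
    network; it must return (mean, var) arrays of shape (n_samples, n_latent).
    """
    rng = np.random.default_rng(seed)
    u = rng.standard_normal((n_samples, n_u))  # u_s ~ N(0, I)
    mean, var = decoder(u, c)
    log_w = log_normal_diag(z, mean, var)  # log p(z | c, u_s), shape (n_samples,)
    # Numerically stable log-mean-exp over the samples
    m = log_w.max()
    return m + np.log(np.mean(np.exp(log_w - m)))
```

As a sanity check, a toy linear decoder with mean u + b[c] and unit variance gives the exact marginal p(z|c) = N(b[c], 2), so the estimate at z = b[c] should approach -0.5 * log(4π) as `n_samples` grows.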

I think you are looking for something like:

model.get_latent_representation(return_dist=True, indices=cell_type_indices)

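It will return a tuple containing the mean and variance of the approximate posterior over z for each cell in cell_type_indices (e.g. np.where(adata.obs["labels"] == "label_1")[0] gives the indices of the cells in the label_1 cell-type class of the adata; replace "labels" with whichever obs column holds your cell-type annotations).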
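Note that these per-cell means and variances describe q(z|x) for individual cells, not p(z|c) directly. One common workaround, swapped in here rather than taken from scANVI itself, is to approximate p(z|c) by the aggregated posterior of the cells labelled c: an equal-weight mixture of their per-cell diagonal Gaussians. The sketch below assumes the two arrays have shape (n_cells, n_latent); `log_aggregated_posterior` is a hypothetical helper and the commented usage lines assume the call above.

```python
import numpy as np


def log_aggregated_posterior(z, qz_mean, qz_var):
    """Log density at z of the mixture (1/N) * sum_n N(z; qz_mean[n], diag(qz_var[n])).

    qz_mean, qz_var: arrays of shape (n_cells, n_latent), e.g. the per-cell
    posterior means/variances for the cells of one cell type.
    z: array of shape (n_latent,).
    """
    qz_mean = np.asarray(qz_mean, dtype=float)
    qz_var = np.asarray(qz_var, dtype=float)
    # Per-cell diagonal-Gaussian log densities, shape (n_cells,)
    log_comp = -0.5 * np.sum(
        np.log(2.0 * np.pi * qz_var) + (z - qz_mean) ** 2 / qz_var, axis=-1
    )
    # Equal mixture weights 1/N, combined with a stable log-sum-exp
    m = log_comp.max()
    return m + np.log(np.exp(log_comp - m).sum()) - np.log(len(log_comp))


# Assumed usage with a trained model ("labels" column name is hypothetical):
# idx = np.where(adata.obs["labels"] == "label_1")[0]
# qz_mean, qz_var = model.get_latent_representation(return_dist=True, indices=idx)
# log_aggregated_posterior(z, qz_mean, qz_var)
```

This only approximates the model's p(z|c) to the extent that the encoder's posterior matches the generative model, so it is best treated as a rough estimate.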