Thanks, I got it to work when I supply a fixed library size (for example 0, as in my code block above). However, I now want to sample library sizes from a LogNormal distribution, as scVI itself does, and that is where I run into problems. I've added my code below. What happens downstream is that theta and mu come out exactly the same, and the rate (theta/mu) of the Gamma distribution ends up zero everywhere. Any idea why this happens or how I can prevent it?
What I’ve implemented is:
```python
import numpy as np
import torch
from torch.distributions import LogNormal, Normal


def sample_cells(model, n_cells, n_samples=1):
    with torch.no_grad():
        if model.module.gene_likelihood not in ["zinb", "nb", "poisson"]:
            raise ValueError("Invalid gene_likelihood.")

        # Sample z from the standard-normal prior over the latent space
        qz_m = torch.zeros(n_cells, model.module.n_latent)
        qz_v = torch.ones(n_cells, model.module.n_latent)
        z = Normal(qz_m, qz_v).sample()

        # TODO: allow for different batch indices
        dec_batch_index = torch.zeros(n_cells, 1)

        # HERE I SAMPLE DIFFERENT LIBRARY SIZES
        # loc/scale are the mean and variance of the log library sizes
        # in a batch of training data
        ln = LogNormal(torch.tensor(6.7649703), torch.tensor(0.16759828))
        library = ln.sample(sample_shape=torch.Size([n_cells, 1]))

        cat_covs = torch.tensor(
            np.repeat([[3.0, 5.0]], n_cells, axis=0), dtype=torch.float32
        )
        generative_outputs = model.module.generative(
            z, library, dec_batch_index, cat_covs=cat_covs
        )
        dist = generative_outputs["px"]
        if model.module.gene_likelihood == "poisson":
            l_train = generative_outputs["px"].rate
            l_train = torch.clamp(l_train, max=1e8)
            dist = torch.distributions.Poisson(
                l_train
            )  # Shape: (n_samples, n_cells_batch, n_genes)
        if n_samples > 1:
            exprs = dist.sample().permute(
                [1, 2, 0]
            )  # Shape: (n_cells_batch, n_genes, n_samples)
        else:
            exprs = dist.sample()  # HERE IS WHERE THE ERROR OCCURS
        exprs = exprs.cpu()
        return exprs.numpy()
```
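To see where the zeros enter, this is the kind of check I can run right after the `generative` call (a minimal sketch; I'm assuming `px` exposes `mu` and `theta` attributes, as scvi-tools' `NegativeBinomial` does, which may differ across versions):

```python
# Diagnostic sketch: inspect the NB parameters behind the Gamma rate.
# Assumes px exposes `mu` and `theta` (true for scvi-tools' NegativeBinomial).
px = generative_outputs["px"]
print("mu finite:   ", torch.isfinite(px.mu).all().item())
print("theta finite:", torch.isfinite(px.theta).all().item())
print("library range:", library.min().item(), library.max().item())
# The NB sampler draws from Gamma(concentration=theta, rate=theta/mu),
# so any entry where mu overflows to inf yields a rate of exactly 0.
```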
Error message:
```
ValueError: Expected parameter rate (Tensor of shape (3582, 1200)) of distribution Gamma(concentration: torch.Size([3582, 1200]), rate: torch.Size([3582, 1200])) to satisfy the constraint GreaterThan(lower_bound=0.0), but found invalid values:
tensor([[0., 0., 0., ..., 0., 0., 0.],
        [0., 0., 0., ..., 0., 0., 0.],
        [0., 0., 0., ..., 0., 0., 0.],
        ...,
        [0., 0., 0., ..., 0., 0., 0.],
        [0., 0., 0., ..., 0., 0., 0.],
        [0., 0., 0., ..., 0., 0., 0.]])
```
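To confirm that it is the zero rate itself that trips the constraint check, the error is reproducible in isolation (a standalone sketch, independent of scVI):

```python
import torch

# A rate of 0 violates Gamma's GreaterThan(lower_bound=0.0) constraint,
# producing the same ValueError as above:
try:
    torch.distributions.Gamma(
        concentration=torch.ones(2, 3),
        rate=torch.zeros(2, 3),  # must be strictly positive
    )
except ValueError as err:
    print(err)
```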