How to dynamically mask categorical_covariate_keys during training?

I call setup_anndata(categorical_covariate_keys=['key1','key2',...]) and need a way to use only some of these covariates in training_step() (with manual optimization). Is there a way to do this? Thanks!

Can you describe in greater detail what you’d like?

Those covariates are stored together in one matrix (one column per key), so you can mask that matrix by column.
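A minimal sketch of that column masking. It uses a small NumPy array as a hypothetical stand-in for tensors[REGISTRY_KEYS.CAT_COVS_KEY], assuming rows are cells, columns follow categorical_covariate_keys, and entries are integer category codes. A PyTorch tensor supports the same boolean-mask and integer-list column indexing.

```python
import numpy as np

# Hypothetical stand-in for tensors[REGISTRY_KEYS.CAT_COVS_KEY]:
# one row per cell, one column per key in categorical_covariate_keys
# (key1, key2, key3), each entry an integer category code.
cat_covs = np.array([
    [0, 2, 1],
    [1, 0, 1],
    [1, 1, 0],
])

# Boolean mask over columns: keep key1 and key3, drop key2.
keep = np.array([True, False, True])
masked = cat_covs[:, keep]

# Equivalent selection with integer column indices.
same = cat_covs[:, [0, 2]]

print(masked.tolist())  # [[0, 1], [1, 1], [1, 0]]
```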

Oh got it, thanks. I think I just need to do something like:

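    from scvi import REGISTRY_KEYS
    def _get_inference_input(self, tensors, cat_idx):
        x = tensors[REGISTRY_KEYS.X_KEY]
        cat_covs = tensors.get(REGISTRY_KEYS.CAT_COVS_KEY)  # None if no categorical covariates are registered
        cat_covs = None if cat_covs is None else cat_covs[:, cat_idx]  # keep only the selected covariate columns
        return {"x": x, "cat_covs": cat_covs}  # keys must match your module's inference() arguments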

I think I’m good here, but can you please confirm that the columns of tensors.get(REGISTRY_KEYS.CAT_COVS_KEY) keep the same order as the categorical_covariate_keys passed to model.setup_anndata()?

Thanks!

Same order!
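Because the columns follow the order of categorical_covariate_keys, cat_idx can be built from covariate names. A minimal sketch; the list below and the helper cat_idx_for are hypothetical, not part of scvi-tools.

```python
# Hypothetical: the same list that was passed to
# setup_anndata(categorical_covariate_keys=...).
categorical_covariate_keys = ["key1", "key2", "key3"]


def cat_idx_for(keys_to_use, registered_keys=categorical_covariate_keys):
    """Map covariate names to column positions in the CAT_COVS tensor.

    Relies on the columns following the order of registered_keys.
    Raises KeyError for a name that was not registered.
    """
    position = {key: i for i, key in enumerate(registered_keys)}
    return [position[key] for key in keys_to_use]


print(cat_idx_for(["key3", "key1"]))  # [2, 0]
```

The returned indices follow the order of keys_to_use; sort them if the selected columns should stay in their registered order.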
