Hey there,
I am implementing a simple CVAE using scvi-tools following these steps in a jupyter notebook:
- calling
CVAE.setupa_anndata(adata)
:
adata = scn.models.CellVAE.setup_anndata(adata,
label_key='cell_type',
batch_key='batch',
layer=None,
)
- instantiating my model:
model = scn.models.CellVAE(adata,
n_latent=64,
**config
)
then at the second step I face this error.
AttributeError: 'CellVAE' object has no attribute '_adata_manager'
class CellVAE(VAEMixin, UnsupervisedTrainingMixin, BaseModelClass):
"""
Skeleton for an scvi-tools model.
Please use this skeleton to create new models.
Parameters
----------
adata
AnnData object that has been registered via :meth:`~mypackage.MyModel.setup_anndata`.
n_hidden
Number of nodes per hidden layer.
n_latent
Dimensionality of the latent space.
n_layers
Number of hidden layers used for encoder and decoder NNs.
**model_kwargs
Keyword args for :class:`~mypackage.MyModule`
Examples
--------
>>> adata = anndata.read_h5ad(path_to_anndata)
>>> mypackage.CellVAE.setup_anndata(adata, batch_key="batch")
>>> vae = mypackage.CellVAE(adata)
>>> vae.train()
>>> adata.obsm["X_mymodel"] = vae.get_latent_representation()
"""
def __init__(
self,
adata: AnnData,
n_hidden: int = 128,
n_latent: int = 10,
n_layers: int = 1,
**model_kwargs,
):
super(CellVAE, self).__init__(adata)
library_log_means, library_log_vars = _init_library_size(
self.adata_manager, self.summary_stats["n_batch"]
)
# self.summary_stats provides information about anndata dimensions and other tensor info
self.module = CellVAEModule(
n_input=self.summary_stats["n_vars"],
n_hidden=n_hidden,
n_latent=n_latent,
n_layers=n_layers,
n_batch=self.summary_stats['n_batch'],
library_log_means=library_log_means,
library_log_vars=library_log_vars,
**model_kwargs,
)
self._model_summary_string = "Overwrite this attribute to get an informative representation for your model"
# necessary line to get params that will be used for saving/loading
self.init_params_ = self._get_init_params(locals())
logger.info("The model has been initialized")
@classmethod
@setup_anndata_dsp.dedent
def setup_anndata(
cls,
adata: AnnData,
batch_key: Optional[str] = None,
label_key: Optional[str] = None,
layer: Optional[str] = None,
categorical_covariate_keys: Optional[List[str]] = None,
continuous_covariate_keys: Optional[List[str]] = None,
**kwargs,
):
"""
"""
setup_method_args = cls._get_setup_method_args(**locals())
anndata_fields = [
LayerField(REGISTRY_KEYS.X_KEY, layer, is_count_data=True),
CategoricalObsField(REGISTRY_KEYS.BATCH_KEY, batch_key),
CategoricalObsField(REGISTRY_KEYS.LABELS_KEY, label_key),
CategoricalJointObsField(
REGISTRY_KEYS.CAT_COVS_KEY, categorical_covariate_keys
),
NumericalJointObsField(
REGISTRY_KEYS.CONT_COVS_KEY, continuous_covariate_keys
),
]
adata_manager = AnnDataManager(
fields=anndata_fields, setup_method_args=setup_method_args
)
adata_manager.register_fields(adata, **kwargs)
cls.register_manager(adata_manager)
"""
def __init__(
self,
adata: AnnData,
n_hidden: int = 128,
n_latent: int = 10,
n_layers: int = 1,
**model_kwargs,
):
super(CellVAE, self).__init__(adata)
library_log_means, library_log_vars = _init_library_size(
self.adata_manager, self.summary_stats["n_batch"]
)
...
scvi-tools version: 0.18.0. - python version: 3.9