Question about get_normalized_expression computation

I have trained a TOTALVI model and am using get_normalized_expression to get imputed values for one gene and one protein. It mostly uses the CPUs (128 threads) and barely touches the GPU (Volatile GPU-Util = 8%, 2391 MB of GPU memory). Is it expected that this is a CPU-heavy rather than a GPU-heavy calculation?

Generally, yes, but the balance between CPU and GPU work depends on the parameters you pass to get_normalized_expression.

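For example, if you increase batch_size and reduce n_samples, I would expect more of the work to shift to the GPU: each forward pass then handles more cells at once, and fewer posterior samples have to be gathered and averaged afterwards.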

It also depends on other parameters, such as transform_batch and indices; really, almost any of them can shift the balance.
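To see why batch_size and n_samples move work around, here is a rough back-of-envelope cost model in plain Python. It is a hypothetical sketch, not scvi-tools code: it assumes each minibatch is pushed through the network once per transform_batch entry, drawing n_samples samples per cell, and that the results for all features are then collected and averaged on the CPU.

```python
import math


def estimate_workload(n_cells: int, n_genes: int, batch_size: int,
                      n_samples: int, n_transform_batches: int = 1) -> dict:
    """Hypothetical cost model for a minibatched posterior-sampling loop.

    Assumption (not verified against scvi-tools internals): every minibatch
    is run once per transform_batch entry, and all sampled values are moved
    to CPU memory for concatenation and averaging.
    """
    n_minibatches = math.ceil(n_cells / batch_size)
    return {
        # how many separate forward passes the GPU has to run
        "gpu_forward_passes": n_minibatches * n_transform_batches,
        # how many values a single forward pass produces on the GPU
        "values_per_pass": n_samples * batch_size * n_genes,
        # how many values the CPU has to collect and reduce in total
        "values_copied_to_cpu": n_samples * n_cells * n_genes,
    }
```

With 10,000 cells and 2,000 genes, going from batch_size=128, n_samples=25 to batch_size=1024, n_samples=5 cuts the number of GPU calls from 79 to 10 and the number of values to gather on the CPU five-fold.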

You can follow this by setting the silent parameter to False and watching CPU and GPU usage while the function runs.
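A minimal standard-library sketch of one way to measure the CPU side of that: wrap the call and compare the process's CPU time with wall-clock time. profile_call is a hypothetical helper, not part of scvi-tools; for the GPU side you would still watch nvidia-smi alongside the progress bar.

```python
import time
from typing import Any, Callable, Tuple


def profile_call(fn: Callable[[], Any]) -> Tuple[Any, float, float]:
    """Run fn() and return (result, wall_seconds, cpu_seconds).

    time.process_time() sums CPU time over all threads of this process, so
    cpu_seconds / wall_seconds approximates how many cores were busy on
    average. A ratio close to the available thread count suggests
    multithreaded CPU work dominates; a much lower ratio suggests time is
    spent elsewhere (e.g. on the GPU or I/O).
    """
    wall_start = time.perf_counter()
    cpu_start = time.process_time()
    result = fn()
    cpu_seconds = time.process_time() - cpu_start
    wall_seconds = time.perf_counter() - wall_start
    return result, wall_seconds, cpu_seconds
```

For example, wrap the call as `profile_call(lambda: model.get_normalized_expression(batch_size=1024, n_samples=5, silent=False))`, where `model` is your trained TOTALVI model, and compare `cpu / wall` across different parameter settings.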

Thanks for the suggestion. I increased batch_size and now see more activity on the GPU than on the CPU: CPU usage is now around n_samples, and more GPU memory is in use.

Is silent a parameter of get_normalized_expression or of some other function?

I am also concerned that get_normalized_expression might use mdata['rna'].X instead of mdata['rna'].layers[rna_layer] (whatever rna_layer we specified when calling scvi.model.TOTALVI.setup_mudata). The code on GitHub reads:

```python
post = self._make_data_loader(adata=adata, indices=indices, batch_size=batch_size)
...
for tensors in post:
    x = tensors[REGISTRY_KEYS.X_KEY]
    y = tensors[REGISTRY_KEYS.PROTEIN_EXP_KEY]
```

I think REGISTRY_KEYS.X_KEY = 'X'. Does self._make_data_loader know to load the specified rna_layer?

silent is a parameter of get_normalized_expression. Setting it to False just shows a progress bar; because the function works minibatch by minibatch, you can compare its progress with your CPU and GPU usage.

It's fine on the backend side. REGISTRY_KEYS.X_KEY is indeed 'X', but it is a key into the tensors produced by the data loader, not into the MuData object. This works because the model was registered and set up with the MuData/AnnData beforehand, so the data loader knows how to map the RNA and protein data from that structure (including the rna_layer you specified) into the tensors that actually flow through the network.
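The indirection described above can be illustrated with a toy model in plain Python (no scvi-tools): setup records once where each registry key lives, and the loader resolves the key to that location, so tensors["X"] holds the registered layer rather than .X. The dict-based mdata, FieldLocation, setup_registry and load_tensors are hypothetical stand-ins, not scvi-tools internals. To inspect the real mapping for a trained model, model.view_anndata_setup() should print the registered fields.

```python
from dataclasses import dataclass
from typing import Dict, List, Optional


@dataclass
class FieldLocation:
    modality: str          # toy modality name, e.g. "rna" or "protein"
    layer: Optional[str]   # None means "use the modality's X"


def setup_registry(rna_layer: Optional[str],
                   protein_layer: Optional[str]) -> Dict[str, FieldLocation]:
    """Toy 'setup' step: remember where each registry key lives."""
    return {
        "X": FieldLocation("rna", rna_layer),
        # toy key standing in for REGISTRY_KEYS.PROTEIN_EXP_KEY
        "protein_expression": FieldLocation("protein", protein_layer),
    }


def load_tensors(mdata: dict,
                 registry: Dict[str, FieldLocation]) -> Dict[str, List[List[float]]]:
    """Toy 'data loader': resolve each registry key to its registered data."""
    tensors = {}
    for key, loc in registry.items():
        mod = mdata[loc.modality]
        tensors[key] = mod["X"] if loc.layer is None else mod["layers"][loc.layer]
    return tensors
```

Registering rna_layer="counts" makes tensors["X"] return the counts layer, even though the key is still called "X".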