CUDA is available, but training scVI models is too slow


As shown in the attached figure, I am planning to use scvi-tools to train models and integrate single-cell data under WSL on my new Windows 11 computer (RTX 4060 Ti, i9-14900K). I installed scvi-tools under Python 3.10 and confirmed that torch.cuda.is_available() returns True. When I start vae.train(), the output indicates that CUDA is being used and GPU memory is allocated, but training is much slower than expected: after 5 minutes a single epoch had still not finished. For comparison, a similar amount of data was previously completed within 5 minutes on another computer (Python 3.8, also RTX 4060 Ti).
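For reference, this is roughly the check and the training call I used (a sketch only; adata is the AnnData object shown further down, and the setup keys are placeholders that may differ from my actual script):

import torch
import scvi

# Confirm PyTorch can see the GPU at all.
print(torch.cuda.is_available())      # True in my case
print(torch.cuda.get_device_name(0))  # "NVIDIA GeForce RTX 4060 Ti"

# Roughly how the model was set up and trained.
scvi.model.SCVI.setup_anndata(adata, layer="counts", batch_key="batch")
vae = scvi.model.SCVI(adata)
vae.train(accelerator="gpu", devices=1)  # explicitly request the GPU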

My colleagues and I both ran into the same problem. We also tried Python 3.8 and 3.9, with no improvement. In addition, we installed the CUDA plugin for JAX.
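In case it is useful for anyone reproducing this, a quick sanity check that the JAX CUDA plugin actually sees the GPU (just a check, not something the scVI training itself requires):

import jax

# Should list a CUDA device (e.g. cuda:0) rather than only CPU devices.
print(jax.devices())
print(jax.default_backend())  # "gpu" (or "cuda") when the CUDA plugin is picked up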

We are not professional programmers, so we are turning to this forum. Has anyone else experienced a similar problem and found a solution?

Thank you so much!

Some of my module versions:

scvi-tools               1.2.0
pytorch-lightning        2.4.0
torch                    2.5.1
torch-geometric          2.6.1
torchaudio               2.5.1
torchmetrics             1.6.0
torchvision              0.20.1

nvcc: NVIDIA (R) Cuda compiler driver
Copyright (c) 2005-2023 NVIDIA Corporation
Built on Fri_Jan__6_16:45:21_PST_2023
Cuda compilation tools, release 12.0, V12.0.140
Build cuda_12.0.r12.0/compiler.32267302_0

It looks like the GPU is being used. Can you confirm that nvidia-smi shows activity during training and that the model is actually initialized on the GPU?
How large is your dataset?
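For the first point, something like this should show where the model parameters actually ended up (assuming vae is your SCVI model; vae.module is the underlying torch module):

# Device on which the model parameters live; expect something like cuda:0.
device = next(vae.module.parameters()).device
print(device)
print(device.type == "cuda")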

Hi,
Here’s my data situation:

AnnData object with n_obs × n_vars = 16000 × 3000
    obs: 'sample', 'group1', 'group2', 'batch', 'decontX_contamination', 'n_genes_by_counts', 'log1p_n_genes_by_counts', 'total_counts', 'log1p_total_counts', 'pct_counts_in_top_20_genes', 'total_counts_mt', 'log1p_total_counts_mt', 'pct_counts_mt', 'total_counts_ribo', 'log1p_total_counts_ribo', 'pct_counts_ribo', 'total_counts_hb', 'log1p_total_counts_hb', 'pct_counts_hb', 'passing_mt', 'passing_nUMIs', 'passing_ngenes', 'passing_decontX', 'predicted_labels', 'over_clustering', 'majority_voting', 'conf_score', 'leiden', 'annotation_level1', 'annotation_level2', 'annotation_level3', 'main_celltype', '_scvi_batch', '_scvi_labels'
    var: 'gene_ids', 'feature_types', 'symbol', 'mt', 'ribo', 'hb', 'n_cells_by_counts', 'mean_counts', 'log1p_mean_counts', 'pct_dropout_by_counts', 'total_counts', 'log1p_total_counts', 'n_cells', 'highly_variable', 'means', 'dispersions', 'dispersions_norm', 'highly_variable_nbatches', 'highly_variable_intersection'
    uns: '_scvi_manager_uuid', '_scvi_uuid', 'annotation_level1_colors', 'annotation_level2_colors', 'annotation_level3_colors', 'dendrogram_leiden', 'group1_colors', 'hvg', 'layers_counts', 'leiden', 'leiden_wilcoxon', 'log1p', 'main_celltype_colors', 'neighbors', 'over_clustering', 'pca', 'sample_colors', 'umap'
    obsm: 'X_pca', 'X_pca_harmony', 'X_scanorama'
    varm: 'PCs'
    layers: 'counts'
    obsp: 'connectivities', 'distances'

The nvidia-smi output before starting training:

Wed Dec  4 13:56:51 2024       
+-----------------------------------------------------------------------------------------+
| NVIDIA-SMI 550.120                Driver Version: 566.14         CUDA Version: 12.7     |
|-----------------------------------------+------------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id          Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |           Memory-Usage | GPU-Util  Compute M. |
|                                         |                        |               MIG M. |
|=========================================+========================+======================|
|   0  NVIDIA GeForce RTX 4060 Ti     On  |   00000000:01:00.0  On |                  N/A |
|  0%   34C    P5             15W /  180W |     620MiB /  16380MiB |      1%      Default |
|                                         |                        |                  N/A |
+-----------------------------------------+------------------------+----------------------+
                                                                                         
+-----------------------------------------------------------------------------------------+
| Processes:                                                                              |
|  GPU   GI   CI        PID   Type   Process name                              GPU Memory |
|        ID   ID                                                               Usage      |
|=========================================================================================|
|  No running processes found                                                             |
+-----------------------------------------------------------------------------------------+

and during vae.train():

Wed Dec  4 13:58:22 2024       
+-----------------------------------------------------------------------------------------+
| NVIDIA-SMI 550.120                Driver Version: 566.14         CUDA Version: 12.7     |
|-----------------------------------------+------------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id          Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |           Memory-Usage | GPU-Util  Compute M. |
|                                         |                        |               MIG M. |
|=========================================+========================+======================|
|   0  NVIDIA GeForce RTX 4060 Ti     On  |   00000000:01:00.0  On |                  N/A |
|  0%   36C    P3             22W /  180W |     839MiB /  16380MiB |     31%      Default |
|                                         |                        |                  N/A |
+-----------------------------------------+------------------------+----------------------+
                                                                                         
+-----------------------------------------------------------------------------------------+
| Processes:                                                                              |
|  GPU   GI   CI        PID   Type   Process name                              GPU Memory |
|        ID   ID                                                               Usage      |
|=========================================================================================|
|    0   N/A  N/A      1778      C   /python3.10                                 N/A      |
+-----------------------------------------------------------------------------------------+

After I updated the NVIDIA driver, training became fast, so it looks like the problem was related to the driver? I still don't fully understand why, but it is now working normally (this was all run in VSCode).

My driver version is now the latest, 566.14.
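For anyone landing on this thread later, these are the version checks I would look at first on the Python side (the driver version itself comes from nvidia-smi, as in the outputs above):

import torch

print(torch.__version__)               # 2.5.1 here
print(torch.version.cuda)              # CUDA version this torch build targets
print(torch.backends.cudnn.version())  # cuDNN build number
print(torch.cuda.get_device_name(0))   # NVIDIA GeForce RTX 4060 Ti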

Great that it’s working. It’s likely an underlying torch/CUDA issue, and it’s hard for us to give a definitive explanation.