GPU Utilization Is 0% Despite JAX Occupying GPU Memory

Hello,

I’m running a model using JAX, and I see that it occupies GPU memory, but the GPU utilization remains at 0%. The following message appeared during the execution:


Could someone help me understand why the GPU is not being utilized despite the memory being allocated? What could be causing this issue, and how can I ensure that JAX is properly using the GPU for computation?

Thank you in advance for your help!

Hi. Are you sure the GPU is not being used? Could you increase the batch size to 4096 and check whether GPU utilization goes up?
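
Allocated memory alone does not tell you much: by default JAX preallocates a large share of GPU memory as soon as the first operation runs, so the memory shows up whether or not any real work lands on the device. A quick way to check is to ask JAX which backend it picked and then time a reasonably large computation while watching utilization. The snippet below is only a sketch, not a diagnosis of your setup: the helper names `describe_backend` and `time_matmul` and the matrix size are made up for illustration, and it assumes a recent JAX install.

```python
import time

import jax
import jax.numpy as jnp


def describe_backend():
    """Return the active JAX backend name and the devices JAX can see."""
    return jax.default_backend(), jax.devices()


def time_matmul(n=2048, repeats=5):
    """Time a jitted n x n matrix multiply (hypothetical helper, illustrative size)."""
    key = jax.random.PRNGKey(0)
    x = jax.random.normal(key, (n, n), dtype=jnp.float32)
    matmul = jax.jit(lambda a: a @ a)

    # Warm-up call so compilation is not included in the timing.
    matmul(x).block_until_ready()

    start = time.perf_counter()
    for _ in range(repeats):
        # JAX dispatches asynchronously; block so we measure the actual compute.
        out = matmul(x).block_until_ready()
    elapsed = (time.perf_counter() - start) / repeats
    return out, elapsed


backend, devices = describe_backend()
# If this prints "cpu", JAX fell back to the CPU and the GPU is not doing the work.
print("backend:", backend)
print("devices:", devices)

out, seconds = time_matmul()
print(f"mean matmul time: {seconds * 1e3:.2f} ms")
```

If the backend is the GPU and utilization still reads 0% during your own run, the workload may simply be too small or too short-lived for the monitoring tool to catch, which is why trying a larger batch size is a reasonable next step.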