@rahulbatra85
I managed to get JAX to train ResNet-50 using the Scenic library and the latest Docker images, thanks to #18747.
When I switch my code from `float32` to `float16`, I get roughly a 2x speedup. However, when I enable `bfloat16`, training runs at the same speed as `float32`. Am I correct in assuming that the Docker + 7900 XTX (RDNA3) + ROCm 6.1 setup does not generate `bfloat16` code yet?
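One way to narrow this down is sketched below in JAX. It assumes a bare matmul is a reasonable proxy for the ResNet-50 workload. First, check that the program JAX hands to XLA actually contains `bf16` tensors. That rules out a model that silently keeps `float32` parameters or activations. Then compare per-dtype timings on the device. The helper names (`lowered_text`, `compiled_text`, `time_matmul`) are hypothetical. The matrix sizes are arbitrary, and a larger `n` may be needed to saturate the GPU.

```python
import time

import jax
import jax.numpy as jnp


def matmul(a, b):
    # Large dense matmuls/convolutions dominate ResNet-50, so a bare matmul is a
    # rough (assumed) proxy for whether a dtype gets fast kernels on this device.
    return a @ b


def lowered_text(dtype, n=256):
    """StableHLO that JAX hands to XLA for `matmul` at `dtype` (before optimization)."""
    x = jnp.ones((n, n), dtype=dtype)
    return jax.jit(matmul).lower(x, x).as_text()


def compiled_text(dtype, n=256):
    """Backend-optimized HLO; extra `convert` ops here may hint at an upcast inside XLA."""
    x = jnp.ones((n, n), dtype=dtype)
    return jax.jit(matmul).lower(x, x).compile().as_text()


def time_matmul(dtype, n=1024, iters=10):
    """Average seconds per jitted matmul at `dtype`, excluding compilation."""
    key = jax.random.PRNGKey(0)
    x = jax.random.normal(key, (n, n), dtype=jnp.float32).astype(dtype)
    f = jax.jit(matmul)
    f(x, x).block_until_ready()  # compile and warm up outside the timed region
    start = time.perf_counter()
    for _ in range(iters):
        out = f(x, x)
    out.block_until_ready()  # JAX dispatches asynchronously; wait for the last result
    return (time.perf_counter() - start) / iters


if __name__ == "__main__":
    print(jax.devices())
    for dt in (jnp.float32, jnp.float16, jnp.bfloat16):
        name = jnp.dtype(dt).name
        has_convert = "convert" in compiled_text(dt)
        ms = time_matmul(dt) * 1e3
        print(f"{name:>8}: {ms:8.2f} ms/matmul, 'convert' in compiled HLO: {has_convert}")
```

Suppose `bf16` appears in the lowered text but the `bfloat16` timings match `float32`. That would point toward the backend's kernel selection rather than the training code. However, this sketch alone cannot confirm that.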
System info (python version, jaxlib version, accelerator, etc.)
Ubuntu 22.04 LTS
ROCm 6.1
Radeon RX 7900 XTX
`rocm/jax:latest` image
jax: 0.4.26
jaxlib: 0.4.26
numpy: 1.26.4
python: 3.10.0 (default, Apr 9 2024, 03:46:30) [GCC 9.4.0]
jax.devices (1 total, 1 local): [rocm(id=0)]
process_count: 1
platform: uname_result(system='Linux', node='mars', release='5.15.0-105-generic', version='#115-Ubuntu SMP Mon Apr 15 09:52:04 UTC 2024', machine='x86_64')