See jax-ml/jax#7654
We should deduplicate reducers when converting from MHLO to HLO. For example, compare:
```
In [1]: import jax

In [2]: import jax.numpy as jnp

In [3]: def f(x, y):
   ...:     return jnp.sum(x) + jnp.sum(y)
   ...:

In [4]: print(jax.jit(f).lower(jnp.arange(10), jnp.arange(15)).compiler_ir())
module @jit_f.2 {
  func.func public @main(%arg0: tensor<10xi32>, %arg1: tensor<15xi32>) -> tensor<i32> {
    %0 = mhlo.constant dense<0> : tensor<i32>
    %1 = mhlo.reduce(%arg0 init: %0) across dimensions = [0] : (tensor<10xi32>, tensor<i32>) -> tensor<i32>
     reducer(%arg2: tensor<i32>, %arg3: tensor<i32>) {
      %5 = mhlo.add %arg2, %arg3 : tensor<i32>
      "mhlo.return"(%5) : (tensor<i32>) -> ()
    }
    %2 = mhlo.constant dense<0> : tensor<i32>
    %3 = mhlo.reduce(%arg1 init: %2) across dimensions = [0] : (tensor<15xi32>, tensor<i32>) -> tensor<i32>
     reducer(%arg2: tensor<i32>, %arg3: tensor<i32>) {
      %5 = mhlo.add %arg2, %arg3 : tensor<i32>
      "mhlo.return"(%5) : (tensor<i32>) -> ()
    }
    %4 = mhlo.add %1, %3 : tensor<i32>
    return %4 : tensor<i32>
  }
}
```
and
```
In [6]: print(jax.jit(f).lower(jnp.arange(10), jnp.arange(15)).compiler_ir(dialect="hlo").as_hlo_text())
HloModule jit_f.4, entry_computation_layout={(s32[10]{0},s32[15]{0})->s32[]}

region_0.4 {
  Arg_0.5 = s32[] parameter(0)
  Arg_1.6 = s32[] parameter(1)
  ROOT add.7 = s32[] add(Arg_0.5, Arg_1.6)
}

region_1.9 {
  Arg_0.10 = s32[] parameter(0)
  Arg_1.11 = s32[] parameter(1)
  ROOT add.12 = s32[] add(Arg_0.10, Arg_1.11)
}

ENTRY main.15 {
  Arg_0.1 = s32[10]{0} parameter(0)
  constant.3 = s32[] constant(0)
  reduce.8 = s32[] reduce(Arg_0.1, constant.3), dimensions={0}, to_apply=region_0.4
  Arg_1.2 = s32[15]{0} parameter(1)
  reduce.13 = s32[] reduce(Arg_1.2, constant.3), dimensions={0}, to_apply=region_1.9
  ROOT add.14 = s32[] add(reduce.8, reduce.13)
}
```
It would be great to merge `region_0.4` and `region_1.9` for readability of the HLO. Some computations end up with hundreds of reducers.
@cheshire
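Until the MHLO-to-HLO conversion does this itself, a minimal sketch of the same idea as a post-processing step on the HLO text might look like the following. `dedup_computations` is a hypothetical helper, not a JAX or XLA API. It assumes the plain layout printed by `as_hlo_text()` above (header `name {`, indented instructions, closing `}` at column 0) and treats two computations as identical only when their bodies match exactly after instruction names are renamed in definition order, so bodies that differ in metadata or attributes are conservatively left unmerged.

```python
import re

# Hypothetical helper (not part of JAX or XLA): merges computations in HLO *text*
# whose bodies are identical once instruction names are canonicalized.
# Assumes the plain layout printed by as_hlo_text(): a header "name {",
# indented instruction lines, and a closing "}" at column 0.
_COMP_RE = re.compile(r"^(?!ENTRY )([\w.\-]+) \{\n(.*?)\n\}\n?", re.MULTILINE | re.DOTALL)
# Name defined on an instruction line, e.g. "  ROOT add.7 = s32[] ..." -> "add.7".
_DEF_RE = re.compile(r"^[ \t]*(?:ROOT[ \t]+)?([\w.\-]+)[ \t]*=", re.MULTILINE)


def _token_re(name):
    # Match `name` as a whole token, but never as an opcode such as "add(".
    return re.compile(r"(?<![\w.\-])" + re.escape(name) + r"(?![\w.\-(])")


def _canonical_body(body):
    """Rename the instructions defined in `body` to %0, %1, ... in definition order."""
    for i, name in enumerate(_DEF_RE.findall(body)):
        body = _token_re(name).sub(f"%{i}", body)
    return body


def dedup_computations(hlo_text):
    """Drop duplicate non-ENTRY computations and redirect references to the first copy."""
    first_by_key = {}
    replacements = {}  # duplicate computation name -> surviving computation name

    def visit(match):
        name, body = match.group(1), match.group(2)
        key = _canonical_body(body)
        if key in first_by_key:
            replacements[name] = first_by_key[key]
            return ""  # remove the duplicate definition
        first_by_key[key] = name
        return match.group(0)

    text = _COMP_RE.sub(visit, hlo_text)
    for dup, keep in replacements.items():
        text = _token_re(dup).sub(keep, text)  # e.g. to_apply=region_1.9 -> region_0.4
    return re.sub(r"\n{3,}", "\n\n", text)  # tidy the blank lines left behind
```

Applied to the HLO dump above, this keeps only `region_0.4` and points both `reduce` instructions at it via `to_apply=region_0.4`. Doing the merge during conversion would avoid re-parsing text, but the text version is handy for inspecting large dumps today.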
Some simple identical-function merging based on `OperationEquivalence` should be able to catch this, I think.
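To illustrate the idea, here is a toy Python model of identical-function merging. `Op`, `Func`, `structural_key` and `merge_identical` are hypothetical names for this sketch only; MLIR's actual `OperationEquivalence` is a C++ utility with its own API. The key is built in the same spirit: it compares op names, attributes and result types, and treats operands as equal when they refer to the same argument or earlier result, so two reducers that differ only in SSA value names collapse to one. It assumes each function is closed, i.e. operands only reference its own arguments or earlier results.

```python
from dataclasses import dataclass
from typing import Dict, List, Tuple


# Toy IR for illustration only; not MLIR's real data structures.
@dataclass(frozen=True)
class Op:
    name: str                   # e.g. "mhlo.add"
    operands: Tuple[str, ...]   # SSA value names used by this op
    result: str                 # SSA value name defined by this op ("" if none)
    result_type: str            # e.g. "tensor<i32>"
    attrs: Tuple[Tuple[str, str], ...] = ()


@dataclass
class Func:
    name: str
    args: List[Tuple[str, str]]  # (SSA name, type) for each block argument
    body: List[Op]


def structural_key(fn: Func) -> tuple:
    """Equal for two functions iff they match op by op up to SSA renaming."""
    ids: Dict[str, int] = {}
    for arg_name, _ in fn.args:
        ids[arg_name] = len(ids)  # arguments are numbered by position
    key: list = [tuple(arg_type for _, arg_type in fn.args)]
    for op in fn.body:
        operand_ids = tuple(ids[o] for o in op.operands)  # names -> positions
        key.append((op.name, op.attrs, op.result_type, operand_ids))
        if op.result:
            ids[op.result] = len(ids)
    return tuple(key)


def merge_identical(funcs: List[Func]) -> Dict[str, str]:
    """Map every function name to the first structurally identical function."""
    rep_by_key: Dict[tuple, str] = {}
    mapping: Dict[str, str] = {}
    for fn in funcs:
        mapping[fn.name] = rep_by_key.setdefault(structural_key(fn), fn.name)
    return mapping
```

Call sites (for example each reduce's reducer reference) would then be rewritten through the returned map, and functions that map to a different name can be deleted. Operand order is part of the key, so `add(a, b)` and `add(b, a)` are deliberately not merged.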