-
It sounds like what you want is a You could probably write this function yourself by scaling the full array by the max value, then computing the |
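A minimal sketch of the shift-by-max idea, for illustration only. It assumes the segment sizes are available as a static Python tuple (the name `sizes` and the helper `segment_logsumexp` are hypothetical, not part of JAX), and it shifts by a per-segment max from `jax.ops.segment_max` rather than the max of the full array, which is a swapped-in choice that should be more robust when segments differ widely in magnitude. The per-segment sums use `jax.ops.segment_sum`, so nothing is padded; the largest intermediates are of shape `(N,)`.

```python
import numpy as np
import jax
import jax.numpy as jnp


def segment_logsumexp(x, sizes):
    """logsumexp over contiguous, unequal-length segments of a 1D array.

    `sizes` is assumed to be a static sequence of positive segment lengths
    whose sum equals x.shape[0]. (Hypothetical helper, not a JAX API.)
    """
    num_segments = len(sizes)
    # Built with NumPy so the ids are compile-time constants under jit,
    # e.g. sizes=(2, 3) -> [0, 0, 1, 1, 1].
    segment_ids = np.repeat(np.arange(num_segments), sizes)

    # Per-segment max, shape (num_segments,), used as the stabilizing shift.
    seg_max = jax.ops.segment_max(
        x, segment_ids, num_segments=num_segments, indices_are_sorted=True
    )
    # Shift every element by its own segment's max, then exponentiate.
    # Caveat: a segment consisting only of -inf would produce NaN here.
    shifted = jnp.exp(x - seg_max[segment_ids])
    seg_sum = jax.ops.segment_sum(
        shifted, segment_ids, num_segments=num_segments, indices_are_sorted=True
    )
    # Undo the shift in log space.
    return jnp.log(seg_sum) + seg_max


# Usage: three segments of lengths 1, 2 and 3 over an array of length 6.
x = jnp.arange(6.0)
result = segment_logsumexp(x, (1, 2, 3))  # shape (3,)
```

Because `sizes` is only used to build constant indices, the function can be wrapped with `jax.jit` by closing over `sizes` or marking it static; whether this is the most memory-efficient option for a given `N` and `d` would need to be measured.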
-
Hi,

I'm working on a JAX project and encountered a challenge related to memory efficiency. I have a large 1D array of shape `(N,)`, which is partitioned into d contiguous segments of known, but unequal, sizes: `(n₁,), (n₂,), ..., (n_d,)`, such that they form a complete partition of the original array.

My goal is to compute the `logsumexp` over each subarray efficiently. While the shapes of the segments are known at compile time, the unequal sizes seem to prevent fully vectorized operations, and the current implementation consumes significant memory.

Could you advise on the most memory-efficient approach to compute the `logsumexp` over these partitions in JAX? I'm particularly interested in strategies that avoid materializing padded arrays or large intermediate representations.

Any insights or suggestions would be greatly appreciated.

Thank you